Add Aerotech ophyd device and config #14

Merged
perl_d merged 3 commits from feat/aerotech into main 2026-05-12 09:14:21 +02:00
7 changed files with 316 additions and 143 deletions
@@ -55,4 +55,17 @@ smargon:
- smargon
- motors
readOnly: false
softwareTrigger: false
softwareTrigger: false
aerotech:
description: REST-based device which connects to AareScan
deviceClass: pxii_bec.devices.aerotech
deviceConfig: { prefix: "http://mx-x10sa-queue-01:5234/" }
onFailure: buffer
enabled: True
readoutPriority: baseline
deviceTags:
- aerotech
- motors
readOnly: false
softwareTrigger: false
+43
View File
@@ -0,0 +1,43 @@
from ophyd import Component as Cpt
from .http import TIMESTAMP_ID, HttpDeviceController, HttpDeviceSignal, HttpOphydDevice
class AerotechController(HttpDeviceController):
_readback_endpoint = "status"
_target_endpoint = "position"
def __init__(self, *, prefix, **kwargs):
self._readbacks: dict[str, dict[str, float | bool]] = {}
super().__init__(prefix=prefix, **kwargs)
def put(self, axis: str, val: float):
self._rest_post(body={axis: val})
def get_readback(self, axis_id: str) -> tuple[float, float] | None:
with self._readback_lock:
if axis_id not in self._readbacks or TIMESTAMP_ID not in self._readbacks:
return None
return self._readbacks.get(axis_id)["pos"], self._readbacks.get(TIMESTAMP_ID) # type: ignore
class Aerotech(HttpOphydDevice):
controller_class = AerotechController
x = Cpt(HttpDeviceSignal, axis_identifier="x", tolerance=0.01)
y = Cpt(HttpDeviceSignal, axis_identifier="y", tolerance=0.01)
z = Cpt(HttpDeviceSignal, axis_identifier="z", tolerance=0.01)
u = Cpt(HttpDeviceSignal, axis_identifier="u", tolerance=0.01)
vel_u_deg_s = Cpt(HttpDeviceSignal, axis_identifier="vel_u_deg_s", tolerance=0.01)
def _test():
a = Aerotech(name="aerotech", prefix="http://mx-x10sa-queue-01:5234")
a.wait_for_connection()
return a
if __name__ == "__main__":
aerotech = _test()
print(aerotech.read())
aerotech.stop()
+178
View File
@@ -0,0 +1,178 @@
import time
from abc import ABC, abstractmethod
from threading import Event, RLock, Thread
from typing import Any
from ophyd import OphydObject
from ophyd_devices import PSIDeviceBase
from ophyd_devices.utils.socket import SocketSignal
from requests import Response, Session
TIMESTAMP_ID = "__timestamp"
_POLL_INTERVAL_SLOW = 0.1
class HttpRestError(Exception):
"""Error for rest calls from a HttpRestSignal."""
def __init__(self, resp: Response, *args: object, value: Any | None = None) -> None:
method, url = resp.request.method, resp.request.url
data = f"{str(value)} to " if value is not None else ""
super().__init__(
f"Could not {method} {data}{url}. Code: {resp.status_code}. Reason: {resp.reason}.",
*args,
)
class HttpDeviceController(OphydObject, ABC):
"""Controller to consolidate polling loops and other REST calls for devices which communicate
with HTTP REST interfaces"""
_readback_endpoint: str
_target_endpoint: str
def __init__(self, *, prefix, **kwargs):
self._readbacks: dict
self._session = Session()
self._prefix = prefix
self._targets = {}
self._signal_registry: set[str] = set()
self._readback_poll_interval: float = _POLL_INTERVAL_SLOW
super().__init__(**kwargs)
self._setup_readback()
def _setup_readback(self):
self._stop_monitor_readback_event = Event()
self._readback_lock = RLock()
self._monitor_readback_thread = Thread(
target=self._monitor,
args=[
self._readback_endpoint,
self._stop_monitor_readback_event,
self._readback_lock,
self._readbacks,
],
)
def manual_update(self):
self._update_reading(self._readback_endpoint, self._readback_lock, self._readbacks)
def _update_reading(self, endpoint: str, lock: RLock, buffer: dict):
data = self._rest_get(endpoint)
timestamp = time.monotonic()
with lock:
buffer.update(data)
buffer["__timestamp"] = timestamp
def _monitor(self, endpoint: str, event: Event, lock: RLock, buffer: dict):
while not event.is_set():
self._update_reading(endpoint, lock, buffer)
time.sleep(self._readback_poll_interval)
def _clean_monitor(self):
if self._monitor_readback_thread.is_alive():
self._stop_monitor_readback_event.set()
self._monitor_readback_thread.join(timeout=2)
if self._monitor_readback_thread.is_alive():
raise RuntimeError("Failed to clean up Aerotech monitor thread.")
def register(self, axis_id: str):
self._signal_registry.add(axis_id)
def _rest_get(self, endpoint):
resp = self._session.get(self._prefix + endpoint)
if not resp.ok:
raise HttpRestError(resp)
return resp.json()
def _rest_put(self, params: dict | None = None, body: dict | None = None):
resp = self._session.put(self._prefix + self._target_endpoint, params=params, json=body)
if not resp.ok:
raise HttpRestError(resp, value=params)
def _rest_post(self, params: dict | None = None, body: dict | None = None):
resp = self._session.post(self._prefix + self._target_endpoint, params=params, json=body)
if not resp.ok:
raise HttpRestError(resp, value=params)
def start_monitor(self):
"""Start or restart the automonitor thread."""
self._clean_monitor()
self._setup_readback()
self._monitor_readback_thread.start()
def monitor_stopped(self):
return not self._monitor_readback_thread.is_alive()
def put(self, axis: str, val: float):
self._rest_put({axis: val})
@abstractmethod
def get_readback(self, axis_id: str) -> tuple[float, float] | None:
"""Return a tuple (reading, timestamp) if the axis_id exists"""
def stop(self):
# There doesn't appear to be a stop endpoint on the server
# Best effort: set the target to the current position
pass
# TODO: self._rest_put(self._readbacks)
class HttpDeviceSignal(SocketSignal):
"""Ophyd signal which gets and puts to a REST API rather than EPICS PVs, mediated through the Aerotech
Controller"""
def __init__(self, *args, axis_identifier: str, **kwargs):
super().__init__(*args, **kwargs)
controller: HttpDeviceController | None = getattr(self.root, "controller", None)
if controller is None:
raise TypeError("HttpDeviceSignal must be used in a device with a HttpDeviceController")
self._controller = controller
self._axis_id = axis_identifier
self._controller.register(self._axis_id)
def _socket_get(self): # type: ignore
self._readback, self.metadata["timestamp"] = self._controller.get_readback(
self._axis_id
) or (0.0, 0.0)
return self._readback
def _socket_set(self, val: float):
self._controller.put(self._axis_id, val)
def get(self, **kwargs):
if self._controller.monitor_stopped():
self._controller.start_monitor()
return super().get(**kwargs)
class HttpOphydDevice(PSIDeviceBase):
controller_class: type[HttpDeviceController]
def __init__(
self,
*,
name: str,
prefix: str = "",
scan_info=None,
device_manager=None,
**kwargs,
):
self.controller = self.controller_class(prefix=prefix)
super().__init__(
name=name,
prefix=prefix,
scan_info=scan_info,
device_manager=device_manager,
**kwargs,
)
def wait_for_connection(self, **kwargs): # type: ignore
self.controller.start_monitor()
self.controller.manual_update()
return super().wait_for_connection(**kwargs)
def stop(self, *, success: bool = False) -> None:
self.controller.stop()
return super().stop(success=success)
+16 -139
View File
@@ -1,129 +1,21 @@
import time
from threading import Event, RLock, Thread
from typing import Any
from ophyd import Component as Cpt
from ophyd import OphydObject
from ophyd_devices import PSIDeviceBase
from ophyd_devices.utils.socket import SocketSignal
from requests import Response, Session
from .http import HttpDeviceController, HttpDeviceSignal, HttpOphydDevice
_TIMESTAMP_ID = "__timestamp"
_POLL_INTERVAL_SLOW = 0.1
class HttpRestError(Exception):
"""Error for rest calls from a HttpRestSignal."""
def __init__(self, resp: Response, *args: object, value: Any | None = None) -> None:
method, url = resp.request.method, resp.request.url
data = f"{str(value)} to " if value is not None else ""
super().__init__(
f"Could not {method} {data}{url}. Code: {resp.status_code}. Reason: {resp.reason}.",
*args,
)
class SmargonSignal(SocketSignal):
"""Ophyd signal which gets and puts to a REST API rather than EPICS PVs, mediated through the SmargonController"""
def __init__(self, *args, axis_identifier: str, **kwargs):
super().__init__(*args, **kwargs)
controller: SmargonController | None = getattr(self.root, "controller", None)
if controller is None:
raise TypeError("SmargonSignal must be used in a device with a SmargonController")
self._controller = controller
self._axis_id = axis_identifier
self._controller.register(self._axis_id)
def _socket_get(self): # type: ignore
self._readback, self.metadata["timestamp"] = self._controller.get_readback(
self._axis_id
) or (0.0, 0.0)
return self._readback
def _socket_set(self, val: float):
self._controller.put(self._axis_id, val)
def get(self, **kwargs):
if self._controller.monitor_stopped():
self._controller.start_monitor()
return super().get(**kwargs)
class SmargonController(OphydObject):
class SmargonController(HttpDeviceController):
"""Controller to consolidate polling loops and other REST calls for the smargon"""
_readback_endpoint = "/readbackSCS"
_target_endpoint = "/targetSCS"
def __init__(self, *, prefix, **kwargs):
self._session = Session()
self._prefix = prefix
self._readback_endpoint = "/readbackSCS"
self._target_endpoint = "/targetSCS"
self._targets = {}
self._signal_registry: set[str] = set()
self._readback_poll_interval: float = _POLL_INTERVAL_SLOW
super().__init__(**kwargs)
self._setup_readback()
def _setup_readback(self):
self._readbacks: dict[str, float] = {}
self._stop_monitor_readback_event = Event()
self._readback_lock = RLock()
self._monitor_readback_thread = Thread(
target=self._monitor,
args=[
self._readback_endpoint,
self._stop_monitor_readback_event,
self._readback_lock,
self._readbacks,
],
)
def manual_update(self):
self._update_reading(self._readback_endpoint, self._readback_lock, self._readbacks)
def _update_reading(self, endpoint: str, lock: RLock, buffer: dict):
data = self._rest_get(endpoint)
timestamp = time.monotonic()
with lock:
buffer.update(data)
buffer["__timestamp"] = timestamp
def _monitor(self, endpoint: str, event: Event, lock: RLock, buffer: dict):
while not event.is_set():
self._update_reading(endpoint, lock, buffer)
time.sleep(self._readback_poll_interval)
def _clean_monitor(self):
if self._monitor_readback_thread.is_alive():
self._stop_monitor_readback_event.set()
self._monitor_readback_thread.join(timeout=2)
if self._monitor_readback_thread.is_alive():
raise RuntimeError("Failed to clean up Smargon monitor thread.")
def register(self, axis_id: str):
self._signal_registry.add(axis_id)
def _rest_get(self, endpoint):
resp = self._session.get(self._prefix + endpoint)
if not resp.ok:
raise HttpRestError(resp)
return resp.json()
def _rest_put(self, val: dict[str, float]):
resp = self._session.put(self._prefix + self._target_endpoint, params=val)
if not resp.ok:
raise HttpRestError(resp, value=val)
def start_monitor(self):
"""Start or restart the automonitor thread."""
self._clean_monitor()
self._setup_readback()
self._monitor_readback_thread.start()
def monitor_stopped(self):
return not self._monitor_readback_thread.is_alive()
super().__init__(prefix=prefix, **kwargs)
def get_readback(self, axis_id: str) -> tuple[float, float] | None:
with self._readback_lock:
@@ -132,34 +24,19 @@ class SmargonController(OphydObject):
return self._readbacks.get(axis_id), self._readbacks.get(_TIMESTAMP_ID) # type: ignore
def put(self, axis: str, val: float):
self._rest_put({axis: val})
self._rest_put(params={axis: val})
def stop(self):
# There doesn't appear to be a stop endpoint on the server
# Best effort: set the target to the current position
self._rest_put(self._readbacks)
self._rest_put(params=self._readbacks)
class Smargon(PSIDeviceBase):
x = Cpt(SmargonSignal, axis_identifier="SHX", tolerance=0.01)
y = Cpt(SmargonSignal, axis_identifier="SHY", tolerance=0.01)
z = Cpt(SmargonSignal, axis_identifier="SHZ", tolerance=0.01)
phi = Cpt(SmargonSignal, axis_identifier="PHI", tolerance=0.01)
chi = Cpt(SmargonSignal, axis_identifier="CHI", tolerance=0.01)
class Smargon(HttpOphydDevice):
controller_class = SmargonController
def __init__(
self, *, name: str, prefix: str = "", scan_info=None, device_manager=None, **kwargs
):
self.controller = SmargonController(prefix=prefix)
super().__init__(
name=name, prefix=prefix, scan_info=scan_info, device_manager=device_manager, **kwargs
)
def wait_for_connection(self, **kwargs): # type: ignore
self.controller.start_monitor()
self.controller.manual_update()
return super().wait_for_connection(**kwargs)
def stop(self, *, success: bool = False) -> None:
self.controller.stop()
return super().stop(success=success)
x = Cpt(HttpDeviceSignal, axis_identifier="SHX", tolerance=0.01)
y = Cpt(HttpDeviceSignal, axis_identifier="SHY", tolerance=0.01)
z = Cpt(HttpDeviceSignal, axis_identifier="SHZ", tolerance=0.01)
phi = Cpt(HttpDeviceSignal, axis_identifier="PHI", tolerance=0.01)
chi = Cpt(HttpDeviceSignal, axis_identifier="CHI", tolerance=0.01)
+3
View File
@@ -77,3 +77,6 @@ good-names-rgxs = [
".*_2D.*",
".*_1D.*",
]
[tool.ruff]
line-length = 100
+58
View File
@@ -0,0 +1,58 @@
from copy import copy
from threading import RLock
from unittest.mock import ANY
import pytest
class MockServer:
def __init__(self) -> None:
self.lock = RLock()
self.mock_data = {
"x": {"pos": 1.0},
"y": {"pos": 1.0},
"z": {"pos": 1.0},
"u": {"pos": 1.0},
"vel_u_deg_s": {"pos": 1.0},
}
def get(self, endpoint):
with self.lock:
return copy(self.mock_data)
def put(self, params: dict | None = None, body: dict | None = None):
with self.lock:
assert body is not None
for k, v in body.items():
self.mock_data[k]["pos"] = v
@pytest.fixture
def aerotech():
mock_server = MockServer()
from pxii_bec.devices.aerotech import Aerotech
s = Aerotech(name="aerotech", prefix="http://test-aerotech.psi.ch")
s.controller._rest_get = mock_server.get
s.controller._rest_post = mock_server.put
yield s
s.controller._stop_monitor_readback_event.set()
class TestAerotech:
def test_aerotech_read(self, aerotech):
aerotech.wait_for_connection()
reading = aerotech.read()
assert dict(reading) == {
"aerotech_x": {"value": 1.0, "timestamp": ANY},
"aerotech_y": {"value": 1.0, "timestamp": ANY},
"aerotech_z": {"value": 1.0, "timestamp": ANY},
"aerotech_u": {"value": 1.0, "timestamp": ANY},
"aerotech_vel_u_deg_s": {"value": 1.0, "timestamp": ANY},
}
def test_aerotech_set_with_status(self, aerotech):
aerotech.wait_for_connection()
st = aerotech.x.set(5.0)
st.wait(timeout=1)
assert aerotech.x.get() == 5.0
+4 -3
View File
@@ -1,6 +1,6 @@
from copy import copy
from threading import RLock
from unittest.mock import ANY, MagicMock, patch
from unittest.mock import ANY
import pytest
@@ -14,9 +14,10 @@ class MockServer:
with self.lock:
return copy(self.mock_data)
def put(self, val: dict[str, float]):
def put(self, params: dict | None = None, body: dict | None = None):
with self.lock:
self.mock_data.update(val)
assert params is not None
self.mock_data.update(params)
@pytest.fixture