From b65ed70f3244fe5c53867d14e90b91e5832b2426 Mon Sep 17 00:00:00 2001 From: perl_d Date: Wed, 6 May 2026 18:07:03 +0200 Subject: [PATCH 1/3] refactor: extract core HTTP device logic --- pxii_bec/devices/http.py | 178 +++++++++++++++++++++++++ pxii_bec/devices/smargopolo_smargon.py | 155 +++------------------ pyproject.toml | 3 + tests/tests_devices/test_smargon.py | 7 +- 4 files changed, 201 insertions(+), 142 deletions(-) create mode 100644 pxii_bec/devices/http.py diff --git a/pxii_bec/devices/http.py b/pxii_bec/devices/http.py new file mode 100644 index 0000000..38b612a --- /dev/null +++ b/pxii_bec/devices/http.py @@ -0,0 +1,178 @@ +import time +from abc import ABC, abstractmethod +from threading import Event, RLock, Thread +from typing import Any + +from ophyd import OphydObject +from ophyd_devices import PSIDeviceBase +from ophyd_devices.utils.socket import SocketSignal +from requests import Response, Session + +TIMESTAMP_ID = "__timestamp" +_POLL_INTERVAL_SLOW = 0.1 + + +class HttpRestError(Exception): + """Error for rest calls from a HttpRestSignal.""" + + def __init__(self, resp: Response, *args: object, value: Any | None = None) -> None: + method, url = resp.request.method, resp.request.url + data = f"{str(value)} to " if value is not None else "" + super().__init__( + f"Could not {method} {data}{url}. Code: {resp.status_code}. Reason: {resp.reason}.", + *args, + ) + + +class HttpDeviceController(OphydObject, ABC): + """Controller to consolidate polling loops and other REST calls for devices which communicate + with HTTP REST interfaces""" + + _readback_endpoint: str + _target_endpoint: str + + def __init__(self, *, prefix, **kwargs): + self._readbacks: dict + self._session = Session() + self._prefix = prefix + self._targets = {} + self._signal_registry: set[str] = set() + self._readback_poll_interval: float = _POLL_INTERVAL_SLOW + + super().__init__(**kwargs) + self._setup_readback() + + def _setup_readback(self): + self._stop_monitor_readback_event = Event() + self._readback_lock = RLock() + self._monitor_readback_thread = Thread( + target=self._monitor, + args=[ + self._readback_endpoint, + self._stop_monitor_readback_event, + self._readback_lock, + self._readbacks, + ], + ) + + def manual_update(self): + self._update_reading(self._readback_endpoint, self._readback_lock, self._readbacks) + + def _update_reading(self, endpoint: str, lock: RLock, buffer: dict): + data = self._rest_get(endpoint) + timestamp = time.monotonic() + with lock: + buffer.update(data) + buffer["__timestamp"] = timestamp + + def _monitor(self, endpoint: str, event: Event, lock: RLock, buffer: dict): + while not event.is_set(): + self._update_reading(endpoint, lock, buffer) + time.sleep(self._readback_poll_interval) + + def _clean_monitor(self): + if self._monitor_readback_thread.is_alive(): + self._stop_monitor_readback_event.set() + self._monitor_readback_thread.join(timeout=2) + if self._monitor_readback_thread.is_alive(): + raise RuntimeError("Failed to clean up Aerotech monitor thread.") + + def register(self, axis_id: str): + self._signal_registry.add(axis_id) + + def _rest_get(self, endpoint): + resp = self._session.get(self._prefix + endpoint) + if not resp.ok: + raise HttpRestError(resp) + return resp.json() + + def _rest_put(self, params: dict | None = None, body: dict | None = None): + resp = self._session.put(self._prefix + self._target_endpoint, params=params, json=body) + if not resp.ok: + raise HttpRestError(resp, value=params) + + def _rest_post(self, params: dict | None = None, body: dict | None = None): + resp = self._session.post(self._prefix + self._target_endpoint, params=params, json=body) + if not resp.ok: + raise HttpRestError(resp, value=params) + + def start_monitor(self): + """Start or restart the automonitor thread.""" + self._clean_monitor() + self._setup_readback() + self._monitor_readback_thread.start() + + def monitor_stopped(self): + return not self._monitor_readback_thread.is_alive() + + def put(self, axis: str, val: float): + self._rest_put({axis: val}) + + @abstractmethod + def get_readback(self, axis_id: str) -> tuple[float, float] | None: + """Return a tuple (reading, timestamp) if the axis_id exists""" + + def stop(self): + # There doesn't appear to be a stop endpoint on the server + # Best effort: set the target to the current position + pass + # TODO: self._rest_put(self._readbacks) + + +class HttpDeviceSignal(SocketSignal): + """Ophyd signal which gets and puts to a REST API rather than EPICS PVs, mediated through the Aerotech + Controller""" + + def __init__(self, *args, axis_identifier: str, **kwargs): + super().__init__(*args, **kwargs) + controller: HttpDeviceController | None = getattr(self.root, "controller", None) + if controller is None: + raise TypeError("HttpDeviceSignal must be used in a device with a HttpDeviceController") + self._controller = controller + self._axis_id = axis_identifier + self._controller.register(self._axis_id) + + def _socket_get(self): # type: ignore + self._readback, self.metadata["timestamp"] = self._controller.get_readback( + self._axis_id + ) or (0.0, 0.0) + return self._readback + + def _socket_set(self, val: float): + self._controller.put(self._axis_id, val) + + def get(self, **kwargs): + if self._controller.monitor_stopped(): + self._controller.start_monitor() + return super().get(**kwargs) + + +class HttpOphydDevice(PSIDeviceBase): + controller_class: type[HttpDeviceController] + + def __init__( + self, + *, + name: str, + prefix: str = "", + scan_info=None, + device_manager=None, + **kwargs, + ): + self.controller = self.controller_class(prefix=prefix) + super().__init__( + name=name, + prefix=prefix, + scan_info=scan_info, + device_manager=device_manager, + **kwargs, + ) + + def wait_for_connection(self, **kwargs): # type: ignore + self.controller.start_monitor() + self.controller.manual_update() + return super().wait_for_connection(**kwargs) + + def stop(self, *, success: bool = False) -> None: + self.controller.stop() + return super().stop(success=success) diff --git a/pxii_bec/devices/smargopolo_smargon.py b/pxii_bec/devices/smargopolo_smargon.py index 899c8a1..a7c4618 100644 --- a/pxii_bec/devices/smargopolo_smargon.py +++ b/pxii_bec/devices/smargopolo_smargon.py @@ -1,129 +1,21 @@ -import time -from threading import Event, RLock, Thread -from typing import Any - from ophyd import Component as Cpt -from ophyd import OphydObject from ophyd_devices import PSIDeviceBase -from ophyd_devices.utils.socket import SocketSignal -from requests import Response, Session + +from .http import HttpDeviceController, HttpDeviceSignal, HttpOphydDevice _TIMESTAMP_ID = "__timestamp" _POLL_INTERVAL_SLOW = 0.1 -class HttpRestError(Exception): - """Error for rest calls from a HttpRestSignal.""" - - def __init__(self, resp: Response, *args: object, value: Any | None = None) -> None: - method, url = resp.request.method, resp.request.url - data = f"{str(value)} to " if value is not None else "" - super().__init__( - f"Could not {method} {data}{url}. Code: {resp.status_code}. Reason: {resp.reason}.", - *args, - ) - - -class SmargonSignal(SocketSignal): - """Ophyd signal which gets and puts to a REST API rather than EPICS PVs, mediated through the SmargonController""" - - def __init__(self, *args, axis_identifier: str, **kwargs): - super().__init__(*args, **kwargs) - controller: SmargonController | None = getattr(self.root, "controller", None) - if controller is None: - raise TypeError("SmargonSignal must be used in a device with a SmargonController") - self._controller = controller - self._axis_id = axis_identifier - self._controller.register(self._axis_id) - - def _socket_get(self): # type: ignore - self._readback, self.metadata["timestamp"] = self._controller.get_readback( - self._axis_id - ) or (0.0, 0.0) - return self._readback - - def _socket_set(self, val: float): - self._controller.put(self._axis_id, val) - - def get(self, **kwargs): - if self._controller.monitor_stopped(): - self._controller.start_monitor() - return super().get(**kwargs) - - -class SmargonController(OphydObject): +class SmargonController(HttpDeviceController): """Controller to consolidate polling loops and other REST calls for the smargon""" + _readback_endpoint = "/readbackSCS" + _target_endpoint = "/targetSCS" + def __init__(self, *, prefix, **kwargs): - self._session = Session() - self._prefix = prefix - self._readback_endpoint = "/readbackSCS" - self._target_endpoint = "/targetSCS" - self._targets = {} - self._signal_registry: set[str] = set() - self._readback_poll_interval: float = _POLL_INTERVAL_SLOW - - super().__init__(**kwargs) - self._setup_readback() - - def _setup_readback(self): self._readbacks: dict[str, float] = {} - self._stop_monitor_readback_event = Event() - self._readback_lock = RLock() - self._monitor_readback_thread = Thread( - target=self._monitor, - args=[ - self._readback_endpoint, - self._stop_monitor_readback_event, - self._readback_lock, - self._readbacks, - ], - ) - - def manual_update(self): - self._update_reading(self._readback_endpoint, self._readback_lock, self._readbacks) - - def _update_reading(self, endpoint: str, lock: RLock, buffer: dict): - data = self._rest_get(endpoint) - timestamp = time.monotonic() - with lock: - buffer.update(data) - buffer["__timestamp"] = timestamp - - def _monitor(self, endpoint: str, event: Event, lock: RLock, buffer: dict): - while not event.is_set(): - self._update_reading(endpoint, lock, buffer) - time.sleep(self._readback_poll_interval) - - def _clean_monitor(self): - if self._monitor_readback_thread.is_alive(): - self._stop_monitor_readback_event.set() - self._monitor_readback_thread.join(timeout=2) - if self._monitor_readback_thread.is_alive(): - raise RuntimeError("Failed to clean up Smargon monitor thread.") - - def register(self, axis_id: str): - self._signal_registry.add(axis_id) - - def _rest_get(self, endpoint): - resp = self._session.get(self._prefix + endpoint) - if not resp.ok: - raise HttpRestError(resp) - return resp.json() - - def _rest_put(self, val: dict[str, float]): - resp = self._session.put(self._prefix + self._target_endpoint, params=val) - if not resp.ok: - raise HttpRestError(resp, value=val) - - def start_monitor(self): - """Start or restart the automonitor thread.""" - self._clean_monitor() - self._setup_readback() - self._monitor_readback_thread.start() - - def monitor_stopped(self): - return not self._monitor_readback_thread.is_alive() + super().__init__(prefix=prefix, **kwargs) def get_readback(self, axis_id: str) -> tuple[float, float] | None: with self._readback_lock: @@ -132,34 +24,19 @@ class SmargonController(OphydObject): return self._readbacks.get(axis_id), self._readbacks.get(_TIMESTAMP_ID) # type: ignore def put(self, axis: str, val: float): - self._rest_put({axis: val}) + self._rest_put(params={axis: val}) def stop(self): # There doesn't appear to be a stop endpoint on the server # Best effort: set the target to the current position - self._rest_put(self._readbacks) + self._rest_put(params=self._readbacks) -class Smargon(PSIDeviceBase): - x = Cpt(SmargonSignal, axis_identifier="SHX", tolerance=0.01) - y = Cpt(SmargonSignal, axis_identifier="SHY", tolerance=0.01) - z = Cpt(SmargonSignal, axis_identifier="SHZ", tolerance=0.01) - phi = Cpt(SmargonSignal, axis_identifier="PHI", tolerance=0.01) - chi = Cpt(SmargonSignal, axis_identifier="CHI", tolerance=0.01) +class Smargon(HttpOphydDevice): + controller_class = SmargonController - def __init__( - self, *, name: str, prefix: str = "", scan_info=None, device_manager=None, **kwargs - ): - self.controller = SmargonController(prefix=prefix) - super().__init__( - name=name, prefix=prefix, scan_info=scan_info, device_manager=device_manager, **kwargs - ) - - def wait_for_connection(self, **kwargs): # type: ignore - self.controller.start_monitor() - self.controller.manual_update() - return super().wait_for_connection(**kwargs) - - def stop(self, *, success: bool = False) -> None: - self.controller.stop() - return super().stop(success=success) + x = Cpt(HttpDeviceSignal, axis_identifier="SHX", tolerance=0.01) + y = Cpt(HttpDeviceSignal, axis_identifier="SHY", tolerance=0.01) + z = Cpt(HttpDeviceSignal, axis_identifier="SHZ", tolerance=0.01) + phi = Cpt(HttpDeviceSignal, axis_identifier="PHI", tolerance=0.01) + chi = Cpt(HttpDeviceSignal, axis_identifier="CHI", tolerance=0.01) diff --git a/pyproject.toml b/pyproject.toml index 133b014..8dbf95a 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -77,3 +77,6 @@ good-names-rgxs = [ ".*_2D.*", ".*_1D.*", ] + +[tool.ruff] +line-length = 100 diff --git a/tests/tests_devices/test_smargon.py b/tests/tests_devices/test_smargon.py index bbb215c..fc14015 100644 --- a/tests/tests_devices/test_smargon.py +++ b/tests/tests_devices/test_smargon.py @@ -1,6 +1,6 @@ from copy import copy from threading import RLock -from unittest.mock import ANY, MagicMock, patch +from unittest.mock import ANY import pytest @@ -14,9 +14,10 @@ class MockServer: with self.lock: return copy(self.mock_data) - def put(self, val: dict[str, float]): + def put(self, params: dict | None = None, body: dict | None = None): with self.lock: - self.mock_data.update(val) + assert params is not None + self.mock_data.update(params) @pytest.fixture -- 2.52.0 From 6c8351238cddf5b578f4ac88ce0554a0c718f73c Mon Sep 17 00:00:00 2001 From: perl_d Date: Wed, 6 May 2026 18:07:28 +0200 Subject: [PATCH 2/3] feat: add aerotech device --- pxii_bec/devices/aerotech.py | 43 +++++++++++++++++++++ tests/tests_devices/test_aerotech.py | 58 ++++++++++++++++++++++++++++ 2 files changed, 101 insertions(+) create mode 100644 pxii_bec/devices/aerotech.py create mode 100644 tests/tests_devices/test_aerotech.py diff --git a/pxii_bec/devices/aerotech.py b/pxii_bec/devices/aerotech.py new file mode 100644 index 0000000..d465e59 --- /dev/null +++ b/pxii_bec/devices/aerotech.py @@ -0,0 +1,43 @@ +from ophyd import Component as Cpt + +from .http import TIMESTAMP_ID, HttpDeviceController, HttpDeviceSignal, HttpOphydDevice + + +class AerotechController(HttpDeviceController): + _readback_endpoint = "status" + _target_endpoint = "position" + + def __init__(self, *, prefix, **kwargs): + self._readbacks: dict[str, dict[str, float | bool]] = {} + super().__init__(prefix=prefix, **kwargs) + + def put(self, axis: str, val: float): + self._rest_post(body={axis: val}) + + def get_readback(self, axis_id: str) -> tuple[float, float] | None: + with self._readback_lock: + if axis_id not in self._readbacks or TIMESTAMP_ID not in self._readbacks: + return None + return self._readbacks.get(axis_id)["pos"], self._readbacks.get(TIMESTAMP_ID) # type: ignore + + +class Aerotech(HttpOphydDevice): + controller_class = AerotechController + + x = Cpt(HttpDeviceSignal, axis_identifier="x", tolerance=0.01) + y = Cpt(HttpDeviceSignal, axis_identifier="y", tolerance=0.01) + z = Cpt(HttpDeviceSignal, axis_identifier="z", tolerance=0.01) + u = Cpt(HttpDeviceSignal, axis_identifier="u", tolerance=0.01) + vel_u_deg_s = Cpt(HttpDeviceSignal, axis_identifier="vel_u_deg_s", tolerance=0.01) + + +def _test(): + a = Aerotech(name="aerotech", prefix="http://mx-x10sa-queue-01:5234") + a.wait_for_connection() + return a + + +if __name__ == "__main__": + aerotech = _test() + print(aerotech.read()) + aerotech.stop() diff --git a/tests/tests_devices/test_aerotech.py b/tests/tests_devices/test_aerotech.py new file mode 100644 index 0000000..49e2fc2 --- /dev/null +++ b/tests/tests_devices/test_aerotech.py @@ -0,0 +1,58 @@ +from copy import copy +from threading import RLock +from unittest.mock import ANY + +import pytest + + +class MockServer: + def __init__(self) -> None: + self.lock = RLock() + self.mock_data = { + "x": {"pos": 1.0}, + "y": {"pos": 1.0}, + "z": {"pos": 1.0}, + "u": {"pos": 1.0}, + "vel_u_deg_s": {"pos": 1.0}, + } + + def get(self, endpoint): + with self.lock: + return copy(self.mock_data) + + def put(self, params: dict | None = None, body: dict | None = None): + with self.lock: + assert body is not None + for k, v in body.items(): + self.mock_data[k]["pos"] = v + + +@pytest.fixture +def aerotech(): + mock_server = MockServer() + from pxii_bec.devices.aerotech import Aerotech + + s = Aerotech(name="aerotech", prefix="http://test-aerotech.psi.ch") + s.controller._rest_get = mock_server.get + s.controller._rest_post = mock_server.put + yield s + s.controller._stop_monitor_readback_event.set() + + +class TestAerotech: + def test_aerotech_read(self, aerotech): + aerotech.wait_for_connection() + reading = aerotech.read() + assert dict(reading) == { + "aerotech_x": {"value": 1.0, "timestamp": ANY}, + "aerotech_y": {"value": 1.0, "timestamp": ANY}, + "aerotech_z": {"value": 1.0, "timestamp": ANY}, + "aerotech_u": {"value": 1.0, "timestamp": ANY}, + "aerotech_vel_u_deg_s": {"value": 1.0, "timestamp": ANY}, + } + + def test_aerotech_set_with_status(self, aerotech): + aerotech.wait_for_connection() + st = aerotech.x.set(5.0) + st.wait(timeout=1) + assert aerotech.x.get() == 5.0 -- 2.52.0 From f35d5964f9f502a8c7e17ecd003c8261f68f792d Mon Sep 17 00:00:00 2001 From: David Perl Date: Thu, 7 May 2026 16:31:53 +0200 Subject: [PATCH 3/3] config: add aerotech device to pxii config --- pxii_bec/device_configs/x10sa_device_config.yaml | 15 ++++++++++++++- 1 file changed, 14 insertions(+), 1 deletion(-) diff --git a/pxii_bec/device_configs/x10sa_device_config.yaml b/pxii_bec/device_configs/x10sa_device_config.yaml index e41ad8e..32e940a 100644 --- a/pxii_bec/device_configs/x10sa_device_config.yaml +++ b/pxii_bec/device_configs/x10sa_device_config.yaml @@ -55,4 +55,17 @@ smargon: - smargon - motors readOnly: false - softwareTrigger: false \ No newline at end of file + softwareTrigger: false + +aerotech: + description: REST-based device which connects to AareScan + deviceClass: pxii_bec.devices.aerotech + deviceConfig: { prefix: "http://mx-x10sa-queue-01:5234/" } + onFailure: buffer + enabled: True + readoutPriority: baseline + deviceTags: + - aerotech + - motors + readOnly: false + softwareTrigger: false -- 2.52.0