From c58b0c1d20d564aca5bbbd39f7597d9b1ea217d0 Mon Sep 17 00:00:00 2001 From: perl_d Date: Wed, 4 Mar 2026 17:38:32 +0100 Subject: [PATCH 1/2] feat: Smargopolo Smargon device v1 --- .../device_configs/x10sa_device_config.yaml | 13 ++ pxii_bec/devices/smargopolo_smargon.py | 158 ++++++++++++++++++ tests/tests_devices/mock_smargopolo.py | 110 ++++++++++++ 3 files changed, 281 insertions(+) create mode 100644 pxii_bec/devices/smargopolo_smargon.py create mode 100644 tests/tests_devices/mock_smargopolo.py diff --git a/pxii_bec/device_configs/x10sa_device_config.yaml b/pxii_bec/device_configs/x10sa_device_config.yaml index db42938..e41ad8e 100644 --- a/pxii_bec/device_configs/x10sa_device_config.yaml +++ b/pxii_bec/device_configs/x10sa_device_config.yaml @@ -42,4 +42,17 @@ dcm_froll: deviceTags: - dcm readOnly: false + softwareTrigger: false + +smargon: + description: REST-based device which connects to Smargopolo + deviceClass: pxii_bec.devices.smargopolo_smargon.Smargon + deviceConfig: {prefix: 'http://x10sa-smargopolo.psi.ch:3000'} + onFailure: buffer + enabled: True + readoutPriority: baseline + deviceTags: + - smargon + - motors + readOnly: false softwareTrigger: false \ No newline at end of file diff --git a/pxii_bec/devices/smargopolo_smargon.py b/pxii_bec/devices/smargopolo_smargon.py new file mode 100644 index 0000000..c0eda96 --- /dev/null +++ b/pxii_bec/devices/smargopolo_smargon.py @@ -0,0 +1,158 @@ +import time +from threading import Event, RLock, Thread +from typing import Any + +from ophyd import Component as Cpt +from ophyd import OphydObject +from ophyd_devices import PSIDeviceBase +from ophyd_devices.utils.socket import SocketSignal +from requests import Response, get, put + +_TIMESTAMP_ID = "__timestamp" +_POLL_INTERVAL_SLOW = 0.1 + + +class HttpRestError(Exception): + """Error for rest calls from a HttpRestSignal.""" + + def __init__(self, resp: Response, *args: object, value: Any | None = None) -> None: + method, url = resp.request.method, resp.request.url + data = f"{str(value)} to " if value is not None else "" + super().__init__( + f"Could not {method} {data}{url}. Code: {resp.status_code}. Reason: {resp.reason}.", + *args, + ) + + +class SmargonSignal(SocketSignal): + """Ophyd signal which gets and puts to a REST API rather than EPICS PVs, mediated through the SmargonController""" + + def __init__(self, *args, axis_identifier: str, **kwargs): + super().__init__(*args, **kwargs) + controller: SmargonController | None = getattr(self.root, "controller", None) + if controller is None: + raise TypeError("SmargonSignal must be used in a device with a SmargonController") + self._controller = controller + self._axis_id = axis_identifier + self._controller.register(self._axis_id) + + def _socket_get(self): # type: ignore + self._readback, self.metadata["timestamp"] = self._controller.get_readback( + self._axis_id + ) or (0.0, 0.0) + return self._readback + + def _socket_set(self, val: float): + self._controller.put(self._axis_id, val) + + def get(self, **kwargs): + if self._controller.monitor_stopped(): + self._controller.start_monitor() + return super().get(**kwargs) + + +class SmargonController(OphydObject): + """Controller to consolidate polling loops and other REST calls for the smargon""" + + def __init__(self, *, prefix, **kwargs): + self._prefix = prefix + self._readback_endpoint = "/readbackSCS" + self._target_endpoint = "/targetSCS" + self._targets = {} + self._signal_registry: set[str] = set() + self._readback_poll_interval: float = _POLL_INTERVAL_SLOW + + super().__init__(**kwargs) + self._setup_readback() + + def _setup_readback(self): + self._readbacks: dict[str, float] = {} + self._stop_monitor_readback_event = Event() + self._readback_lock = RLock() + self._monitor_readback_thread = Thread( + target=self._monitor, + args=[ + self._readback_endpoint, + self._stop_monitor_readback_event, + self._readback_lock, + self._readbacks, + ], + ) + + def _monitor(self, endpoint: str, event: Event, lock: RLock, buffer: dict): + while not event.is_set(): + data = self._rest_get(endpoint) + timestamp = time.monotonic() + with lock: + buffer.update(data) + buffer["__timestamp"] = timestamp + time.sleep(self._readback_poll_interval) + + def _clean_monitor(self): + if self._monitor_readback_thread.is_alive(): + self._stop_monitor_readback_event.set() + self._monitor_readback_thread.join(timeout=2) + if self._monitor_readback_thread.is_alive(): + raise RuntimeError("Failed to clean up Smargon monitor thread.") + + def register(self, axis_id: str): + self._signal_registry.add(axis_id) + + def _rest_get(self, endpoint): + resp = get(self._prefix + endpoint) + if not resp.ok: + raise HttpRestError(resp) + return resp.json() + + def _rest_put(self, val: dict[str, float]): + resp = put(self._prefix + self._target_endpoint, params=val) + if not resp.ok: + raise HttpRestError(resp, value=val) + + def start_monitor(self): + """Start or restart the automonitor thread.""" + self._clean_monitor() + self._setup_readback() + self._monitor_readback_thread.start() + + def monitor_stopped(self): + return not self._monitor_readback_thread.is_alive() + + def get_readback(self, axis_id: str) -> tuple[float, float] | None: + with self._readback_lock: + if axis_id not in self._readbacks or _TIMESTAMP_ID not in self._readbacks: + return None + return self._readbacks.get(axis_id), self._readbacks.get(_TIMESTAMP_ID) # type: ignore + + def put(self, axis: str, val: float): + self._rest_put({axis: val}) + + def stop(self): + # There doesn't appear to be a stop endpoint on the server + # Best effort: set the target to the current position + self._rest_put(self._readbacks) + + +class Smargon(PSIDeviceBase): + + x = Cpt(SmargonSignal, axis_identifier="SHX", tolerance=0.01) + y = Cpt(SmargonSignal, axis_identifier="SHY", tolerance=0.01) + z = Cpt(SmargonSignal, axis_identifier="SHZ", tolerance=0.01) + phi = Cpt(SmargonSignal, axis_identifier="PHI", tolerance=0.01) + chi = Cpt(SmargonSignal, axis_identifier="CHI", tolerance=0.01) + + def __init__( + self, *, name: str, prefix: str = "", scan_info=None, device_manager=None, **kwargs + ): + self.controller = SmargonController(prefix=prefix) + super().__init__( + name=name, prefix=prefix, scan_info=scan_info, device_manager=device_manager, **kwargs + ) + + def wait_for_connection(self, **kwargs): # type: ignore + self.controller.start_monitor() + return super().wait_for_connection(**kwargs) + + def stop(self, *, success: bool = False) -> None: + self.controller.stop() + return super().stop(success=success) diff --git a/tests/tests_devices/mock_smargopolo.py b/tests/tests_devices/mock_smargopolo.py new file mode 100644 index 0000000..466063e --- /dev/null +++ b/tests/tests_devices/mock_smargopolo.py @@ -0,0 +1,110 @@ +"""A mock smargopolo REST interface with mock motoers, for testing devices against""" + +import asyncio +import random +import time +from contextlib import asynccontextmanager +from typing import Iterable + +import uvicorn +from fastapi import FastAPI, HTTPException, Query, Request +from pydantic import BaseModel + +AXES = ["SHX", "SHY", "SHZ", "PHI", "CHI"] + + +class Motor: + def __init__(self, velocity: float = 1.0): + self.position = 0.0 + self.target = 0.0 + self.velocity = velocity + self.moving = False + self._last_update = time.monotonic() + + def update(self): + now = time.monotonic() + dt = now - self._last_update + self._last_update = now + + if not self.moving: + return + + jitter_factor = random.random() * 0.05 - 0.025 # +- 2.5% jitter in step + distance = self.target - self.position + direction = 1 if distance > 0 else -1 + step = direction * self.velocity * dt + + if abs(step) >= abs(distance): + self.position = self.target + (step * jitter_factor) + self.moving = False + else: + self.position += step * (1 + jitter_factor) + + +motors: dict[str, Motor] = { + "SHX": Motor(velocity=3), + "SHY": Motor(velocity=2.5), + "SHZ": Motor(velocity=2), + "PHI": Motor(velocity=1.0), + "CHI": Motor(velocity=0.7), +} + + +class MoveRequest(BaseModel): + target: float + + +@asynccontextmanager +async def lifespan(app: FastAPI): + async def updater(): + while True: + for motor in motors.values(): + motor.update() + await asyncio.sleep(0.02) # 50 Hz update loop + + task = asyncio.create_task(updater()) + yield + task.cancel() + try: + await task + except asyncio.CancelledError: + pass + + +app = FastAPI(lifespan=lifespan) + + +def validate_axes(axes: Iterable[str] | None) -> list[str]: + if axes is None: + return AXES + for a in axes: + if a not in AXES: + raise HTTPException(status_code=404, detail=f"Unknown axis: {a}") + return list(axes) + + +@app.get("/readbackSCS") +async def readback_scs(axis: list[str] | None = Query(None)): + selected_axes = validate_axes(axis) + return {ax: motors[ax].position for ax in selected_axes} + + +@app.put("/targetSCS") +async def target_scs(req: Request): + targets = {ax: float(t) for ax, t in req.query_params.items()} + if targets is None: + return {} + + selected_axes = validate_axes(targets.keys()) + + for a in selected_axes: + motor = motors[a] + motor.update() + motor.target = targets[a] + motor.moving = True + + return {"targets": targets, "message": "Move started"} + + +if __name__ == "__main__": + uvicorn.run(app, host="0.0.0.0", port=8000, reload=False) -- 2.52.0 From a11fe919985aa3a470ff28049d9185a492861974 Mon Sep 17 00:00:00 2001 From: David Perl Date: Thu, 5 Mar 2026 15:32:20 +0100 Subject: [PATCH 2/2] tests: add test for smargon --- pxii_bec/devices/smargopolo_smargon.py | 17 ++++++--- pyproject.toml | 1 + tests/tests_devices/test_smargon.py | 50 ++++++++++++++++++++++++++ 3 files changed, 63 insertions(+), 5 deletions(-) create mode 100644 tests/tests_devices/test_smargon.py diff --git a/pxii_bec/devices/smargopolo_smargon.py b/pxii_bec/devices/smargopolo_smargon.py index c0eda96..9af7bdf 100644 --- a/pxii_bec/devices/smargopolo_smargon.py +++ b/pxii_bec/devices/smargopolo_smargon.py @@ -79,13 +79,19 @@ class SmargonController(OphydObject): ], ) + def manual_update(self): + self._update_reading(self._readback_endpoint, self._readback_lock, self._readbacks) + + def _update_reading(self, endpoint: str, lock: RLock, buffer: dict): + data = self._rest_get(endpoint) + timestamp = time.monotonic() + with lock: + buffer.update(data) + buffer["__timestamp"] = timestamp + def _monitor(self, endpoint: str, event: Event, lock: RLock, buffer: dict): while not event.is_set(): - data = self._rest_get(endpoint) - timestamp = time.monotonic() - with lock: - buffer.update(data) - buffer["__timestamp"] = timestamp + self._update_reading(endpoint, lock, buffer) time.sleep(self._readback_poll_interval) def _clean_monitor(self): @@ -151,6 +157,7 @@ class Smargon(PSIDeviceBase): def wait_for_connection(self, **kwargs): # type: ignore self.controller.start_monitor() + self.controller.manual_update() return super().wait_for_connection(**kwargs) def stop(self, *, success: bool = False) -> None: diff --git a/pyproject.toml b/pyproject.toml index d259c6e..133b014 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -25,6 +25,7 @@ dev = [ "pytest-random-order", "ophyd_devices", "bec_server", + "requests-mock", ] [project.entry-points."bec"] diff --git a/tests/tests_devices/test_smargon.py b/tests/tests_devices/test_smargon.py new file mode 100644 index 0000000..bbb215c --- /dev/null +++ b/tests/tests_devices/test_smargon.py @@ -0,0 +1,50 @@ +from copy import copy +from threading import RLock +from unittest.mock import ANY, MagicMock, patch + +import pytest + + +class MockServer: + def __init__(self) -> None: + self.lock = RLock() + self.mock_data = {"SHX": 1.0, "SHY": 1.0, "SHZ": 1.0, "PHI": 1.0, "CHI": 1.0} + + def get(self, endpoint): + with self.lock: + return copy(self.mock_data) + + def put(self, val: dict[str, float]): + with self.lock: + self.mock_data.update(val) + + +@pytest.fixture +def smargon(): + mock_server = MockServer() + from pxii_bec.devices.smargopolo_smargon import Smargon + + s = Smargon(name="smargon", prefix="http://test-smargopolo.psi.ch") + s.controller._rest_get = mock_server.get + s.controller._rest_put = mock_server.put + yield s + s.controller._stop_monitor_readback_event.set() + + +class TestSmargon: + def test_smargon_read(self, smargon): + smargon.wait_for_connection() + reading = smargon.read() + assert dict(reading) == { + "smargon_x": {"value": 1.0, "timestamp": ANY}, + "smargon_y": {"value": 1.0, "timestamp": ANY}, + "smargon_z": {"value": 1.0, "timestamp": ANY}, + "smargon_phi": {"value": 1.0, "timestamp": ANY}, + "smargon_chi": {"value": 1.0, "timestamp": ANY}, + } + + def test_smargon_set_with_status(self, smargon): + smargon.wait_for_connection() + st = smargon.x.set(5.0) + st.wait(timeout=1) + assert smargon.x.get() == 5.0 -- 2.52.0