feat: Smargopolo Smargon device v1 #11

Merged
perl_d merged 2 commits from feat/smargopolo_smargon into main 2026-03-06 16:27:06 +01:00
5 changed files with 339 additions and 0 deletions
@@ -42,4 +42,17 @@ dcm_froll:
deviceTags:
- dcm
readOnly: false
softwareTrigger: false
smargon:
description: REST-based device which connects to Smargopolo
deviceClass: pxii_bec.devices.smargopolo_smargon.Smargon
deviceConfig: {prefix: 'http://x10sa-smargopolo.psi.ch:3000'}
onFailure: buffer
enabled: True
readoutPriority: baseline
deviceTags:
- smargon
- motors
readOnly: false
softwareTrigger: false
+165
View File
@@ -0,0 +1,165 @@
import time
from threading import Event, RLock, Thread
from typing import Any
from ophyd import Component as Cpt
from ophyd import OphydObject
from ophyd_devices import PSIDeviceBase
from ophyd_devices.utils.socket import SocketSignal
from requests import Response, get, put
_TIMESTAMP_ID = "__timestamp"
_POLL_INTERVAL_SLOW = 0.1
class HttpRestError(Exception):
"""Error for rest calls from a HttpRestSignal."""
def __init__(self, resp: Response, *args: object, value: Any | None = None) -> None:
method, url = resp.request.method, resp.request.url
data = f"{str(value)} to " if value is not None else ""
super().__init__(
f"Could not {method} {data}{url}. Code: {resp.status_code}. Reason: {resp.reason}.",
*args,
)
class SmargonSignal(SocketSignal):
"""Ophyd signal which gets and puts to a REST API rather than EPICS PVs, mediated through the SmargonController"""
def __init__(self, *args, axis_identifier: str, **kwargs):
super().__init__(*args, **kwargs)
controller: SmargonController | None = getattr(self.root, "controller", None)
if controller is None:
raise TypeError("SmargonSignal must be used in a device with a SmargonController")
self._controller = controller
self._axis_id = axis_identifier
self._controller.register(self._axis_id)
def _socket_get(self): # type: ignore
self._readback, self.metadata["timestamp"] = self._controller.get_readback(
self._axis_id
) or (0.0, 0.0)
return self._readback
def _socket_set(self, val: float):
self._controller.put(self._axis_id, val)
def get(self, **kwargs):
if self._controller.monitor_stopped():
self._controller.start_monitor()
return super().get(**kwargs)
class SmargonController(OphydObject):
"""Controller to consolidate polling loops and other REST calls for the smargon"""
def __init__(self, *, prefix, **kwargs):
self._prefix = prefix
self._readback_endpoint = "/readbackSCS"
self._target_endpoint = "/targetSCS"
self._targets = {}
self._signal_registry: set[str] = set()
self._readback_poll_interval: float = _POLL_INTERVAL_SLOW
super().__init__(**kwargs)
self._setup_readback()
def _setup_readback(self):
self._readbacks: dict[str, float] = {}
self._stop_monitor_readback_event = Event()
self._readback_lock = RLock()
self._monitor_readback_thread = Thread(
target=self._monitor,
args=[
self._readback_endpoint,
self._stop_monitor_readback_event,
self._readback_lock,
self._readbacks,
],
)
def manual_update(self):
self._update_reading(self._readback_endpoint, self._readback_lock, self._readbacks)
def _update_reading(self, endpoint: str, lock: RLock, buffer: dict):
data = self._rest_get(endpoint)
timestamp = time.monotonic()
with lock:
buffer.update(data)
buffer["__timestamp"] = timestamp
def _monitor(self, endpoint: str, event: Event, lock: RLock, buffer: dict):
while not event.is_set():
self._update_reading(endpoint, lock, buffer)
time.sleep(self._readback_poll_interval)
def _clean_monitor(self):
if self._monitor_readback_thread.is_alive():
self._stop_monitor_readback_event.set()
self._monitor_readback_thread.join(timeout=2)
if self._monitor_readback_thread.is_alive():
raise RuntimeError("Failed to clean up Smargon monitor thread.")
def register(self, axis_id: str):
self._signal_registry.add(axis_id)
def _rest_get(self, endpoint):
resp = get(self._prefix + endpoint)
if not resp.ok:
raise HttpRestError(resp)
return resp.json()
def _rest_put(self, val: dict[str, float]):
resp = put(self._prefix + self._target_endpoint, params=val)
if not resp.ok:
raise HttpRestError(resp, value=val)
def start_monitor(self):
"""Start or restart the automonitor thread."""
self._clean_monitor()
self._setup_readback()
self._monitor_readback_thread.start()
def monitor_stopped(self):
return not self._monitor_readback_thread.is_alive()
def get_readback(self, axis_id: str) -> tuple[float, float] | None:
with self._readback_lock:
if axis_id not in self._readbacks or _TIMESTAMP_ID not in self._readbacks:
return None
return self._readbacks.get(axis_id), self._readbacks.get(_TIMESTAMP_ID) # type: ignore
def put(self, axis: str, val: float):
self._rest_put({axis: val})
def stop(self):
# There doesn't appear to be a stop endpoint on the server
# Best effort: set the target to the current position
self._rest_put(self._readbacks)
class Smargon(PSIDeviceBase):
x = Cpt(SmargonSignal, axis_identifier="SHX", tolerance=0.01)
y = Cpt(SmargonSignal, axis_identifier="SHY", tolerance=0.01)
z = Cpt(SmargonSignal, axis_identifier="SHZ", tolerance=0.01)
phi = Cpt(SmargonSignal, axis_identifier="PHI", tolerance=0.01)
chi = Cpt(SmargonSignal, axis_identifier="CHI", tolerance=0.01)
def __init__(
self, *, name: str, prefix: str = "", scan_info=None, device_manager=None, **kwargs
):
self.controller = SmargonController(prefix=prefix)
super().__init__(
name=name, prefix=prefix, scan_info=scan_info, device_manager=device_manager, **kwargs
)
def wait_for_connection(self, **kwargs): # type: ignore
self.controller.start_monitor()
self.controller.manual_update()
return super().wait_for_connection(**kwargs)
def stop(self, *, success: bool = False) -> None:
self.controller.stop()
return super().stop(success=success)
+1
View File
@@ -25,6 +25,7 @@ dev = [
"pytest-random-order",
"ophyd_devices",
"bec_server",
"requests-mock",
]
[project.entry-points."bec"]
+110
View File
@@ -0,0 +1,110 @@
"""A mock smargopolo REST interface with mock motoers, for testing devices against"""
import asyncio
import random
import time
from contextlib import asynccontextmanager
from typing import Iterable
import uvicorn
from fastapi import FastAPI, HTTPException, Query, Request
from pydantic import BaseModel
AXES = ["SHX", "SHY", "SHZ", "PHI", "CHI"]
class Motor:
def __init__(self, velocity: float = 1.0):
self.position = 0.0
self.target = 0.0
self.velocity = velocity
self.moving = False
self._last_update = time.monotonic()
def update(self):
now = time.monotonic()
dt = now - self._last_update
self._last_update = now
if not self.moving:
return
jitter_factor = random.random() * 0.05 - 0.025 # +- 2.5% jitter in step
distance = self.target - self.position
direction = 1 if distance > 0 else -1
step = direction * self.velocity * dt
if abs(step) >= abs(distance):
self.position = self.target + (step * jitter_factor)
self.moving = False
else:
self.position += step * (1 + jitter_factor)
motors: dict[str, Motor] = {
"SHX": Motor(velocity=3),
"SHY": Motor(velocity=2.5),
"SHZ": Motor(velocity=2),
"PHI": Motor(velocity=1.0),
"CHI": Motor(velocity=0.7),
}
class MoveRequest(BaseModel):
target: float
@asynccontextmanager
async def lifespan(app: FastAPI):
async def updater():
while True:
for motor in motors.values():
motor.update()
await asyncio.sleep(0.02) # 50 Hz update loop
task = asyncio.create_task(updater())
yield
task.cancel()
try:
await task
except asyncio.CancelledError:
pass
app = FastAPI(lifespan=lifespan)
def validate_axes(axes: Iterable[str] | None) -> list[str]:
if axes is None:
return AXES
for a in axes:
if a not in AXES:
raise HTTPException(status_code=404, detail=f"Unknown axis: {a}")
return list(axes)
@app.get("/readbackSCS")
async def readback_scs(axis: list[str] | None = Query(None)):
selected_axes = validate_axes(axis)
return {ax: motors[ax].position for ax in selected_axes}
@app.put("/targetSCS")
async def target_scs(req: Request):
targets = {ax: float(t) for ax, t in req.query_params.items()}
if targets is None:
return {}
selected_axes = validate_axes(targets.keys())
for a in selected_axes:
motor = motors[a]
motor.update()
motor.target = targets[a]
motor.moving = True
return {"targets": targets, "message": "Move started"}
if __name__ == "__main__":
uvicorn.run(app, host="0.0.0.0", port=8000, reload=False)
+50
View File
@@ -0,0 +1,50 @@
from copy import copy
from threading import RLock
from unittest.mock import ANY, MagicMock, patch
import pytest
class MockServer:
def __init__(self) -> None:
self.lock = RLock()
self.mock_data = {"SHX": 1.0, "SHY": 1.0, "SHZ": 1.0, "PHI": 1.0, "CHI": 1.0}
def get(self, endpoint):
with self.lock:
return copy(self.mock_data)
def put(self, val: dict[str, float]):
with self.lock:
self.mock_data.update(val)
@pytest.fixture
def smargon():
mock_server = MockServer()
from pxii_bec.devices.smargopolo_smargon import Smargon
s = Smargon(name="smargon", prefix="http://test-smargopolo.psi.ch")
s.controller._rest_get = mock_server.get
s.controller._rest_put = mock_server.put
yield s
s.controller._stop_monitor_readback_event.set()
class TestSmargon:
def test_smargon_read(self, smargon):
smargon.wait_for_connection()
reading = smargon.read()
assert dict(reading) == {
"smargon_x": {"value": 1.0, "timestamp": ANY},
"smargon_y": {"value": 1.0, "timestamp": ANY},
"smargon_z": {"value": 1.0, "timestamp": ANY},
"smargon_phi": {"value": 1.0, "timestamp": ANY},
"smargon_chi": {"value": 1.0, "timestamp": ANY},
}
def test_smargon_set_with_status(self, smargon):
smargon.wait_for_connection()
st = smargon.x.set(5.0)
st.wait(timeout=1)
assert smargon.x.get() == 5.0