feat: Smargopolo Smargon device v1 #11
@@ -42,4 +42,17 @@ dcm_froll:
|
||||
deviceTags:
|
||||
- dcm
|
||||
readOnly: false
|
||||
softwareTrigger: false
|
||||
|
||||
smargon:
|
||||
description: REST-based device which connects to Smargopolo
|
||||
deviceClass: pxii_bec.devices.smargopolo_smargon.Smargon
|
||||
deviceConfig: {prefix: 'http://x10sa-smargopolo.psi.ch:3000'}
|
||||
onFailure: buffer
|
||||
enabled: True
|
||||
readoutPriority: baseline
|
||||
deviceTags:
|
||||
- smargon
|
||||
- motors
|
||||
readOnly: false
|
||||
softwareTrigger: false
|
||||
@@ -0,0 +1,165 @@
|
||||
import time
|
||||
from threading import Event, RLock, Thread
|
||||
from typing import Any
|
||||
|
||||
from ophyd import Component as Cpt
|
||||
from ophyd import OphydObject
|
||||
from ophyd_devices import PSIDeviceBase
|
||||
from ophyd_devices.utils.socket import SocketSignal
|
||||
from requests import Response, get, put
|
||||
|
||||
_TIMESTAMP_ID = "__timestamp"
|
||||
_POLL_INTERVAL_SLOW = 0.1
|
||||
|
||||
|
||||
class HttpRestError(Exception):
|
||||
"""Error for rest calls from a HttpRestSignal."""
|
||||
|
||||
def __init__(self, resp: Response, *args: object, value: Any | None = None) -> None:
|
||||
method, url = resp.request.method, resp.request.url
|
||||
data = f"{str(value)} to " if value is not None else ""
|
||||
super().__init__(
|
||||
f"Could not {method} {data}{url}. Code: {resp.status_code}. Reason: {resp.reason}.",
|
||||
*args,
|
||||
)
|
||||
|
||||
|
||||
class SmargonSignal(SocketSignal):
|
||||
"""Ophyd signal which gets and puts to a REST API rather than EPICS PVs, mediated through the SmargonController"""
|
||||
|
||||
def __init__(self, *args, axis_identifier: str, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
controller: SmargonController | None = getattr(self.root, "controller", None)
|
||||
if controller is None:
|
||||
raise TypeError("SmargonSignal must be used in a device with a SmargonController")
|
||||
self._controller = controller
|
||||
self._axis_id = axis_identifier
|
||||
self._controller.register(self._axis_id)
|
||||
|
||||
def _socket_get(self): # type: ignore
|
||||
self._readback, self.metadata["timestamp"] = self._controller.get_readback(
|
||||
self._axis_id
|
||||
) or (0.0, 0.0)
|
||||
return self._readback
|
||||
|
||||
def _socket_set(self, val: float):
|
||||
self._controller.put(self._axis_id, val)
|
||||
|
||||
def get(self, **kwargs):
|
||||
if self._controller.monitor_stopped():
|
||||
self._controller.start_monitor()
|
||||
return super().get(**kwargs)
|
||||
|
||||
|
||||
class SmargonController(OphydObject):
|
||||
"""Controller to consolidate polling loops and other REST calls for the smargon"""
|
||||
|
||||
def __init__(self, *, prefix, **kwargs):
|
||||
self._prefix = prefix
|
||||
self._readback_endpoint = "/readbackSCS"
|
||||
self._target_endpoint = "/targetSCS"
|
||||
self._targets = {}
|
||||
self._signal_registry: set[str] = set()
|
||||
self._readback_poll_interval: float = _POLL_INTERVAL_SLOW
|
||||
|
||||
super().__init__(**kwargs)
|
||||
self._setup_readback()
|
||||
|
||||
def _setup_readback(self):
|
||||
self._readbacks: dict[str, float] = {}
|
||||
self._stop_monitor_readback_event = Event()
|
||||
self._readback_lock = RLock()
|
||||
self._monitor_readback_thread = Thread(
|
||||
target=self._monitor,
|
||||
args=[
|
||||
self._readback_endpoint,
|
||||
self._stop_monitor_readback_event,
|
||||
self._readback_lock,
|
||||
self._readbacks,
|
||||
],
|
||||
)
|
||||
|
||||
def manual_update(self):
|
||||
self._update_reading(self._readback_endpoint, self._readback_lock, self._readbacks)
|
||||
|
||||
def _update_reading(self, endpoint: str, lock: RLock, buffer: dict):
|
||||
data = self._rest_get(endpoint)
|
||||
timestamp = time.monotonic()
|
||||
with lock:
|
||||
buffer.update(data)
|
||||
buffer["__timestamp"] = timestamp
|
||||
|
||||
def _monitor(self, endpoint: str, event: Event, lock: RLock, buffer: dict):
|
||||
while not event.is_set():
|
||||
self._update_reading(endpoint, lock, buffer)
|
||||
time.sleep(self._readback_poll_interval)
|
||||
|
||||
def _clean_monitor(self):
|
||||
if self._monitor_readback_thread.is_alive():
|
||||
self._stop_monitor_readback_event.set()
|
||||
self._monitor_readback_thread.join(timeout=2)
|
||||
if self._monitor_readback_thread.is_alive():
|
||||
raise RuntimeError("Failed to clean up Smargon monitor thread.")
|
||||
|
||||
def register(self, axis_id: str):
|
||||
self._signal_registry.add(axis_id)
|
||||
|
||||
def _rest_get(self, endpoint):
|
||||
resp = get(self._prefix + endpoint)
|
||||
if not resp.ok:
|
||||
raise HttpRestError(resp)
|
||||
return resp.json()
|
||||
|
||||
def _rest_put(self, val: dict[str, float]):
|
||||
resp = put(self._prefix + self._target_endpoint, params=val)
|
||||
if not resp.ok:
|
||||
raise HttpRestError(resp, value=val)
|
||||
|
||||
def start_monitor(self):
|
||||
"""Start or restart the automonitor thread."""
|
||||
self._clean_monitor()
|
||||
self._setup_readback()
|
||||
self._monitor_readback_thread.start()
|
||||
|
||||
def monitor_stopped(self):
|
||||
return not self._monitor_readback_thread.is_alive()
|
||||
|
||||
def get_readback(self, axis_id: str) -> tuple[float, float] | None:
|
||||
with self._readback_lock:
|
||||
if axis_id not in self._readbacks or _TIMESTAMP_ID not in self._readbacks:
|
||||
return None
|
||||
return self._readbacks.get(axis_id), self._readbacks.get(_TIMESTAMP_ID) # type: ignore
|
||||
|
||||
def put(self, axis: str, val: float):
|
||||
self._rest_put({axis: val})
|
||||
|
||||
def stop(self):
|
||||
# There doesn't appear to be a stop endpoint on the server
|
||||
# Best effort: set the target to the current position
|
||||
self._rest_put(self._readbacks)
|
||||
|
||||
|
||||
class Smargon(PSIDeviceBase):
|
||||
|
||||
x = Cpt(SmargonSignal, axis_identifier="SHX", tolerance=0.01)
|
||||
y = Cpt(SmargonSignal, axis_identifier="SHY", tolerance=0.01)
|
||||
z = Cpt(SmargonSignal, axis_identifier="SHZ", tolerance=0.01)
|
||||
phi = Cpt(SmargonSignal, axis_identifier="PHI", tolerance=0.01)
|
||||
chi = Cpt(SmargonSignal, axis_identifier="CHI", tolerance=0.01)
|
||||
|
||||
def __init__(
|
||||
self, *, name: str, prefix: str = "", scan_info=None, device_manager=None, **kwargs
|
||||
):
|
||||
self.controller = SmargonController(prefix=prefix)
|
||||
super().__init__(
|
||||
name=name, prefix=prefix, scan_info=scan_info, device_manager=device_manager, **kwargs
|
||||
)
|
||||
|
||||
def wait_for_connection(self, **kwargs): # type: ignore
|
||||
self.controller.start_monitor()
|
||||
self.controller.manual_update()
|
||||
return super().wait_for_connection(**kwargs)
|
||||
|
||||
def stop(self, *, success: bool = False) -> None:
|
||||
self.controller.stop()
|
||||
return super().stop(success=success)
|
||||
@@ -25,6 +25,7 @@ dev = [
|
||||
"pytest-random-order",
|
||||
"ophyd_devices",
|
||||
"bec_server",
|
||||
"requests-mock",
|
||||
]
|
||||
|
||||
[project.entry-points."bec"]
|
||||
|
||||
@@ -0,0 +1,110 @@
|
||||
"""A mock smargopolo REST interface with mock motoers, for testing devices against"""
|
||||
|
||||
import asyncio
|
||||
import random
|
||||
import time
|
||||
from contextlib import asynccontextmanager
|
||||
from typing import Iterable
|
||||
|
||||
import uvicorn
|
||||
from fastapi import FastAPI, HTTPException, Query, Request
|
||||
from pydantic import BaseModel
|
||||
|
||||
AXES = ["SHX", "SHY", "SHZ", "PHI", "CHI"]
|
||||
|
||||
|
||||
class Motor:
|
||||
def __init__(self, velocity: float = 1.0):
|
||||
self.position = 0.0
|
||||
self.target = 0.0
|
||||
self.velocity = velocity
|
||||
self.moving = False
|
||||
self._last_update = time.monotonic()
|
||||
|
||||
def update(self):
|
||||
now = time.monotonic()
|
||||
dt = now - self._last_update
|
||||
self._last_update = now
|
||||
|
||||
if not self.moving:
|
||||
return
|
||||
|
||||
jitter_factor = random.random() * 0.05 - 0.025 # +- 2.5% jitter in step
|
||||
distance = self.target - self.position
|
||||
direction = 1 if distance > 0 else -1
|
||||
step = direction * self.velocity * dt
|
||||
|
||||
if abs(step) >= abs(distance):
|
||||
self.position = self.target + (step * jitter_factor)
|
||||
self.moving = False
|
||||
else:
|
||||
self.position += step * (1 + jitter_factor)
|
||||
|
||||
|
||||
motors: dict[str, Motor] = {
|
||||
"SHX": Motor(velocity=3),
|
||||
"SHY": Motor(velocity=2.5),
|
||||
"SHZ": Motor(velocity=2),
|
||||
"PHI": Motor(velocity=1.0),
|
||||
"CHI": Motor(velocity=0.7),
|
||||
}
|
||||
|
||||
|
||||
class MoveRequest(BaseModel):
|
||||
target: float
|
||||
|
||||
|
||||
@asynccontextmanager
|
||||
async def lifespan(app: FastAPI):
|
||||
async def updater():
|
||||
while True:
|
||||
for motor in motors.values():
|
||||
motor.update()
|
||||
await asyncio.sleep(0.02) # 50 Hz update loop
|
||||
|
||||
task = asyncio.create_task(updater())
|
||||
yield
|
||||
task.cancel()
|
||||
try:
|
||||
await task
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
|
||||
|
||||
app = FastAPI(lifespan=lifespan)
|
||||
|
||||
|
||||
def validate_axes(axes: Iterable[str] | None) -> list[str]:
|
||||
if axes is None:
|
||||
return AXES
|
||||
for a in axes:
|
||||
if a not in AXES:
|
||||
raise HTTPException(status_code=404, detail=f"Unknown axis: {a}")
|
||||
return list(axes)
|
||||
|
||||
|
||||
@app.get("/readbackSCS")
|
||||
async def readback_scs(axis: list[str] | None = Query(None)):
|
||||
selected_axes = validate_axes(axis)
|
||||
return {ax: motors[ax].position for ax in selected_axes}
|
||||
|
||||
|
||||
@app.put("/targetSCS")
|
||||
async def target_scs(req: Request):
|
||||
targets = {ax: float(t) for ax, t in req.query_params.items()}
|
||||
if targets is None:
|
||||
return {}
|
||||
|
||||
selected_axes = validate_axes(targets.keys())
|
||||
|
||||
for a in selected_axes:
|
||||
motor = motors[a]
|
||||
motor.update()
|
||||
motor.target = targets[a]
|
||||
motor.moving = True
|
||||
|
||||
return {"targets": targets, "message": "Move started"}
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
uvicorn.run(app, host="0.0.0.0", port=8000, reload=False)
|
||||
@@ -0,0 +1,50 @@
|
||||
from copy import copy
|
||||
from threading import RLock
|
||||
from unittest.mock import ANY, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
|
||||
class MockServer:
|
||||
def __init__(self) -> None:
|
||||
self.lock = RLock()
|
||||
self.mock_data = {"SHX": 1.0, "SHY": 1.0, "SHZ": 1.0, "PHI": 1.0, "CHI": 1.0}
|
||||
|
||||
def get(self, endpoint):
|
||||
with self.lock:
|
||||
return copy(self.mock_data)
|
||||
|
||||
def put(self, val: dict[str, float]):
|
||||
with self.lock:
|
||||
self.mock_data.update(val)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def smargon():
|
||||
mock_server = MockServer()
|
||||
from pxii_bec.devices.smargopolo_smargon import Smargon
|
||||
|
||||
s = Smargon(name="smargon", prefix="http://test-smargopolo.psi.ch")
|
||||
s.controller._rest_get = mock_server.get
|
||||
s.controller._rest_put = mock_server.put
|
||||
yield s
|
||||
s.controller._stop_monitor_readback_event.set()
|
||||
|
||||
|
||||
class TestSmargon:
|
||||
def test_smargon_read(self, smargon):
|
||||
smargon.wait_for_connection()
|
||||
reading = smargon.read()
|
||||
assert dict(reading) == {
|
||||
"smargon_x": {"value": 1.0, "timestamp": ANY},
|
||||
"smargon_y": {"value": 1.0, "timestamp": ANY},
|
||||
"smargon_z": {"value": 1.0, "timestamp": ANY},
|
||||
"smargon_phi": {"value": 1.0, "timestamp": ANY},
|
||||
"smargon_chi": {"value": 1.0, "timestamp": ANY},
|
||||
}
|
||||
|
||||
def test_smargon_set_with_status(self, smargon):
|
||||
smargon.wait_for_connection()
|
||||
st = smargon.x.set(5.0)
|
||||
st.wait(timeout=1)
|
||||
assert smargon.x.get() == 5.0
|
||||
Reference in New Issue
Block a user