Add Aerotech ophyd device and config #14
@@ -55,4 +55,17 @@ smargon:
|
||||
- smargon
|
||||
- motors
|
||||
readOnly: false
|
||||
softwareTrigger: false
|
||||
softwareTrigger: false
|
||||
|
||||
aerotech:
|
||||
description: REST-based device which connects to AareScan
|
||||
deviceClass: pxii_bec.devices.aerotech
|
||||
deviceConfig: { prefix: "http://mx-x10sa-queue-01:5234/" }
|
||||
onFailure: buffer
|
||||
enabled: True
|
||||
readoutPriority: baseline
|
||||
deviceTags:
|
||||
- aerotech
|
||||
- motors
|
||||
readOnly: false
|
||||
softwareTrigger: false
|
||||
|
||||
@@ -0,0 +1,43 @@
|
||||
from ophyd import Component as Cpt
|
||||
|
||||
from .http import TIMESTAMP_ID, HttpDeviceController, HttpDeviceSignal, HttpOphydDevice
|
||||
|
||||
|
||||
class AerotechController(HttpDeviceController):
|
||||
_readback_endpoint = "status"
|
||||
_target_endpoint = "position"
|
||||
|
||||
def __init__(self, *, prefix, **kwargs):
|
||||
self._readbacks: dict[str, dict[str, float | bool]] = {}
|
||||
super().__init__(prefix=prefix, **kwargs)
|
||||
|
||||
def put(self, axis: str, val: float):
|
||||
self._rest_post(body={axis: val})
|
||||
|
||||
def get_readback(self, axis_id: str) -> tuple[float, float] | None:
|
||||
with self._readback_lock:
|
||||
if axis_id not in self._readbacks or TIMESTAMP_ID not in self._readbacks:
|
||||
return None
|
||||
return self._readbacks.get(axis_id)["pos"], self._readbacks.get(TIMESTAMP_ID) # type: ignore
|
||||
|
||||
|
||||
class Aerotech(HttpOphydDevice):
|
||||
controller_class = AerotechController
|
||||
|
||||
x = Cpt(HttpDeviceSignal, axis_identifier="x", tolerance=0.01)
|
||||
y = Cpt(HttpDeviceSignal, axis_identifier="y", tolerance=0.01)
|
||||
z = Cpt(HttpDeviceSignal, axis_identifier="z", tolerance=0.01)
|
||||
u = Cpt(HttpDeviceSignal, axis_identifier="u", tolerance=0.01)
|
||||
vel_u_deg_s = Cpt(HttpDeviceSignal, axis_identifier="vel_u_deg_s", tolerance=0.01)
|
||||
|
||||
|
||||
def _test():
|
||||
a = Aerotech(name="aerotech", prefix="http://mx-x10sa-queue-01:5234")
|
||||
a.wait_for_connection()
|
||||
return a
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
aerotech = _test()
|
||||
print(aerotech.read())
|
||||
aerotech.stop()
|
||||
@@ -0,0 +1,178 @@
|
||||
import time
|
||||
from abc import ABC, abstractmethod
|
||||
from threading import Event, RLock, Thread
|
||||
from typing import Any
|
||||
|
||||
from ophyd import OphydObject
|
||||
from ophyd_devices import PSIDeviceBase
|
||||
from ophyd_devices.utils.socket import SocketSignal
|
||||
from requests import Response, Session
|
||||
|
||||
TIMESTAMP_ID = "__timestamp"
|
||||
_POLL_INTERVAL_SLOW = 0.1
|
||||
|
||||
|
||||
class HttpRestError(Exception):
|
||||
"""Error for rest calls from a HttpRestSignal."""
|
||||
|
||||
def __init__(self, resp: Response, *args: object, value: Any | None = None) -> None:
|
||||
method, url = resp.request.method, resp.request.url
|
||||
data = f"{str(value)} to " if value is not None else ""
|
||||
super().__init__(
|
||||
f"Could not {method} {data}{url}. Code: {resp.status_code}. Reason: {resp.reason}.",
|
||||
*args,
|
||||
)
|
||||
|
||||
|
||||
class HttpDeviceController(OphydObject, ABC):
|
||||
"""Controller to consolidate polling loops and other REST calls for devices which communicate
|
||||
with HTTP REST interfaces"""
|
||||
|
||||
_readback_endpoint: str
|
||||
_target_endpoint: str
|
||||
|
||||
def __init__(self, *, prefix, **kwargs):
|
||||
self._readbacks: dict
|
||||
self._session = Session()
|
||||
self._prefix = prefix
|
||||
self._targets = {}
|
||||
self._signal_registry: set[str] = set()
|
||||
self._readback_poll_interval: float = _POLL_INTERVAL_SLOW
|
||||
|
||||
super().__init__(**kwargs)
|
||||
self._setup_readback()
|
||||
|
||||
def _setup_readback(self):
|
||||
self._stop_monitor_readback_event = Event()
|
||||
self._readback_lock = RLock()
|
||||
self._monitor_readback_thread = Thread(
|
||||
target=self._monitor,
|
||||
args=[
|
||||
self._readback_endpoint,
|
||||
self._stop_monitor_readback_event,
|
||||
self._readback_lock,
|
||||
self._readbacks,
|
||||
],
|
||||
)
|
||||
|
||||
def manual_update(self):
|
||||
self._update_reading(self._readback_endpoint, self._readback_lock, self._readbacks)
|
||||
|
||||
def _update_reading(self, endpoint: str, lock: RLock, buffer: dict):
|
||||
data = self._rest_get(endpoint)
|
||||
timestamp = time.monotonic()
|
||||
with lock:
|
||||
buffer.update(data)
|
||||
buffer["__timestamp"] = timestamp
|
||||
|
||||
def _monitor(self, endpoint: str, event: Event, lock: RLock, buffer: dict):
|
||||
while not event.is_set():
|
||||
self._update_reading(endpoint, lock, buffer)
|
||||
time.sleep(self._readback_poll_interval)
|
||||
|
||||
def _clean_monitor(self):
|
||||
if self._monitor_readback_thread.is_alive():
|
||||
self._stop_monitor_readback_event.set()
|
||||
self._monitor_readback_thread.join(timeout=2)
|
||||
if self._monitor_readback_thread.is_alive():
|
||||
raise RuntimeError("Failed to clean up Aerotech monitor thread.")
|
||||
|
||||
def register(self, axis_id: str):
|
||||
self._signal_registry.add(axis_id)
|
||||
|
||||
def _rest_get(self, endpoint):
|
||||
resp = self._session.get(self._prefix + endpoint)
|
||||
if not resp.ok:
|
||||
raise HttpRestError(resp)
|
||||
return resp.json()
|
||||
|
||||
def _rest_put(self, params: dict | None = None, body: dict | None = None):
|
||||
resp = self._session.put(self._prefix + self._target_endpoint, params=params, json=body)
|
||||
if not resp.ok:
|
||||
raise HttpRestError(resp, value=params)
|
||||
|
||||
def _rest_post(self, params: dict | None = None, body: dict | None = None):
|
||||
resp = self._session.post(self._prefix + self._target_endpoint, params=params, json=body)
|
||||
if not resp.ok:
|
||||
raise HttpRestError(resp, value=params)
|
||||
|
||||
def start_monitor(self):
|
||||
"""Start or restart the automonitor thread."""
|
||||
self._clean_monitor()
|
||||
self._setup_readback()
|
||||
self._monitor_readback_thread.start()
|
||||
|
||||
def monitor_stopped(self):
|
||||
return not self._monitor_readback_thread.is_alive()
|
||||
|
||||
def put(self, axis: str, val: float):
|
||||
self._rest_put({axis: val})
|
||||
|
||||
@abstractmethod
|
||||
def get_readback(self, axis_id: str) -> tuple[float, float] | None:
|
||||
"""Return a tuple (reading, timestamp) if the axis_id exists"""
|
||||
|
||||
def stop(self):
|
||||
# There doesn't appear to be a stop endpoint on the server
|
||||
# Best effort: set the target to the current position
|
||||
pass
|
||||
# TODO: self._rest_put(self._readbacks)
|
||||
|
||||
|
||||
class HttpDeviceSignal(SocketSignal):
|
||||
"""Ophyd signal which gets and puts to a REST API rather than EPICS PVs, mediated through the Aerotech
|
||||
Controller"""
|
||||
|
||||
def __init__(self, *args, axis_identifier: str, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
controller: HttpDeviceController | None = getattr(self.root, "controller", None)
|
||||
if controller is None:
|
||||
raise TypeError("HttpDeviceSignal must be used in a device with a HttpDeviceController")
|
||||
self._controller = controller
|
||||
self._axis_id = axis_identifier
|
||||
self._controller.register(self._axis_id)
|
||||
|
||||
def _socket_get(self): # type: ignore
|
||||
self._readback, self.metadata["timestamp"] = self._controller.get_readback(
|
||||
self._axis_id
|
||||
) or (0.0, 0.0)
|
||||
return self._readback
|
||||
|
||||
def _socket_set(self, val: float):
|
||||
self._controller.put(self._axis_id, val)
|
||||
|
||||
def get(self, **kwargs):
|
||||
if self._controller.monitor_stopped():
|
||||
self._controller.start_monitor()
|
||||
return super().get(**kwargs)
|
||||
|
||||
|
||||
class HttpOphydDevice(PSIDeviceBase):
|
||||
controller_class: type[HttpDeviceController]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
name: str,
|
||||
prefix: str = "",
|
||||
scan_info=None,
|
||||
device_manager=None,
|
||||
**kwargs,
|
||||
):
|
||||
self.controller = self.controller_class(prefix=prefix)
|
||||
super().__init__(
|
||||
name=name,
|
||||
prefix=prefix,
|
||||
scan_info=scan_info,
|
||||
device_manager=device_manager,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
def wait_for_connection(self, **kwargs): # type: ignore
|
||||
self.controller.start_monitor()
|
||||
self.controller.manual_update()
|
||||
return super().wait_for_connection(**kwargs)
|
||||
|
||||
def stop(self, *, success: bool = False) -> None:
|
||||
self.controller.stop()
|
||||
return super().stop(success=success)
|
||||
@@ -1,129 +1,21 @@
|
||||
import time
|
||||
from threading import Event, RLock, Thread
|
||||
from typing import Any
|
||||
|
||||
from ophyd import Component as Cpt
|
||||
from ophyd import OphydObject
|
||||
from ophyd_devices import PSIDeviceBase
|
||||
from ophyd_devices.utils.socket import SocketSignal
|
||||
from requests import Response, Session
|
||||
|
||||
from .http import HttpDeviceController, HttpDeviceSignal, HttpOphydDevice
|
||||
|
||||
_TIMESTAMP_ID = "__timestamp"
|
||||
_POLL_INTERVAL_SLOW = 0.1
|
||||
|
||||
|
||||
class HttpRestError(Exception):
|
||||
"""Error for rest calls from a HttpRestSignal."""
|
||||
|
||||
def __init__(self, resp: Response, *args: object, value: Any | None = None) -> None:
|
||||
method, url = resp.request.method, resp.request.url
|
||||
data = f"{str(value)} to " if value is not None else ""
|
||||
super().__init__(
|
||||
f"Could not {method} {data}{url}. Code: {resp.status_code}. Reason: {resp.reason}.",
|
||||
*args,
|
||||
)
|
||||
|
||||
|
||||
class SmargonSignal(SocketSignal):
|
||||
"""Ophyd signal which gets and puts to a REST API rather than EPICS PVs, mediated through the SmargonController"""
|
||||
|
||||
def __init__(self, *args, axis_identifier: str, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
controller: SmargonController | None = getattr(self.root, "controller", None)
|
||||
if controller is None:
|
||||
raise TypeError("SmargonSignal must be used in a device with a SmargonController")
|
||||
self._controller = controller
|
||||
self._axis_id = axis_identifier
|
||||
self._controller.register(self._axis_id)
|
||||
|
||||
def _socket_get(self): # type: ignore
|
||||
self._readback, self.metadata["timestamp"] = self._controller.get_readback(
|
||||
self._axis_id
|
||||
) or (0.0, 0.0)
|
||||
return self._readback
|
||||
|
||||
def _socket_set(self, val: float):
|
||||
self._controller.put(self._axis_id, val)
|
||||
|
||||
def get(self, **kwargs):
|
||||
if self._controller.monitor_stopped():
|
||||
self._controller.start_monitor()
|
||||
return super().get(**kwargs)
|
||||
|
||||
|
||||
class SmargonController(OphydObject):
|
||||
class SmargonController(HttpDeviceController):
|
||||
"""Controller to consolidate polling loops and other REST calls for the smargon"""
|
||||
|
||||
_readback_endpoint = "/readbackSCS"
|
||||
_target_endpoint = "/targetSCS"
|
||||
|
||||
def __init__(self, *, prefix, **kwargs):
|
||||
self._session = Session()
|
||||
self._prefix = prefix
|
||||
self._readback_endpoint = "/readbackSCS"
|
||||
self._target_endpoint = "/targetSCS"
|
||||
self._targets = {}
|
||||
self._signal_registry: set[str] = set()
|
||||
self._readback_poll_interval: float = _POLL_INTERVAL_SLOW
|
||||
|
||||
super().__init__(**kwargs)
|
||||
self._setup_readback()
|
||||
|
||||
def _setup_readback(self):
|
||||
self._readbacks: dict[str, float] = {}
|
||||
self._stop_monitor_readback_event = Event()
|
||||
self._readback_lock = RLock()
|
||||
self._monitor_readback_thread = Thread(
|
||||
target=self._monitor,
|
||||
args=[
|
||||
self._readback_endpoint,
|
||||
self._stop_monitor_readback_event,
|
||||
self._readback_lock,
|
||||
self._readbacks,
|
||||
],
|
||||
)
|
||||
|
||||
def manual_update(self):
|
||||
self._update_reading(self._readback_endpoint, self._readback_lock, self._readbacks)
|
||||
|
||||
def _update_reading(self, endpoint: str, lock: RLock, buffer: dict):
|
||||
data = self._rest_get(endpoint)
|
||||
timestamp = time.monotonic()
|
||||
with lock:
|
||||
buffer.update(data)
|
||||
buffer["__timestamp"] = timestamp
|
||||
|
||||
def _monitor(self, endpoint: str, event: Event, lock: RLock, buffer: dict):
|
||||
while not event.is_set():
|
||||
self._update_reading(endpoint, lock, buffer)
|
||||
time.sleep(self._readback_poll_interval)
|
||||
|
||||
def _clean_monitor(self):
|
||||
if self._monitor_readback_thread.is_alive():
|
||||
self._stop_monitor_readback_event.set()
|
||||
self._monitor_readback_thread.join(timeout=2)
|
||||
if self._monitor_readback_thread.is_alive():
|
||||
raise RuntimeError("Failed to clean up Smargon monitor thread.")
|
||||
|
||||
def register(self, axis_id: str):
|
||||
self._signal_registry.add(axis_id)
|
||||
|
||||
def _rest_get(self, endpoint):
|
||||
resp = self._session.get(self._prefix + endpoint)
|
||||
if not resp.ok:
|
||||
raise HttpRestError(resp)
|
||||
return resp.json()
|
||||
|
||||
def _rest_put(self, val: dict[str, float]):
|
||||
resp = self._session.put(self._prefix + self._target_endpoint, params=val)
|
||||
if not resp.ok:
|
||||
raise HttpRestError(resp, value=val)
|
||||
|
||||
def start_monitor(self):
|
||||
"""Start or restart the automonitor thread."""
|
||||
self._clean_monitor()
|
||||
self._setup_readback()
|
||||
self._monitor_readback_thread.start()
|
||||
|
||||
def monitor_stopped(self):
|
||||
return not self._monitor_readback_thread.is_alive()
|
||||
super().__init__(prefix=prefix, **kwargs)
|
||||
|
||||
def get_readback(self, axis_id: str) -> tuple[float, float] | None:
|
||||
with self._readback_lock:
|
||||
@@ -132,34 +24,19 @@ class SmargonController(OphydObject):
|
||||
return self._readbacks.get(axis_id), self._readbacks.get(_TIMESTAMP_ID) # type: ignore
|
||||
|
||||
def put(self, axis: str, val: float):
|
||||
self._rest_put({axis: val})
|
||||
self._rest_put(params={axis: val})
|
||||
|
||||
def stop(self):
|
||||
# There doesn't appear to be a stop endpoint on the server
|
||||
# Best effort: set the target to the current position
|
||||
self._rest_put(self._readbacks)
|
||||
self._rest_put(params=self._readbacks)
|
||||
|
||||
|
||||
class Smargon(PSIDeviceBase):
|
||||
x = Cpt(SmargonSignal, axis_identifier="SHX", tolerance=0.01)
|
||||
y = Cpt(SmargonSignal, axis_identifier="SHY", tolerance=0.01)
|
||||
z = Cpt(SmargonSignal, axis_identifier="SHZ", tolerance=0.01)
|
||||
phi = Cpt(SmargonSignal, axis_identifier="PHI", tolerance=0.01)
|
||||
chi = Cpt(SmargonSignal, axis_identifier="CHI", tolerance=0.01)
|
||||
class Smargon(HttpOphydDevice):
|
||||
controller_class = SmargonController
|
||||
|
||||
def __init__(
|
||||
self, *, name: str, prefix: str = "", scan_info=None, device_manager=None, **kwargs
|
||||
):
|
||||
self.controller = SmargonController(prefix=prefix)
|
||||
super().__init__(
|
||||
name=name, prefix=prefix, scan_info=scan_info, device_manager=device_manager, **kwargs
|
||||
)
|
||||
|
||||
def wait_for_connection(self, **kwargs): # type: ignore
|
||||
self.controller.start_monitor()
|
||||
self.controller.manual_update()
|
||||
return super().wait_for_connection(**kwargs)
|
||||
|
||||
def stop(self, *, success: bool = False) -> None:
|
||||
self.controller.stop()
|
||||
return super().stop(success=success)
|
||||
x = Cpt(HttpDeviceSignal, axis_identifier="SHX", tolerance=0.01)
|
||||
y = Cpt(HttpDeviceSignal, axis_identifier="SHY", tolerance=0.01)
|
||||
z = Cpt(HttpDeviceSignal, axis_identifier="SHZ", tolerance=0.01)
|
||||
phi = Cpt(HttpDeviceSignal, axis_identifier="PHI", tolerance=0.01)
|
||||
chi = Cpt(HttpDeviceSignal, axis_identifier="CHI", tolerance=0.01)
|
||||
|
||||
@@ -77,3 +77,6 @@ good-names-rgxs = [
|
||||
".*_2D.*",
|
||||
".*_1D.*",
|
||||
]
|
||||
|
||||
[tool.ruff]
|
||||
line-length = 100
|
||||
|
||||
@@ -0,0 +1,58 @@
|
||||
from copy import copy
|
||||
from threading import RLock
|
||||
from unittest.mock import ANY
|
||||
|
||||
import pytest
|
||||
|
||||
|
||||
class MockServer:
|
||||
def __init__(self) -> None:
|
||||
self.lock = RLock()
|
||||
self.mock_data = {
|
||||
"x": {"pos": 1.0},
|
||||
"y": {"pos": 1.0},
|
||||
"z": {"pos": 1.0},
|
||||
"u": {"pos": 1.0},
|
||||
"vel_u_deg_s": {"pos": 1.0},
|
||||
}
|
||||
|
||||
def get(self, endpoint):
|
||||
with self.lock:
|
||||
return copy(self.mock_data)
|
||||
|
||||
def put(self, params: dict | None = None, body: dict | None = None):
|
||||
with self.lock:
|
||||
assert body is not None
|
||||
for k, v in body.items():
|
||||
self.mock_data[k]["pos"] = v
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def aerotech():
|
||||
mock_server = MockServer()
|
||||
from pxii_bec.devices.aerotech import Aerotech
|
||||
|
||||
s = Aerotech(name="aerotech", prefix="http://test-aerotech.psi.ch")
|
||||
s.controller._rest_get = mock_server.get
|
||||
s.controller._rest_post = mock_server.put
|
||||
yield s
|
||||
s.controller._stop_monitor_readback_event.set()
|
||||
|
||||
|
||||
class TestAerotech:
|
||||
def test_aerotech_read(self, aerotech):
|
||||
aerotech.wait_for_connection()
|
||||
reading = aerotech.read()
|
||||
assert dict(reading) == {
|
||||
"aerotech_x": {"value": 1.0, "timestamp": ANY},
|
||||
"aerotech_y": {"value": 1.0, "timestamp": ANY},
|
||||
"aerotech_z": {"value": 1.0, "timestamp": ANY},
|
||||
"aerotech_u": {"value": 1.0, "timestamp": ANY},
|
||||
"aerotech_vel_u_deg_s": {"value": 1.0, "timestamp": ANY},
|
||||
}
|
||||
|
||||
def test_aerotech_set_with_status(self, aerotech):
|
||||
aerotech.wait_for_connection()
|
||||
st = aerotech.x.set(5.0)
|
||||
st.wait(timeout=1)
|
||||
assert aerotech.x.get() == 5.0
|
||||
@@ -1,6 +1,6 @@
|
||||
from copy import copy
|
||||
from threading import RLock
|
||||
from unittest.mock import ANY, MagicMock, patch
|
||||
from unittest.mock import ANY
|
||||
|
||||
import pytest
|
||||
|
||||
@@ -14,9 +14,10 @@ class MockServer:
|
||||
with self.lock:
|
||||
return copy(self.mock_data)
|
||||
|
||||
def put(self, val: dict[str, float]):
|
||||
def put(self, params: dict | None = None, body: dict | None = None):
|
||||
with self.lock:
|
||||
self.mock_data.update(val)
|
||||
assert params is not None
|
||||
self.mock_data.update(params)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
|
||||
Reference in New Issue
Block a user