mirror of
https://github.com/bec-project/ophyd_devices.git
synced 2026-02-02 21:54:55 +01:00
fix(controller): add configurable timeout, en/disable controller axes on on/off
This commit is contained in:
@@ -130,7 +130,7 @@ class SimDeviceWithStatusStageUnstage(Device):
|
|||||||
|
|
||||||
|
|
||||||
class SynController(OphydObject):
|
class SynController(OphydObject):
|
||||||
def on(self):
|
def on(self, timeout: int = 10):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
def off(self):
|
def off(self):
|
||||||
|
|||||||
@@ -84,7 +84,7 @@ class SocketMock:
|
|||||||
self.sock = None
|
self.sock = None
|
||||||
self.open()
|
self.open()
|
||||||
|
|
||||||
def connect(self):
|
def connect(self, timeout: int = 10):
|
||||||
"""Mock connect method"""
|
"""Mock connect method"""
|
||||||
print(f"connecting to {self.host} port {self.port}")
|
print(f"connecting to {self.host} port {self.port}")
|
||||||
|
|
||||||
@@ -116,7 +116,7 @@ class SocketMock:
|
|||||||
"""Mock receive method"""
|
"""Mock receive method"""
|
||||||
return self._recv(buffer_length=buffer_length)
|
return self._recv(buffer_length=buffer_length)
|
||||||
|
|
||||||
def open(self):
|
def open(self, timeout: int = 10):
|
||||||
"""Mock open method"""
|
"""Mock open method"""
|
||||||
self._initialize_socket()
|
self._initialize_socket()
|
||||||
self.is_open = True
|
self.is_open = True
|
||||||
|
|||||||
@@ -1,10 +1,17 @@
|
|||||||
import functools
|
import functools
|
||||||
import threading
|
import threading
|
||||||
|
from typing import TYPE_CHECKING, Type
|
||||||
|
|
||||||
from bec_lib import bec_logger
|
from bec_lib import bec_logger
|
||||||
from ophyd import Device
|
from ophyd import Device
|
||||||
from ophyd.ophydobj import OphydObject
|
from ophyd.ophydobj import OphydObject
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from bec_server.device_server.device_server import DeviceManagerDS
|
||||||
|
|
||||||
|
from ophyd_devices.utils.socket import SocketIO
|
||||||
|
|
||||||
|
|
||||||
logger = bec_logger.logger
|
logger = bec_logger.logger
|
||||||
|
|
||||||
|
|
||||||
@@ -59,7 +66,16 @@ def axis_checked(fcn):
|
|||||||
|
|
||||||
|
|
||||||
class Controller(OphydObject):
|
class Controller(OphydObject):
|
||||||
"""Base class for all socker-based controllers."""
|
"""
|
||||||
|
Base class for all socket-based controllers.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
name (str, optional): Name of the controller
|
||||||
|
socket_cls (Type[SocketIO]): Socket class to use for communication
|
||||||
|
socket_host (str): Hostname or IP address of the controller
|
||||||
|
socket_port (int): Port number of the controller
|
||||||
|
device_manager (DeviceManagerDS): Device manager instance
|
||||||
|
"""
|
||||||
|
|
||||||
_controller_instances = {}
|
_controller_instances = {}
|
||||||
_initialized = False
|
_initialized = False
|
||||||
@@ -70,10 +86,11 @@ class Controller(OphydObject):
|
|||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
*,
|
*,
|
||||||
name=None,
|
socket_cls: Type["SocketIO"],
|
||||||
socket_cls=None,
|
socket_host: str,
|
||||||
socket_host=None,
|
socket_port: int,
|
||||||
socket_port=None,
|
device_manager: "DeviceManagerDS",
|
||||||
|
name: str = "",
|
||||||
attr_name="",
|
attr_name="",
|
||||||
parent=None,
|
parent=None,
|
||||||
labels=None,
|
labels=None,
|
||||||
@@ -84,10 +101,11 @@ class Controller(OphydObject):
|
|||||||
name=name, attr_name=attr_name, parent=parent, labels=labels, kind=kind
|
name=name, attr_name=attr_name, parent=parent, labels=labels, kind=kind
|
||||||
)
|
)
|
||||||
self._lock = threading.RLock()
|
self._lock = threading.RLock()
|
||||||
self._axis = []
|
self._axis: list[Device] = []
|
||||||
self._initialize()
|
self._initialize()
|
||||||
self._initialized = True
|
self._initialized = True
|
||||||
self.sock = None
|
self.sock = None
|
||||||
|
self.dm = device_manager
|
||||||
self._socket_cls = socket_cls
|
self._socket_cls = socket_cls
|
||||||
self._socket_host = socket_host
|
self._socket_host = socket_host
|
||||||
self._socket_port = socket_port
|
self._socket_port = socket_port
|
||||||
@@ -127,15 +145,6 @@ class Controller(OphydObject):
|
|||||||
return var.split("\r\n")[0]
|
return var.split("\r\n")[0]
|
||||||
return var
|
return var
|
||||||
|
|
||||||
def get_device_manager(self):
|
|
||||||
"""
|
|
||||||
Helper function to get the device manager.
|
|
||||||
"""
|
|
||||||
for axis in self._axis:
|
|
||||||
if hasattr(axis, "device_manager") and axis.device_manager:
|
|
||||||
return axis.device_manager
|
|
||||||
raise ControllerError("Could not access the device_manager")
|
|
||||||
|
|
||||||
def get_axis_by_name(self, name: str) -> Device:
|
def get_axis_by_name(self, name: str) -> Device:
|
||||||
"""
|
"""
|
||||||
Get an axis by name.
|
Get an axis by name.
|
||||||
@@ -152,21 +161,58 @@ class Controller(OphydObject):
|
|||||||
return axis
|
return axis
|
||||||
raise RuntimeError(f"Could not find an axis with name {name}")
|
raise RuntimeError(f"Could not find an axis with name {name}")
|
||||||
|
|
||||||
def set_device_enabled(self, device_name: str, enabled: bool) -> None:
|
def set_device_read_write(self, device_name: str, enabled: bool) -> None:
|
||||||
"""
|
"""
|
||||||
Enable or disable a device for write access.
|
Change the read-only status of a device.
|
||||||
|
If the device is not configured, a warning is logged.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
device_name (str): Name of the device
|
||||||
|
enabled (bool): Set device to read-only or writable
|
||||||
|
"""
|
||||||
|
if device_name not in self.dm.devices:
|
||||||
|
logger.warning(
|
||||||
|
f"Device {device_name} is not available on the device manager, cannot be set to read-only: {not enabled}."
|
||||||
|
)
|
||||||
|
return
|
||||||
|
self.dm.devices[device_name].read_only = not enabled
|
||||||
|
|
||||||
|
def set_device_enable(self, device_name: str, enabled: bool) -> None:
|
||||||
|
"""
|
||||||
|
Enable/disable a device for write access.
|
||||||
If the device is not configured, a warning is logged.
|
If the device is not configured, a warning is logged.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
device_name (str): Name of the device
|
device_name (str): Name of the device
|
||||||
enabled (bool): Enable or disable the device
|
enabled (bool): Enable or disable the device
|
||||||
"""
|
"""
|
||||||
if device_name not in self.get_device_manager().devices:
|
if device_name not in self.dm.devices:
|
||||||
logger.warning(
|
logger.warning(
|
||||||
f"Device {device_name} is not configured and cannot be enabled/disabled."
|
f"Device {device_name} is not available on the device manager, cannot be set to enabled: {enabled}."
|
||||||
)
|
)
|
||||||
return
|
return
|
||||||
self.get_device_manager().devices[device_name].read_only = not enabled
|
self.dm.devices[device_name].enabled = enabled
|
||||||
|
if enabled:
|
||||||
|
self.on()
|
||||||
|
else:
|
||||||
|
all_disabled = all(
|
||||||
|
not self.dm.devices[axis.name].enabled for axis in self._axis if axis is not None
|
||||||
|
)
|
||||||
|
if all_disabled:
|
||||||
|
self.off()
|
||||||
|
|
||||||
|
def set_all_devices_enable(self, enabled: bool) -> None:
|
||||||
|
"""
|
||||||
|
Enable or disable all devices registered for the controller.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
enabled (bool): Enable or disable all devices
|
||||||
|
"""
|
||||||
|
for axis in self._axis:
|
||||||
|
if axis is None:
|
||||||
|
logger.info("Axis is not assigned, skipping enabling/disabling.")
|
||||||
|
continue
|
||||||
|
self.set_device_enable(axis.name, enabled)
|
||||||
|
|
||||||
def _initialize(self):
|
def _initialize(self):
|
||||||
self._connected = False
|
self._connected = False
|
||||||
@@ -220,11 +266,16 @@ class Controller(OphydObject):
|
|||||||
f"Axis {axis_Id_numeric} exceeds the available number of axes ({self._axes_per_controller})"
|
f"Axis {axis_Id_numeric} exceeds the available number of axes ({self._axes_per_controller})"
|
||||||
)
|
)
|
||||||
|
|
||||||
def on(self) -> None:
|
def on(self, timeout: int = 10) -> None:
|
||||||
"""Open a new socket connection to the controller"""
|
"""
|
||||||
|
Open a new socket connection to the controller
|
||||||
|
|
||||||
|
Args:
|
||||||
|
timeout (int): Time in seconds to wait for connection
|
||||||
|
"""
|
||||||
if not self.connected or self.sock is None:
|
if not self.connected or self.sock is None:
|
||||||
self.sock = self._socket_cls(host=self._socket_host, port=self._socket_port)
|
self.sock = self._socket_cls(host=self._socket_host, port=self._socket_port)
|
||||||
self.sock.open()
|
self.sock.open(timeout=timeout)
|
||||||
self.connected = True
|
self.connected = True
|
||||||
else:
|
else:
|
||||||
logger.info("The connection has already been established.")
|
logger.info("The connection has already been established.")
|
||||||
@@ -235,6 +286,8 @@ class Controller(OphydObject):
|
|||||||
self.sock.close()
|
self.sock.close()
|
||||||
self.connected = False
|
self.connected = False
|
||||||
self.sock = None
|
self.sock = None
|
||||||
|
# Disable all axes associated with this controller
|
||||||
|
self.set_all_devices_enable(False)
|
||||||
else:
|
else:
|
||||||
logger.info("The connection is already closed.")
|
logger.info("The connection is already closed.")
|
||||||
|
|
||||||
@@ -242,12 +295,15 @@ class Controller(OphydObject):
|
|||||||
socket_cls = kwargs.get("socket_cls")
|
socket_cls = kwargs.get("socket_cls")
|
||||||
socket_host = kwargs.get("socket_host")
|
socket_host = kwargs.get("socket_host")
|
||||||
socket_port = kwargs.get("socket_port")
|
socket_port = kwargs.get("socket_port")
|
||||||
|
device_manager = kwargs.get("device_manager")
|
||||||
if not socket_cls:
|
if not socket_cls:
|
||||||
raise RuntimeError("Socket class must be specified.")
|
raise RuntimeError("Socket class must be specified.")
|
||||||
if not socket_host:
|
if not socket_host:
|
||||||
raise RuntimeError("Socket host must be specified.")
|
raise RuntimeError("Socket host must be specified.")
|
||||||
if not socket_port:
|
if not socket_port:
|
||||||
raise RuntimeError("Socket port must be specified.")
|
raise RuntimeError("Socket port must be specified.")
|
||||||
|
if not device_manager:
|
||||||
|
raise RuntimeError("Device manager must be specified.")
|
||||||
host_port = f"{socket_host}:{socket_port}"
|
host_port = f"{socket_host}:{socket_port}"
|
||||||
if host_port not in cls._controller_instances:
|
if host_port not in cls._controller_instances:
|
||||||
cls._controller_instances[host_port] = object.__new__(cls)
|
cls._controller_instances[host_port] = object.__new__(cls)
|
||||||
|
|||||||
@@ -173,18 +173,23 @@ class SocketSignal(abc.ABC, Signal):
|
|||||||
class SocketIO:
|
class SocketIO:
|
||||||
"""SocketIO helper class for TCP IP connections"""
|
"""SocketIO helper class for TCP IP connections"""
|
||||||
|
|
||||||
def __init__(self, host, port, max_retry=10):
|
def __init__(self, host: str, port: int, socket_timeout: int = 2):
|
||||||
self.host = host
|
self.host = host
|
||||||
self.port = port
|
self.port = port
|
||||||
self.is_open = False
|
self.is_open = False
|
||||||
self.max_retry = max_retry
|
self.socket_timeout = socket_timeout
|
||||||
self._initialize_socket()
|
self._initialize_socket()
|
||||||
|
|
||||||
def connect(self):
|
def connect(self, timeout: int = 10):
|
||||||
print(f"connecting to {self.host} port {self.port}")
|
"""
|
||||||
# self.sock.create_connection((host, port))
|
Establish socket connection to host:port within timeout period
|
||||||
retry_count = 0
|
|
||||||
while True:
|
Args:
|
||||||
|
timeout (int): Time in seconds to wait for connection
|
||||||
|
"""
|
||||||
|
logger.info(f"Connecting to {self.host}:{self.port}")
|
||||||
|
start_time = time.time()
|
||||||
|
while time.time() - start_time < timeout:
|
||||||
try:
|
try:
|
||||||
if self.sock is None:
|
if self.sock is None:
|
||||||
self._initialize_socket()
|
self._initialize_socket()
|
||||||
@@ -192,10 +197,12 @@ class SocketIO:
|
|||||||
break
|
break
|
||||||
except Exception as exc:
|
except Exception as exc:
|
||||||
self.sock = None
|
self.sock = None
|
||||||
time.sleep(2)
|
logger.warning(f"Connection failed, retrying after 0.2 seconds... {exc}")
|
||||||
retry_count += 1
|
time.sleep(1)
|
||||||
if retry_count > self.max_retry:
|
else:
|
||||||
raise exc
|
raise ConnectionError(
|
||||||
|
f"Could not connect to {self.host}:{self.port} within {time.time()-start_time} seconds"
|
||||||
|
)
|
||||||
|
|
||||||
def _put(self, msg_bytes):
|
def _put(self, msg_bytes):
|
||||||
logger.debug(f"put message: {msg_bytes}")
|
logger.debug(f"put message: {msg_bytes}")
|
||||||
@@ -208,7 +215,7 @@ class SocketIO:
|
|||||||
|
|
||||||
def _initialize_socket(self):
|
def _initialize_socket(self):
|
||||||
self.sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
|
self.sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
|
||||||
self.sock.settimeout(5)
|
self.sock.settimeout(self.socket_timeout)
|
||||||
|
|
||||||
def put(self, msg):
|
def put(self, msg):
|
||||||
return self._put(msg)
|
return self._put(msg)
|
||||||
@@ -216,8 +223,14 @@ class SocketIO:
|
|||||||
def receive(self, buffer_length=1024):
|
def receive(self, buffer_length=1024):
|
||||||
return self._recv(buffer_length=buffer_length)
|
return self._recv(buffer_length=buffer_length)
|
||||||
|
|
||||||
def open(self):
|
def open(self, timeout: int = 10):
|
||||||
self.connect()
|
""" "
|
||||||
|
Open the socket connection to the host:port
|
||||||
|
|
||||||
|
Args:
|
||||||
|
timeout (int): Time in seconds to wait for connection
|
||||||
|
"""
|
||||||
|
self.connect(timeout=timeout)
|
||||||
self.is_open = True
|
self.is_open = True
|
||||||
|
|
||||||
def close(self):
|
def close(self):
|
||||||
@@ -235,7 +248,7 @@ class SocketMock:
|
|||||||
self.is_open = False
|
self.is_open = False
|
||||||
# self.open()
|
# self.open()
|
||||||
|
|
||||||
def connect(self):
|
def connect(self, timeout: int = 10):
|
||||||
print(f"connecting to {self.host} port {self.port}")
|
print(f"connecting to {self.host} port {self.port}")
|
||||||
|
|
||||||
def _put(self, msg_bytes):
|
def _put(self, msg_bytes):
|
||||||
@@ -261,7 +274,7 @@ class SocketMock:
|
|||||||
def receive(self, buffer_length=1024):
|
def receive(self, buffer_length=1024):
|
||||||
return self._recv(buffer_length=buffer_length)
|
return self._recv(buffer_length=buffer_length)
|
||||||
|
|
||||||
def open(self):
|
def open(self, timeout: int = 10):
|
||||||
self._initialize_socket()
|
self._initialize_socket()
|
||||||
self.is_open = True
|
self.is_open = True
|
||||||
|
|
||||||
|
|||||||
@@ -3,8 +3,13 @@ from unittest import mock
|
|||||||
from ophyd_devices.utils.controller import Controller
|
from ophyd_devices.utils.controller import Controller
|
||||||
|
|
||||||
|
|
||||||
def test_controller_off():
|
def test_controller_off(dm_with_devices):
|
||||||
controller = Controller(socket_cls=mock.MagicMock(), socket_host="dummy", socket_port=123)
|
controller = Controller(
|
||||||
|
socket_cls=mock.MagicMock(),
|
||||||
|
socket_host="dummy",
|
||||||
|
socket_port=123,
|
||||||
|
device_manager=dm_with_devices,
|
||||||
|
)
|
||||||
controller.on()
|
controller.on()
|
||||||
with mock.patch.object(controller.sock, "close") as mock_close:
|
with mock.patch.object(controller.sock, "close") as mock_close:
|
||||||
controller.off()
|
controller.off()
|
||||||
@@ -17,10 +22,12 @@ def test_controller_off():
|
|||||||
controller._reset_controller()
|
controller._reset_controller()
|
||||||
|
|
||||||
|
|
||||||
def test_controller_on():
|
def test_controller_on(dm_with_devices):
|
||||||
socket_cls = mock.MagicMock()
|
socket_cls = mock.MagicMock()
|
||||||
Controller._controller_instances = {}
|
Controller._controller_instances = {}
|
||||||
controller = Controller(socket_cls=socket_cls, socket_host="dummy", socket_port=123)
|
controller = Controller(
|
||||||
|
socket_cls=socket_cls, socket_host="dummy", socket_port=123, device_manager=dm_with_devices
|
||||||
|
)
|
||||||
controller.on()
|
controller.on()
|
||||||
assert controller.sock is not None
|
assert controller.sock is not None
|
||||||
assert controller.connected is True
|
assert controller.connected is True
|
||||||
@@ -30,3 +37,39 @@ def test_controller_on():
|
|||||||
controller.on()
|
controller.on()
|
||||||
socket_cls().open.assert_called_once()
|
socket_cls().open.assert_called_once()
|
||||||
controller._reset_controller()
|
controller._reset_controller()
|
||||||
|
|
||||||
|
|
||||||
|
def test_controller_with_multiple_axes(dm_with_devices):
|
||||||
|
"""Test that turning the controller on and off enables/disables all axes attached to it."""
|
||||||
|
socket_cls = mock.MagicMock()
|
||||||
|
Controller._controller_instances = {}
|
||||||
|
Controller._axes_per_controller = 2
|
||||||
|
controller = Controller(
|
||||||
|
socket_cls=socket_cls, socket_host="dummy", socket_port=123, device_manager=dm_with_devices
|
||||||
|
)
|
||||||
|
with mock.patch.object(controller.dm, "config_helper") as mock_config_helper:
|
||||||
|
# Disable samx, samy first
|
||||||
|
dm_with_devices.devices.get("samx").enabled = False
|
||||||
|
dm_with_devices.devices.get("samy").enabled = False
|
||||||
|
# Set axes on the controller
|
||||||
|
controller.set_axis(axis=dm_with_devices.devices["samx"], axis_nr=0)
|
||||||
|
controller.set_axis(axis=dm_with_devices.devices["samy"], axis_nr=1)
|
||||||
|
# Turn the controller on, should turn the controller on, but not enable the axes
|
||||||
|
controller.on()
|
||||||
|
assert dm_with_devices.devices.get("samx").enabled is False
|
||||||
|
assert dm_with_devices.devices.get("samy").enabled is False
|
||||||
|
assert controller.connected is True
|
||||||
|
controller.set_all_devices_enable(True)
|
||||||
|
assert dm_with_devices.devices.get("samx").enabled is True
|
||||||
|
assert dm_with_devices.devices.get("samy").enabled is True
|
||||||
|
# Disable one axis after another, the last one should turn the controller off
|
||||||
|
controller.set_device_enable("samx", False)
|
||||||
|
assert controller.connected is True
|
||||||
|
assert dm_with_devices.devices.get("samx").enabled is False
|
||||||
|
assert dm_with_devices.devices.get("samy").enabled is True
|
||||||
|
controller.set_device_enable("samy", False)
|
||||||
|
assert dm_with_devices.devices.get("samy").enabled is False
|
||||||
|
# Enabling one axis should turn the controller back on
|
||||||
|
assert controller.connected is False
|
||||||
|
controller.set_device_enable("samx", True)
|
||||||
|
assert controller.connected is True
|
||||||
|
|||||||
@@ -1,4 +1,7 @@
|
|||||||
import socket
|
import socket
|
||||||
|
from unittest import mock
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
from ophyd_devices.utils.socket import SocketIO
|
from ophyd_devices.utils.socket import SocketIO
|
||||||
|
|
||||||
@@ -62,6 +65,21 @@ def test_open():
|
|||||||
assert socketio.sock.port == socketio.port
|
assert socketio.sock.port == socketio.port
|
||||||
|
|
||||||
|
|
||||||
|
def test_socket_open_with_timeout():
|
||||||
|
dsocket = DummySocket()
|
||||||
|
socketio = SocketIO("localhost", 8080)
|
||||||
|
socketio.sock = dsocket
|
||||||
|
with mock.patch.object(dsocket, "connect") as mock_connect:
|
||||||
|
socketio.open(timeout=0.1)
|
||||||
|
mock_connect.assert_called_once()
|
||||||
|
mock_connect.reset_mock()
|
||||||
|
# There is a 1s sleep in the retry loop, mock_connect should be called only once
|
||||||
|
mock_connect.side_effect = Exception("Connection failed")
|
||||||
|
with pytest.raises(ConnectionError):
|
||||||
|
socketio.open(timeout=0.4)
|
||||||
|
mock_connect.assert_called_once()
|
||||||
|
|
||||||
|
|
||||||
def test_close():
|
def test_close():
|
||||||
socketio = SocketIO("localhost", 8080)
|
socketio = SocketIO("localhost", 8080)
|
||||||
socketio.close()
|
socketio.close()
|
||||||
|
|||||||
Reference in New Issue
Block a user