From 2fb64e995e417d5791295dfd1961cdfad2faa6c1 Mon Sep 17 00:00:00 2001 From: appel_c Date: Tue, 25 Nov 2025 09:34:06 +0100 Subject: [PATCH] fix(controller): add configurable timeout, en/disable controller axes on on/off --- ophyd_devices/sim/sim_test_devices.py | 2 +- ophyd_devices/tests/utils.py | 4 +- ophyd_devices/utils/controller.py | 102 ++++++++++++++++++++------ ophyd_devices/utils/socket.py | 45 ++++++++---- tests/test_controller.py | 51 ++++++++++++- tests/test_socket.py | 18 +++++ 6 files changed, 176 insertions(+), 46 deletions(-) diff --git a/ophyd_devices/sim/sim_test_devices.py b/ophyd_devices/sim/sim_test_devices.py index e3cdd45..635b095 100644 --- a/ophyd_devices/sim/sim_test_devices.py +++ b/ophyd_devices/sim/sim_test_devices.py @@ -130,7 +130,7 @@ class SimDeviceWithStatusStageUnstage(Device): class SynController(OphydObject): - def on(self): + def on(self, timeout: int = 10): pass def off(self): diff --git a/ophyd_devices/tests/utils.py b/ophyd_devices/tests/utils.py index ca7898c..fb29171 100644 --- a/ophyd_devices/tests/utils.py +++ b/ophyd_devices/tests/utils.py @@ -84,7 +84,7 @@ class SocketMock: self.sock = None self.open() - def connect(self): + def connect(self, timeout: int = 10): """Mock connect method""" print(f"connecting to {self.host} port {self.port}") @@ -116,7 +116,7 @@ class SocketMock: """Mock receive method""" return self._recv(buffer_length=buffer_length) - def open(self): + def open(self, timeout: int = 10): """Mock open method""" self._initialize_socket() self.is_open = True diff --git a/ophyd_devices/utils/controller.py b/ophyd_devices/utils/controller.py index 67f293b..b227e1a 100644 --- a/ophyd_devices/utils/controller.py +++ b/ophyd_devices/utils/controller.py @@ -1,10 +1,17 @@ import functools import threading +from typing import TYPE_CHECKING, Type from bec_lib import bec_logger from ophyd import Device from ophyd.ophydobj import OphydObject +if TYPE_CHECKING: + from bec_server.device_server.device_server import DeviceManagerDS + + from ophyd_devices.utils.socket import SocketIO + + logger = bec_logger.logger @@ -59,7 +66,16 @@ def axis_checked(fcn): class Controller(OphydObject): - """Base class for all socker-based controllers.""" + """ + Base class for all socket-based controllers. + + Args: + name (str, optional): Name of the controller + socket_cls (Type[SocketIO]): Socket class to use for communication + socket_host (str): Hostname or IP address of the controller + socket_port (int): Port number of the controller + device_manager (DeviceManagerDS): Device manager instance + """ _controller_instances = {} _initialized = False @@ -70,10 +86,11 @@ class Controller(OphydObject): def __init__( self, *, - name=None, - socket_cls=None, - socket_host=None, - socket_port=None, + socket_cls: Type["SocketIO"], + socket_host: str, + socket_port: int, + device_manager: "DeviceManagerDS", + name: str = "", attr_name="", parent=None, labels=None, @@ -84,10 +101,11 @@ class Controller(OphydObject): name=name, attr_name=attr_name, parent=parent, labels=labels, kind=kind ) self._lock = threading.RLock() - self._axis = [] + self._axis: list[Device] = [] self._initialize() self._initialized = True self.sock = None + self.dm = device_manager self._socket_cls = socket_cls self._socket_host = socket_host self._socket_port = socket_port @@ -127,15 +145,6 @@ class Controller(OphydObject): return var.split("\r\n")[0] return var - def get_device_manager(self): - """ - Helper function to get the device manager. - """ - for axis in self._axis: - if hasattr(axis, "device_manager") and axis.device_manager: - return axis.device_manager - raise ControllerError("Could not access the device_manager") - def get_axis_by_name(self, name: str) -> Device: """ Get an axis by name. @@ -152,21 +161,58 @@ class Controller(OphydObject): return axis raise RuntimeError(f"Could not find an axis with name {name}") - def set_device_enabled(self, device_name: str, enabled: bool) -> None: + def set_device_read_write(self, device_name: str, enabled: bool) -> None: """ - Enable or disable a device for write access. + Change the read-only status of a device. + If the device is not configured, a warning is logged. + + Args: + device_name (str): Name of the device + enabled (bool): Set device to read-only or writable + """ + if device_name not in self.dm.devices: + logger.warning( + f"Device {device_name} is not available on the device manager, cannot be set to read-only: {not enabled}." + ) + return + self.dm.devices[device_name].read_only = not enabled + + def set_device_enable(self, device_name: str, enabled: bool) -> None: + """ + Enable/disable a device for write access. If the device is not configured, a warning is logged. Args: device_name (str): Name of the device enabled (bool): Enable or disable the device """ - if device_name not in self.get_device_manager().devices: + if device_name not in self.dm.devices: logger.warning( - f"Device {device_name} is not configured and cannot be enabled/disabled." + f"Device {device_name} is not available on the device manager, cannot be set to enabled: {enabled}." ) return - self.get_device_manager().devices[device_name].read_only = not enabled + self.dm.devices[device_name].enabled = enabled + if enabled: + self.on() + else: + all_disabled = all( + not self.dm.devices[axis.name].enabled for axis in self._axis if axis is not None + ) + if all_disabled: + self.off() + + def set_all_devices_enable(self, enabled: bool) -> None: + """ + Enable or disable all devices registered for the controller. + + Args: + enabled (bool): Enable or disable all devices + """ + for axis in self._axis: + if axis is None: + logger.info("Axis is not assigned, skipping enabling/disabling.") + continue + self.set_device_enable(axis.name, enabled) def _initialize(self): self._connected = False @@ -220,11 +266,16 @@ class Controller(OphydObject): f"Axis {axis_Id_numeric} exceeds the available number of axes ({self._axes_per_controller})" ) - def on(self) -> None: - """Open a new socket connection to the controller""" + def on(self, timeout: int = 10) -> None: + """ + Open a new socket connection to the controller + + Args: + timeout (int): Time in seconds to wait for connection + """ if not self.connected or self.sock is None: self.sock = self._socket_cls(host=self._socket_host, port=self._socket_port) - self.sock.open() + self.sock.open(timeout=timeout) self.connected = True else: logger.info("The connection has already been established.") @@ -235,6 +286,8 @@ class Controller(OphydObject): self.sock.close() self.connected = False self.sock = None + # Disable all axes associated with this controller + self.set_all_devices_enable(False) else: logger.info("The connection is already closed.") @@ -242,12 +295,15 @@ class Controller(OphydObject): socket_cls = kwargs.get("socket_cls") socket_host = kwargs.get("socket_host") socket_port = kwargs.get("socket_port") + device_manager = kwargs.get("device_manager") if not socket_cls: raise RuntimeError("Socket class must be specified.") if not socket_host: raise RuntimeError("Socket host must be specified.") if not socket_port: raise RuntimeError("Socket port must be specified.") + if not device_manager: + raise RuntimeError("Device manager must be specified.") host_port = f"{socket_host}:{socket_port}" if host_port not in cls._controller_instances: cls._controller_instances[host_port] = object.__new__(cls) diff --git a/ophyd_devices/utils/socket.py b/ophyd_devices/utils/socket.py index 01b5bc9..5764cd6 100644 --- a/ophyd_devices/utils/socket.py +++ b/ophyd_devices/utils/socket.py @@ -173,18 +173,23 @@ class SocketSignal(abc.ABC, Signal): class SocketIO: """SocketIO helper class for TCP IP connections""" - def __init__(self, host, port, max_retry=10): + def __init__(self, host: str, port: int, socket_timeout: int = 2): self.host = host self.port = port self.is_open = False - self.max_retry = max_retry + self.socket_timeout = socket_timeout self._initialize_socket() - def connect(self): - print(f"connecting to {self.host} port {self.port}") - # self.sock.create_connection((host, port)) - retry_count = 0 - while True: + def connect(self, timeout: int = 10): + """ + Establish socket connection to host:port within timeout period + + Args: + timeout (int): Time in seconds to wait for connection + """ + logger.info(f"Connecting to {self.host}:{self.port}") + start_time = time.time() + while time.time() - start_time < timeout: try: if self.sock is None: self._initialize_socket() @@ -192,10 +197,12 @@ class SocketIO: break except Exception as exc: self.sock = None - time.sleep(2) - retry_count += 1 - if retry_count > self.max_retry: - raise exc + logger.warning(f"Connection failed, retrying after 0.2 seconds... {exc}") + time.sleep(1) + else: + raise ConnectionError( + f"Could not connect to {self.host}:{self.port} within {time.time()-start_time} seconds" + ) def _put(self, msg_bytes): logger.debug(f"put message: {msg_bytes}") @@ -208,7 +215,7 @@ class SocketIO: def _initialize_socket(self): self.sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) - self.sock.settimeout(5) + self.sock.settimeout(self.socket_timeout) def put(self, msg): return self._put(msg) @@ -216,8 +223,14 @@ class SocketIO: def receive(self, buffer_length=1024): return self._recv(buffer_length=buffer_length) - def open(self): - self.connect() + def open(self, timeout: int = 10): + """ " + Open the socket connection to the host:port + + Args: + timeout (int): Time in seconds to wait for connection + """ + self.connect(timeout=timeout) self.is_open = True def close(self): @@ -235,7 +248,7 @@ class SocketMock: self.is_open = False # self.open() - def connect(self): + def connect(self, timeout: int = 10): print(f"connecting to {self.host} port {self.port}") def _put(self, msg_bytes): @@ -261,7 +274,7 @@ class SocketMock: def receive(self, buffer_length=1024): return self._recv(buffer_length=buffer_length) - def open(self): + def open(self, timeout: int = 10): self._initialize_socket() self.is_open = True diff --git a/tests/test_controller.py b/tests/test_controller.py index 6543587..bf32c63 100644 --- a/tests/test_controller.py +++ b/tests/test_controller.py @@ -3,8 +3,13 @@ from unittest import mock from ophyd_devices.utils.controller import Controller -def test_controller_off(): - controller = Controller(socket_cls=mock.MagicMock(), socket_host="dummy", socket_port=123) +def test_controller_off(dm_with_devices): + controller = Controller( + socket_cls=mock.MagicMock(), + socket_host="dummy", + socket_port=123, + device_manager=dm_with_devices, + ) controller.on() with mock.patch.object(controller.sock, "close") as mock_close: controller.off() @@ -17,10 +22,12 @@ def test_controller_off(): controller._reset_controller() -def test_controller_on(): +def test_controller_on(dm_with_devices): socket_cls = mock.MagicMock() Controller._controller_instances = {} - controller = Controller(socket_cls=socket_cls, socket_host="dummy", socket_port=123) + controller = Controller( + socket_cls=socket_cls, socket_host="dummy", socket_port=123, device_manager=dm_with_devices + ) controller.on() assert controller.sock is not None assert controller.connected is True @@ -30,3 +37,39 @@ def test_controller_on(): controller.on() socket_cls().open.assert_called_once() controller._reset_controller() + + +def test_controller_with_multiple_axes(dm_with_devices): + """Test that turning the controller on and off enables/disables all axes attached to it.""" + socket_cls = mock.MagicMock() + Controller._controller_instances = {} + Controller._axes_per_controller = 2 + controller = Controller( + socket_cls=socket_cls, socket_host="dummy", socket_port=123, device_manager=dm_with_devices + ) + with mock.patch.object(controller.dm, "config_helper") as mock_config_helper: + # Disable samx, samy first + dm_with_devices.devices.get("samx").enabled = False + dm_with_devices.devices.get("samy").enabled = False + # Set axes on the controller + controller.set_axis(axis=dm_with_devices.devices["samx"], axis_nr=0) + controller.set_axis(axis=dm_with_devices.devices["samy"], axis_nr=1) + # Turn the controller on, should turn the controller on, but not enable the axes + controller.on() + assert dm_with_devices.devices.get("samx").enabled is False + assert dm_with_devices.devices.get("samy").enabled is False + assert controller.connected is True + controller.set_all_devices_enable(True) + assert dm_with_devices.devices.get("samx").enabled is True + assert dm_with_devices.devices.get("samy").enabled is True + # Disable one axis after another, the last one should turn the controller off + controller.set_device_enable("samx", False) + assert controller.connected is True + assert dm_with_devices.devices.get("samx").enabled is False + assert dm_with_devices.devices.get("samy").enabled is True + controller.set_device_enable("samy", False) + assert dm_with_devices.devices.get("samy").enabled is False + # Enabling one axis should turn the controller back on + assert controller.connected is False + controller.set_device_enable("samx", True) + assert controller.connected is True diff --git a/tests/test_socket.py b/tests/test_socket.py index dc33c0d..17f7bcc 100644 --- a/tests/test_socket.py +++ b/tests/test_socket.py @@ -1,4 +1,7 @@ import socket +from unittest import mock + +import pytest from ophyd_devices.utils.socket import SocketIO @@ -62,6 +65,21 @@ def test_open(): assert socketio.sock.port == socketio.port +def test_socket_open_with_timeout(): + dsocket = DummySocket() + socketio = SocketIO("localhost", 8080) + socketio.sock = dsocket + with mock.patch.object(dsocket, "connect") as mock_connect: + socketio.open(timeout=0.1) + mock_connect.assert_called_once() + mock_connect.reset_mock() + # There is a 1s sleep in the retry loop, mock_connect should be called only once + mock_connect.side_effect = Exception("Connection failed") + with pytest.raises(ConnectionError): + socketio.open(timeout=0.4) + mock_connect.assert_called_once() + + def test_close(): socketio = SocketIO("localhost", 8080) socketio.close()