fix(controller): add configurable timeout, en/disable controller axes on on/off

This commit is contained in:
2025-11-25 09:34:06 +01:00
committed by Christian Appel
parent 69f7a353cf
commit 2fb64e995e
6 changed files with 176 additions and 46 deletions

View File

@@ -130,7 +130,7 @@ class SimDeviceWithStatusStageUnstage(Device):
class SynController(OphydObject): class SynController(OphydObject):
def on(self): def on(self, timeout: int = 10):
pass pass
def off(self): def off(self):

View File

@@ -84,7 +84,7 @@ class SocketMock:
self.sock = None self.sock = None
self.open() self.open()
def connect(self): def connect(self, timeout: int = 10):
"""Mock connect method""" """Mock connect method"""
print(f"connecting to {self.host} port {self.port}") print(f"connecting to {self.host} port {self.port}")
@@ -116,7 +116,7 @@ class SocketMock:
"""Mock receive method""" """Mock receive method"""
return self._recv(buffer_length=buffer_length) return self._recv(buffer_length=buffer_length)
def open(self): def open(self, timeout: int = 10):
"""Mock open method""" """Mock open method"""
self._initialize_socket() self._initialize_socket()
self.is_open = True self.is_open = True

View File

@@ -1,10 +1,17 @@
import functools import functools
import threading import threading
from typing import TYPE_CHECKING, Type
from bec_lib import bec_logger from bec_lib import bec_logger
from ophyd import Device from ophyd import Device
from ophyd.ophydobj import OphydObject from ophyd.ophydobj import OphydObject
if TYPE_CHECKING:
from bec_server.device_server.device_server import DeviceManagerDS
from ophyd_devices.utils.socket import SocketIO
logger = bec_logger.logger logger = bec_logger.logger
@@ -59,7 +66,16 @@ def axis_checked(fcn):
class Controller(OphydObject): class Controller(OphydObject):
"""Base class for all socker-based controllers.""" """
Base class for all socket-based controllers.
Args:
name (str, optional): Name of the controller
socket_cls (Type[SocketIO]): Socket class to use for communication
socket_host (str): Hostname or IP address of the controller
socket_port (int): Port number of the controller
device_manager (DeviceManagerDS): Device manager instance
"""
_controller_instances = {} _controller_instances = {}
_initialized = False _initialized = False
@@ -70,10 +86,11 @@ class Controller(OphydObject):
def __init__( def __init__(
self, self,
*, *,
name=None, socket_cls: Type["SocketIO"],
socket_cls=None, socket_host: str,
socket_host=None, socket_port: int,
socket_port=None, device_manager: "DeviceManagerDS",
name: str = "",
attr_name="", attr_name="",
parent=None, parent=None,
labels=None, labels=None,
@@ -84,10 +101,11 @@ class Controller(OphydObject):
name=name, attr_name=attr_name, parent=parent, labels=labels, kind=kind name=name, attr_name=attr_name, parent=parent, labels=labels, kind=kind
) )
self._lock = threading.RLock() self._lock = threading.RLock()
self._axis = [] self._axis: list[Device] = []
self._initialize() self._initialize()
self._initialized = True self._initialized = True
self.sock = None self.sock = None
self.dm = device_manager
self._socket_cls = socket_cls self._socket_cls = socket_cls
self._socket_host = socket_host self._socket_host = socket_host
self._socket_port = socket_port self._socket_port = socket_port
@@ -127,15 +145,6 @@ class Controller(OphydObject):
return var.split("\r\n")[0] return var.split("\r\n")[0]
return var return var
def get_device_manager(self):
"""
Helper function to get the device manager.
"""
for axis in self._axis:
if hasattr(axis, "device_manager") and axis.device_manager:
return axis.device_manager
raise ControllerError("Could not access the device_manager")
def get_axis_by_name(self, name: str) -> Device: def get_axis_by_name(self, name: str) -> Device:
""" """
Get an axis by name. Get an axis by name.
@@ -152,21 +161,58 @@ class Controller(OphydObject):
return axis return axis
raise RuntimeError(f"Could not find an axis with name {name}") raise RuntimeError(f"Could not find an axis with name {name}")
def set_device_enabled(self, device_name: str, enabled: bool) -> None: def set_device_read_write(self, device_name: str, enabled: bool) -> None:
""" """
Enable or disable a device for write access. Change the read-only status of a device.
If the device is not configured, a warning is logged.
Args:
device_name (str): Name of the device
enabled (bool): Set device to read-only or writable
"""
if device_name not in self.dm.devices:
logger.warning(
f"Device {device_name} is not available on the device manager, cannot be set to read-only: {not enabled}."
)
return
self.dm.devices[device_name].read_only = not enabled
def set_device_enable(self, device_name: str, enabled: bool) -> None:
"""
Enable/disable a device for write access.
If the device is not configured, a warning is logged. If the device is not configured, a warning is logged.
Args: Args:
device_name (str): Name of the device device_name (str): Name of the device
enabled (bool): Enable or disable the device enabled (bool): Enable or disable the device
""" """
if device_name not in self.get_device_manager().devices: if device_name not in self.dm.devices:
logger.warning( logger.warning(
f"Device {device_name} is not configured and cannot be enabled/disabled." f"Device {device_name} is not available on the device manager, cannot be set to enabled: {enabled}."
) )
return return
self.get_device_manager().devices[device_name].read_only = not enabled self.dm.devices[device_name].enabled = enabled
if enabled:
self.on()
else:
all_disabled = all(
not self.dm.devices[axis.name].enabled for axis in self._axis if axis is not None
)
if all_disabled:
self.off()
def set_all_devices_enable(self, enabled: bool) -> None:
"""
Enable or disable all devices registered for the controller.
Args:
enabled (bool): Enable or disable all devices
"""
for axis in self._axis:
if axis is None:
logger.info("Axis is not assigned, skipping enabling/disabling.")
continue
self.set_device_enable(axis.name, enabled)
def _initialize(self): def _initialize(self):
self._connected = False self._connected = False
@@ -220,11 +266,16 @@ class Controller(OphydObject):
f"Axis {axis_Id_numeric} exceeds the available number of axes ({self._axes_per_controller})" f"Axis {axis_Id_numeric} exceeds the available number of axes ({self._axes_per_controller})"
) )
def on(self) -> None: def on(self, timeout: int = 10) -> None:
"""Open a new socket connection to the controller""" """
Open a new socket connection to the controller
Args:
timeout (int): Time in seconds to wait for connection
"""
if not self.connected or self.sock is None: if not self.connected or self.sock is None:
self.sock = self._socket_cls(host=self._socket_host, port=self._socket_port) self.sock = self._socket_cls(host=self._socket_host, port=self._socket_port)
self.sock.open() self.sock.open(timeout=timeout)
self.connected = True self.connected = True
else: else:
logger.info("The connection has already been established.") logger.info("The connection has already been established.")
@@ -235,6 +286,8 @@ class Controller(OphydObject):
self.sock.close() self.sock.close()
self.connected = False self.connected = False
self.sock = None self.sock = None
# Disable all axes associated with this controller
self.set_all_devices_enable(False)
else: else:
logger.info("The connection is already closed.") logger.info("The connection is already closed.")
@@ -242,12 +295,15 @@ class Controller(OphydObject):
socket_cls = kwargs.get("socket_cls") socket_cls = kwargs.get("socket_cls")
socket_host = kwargs.get("socket_host") socket_host = kwargs.get("socket_host")
socket_port = kwargs.get("socket_port") socket_port = kwargs.get("socket_port")
device_manager = kwargs.get("device_manager")
if not socket_cls: if not socket_cls:
raise RuntimeError("Socket class must be specified.") raise RuntimeError("Socket class must be specified.")
if not socket_host: if not socket_host:
raise RuntimeError("Socket host must be specified.") raise RuntimeError("Socket host must be specified.")
if not socket_port: if not socket_port:
raise RuntimeError("Socket port must be specified.") raise RuntimeError("Socket port must be specified.")
if not device_manager:
raise RuntimeError("Device manager must be specified.")
host_port = f"{socket_host}:{socket_port}" host_port = f"{socket_host}:{socket_port}"
if host_port not in cls._controller_instances: if host_port not in cls._controller_instances:
cls._controller_instances[host_port] = object.__new__(cls) cls._controller_instances[host_port] = object.__new__(cls)

View File

@@ -173,18 +173,23 @@ class SocketSignal(abc.ABC, Signal):
class SocketIO: class SocketIO:
"""SocketIO helper class for TCP IP connections""" """SocketIO helper class for TCP IP connections"""
def __init__(self, host, port, max_retry=10): def __init__(self, host: str, port: int, socket_timeout: int = 2):
self.host = host self.host = host
self.port = port self.port = port
self.is_open = False self.is_open = False
self.max_retry = max_retry self.socket_timeout = socket_timeout
self._initialize_socket() self._initialize_socket()
def connect(self): def connect(self, timeout: int = 10):
print(f"connecting to {self.host} port {self.port}") """
# self.sock.create_connection((host, port)) Establish socket connection to host:port within timeout period
retry_count = 0
while True: Args:
timeout (int): Time in seconds to wait for connection
"""
logger.info(f"Connecting to {self.host}:{self.port}")
start_time = time.time()
while time.time() - start_time < timeout:
try: try:
if self.sock is None: if self.sock is None:
self._initialize_socket() self._initialize_socket()
@@ -192,10 +197,12 @@ class SocketIO:
break break
except Exception as exc: except Exception as exc:
self.sock = None self.sock = None
time.sleep(2) logger.warning(f"Connection failed, retrying after 0.2 seconds... {exc}")
retry_count += 1 time.sleep(1)
if retry_count > self.max_retry: else:
raise exc raise ConnectionError(
f"Could not connect to {self.host}:{self.port} within {time.time()-start_time} seconds"
)
def _put(self, msg_bytes): def _put(self, msg_bytes):
logger.debug(f"put message: {msg_bytes}") logger.debug(f"put message: {msg_bytes}")
@@ -208,7 +215,7 @@ class SocketIO:
def _initialize_socket(self): def _initialize_socket(self):
self.sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) self.sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
self.sock.settimeout(5) self.sock.settimeout(self.socket_timeout)
def put(self, msg): def put(self, msg):
return self._put(msg) return self._put(msg)
@@ -216,8 +223,14 @@ class SocketIO:
def receive(self, buffer_length=1024): def receive(self, buffer_length=1024):
return self._recv(buffer_length=buffer_length) return self._recv(buffer_length=buffer_length)
def open(self): def open(self, timeout: int = 10):
self.connect() """ "
Open the socket connection to the host:port
Args:
timeout (int): Time in seconds to wait for connection
"""
self.connect(timeout=timeout)
self.is_open = True self.is_open = True
def close(self): def close(self):
@@ -235,7 +248,7 @@ class SocketMock:
self.is_open = False self.is_open = False
# self.open() # self.open()
def connect(self): def connect(self, timeout: int = 10):
print(f"connecting to {self.host} port {self.port}") print(f"connecting to {self.host} port {self.port}")
def _put(self, msg_bytes): def _put(self, msg_bytes):
@@ -261,7 +274,7 @@ class SocketMock:
def receive(self, buffer_length=1024): def receive(self, buffer_length=1024):
return self._recv(buffer_length=buffer_length) return self._recv(buffer_length=buffer_length)
def open(self): def open(self, timeout: int = 10):
self._initialize_socket() self._initialize_socket()
self.is_open = True self.is_open = True

View File

@@ -3,8 +3,13 @@ from unittest import mock
from ophyd_devices.utils.controller import Controller from ophyd_devices.utils.controller import Controller
def test_controller_off(): def test_controller_off(dm_with_devices):
controller = Controller(socket_cls=mock.MagicMock(), socket_host="dummy", socket_port=123) controller = Controller(
socket_cls=mock.MagicMock(),
socket_host="dummy",
socket_port=123,
device_manager=dm_with_devices,
)
controller.on() controller.on()
with mock.patch.object(controller.sock, "close") as mock_close: with mock.patch.object(controller.sock, "close") as mock_close:
controller.off() controller.off()
@@ -17,10 +22,12 @@ def test_controller_off():
controller._reset_controller() controller._reset_controller()
def test_controller_on(): def test_controller_on(dm_with_devices):
socket_cls = mock.MagicMock() socket_cls = mock.MagicMock()
Controller._controller_instances = {} Controller._controller_instances = {}
controller = Controller(socket_cls=socket_cls, socket_host="dummy", socket_port=123) controller = Controller(
socket_cls=socket_cls, socket_host="dummy", socket_port=123, device_manager=dm_with_devices
)
controller.on() controller.on()
assert controller.sock is not None assert controller.sock is not None
assert controller.connected is True assert controller.connected is True
@@ -30,3 +37,39 @@ def test_controller_on():
controller.on() controller.on()
socket_cls().open.assert_called_once() socket_cls().open.assert_called_once()
controller._reset_controller() controller._reset_controller()
def test_controller_with_multiple_axes(dm_with_devices):
"""Test that turning the controller on and off enables/disables all axes attached to it."""
socket_cls = mock.MagicMock()
Controller._controller_instances = {}
Controller._axes_per_controller = 2
controller = Controller(
socket_cls=socket_cls, socket_host="dummy", socket_port=123, device_manager=dm_with_devices
)
with mock.patch.object(controller.dm, "config_helper") as mock_config_helper:
# Disable samx, samy first
dm_with_devices.devices.get("samx").enabled = False
dm_with_devices.devices.get("samy").enabled = False
# Set axes on the controller
controller.set_axis(axis=dm_with_devices.devices["samx"], axis_nr=0)
controller.set_axis(axis=dm_with_devices.devices["samy"], axis_nr=1)
# Turn the controller on, should turn the controller on, but not enable the axes
controller.on()
assert dm_with_devices.devices.get("samx").enabled is False
assert dm_with_devices.devices.get("samy").enabled is False
assert controller.connected is True
controller.set_all_devices_enable(True)
assert dm_with_devices.devices.get("samx").enabled is True
assert dm_with_devices.devices.get("samy").enabled is True
# Disable one axis after another, the last one should turn the controller off
controller.set_device_enable("samx", False)
assert controller.connected is True
assert dm_with_devices.devices.get("samx").enabled is False
assert dm_with_devices.devices.get("samy").enabled is True
controller.set_device_enable("samy", False)
assert dm_with_devices.devices.get("samy").enabled is False
# Enabling one axis should turn the controller back on
assert controller.connected is False
controller.set_device_enable("samx", True)
assert controller.connected is True

View File

@@ -1,4 +1,7 @@
import socket import socket
from unittest import mock
import pytest
from ophyd_devices.utils.socket import SocketIO from ophyd_devices.utils.socket import SocketIO
@@ -62,6 +65,21 @@ def test_open():
assert socketio.sock.port == socketio.port assert socketio.sock.port == socketio.port
def test_socket_open_with_timeout():
dsocket = DummySocket()
socketio = SocketIO("localhost", 8080)
socketio.sock = dsocket
with mock.patch.object(dsocket, "connect") as mock_connect:
socketio.open(timeout=0.1)
mock_connect.assert_called_once()
mock_connect.reset_mock()
# There is a 1s sleep in the retry loop, mock_connect should be called only once
mock_connect.side_effect = Exception("Connection failed")
with pytest.raises(ConnectionError):
socketio.open(timeout=0.4)
mock_connect.assert_called_once()
def test_close(): def test_close():
socketio = SocketIO("localhost", 8080) socketio = SocketIO("localhost", 8080)
socketio.close() socketio.close()