fix(controller): add configurable timeout, en/disable controller axes on on/off

This commit is contained in:
2025-11-25 09:34:06 +01:00
committed by Christian Appel
parent 69f7a353cf
commit 2fb64e995e
6 changed files with 176 additions and 46 deletions

View File

@@ -130,7 +130,7 @@ class SimDeviceWithStatusStageUnstage(Device):
class SynController(OphydObject):
def on(self):
def on(self, timeout: int = 10):
pass
def off(self):

View File

@@ -84,7 +84,7 @@ class SocketMock:
self.sock = None
self.open()
def connect(self):
def connect(self, timeout: int = 10):
"""Mock connect method"""
print(f"connecting to {self.host} port {self.port}")
@@ -116,7 +116,7 @@ class SocketMock:
"""Mock receive method"""
return self._recv(buffer_length=buffer_length)
def open(self):
def open(self, timeout: int = 10):
"""Mock open method"""
self._initialize_socket()
self.is_open = True

View File

@@ -1,10 +1,17 @@
import functools
import threading
from typing import TYPE_CHECKING, Type
from bec_lib import bec_logger
from ophyd import Device
from ophyd.ophydobj import OphydObject
if TYPE_CHECKING:
from bec_server.device_server.device_server import DeviceManagerDS
from ophyd_devices.utils.socket import SocketIO
logger = bec_logger.logger
@@ -59,7 +66,16 @@ def axis_checked(fcn):
class Controller(OphydObject):
"""Base class for all socker-based controllers."""
"""
Base class for all socket-based controllers.
Args:
name (str, optional): Name of the controller
socket_cls (Type[SocketIO]): Socket class to use for communication
socket_host (str): Hostname or IP address of the controller
socket_port (int): Port number of the controller
device_manager (DeviceManagerDS): Device manager instance
"""
_controller_instances = {}
_initialized = False
@@ -70,10 +86,11 @@ class Controller(OphydObject):
def __init__(
self,
*,
name=None,
socket_cls=None,
socket_host=None,
socket_port=None,
socket_cls: Type["SocketIO"],
socket_host: str,
socket_port: int,
device_manager: "DeviceManagerDS",
name: str = "",
attr_name="",
parent=None,
labels=None,
@@ -84,10 +101,11 @@ class Controller(OphydObject):
name=name, attr_name=attr_name, parent=parent, labels=labels, kind=kind
)
self._lock = threading.RLock()
self._axis = []
self._axis: list[Device] = []
self._initialize()
self._initialized = True
self.sock = None
self.dm = device_manager
self._socket_cls = socket_cls
self._socket_host = socket_host
self._socket_port = socket_port
@@ -127,15 +145,6 @@ class Controller(OphydObject):
return var.split("\r\n")[0]
return var
def get_device_manager(self):
"""
Helper function to get the device manager.
"""
for axis in self._axis:
if hasattr(axis, "device_manager") and axis.device_manager:
return axis.device_manager
raise ControllerError("Could not access the device_manager")
def get_axis_by_name(self, name: str) -> Device:
"""
Get an axis by name.
@@ -152,21 +161,58 @@ class Controller(OphydObject):
return axis
raise RuntimeError(f"Could not find an axis with name {name}")
def set_device_enabled(self, device_name: str, enabled: bool) -> None:
def set_device_read_write(self, device_name: str, enabled: bool) -> None:
"""
Enable or disable a device for write access.
Change the read-only status of a device.
If the device is not configured, a warning is logged.
Args:
device_name (str): Name of the device
enabled (bool): Set device to read-only or writable
"""
if device_name not in self.dm.devices:
logger.warning(
f"Device {device_name} is not available on the device manager, cannot be set to read-only: {not enabled}."
)
return
self.dm.devices[device_name].read_only = not enabled
def set_device_enable(self, device_name: str, enabled: bool) -> None:
"""
Enable/disable a device for write access.
If the device is not configured, a warning is logged.
Args:
device_name (str): Name of the device
enabled (bool): Enable or disable the device
"""
if device_name not in self.get_device_manager().devices:
if device_name not in self.dm.devices:
logger.warning(
f"Device {device_name} is not configured and cannot be enabled/disabled."
f"Device {device_name} is not available on the device manager, cannot be set to enabled: {enabled}."
)
return
self.get_device_manager().devices[device_name].read_only = not enabled
self.dm.devices[device_name].enabled = enabled
if enabled:
self.on()
else:
all_disabled = all(
not self.dm.devices[axis.name].enabled for axis in self._axis if axis is not None
)
if all_disabled:
self.off()
def set_all_devices_enable(self, enabled: bool) -> None:
"""
Enable or disable all devices registered for the controller.
Args:
enabled (bool): Enable or disable all devices
"""
for axis in self._axis:
if axis is None:
logger.info("Axis is not assigned, skipping enabling/disabling.")
continue
self.set_device_enable(axis.name, enabled)
def _initialize(self):
self._connected = False
@@ -220,11 +266,16 @@ class Controller(OphydObject):
f"Axis {axis_Id_numeric} exceeds the available number of axes ({self._axes_per_controller})"
)
def on(self) -> None:
"""Open a new socket connection to the controller"""
def on(self, timeout: int = 10) -> None:
"""
Open a new socket connection to the controller
Args:
timeout (int): Time in seconds to wait for connection
"""
if not self.connected or self.sock is None:
self.sock = self._socket_cls(host=self._socket_host, port=self._socket_port)
self.sock.open()
self.sock.open(timeout=timeout)
self.connected = True
else:
logger.info("The connection has already been established.")
@@ -235,6 +286,8 @@ class Controller(OphydObject):
self.sock.close()
self.connected = False
self.sock = None
# Disable all axes associated with this controller
self.set_all_devices_enable(False)
else:
logger.info("The connection is already closed.")
@@ -242,12 +295,15 @@ class Controller(OphydObject):
socket_cls = kwargs.get("socket_cls")
socket_host = kwargs.get("socket_host")
socket_port = kwargs.get("socket_port")
device_manager = kwargs.get("device_manager")
if not socket_cls:
raise RuntimeError("Socket class must be specified.")
if not socket_host:
raise RuntimeError("Socket host must be specified.")
if not socket_port:
raise RuntimeError("Socket port must be specified.")
if not device_manager:
raise RuntimeError("Device manager must be specified.")
host_port = f"{socket_host}:{socket_port}"
if host_port not in cls._controller_instances:
cls._controller_instances[host_port] = object.__new__(cls)

View File

@@ -173,18 +173,23 @@ class SocketSignal(abc.ABC, Signal):
class SocketIO:
"""SocketIO helper class for TCP IP connections"""
def __init__(self, host, port, max_retry=10):
def __init__(self, host: str, port: int, socket_timeout: int = 2):
self.host = host
self.port = port
self.is_open = False
self.max_retry = max_retry
self.socket_timeout = socket_timeout
self._initialize_socket()
def connect(self):
print(f"connecting to {self.host} port {self.port}")
# self.sock.create_connection((host, port))
retry_count = 0
while True:
def connect(self, timeout: int = 10):
"""
Establish socket connection to host:port within timeout period
Args:
timeout (int): Time in seconds to wait for connection
"""
logger.info(f"Connecting to {self.host}:{self.port}")
start_time = time.time()
while time.time() - start_time < timeout:
try:
if self.sock is None:
self._initialize_socket()
@@ -192,10 +197,12 @@ class SocketIO:
break
except Exception as exc:
self.sock = None
time.sleep(2)
retry_count += 1
if retry_count > self.max_retry:
raise exc
logger.warning(f"Connection failed, retrying after 0.2 seconds... {exc}")
time.sleep(1)
else:
raise ConnectionError(
f"Could not connect to {self.host}:{self.port} within {time.time()-start_time} seconds"
)
def _put(self, msg_bytes):
logger.debug(f"put message: {msg_bytes}")
@@ -208,7 +215,7 @@ class SocketIO:
def _initialize_socket(self):
self.sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
self.sock.settimeout(5)
self.sock.settimeout(self.socket_timeout)
def put(self, msg):
return self._put(msg)
@@ -216,8 +223,14 @@ class SocketIO:
def receive(self, buffer_length=1024):
return self._recv(buffer_length=buffer_length)
def open(self):
self.connect()
def open(self, timeout: int = 10):
""" "
Open the socket connection to the host:port
Args:
timeout (int): Time in seconds to wait for connection
"""
self.connect(timeout=timeout)
self.is_open = True
def close(self):
@@ -235,7 +248,7 @@ class SocketMock:
self.is_open = False
# self.open()
def connect(self):
def connect(self, timeout: int = 10):
print(f"connecting to {self.host} port {self.port}")
def _put(self, msg_bytes):
@@ -261,7 +274,7 @@ class SocketMock:
def receive(self, buffer_length=1024):
return self._recv(buffer_length=buffer_length)
def open(self):
def open(self, timeout: int = 10):
self._initialize_socket()
self.is_open = True