mirror of
https://github.com/bec-project/ophyd_devices.git
synced 2026-02-02 21:54:55 +01:00
fix(controller): add configurable timeout, en/disable controller axes on on/off
This commit is contained in:
@@ -130,7 +130,7 @@ class SimDeviceWithStatusStageUnstage(Device):
|
||||
|
||||
|
||||
class SynController(OphydObject):
|
||||
def on(self):
|
||||
def on(self, timeout: int = 10):
|
||||
pass
|
||||
|
||||
def off(self):
|
||||
|
||||
@@ -84,7 +84,7 @@ class SocketMock:
|
||||
self.sock = None
|
||||
self.open()
|
||||
|
||||
def connect(self):
|
||||
def connect(self, timeout: int = 10):
|
||||
"""Mock connect method"""
|
||||
print(f"connecting to {self.host} port {self.port}")
|
||||
|
||||
@@ -116,7 +116,7 @@ class SocketMock:
|
||||
"""Mock receive method"""
|
||||
return self._recv(buffer_length=buffer_length)
|
||||
|
||||
def open(self):
|
||||
def open(self, timeout: int = 10):
|
||||
"""Mock open method"""
|
||||
self._initialize_socket()
|
||||
self.is_open = True
|
||||
|
||||
@@ -1,10 +1,17 @@
|
||||
import functools
|
||||
import threading
|
||||
from typing import TYPE_CHECKING, Type
|
||||
|
||||
from bec_lib import bec_logger
|
||||
from ophyd import Device
|
||||
from ophyd.ophydobj import OphydObject
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from bec_server.device_server.device_server import DeviceManagerDS
|
||||
|
||||
from ophyd_devices.utils.socket import SocketIO
|
||||
|
||||
|
||||
logger = bec_logger.logger
|
||||
|
||||
|
||||
@@ -59,7 +66,16 @@ def axis_checked(fcn):
|
||||
|
||||
|
||||
class Controller(OphydObject):
|
||||
"""Base class for all socker-based controllers."""
|
||||
"""
|
||||
Base class for all socket-based controllers.
|
||||
|
||||
Args:
|
||||
name (str, optional): Name of the controller
|
||||
socket_cls (Type[SocketIO]): Socket class to use for communication
|
||||
socket_host (str): Hostname or IP address of the controller
|
||||
socket_port (int): Port number of the controller
|
||||
device_manager (DeviceManagerDS): Device manager instance
|
||||
"""
|
||||
|
||||
_controller_instances = {}
|
||||
_initialized = False
|
||||
@@ -70,10 +86,11 @@ class Controller(OphydObject):
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
name=None,
|
||||
socket_cls=None,
|
||||
socket_host=None,
|
||||
socket_port=None,
|
||||
socket_cls: Type["SocketIO"],
|
||||
socket_host: str,
|
||||
socket_port: int,
|
||||
device_manager: "DeviceManagerDS",
|
||||
name: str = "",
|
||||
attr_name="",
|
||||
parent=None,
|
||||
labels=None,
|
||||
@@ -84,10 +101,11 @@ class Controller(OphydObject):
|
||||
name=name, attr_name=attr_name, parent=parent, labels=labels, kind=kind
|
||||
)
|
||||
self._lock = threading.RLock()
|
||||
self._axis = []
|
||||
self._axis: list[Device] = []
|
||||
self._initialize()
|
||||
self._initialized = True
|
||||
self.sock = None
|
||||
self.dm = device_manager
|
||||
self._socket_cls = socket_cls
|
||||
self._socket_host = socket_host
|
||||
self._socket_port = socket_port
|
||||
@@ -127,15 +145,6 @@ class Controller(OphydObject):
|
||||
return var.split("\r\n")[0]
|
||||
return var
|
||||
|
||||
def get_device_manager(self):
|
||||
"""
|
||||
Helper function to get the device manager.
|
||||
"""
|
||||
for axis in self._axis:
|
||||
if hasattr(axis, "device_manager") and axis.device_manager:
|
||||
return axis.device_manager
|
||||
raise ControllerError("Could not access the device_manager")
|
||||
|
||||
def get_axis_by_name(self, name: str) -> Device:
|
||||
"""
|
||||
Get an axis by name.
|
||||
@@ -152,21 +161,58 @@ class Controller(OphydObject):
|
||||
return axis
|
||||
raise RuntimeError(f"Could not find an axis with name {name}")
|
||||
|
||||
def set_device_enabled(self, device_name: str, enabled: bool) -> None:
|
||||
def set_device_read_write(self, device_name: str, enabled: bool) -> None:
|
||||
"""
|
||||
Enable or disable a device for write access.
|
||||
Change the read-only status of a device.
|
||||
If the device is not configured, a warning is logged.
|
||||
|
||||
Args:
|
||||
device_name (str): Name of the device
|
||||
enabled (bool): Set device to read-only or writable
|
||||
"""
|
||||
if device_name not in self.dm.devices:
|
||||
logger.warning(
|
||||
f"Device {device_name} is not available on the device manager, cannot be set to read-only: {not enabled}."
|
||||
)
|
||||
return
|
||||
self.dm.devices[device_name].read_only = not enabled
|
||||
|
||||
def set_device_enable(self, device_name: str, enabled: bool) -> None:
|
||||
"""
|
||||
Enable/disable a device for write access.
|
||||
If the device is not configured, a warning is logged.
|
||||
|
||||
Args:
|
||||
device_name (str): Name of the device
|
||||
enabled (bool): Enable or disable the device
|
||||
"""
|
||||
if device_name not in self.get_device_manager().devices:
|
||||
if device_name not in self.dm.devices:
|
||||
logger.warning(
|
||||
f"Device {device_name} is not configured and cannot be enabled/disabled."
|
||||
f"Device {device_name} is not available on the device manager, cannot be set to enabled: {enabled}."
|
||||
)
|
||||
return
|
||||
self.get_device_manager().devices[device_name].read_only = not enabled
|
||||
self.dm.devices[device_name].enabled = enabled
|
||||
if enabled:
|
||||
self.on()
|
||||
else:
|
||||
all_disabled = all(
|
||||
not self.dm.devices[axis.name].enabled for axis in self._axis if axis is not None
|
||||
)
|
||||
if all_disabled:
|
||||
self.off()
|
||||
|
||||
def set_all_devices_enable(self, enabled: bool) -> None:
|
||||
"""
|
||||
Enable or disable all devices registered for the controller.
|
||||
|
||||
Args:
|
||||
enabled (bool): Enable or disable all devices
|
||||
"""
|
||||
for axis in self._axis:
|
||||
if axis is None:
|
||||
logger.info("Axis is not assigned, skipping enabling/disabling.")
|
||||
continue
|
||||
self.set_device_enable(axis.name, enabled)
|
||||
|
||||
def _initialize(self):
|
||||
self._connected = False
|
||||
@@ -220,11 +266,16 @@ class Controller(OphydObject):
|
||||
f"Axis {axis_Id_numeric} exceeds the available number of axes ({self._axes_per_controller})"
|
||||
)
|
||||
|
||||
def on(self) -> None:
|
||||
"""Open a new socket connection to the controller"""
|
||||
def on(self, timeout: int = 10) -> None:
|
||||
"""
|
||||
Open a new socket connection to the controller
|
||||
|
||||
Args:
|
||||
timeout (int): Time in seconds to wait for connection
|
||||
"""
|
||||
if not self.connected or self.sock is None:
|
||||
self.sock = self._socket_cls(host=self._socket_host, port=self._socket_port)
|
||||
self.sock.open()
|
||||
self.sock.open(timeout=timeout)
|
||||
self.connected = True
|
||||
else:
|
||||
logger.info("The connection has already been established.")
|
||||
@@ -235,6 +286,8 @@ class Controller(OphydObject):
|
||||
self.sock.close()
|
||||
self.connected = False
|
||||
self.sock = None
|
||||
# Disable all axes associated with this controller
|
||||
self.set_all_devices_enable(False)
|
||||
else:
|
||||
logger.info("The connection is already closed.")
|
||||
|
||||
@@ -242,12 +295,15 @@ class Controller(OphydObject):
|
||||
socket_cls = kwargs.get("socket_cls")
|
||||
socket_host = kwargs.get("socket_host")
|
||||
socket_port = kwargs.get("socket_port")
|
||||
device_manager = kwargs.get("device_manager")
|
||||
if not socket_cls:
|
||||
raise RuntimeError("Socket class must be specified.")
|
||||
if not socket_host:
|
||||
raise RuntimeError("Socket host must be specified.")
|
||||
if not socket_port:
|
||||
raise RuntimeError("Socket port must be specified.")
|
||||
if not device_manager:
|
||||
raise RuntimeError("Device manager must be specified.")
|
||||
host_port = f"{socket_host}:{socket_port}"
|
||||
if host_port not in cls._controller_instances:
|
||||
cls._controller_instances[host_port] = object.__new__(cls)
|
||||
|
||||
@@ -173,18 +173,23 @@ class SocketSignal(abc.ABC, Signal):
|
||||
class SocketIO:
|
||||
"""SocketIO helper class for TCP IP connections"""
|
||||
|
||||
def __init__(self, host, port, max_retry=10):
|
||||
def __init__(self, host: str, port: int, socket_timeout: int = 2):
|
||||
self.host = host
|
||||
self.port = port
|
||||
self.is_open = False
|
||||
self.max_retry = max_retry
|
||||
self.socket_timeout = socket_timeout
|
||||
self._initialize_socket()
|
||||
|
||||
def connect(self):
|
||||
print(f"connecting to {self.host} port {self.port}")
|
||||
# self.sock.create_connection((host, port))
|
||||
retry_count = 0
|
||||
while True:
|
||||
def connect(self, timeout: int = 10):
|
||||
"""
|
||||
Establish socket connection to host:port within timeout period
|
||||
|
||||
Args:
|
||||
timeout (int): Time in seconds to wait for connection
|
||||
"""
|
||||
logger.info(f"Connecting to {self.host}:{self.port}")
|
||||
start_time = time.time()
|
||||
while time.time() - start_time < timeout:
|
||||
try:
|
||||
if self.sock is None:
|
||||
self._initialize_socket()
|
||||
@@ -192,10 +197,12 @@ class SocketIO:
|
||||
break
|
||||
except Exception as exc:
|
||||
self.sock = None
|
||||
time.sleep(2)
|
||||
retry_count += 1
|
||||
if retry_count > self.max_retry:
|
||||
raise exc
|
||||
logger.warning(f"Connection failed, retrying after 0.2 seconds... {exc}")
|
||||
time.sleep(1)
|
||||
else:
|
||||
raise ConnectionError(
|
||||
f"Could not connect to {self.host}:{self.port} within {time.time()-start_time} seconds"
|
||||
)
|
||||
|
||||
def _put(self, msg_bytes):
|
||||
logger.debug(f"put message: {msg_bytes}")
|
||||
@@ -208,7 +215,7 @@ class SocketIO:
|
||||
|
||||
def _initialize_socket(self):
|
||||
self.sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
|
||||
self.sock.settimeout(5)
|
||||
self.sock.settimeout(self.socket_timeout)
|
||||
|
||||
def put(self, msg):
|
||||
return self._put(msg)
|
||||
@@ -216,8 +223,14 @@ class SocketIO:
|
||||
def receive(self, buffer_length=1024):
|
||||
return self._recv(buffer_length=buffer_length)
|
||||
|
||||
def open(self):
|
||||
self.connect()
|
||||
def open(self, timeout: int = 10):
|
||||
""" "
|
||||
Open the socket connection to the host:port
|
||||
|
||||
Args:
|
||||
timeout (int): Time in seconds to wait for connection
|
||||
"""
|
||||
self.connect(timeout=timeout)
|
||||
self.is_open = True
|
||||
|
||||
def close(self):
|
||||
@@ -235,7 +248,7 @@ class SocketMock:
|
||||
self.is_open = False
|
||||
# self.open()
|
||||
|
||||
def connect(self):
|
||||
def connect(self, timeout: int = 10):
|
||||
print(f"connecting to {self.host} port {self.port}")
|
||||
|
||||
def _put(self, msg_bytes):
|
||||
@@ -261,7 +274,7 @@ class SocketMock:
|
||||
def receive(self, buffer_length=1024):
|
||||
return self._recv(buffer_length=buffer_length)
|
||||
|
||||
def open(self):
|
||||
def open(self, timeout: int = 10):
|
||||
self._initialize_socket()
|
||||
self.is_open = True
|
||||
|
||||
|
||||
Reference in New Issue
Block a user