refactor: cleanup and unifying galil classes
This commit is contained in:
parent
89cf412551
commit
981b877038
@ -30,29 +30,7 @@ logger = bec_logger.logger
|
||||
|
||||
|
||||
class FuprGalilController(GalilController):
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
name="GalilController",
|
||||
kind=None,
|
||||
parent=None,
|
||||
socket_cls=None,
|
||||
socket_host=None,
|
||||
socket_port=None,
|
||||
attr_name="",
|
||||
labels=None,
|
||||
):
|
||||
super().__init__(
|
||||
name=name,
|
||||
kind=kind,
|
||||
parent=parent,
|
||||
socket_cls=socket_cls,
|
||||
socket_host=socket_host,
|
||||
socket_port=socket_port,
|
||||
attr_name=attr_name,
|
||||
labels=labels,
|
||||
)
|
||||
self._galil_axis_per_controller = 1
|
||||
_axes_per_controller = 1
|
||||
|
||||
def is_axis_moving(self, axis_Id, axis_Id_numeric) -> bool:
|
||||
if axis_Id is None and axis_Id_numeric is not None:
|
||||
|
@ -44,6 +44,7 @@ def retry_once(fcn):
|
||||
|
||||
|
||||
class GalilController(Controller):
|
||||
_axes_per_controller = 8
|
||||
USER_ACCESS = [
|
||||
"describe",
|
||||
"show_running_threads",
|
||||
@ -53,42 +54,6 @@ class GalilController(Controller):
|
||||
"lgalil_is_air_off_and_orchestra_enabled",
|
||||
]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
name="GalilController",
|
||||
kind=None,
|
||||
parent=None,
|
||||
socket_cls=None,
|
||||
socket_host=None,
|
||||
socket_port=None,
|
||||
attr_name="",
|
||||
labels=None,
|
||||
):
|
||||
if not hasattr(self, "_initialized") or not self._initialized:
|
||||
self._galil_axis_per_controller = 8
|
||||
self._axis = [None for axis_num in range(self._galil_axis_per_controller)]
|
||||
super().__init__(
|
||||
name=name,
|
||||
socket_cls=socket_cls,
|
||||
socket_host=socket_host,
|
||||
socket_port=socket_port,
|
||||
attr_name=attr_name,
|
||||
parent=parent,
|
||||
labels=labels,
|
||||
kind=kind,
|
||||
)
|
||||
|
||||
def set_axis(self, axis: Device, axis_nr: int) -> None:
|
||||
"""Assign an axis to a device instance.
|
||||
|
||||
Args:
|
||||
axis (Device): Device instance (e.g. GalilMotor)
|
||||
axis_nr (int): Controller axis number
|
||||
|
||||
"""
|
||||
self._axis[axis_nr] = axis
|
||||
|
||||
@threadlocked
|
||||
def socket_put(self, val: str) -> None:
|
||||
self.sock.put(f"{val}\r".encode())
|
||||
@ -160,7 +125,7 @@ class GalilController(Controller):
|
||||
"""
|
||||
return bool(float(self.socket_put_and_receive("MG allaxref").strip()))
|
||||
|
||||
def drive_axis_to_limit(self, axis_Id_numeric, direction: str) -> None:
|
||||
def drive_axis_to_limit(self, axis_Id_numeric: int, direction: str) -> None:
|
||||
"""
|
||||
Drive an axis to the limit in a specified direction.
|
||||
|
||||
@ -215,11 +180,11 @@ class GalilController(Controller):
|
||||
def show_running_threads(self) -> None:
|
||||
t = PrettyTable()
|
||||
t.title = f"Threads on {self.sock.host}:{self.sock.port}"
|
||||
t.field_names = [str(ax) for ax in range(self._galil_axis_per_controller)]
|
||||
t.field_names = [str(ax) for ax in range(self._axes_per_controller)]
|
||||
t.add_row(
|
||||
[
|
||||
"active" if self.is_thread_active(t) else "inactive"
|
||||
for t in range(self._galil_axis_per_controller)
|
||||
for t in range(self._axes_per_controller)
|
||||
]
|
||||
)
|
||||
print(t)
|
||||
@ -253,7 +218,7 @@ class GalilController(Controller):
|
||||
"Limits",
|
||||
"Position",
|
||||
]
|
||||
for ax in range(self._galil_axis_per_controller):
|
||||
for ax in range(self._axes_per_controller):
|
||||
axis = self._axis[ax]
|
||||
if axis is not None:
|
||||
t.add_row(
|
||||
|
@ -1,428 +0,0 @@
|
||||
import abc
|
||||
import functools
|
||||
import socket
|
||||
import threading
|
||||
import time
|
||||
|
||||
from ophyd import PositionerBase, Signal
|
||||
from ophyd.device import Component as Cpt
|
||||
from ophyd.device import Device
|
||||
from prettytable import PrettyTable
|
||||
from typeguard import typechecked
|
||||
|
||||
from ophyd_devices.utils.controller import threadlocked
|
||||
from ophyd_devices.utils.socket import raise_if_disconnected
|
||||
|
||||
|
||||
def channel_checked(fcn):
|
||||
"""Decorator to catch attempted access to channels that are not available."""
|
||||
|
||||
@functools.wraps(fcn)
|
||||
def wrapper(self, *args, **kwargs):
|
||||
self._check_channel(args[0])
|
||||
return fcn(self, *args, **kwargs)
|
||||
|
||||
return wrapper
|
||||
|
||||
|
||||
class NPointController:
|
||||
_controller_instance = None
|
||||
|
||||
NUM_CHANNELS = 3
|
||||
_read_single_loc_bit = "A0"
|
||||
_write_single_loc_bit = "A2"
|
||||
_trailing_bit = "55"
|
||||
_range_offset = "78"
|
||||
_channel_base = ["11", "83"]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
comm_socket: SocketIO,
|
||||
server_ip: str = "129.129.99.87",
|
||||
server_port: int = 23,
|
||||
) -> None:
|
||||
self._lock = threading.RLock()
|
||||
super().__init__()
|
||||
self._server_and_port_name = (server_ip, server_port)
|
||||
self.socket = comm_socket
|
||||
self.connected = False
|
||||
|
||||
def __new__(cls, *args, **kwargs):
|
||||
if not NPointController._controller_instance:
|
||||
NPointController._controller_instance = object.__new__(cls)
|
||||
return NPointController._controller_instance
|
||||
|
||||
@classmethod
|
||||
def create(cls):
|
||||
return cls(SocketIO())
|
||||
|
||||
def show_all(self) -> None:
|
||||
"""Display current status of all channels
|
||||
|
||||
Returns:
|
||||
None
|
||||
"""
|
||||
if not self.connected:
|
||||
print("npoint controller is currently disabled.")
|
||||
return
|
||||
print(f"Connected to controller at {self._server_and_port_name}")
|
||||
t = PrettyTable()
|
||||
t.field_names = ["Channel", "Range", "Position", "Target"]
|
||||
for ii in range(self.NUM_CHANNELS):
|
||||
t.add_row(
|
||||
[
|
||||
ii,
|
||||
self._get_range(ii),
|
||||
self._get_current_pos(ii),
|
||||
self._get_target_pos(ii),
|
||||
]
|
||||
)
|
||||
print(t)
|
||||
|
||||
@threadlocked
|
||||
def on(self) -> None:
|
||||
"""Enable the NPoint controller and open a new socket.
|
||||
|
||||
Raises:
|
||||
TimeoutError: Raised if the socket connection raises a timeout.
|
||||
|
||||
Returns:
|
||||
None
|
||||
"""
|
||||
if self.connected:
|
||||
print("You are already connected to the NPoint controller.")
|
||||
return
|
||||
if not self.socket.is_open:
|
||||
self.socket.open()
|
||||
try:
|
||||
self.socket.connect(self._server_and_port_name[0], self._server_and_port_name[1])
|
||||
except socket.timeout:
|
||||
raise TimeoutError(
|
||||
f"Failed to connect to the specified server and port {self._server_and_port_name}."
|
||||
)
|
||||
except OSError:
|
||||
print("ERROR while connecting. Let's try again")
|
||||
self.socket.close()
|
||||
time.sleep(0.5)
|
||||
self.socket.open()
|
||||
self.socket.connect(self._server_and_port_name[0], self._server_and_port_name[1])
|
||||
self.connected = True
|
||||
|
||||
@threadlocked
|
||||
def off(self) -> None:
|
||||
"""Disable the controller and close the socket.
|
||||
|
||||
Returns:
|
||||
None
|
||||
"""
|
||||
self.socket.close()
|
||||
self.connected = False
|
||||
|
||||
@channel_checked
|
||||
def _get_range(self, channel: int) -> int:
|
||||
"""Get the range of the specified channel axis.
|
||||
|
||||
Args:
|
||||
channel (int): Channel for which the range should be requested.
|
||||
|
||||
Raises:
|
||||
RuntimeError: Raised if the received message doesn't have the expected number of bytes (10).
|
||||
|
||||
Returns:
|
||||
int: Range
|
||||
"""
|
||||
|
||||
# for first channel: 0x11 83 10 78
|
||||
addr = self._channel_base.copy()
|
||||
addr.extend([f"{16 + 16 * channel:x}", self._range_offset])
|
||||
send_buffer = self.__read_single_location_buffer(addr)
|
||||
|
||||
recvd = self._put_and_receive(send_buffer)
|
||||
if len(recvd) != 10:
|
||||
raise RuntimeError(
|
||||
f"Received buffer is corrupted. Expected 10 bytes and instead got {len(recvd)}"
|
||||
)
|
||||
device_range = self._hex_list_to_int(recvd[5:-1], signed=False)
|
||||
return device_range
|
||||
|
||||
@channel_checked
|
||||
def _get_current_pos(self, channel: int) -> float:
|
||||
# for first channel: 0x11 83 13 34
|
||||
addr = self._channel_base.copy()
|
||||
addr.extend([f"{19 + 16 * channel:x}", "34"])
|
||||
send_buffer = self.__read_single_location_buffer(addr)
|
||||
|
||||
recvd = self._put_and_receive(send_buffer)
|
||||
|
||||
pos_buffer = recvd[5:-1]
|
||||
pos = self._hex_list_to_int(pos_buffer) / 1048574 * 100
|
||||
return pos
|
||||
|
||||
@channel_checked
|
||||
def _set_target_pos(self, channel: int, pos: float) -> None:
|
||||
# for first channel: 0x11 83 12 18 00 00 00 00
|
||||
addr = self._channel_base.copy()
|
||||
addr.extend([f"{18 + channel * 16:x}", "18"])
|
||||
|
||||
target = int(round(1048574 / 100 * pos))
|
||||
data = [f"{m:02x}" for m in target.to_bytes(4, byteorder="big", signed=True)]
|
||||
|
||||
send_buffer = self.__write_single_location_buffer(addr, data)
|
||||
self._put(send_buffer)
|
||||
|
||||
@channel_checked
|
||||
def _get_target_pos(self, channel: int) -> float:
|
||||
# for first channel: 0x11 83 12 18
|
||||
addr = self._channel_base.copy()
|
||||
addr.extend([f"{18 + channel * 16:x}", "18"])
|
||||
send_buffer = self.__read_single_location_buffer(addr)
|
||||
|
||||
recvd = self._put_and_receive(send_buffer)
|
||||
pos_buffer = recvd[5:-1]
|
||||
pos = self._hex_list_to_int(pos_buffer) / 1048574 * 100
|
||||
return pos
|
||||
|
||||
@channel_checked
|
||||
def _set_servo(self, channel: int, enable: bool) -> None:
|
||||
print("Not tested")
|
||||
return
|
||||
# for first channel: 0x11 83 10 84 00 00 00 00
|
||||
addr = self._channel_base.copy()
|
||||
addr.extend([f"{16 + channel * 16:x}", "84"])
|
||||
|
||||
if enable:
|
||||
data = ["00"] * 3 + ["01"]
|
||||
else:
|
||||
data = ["00"] * 4
|
||||
send_buffer = self.__write_single_location_buffer(addr, data)
|
||||
|
||||
self._put(send_buffer)
|
||||
|
||||
@channel_checked
|
||||
def _get_servo(self, channel: int) -> int:
|
||||
# for first channel: 0x11 83 10 84 00 00 00 00
|
||||
addr = self._channel_base.copy()
|
||||
addr.extend([f"{16 + channel * 16:x}", "84"])
|
||||
send_buffer = self.__read_single_location_buffer(addr)
|
||||
|
||||
recvd = self._put_and_receive(send_buffer)
|
||||
buffer = recvd[5:-1]
|
||||
status = self._hex_list_to_int(buffer)
|
||||
return status
|
||||
|
||||
@threadlocked
|
||||
def _put(self, buffer: list) -> None:
|
||||
"""Translates a list of hex values to bytes and sends them to the socket.
|
||||
|
||||
Args:
|
||||
buffer (list): List of hex values without leading 0x
|
||||
|
||||
Returns:
|
||||
None
|
||||
"""
|
||||
|
||||
buffer = b"".join([bytes.fromhex(m) for m in buffer])
|
||||
self.socket.put(buffer)
|
||||
|
||||
@threadlocked
|
||||
def _put_and_receive(self, msg_hex_list: list) -> list:
|
||||
"""Send msg to socket and wait for a reply.
|
||||
|
||||
Args:
|
||||
msg_hex_list (list): List of hex values without leading 0x.
|
||||
|
||||
Returns:
|
||||
list: Received message as a list of hex values
|
||||
"""
|
||||
|
||||
buffer = b"".join([bytes.fromhex(m) for m in msg_hex_list])
|
||||
self.socket.put(buffer)
|
||||
recv_msg = self.socket.receive()
|
||||
recv_hex_list = [hex(m) for m in recv_msg]
|
||||
self._verify_received_msg(msg_hex_list, recv_hex_list)
|
||||
return recv_hex_list
|
||||
|
||||
def _verify_received_msg(self, in_list: list, out_list: list) -> None:
|
||||
"""Ensure that the first address bits of sent and received messages are the same.
|
||||
|
||||
Args:
|
||||
in_list (list): list containing the sent message
|
||||
out_list (list): list containing the received message
|
||||
|
||||
Raises:
|
||||
RuntimeError: Raised if first two address bits of 'in' and 'out' are not identical
|
||||
|
||||
Returns:
|
||||
None
|
||||
"""
|
||||
|
||||
# first, translate hex (str) values to int
|
||||
in_list_int = [int(val, 16) for val in in_list]
|
||||
out_list_int = [int(val, 16) for val in out_list]
|
||||
|
||||
# first ints of the reply should be the same. Otherwise something went wrong
|
||||
if not in_list_int[:2] == out_list_int[:2]:
|
||||
raise RuntimeError("Connection failure. Please restart the controller.")
|
||||
|
||||
def _check_channel(self, channel: int) -> None:
|
||||
if channel >= self.NUM_CHANNELS:
|
||||
raise ValueError(
|
||||
f"Channel {channel+1} exceeds the available number of channels ({self.NUM_CHANNELS})"
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _hex_list_to_int(in_buffer: list, byteorder="little", signed=True) -> int:
|
||||
"""Translate hex list to int.
|
||||
|
||||
Args:
|
||||
in_buffer (list): Input buffer; received as list of hex values
|
||||
byteorder (str, optional): Byteorder of in_buffer. Defaults to "little".
|
||||
signed (bool, optional): Whether the hex list represents a signed int. Defaults to True.
|
||||
|
||||
Returns:
|
||||
int: Translated integer.
|
||||
"""
|
||||
if byteorder == "little":
|
||||
in_buffer.reverse()
|
||||
|
||||
# make sure that all hex strings have the same format ("FF")
|
||||
val_hex = [f"{int(m, 16):02x}" for m in in_buffer]
|
||||
|
||||
val_bytes = [bytes.fromhex(m) for m in val_hex]
|
||||
val = int.from_bytes(b"".join(val_bytes), byteorder="big", signed=signed)
|
||||
return val
|
||||
|
||||
@staticmethod
|
||||
def __read_single_location_buffer(addr) -> list:
|
||||
"""Prepare buffer for reading from a single memory location (hex address).
|
||||
Number of bytes: 6
|
||||
Format: 0xA0 [addr] 0x55
|
||||
Return Value: 0xA0 [addr] [data] 0x55
|
||||
Sample Hex Transmission from PC to LC.400: A0 18 12 83 11 55
|
||||
Sample Hex Return Transmission from LC.400 to PC: A0 18 12 83 11 64 00 00 00 55
|
||||
|
||||
Args:
|
||||
addr (list): Hex address to read from
|
||||
|
||||
Returns:
|
||||
list: List of hex values representing the read instruction.
|
||||
"""
|
||||
buffer = []
|
||||
buffer.append(NPointController._read_single_loc_bit)
|
||||
if isinstance(addr, list):
|
||||
addr.reverse()
|
||||
buffer.extend(addr)
|
||||
else:
|
||||
buffer.append(addr)
|
||||
buffer.append(NPointController._trailing_bit)
|
||||
|
||||
return buffer
|
||||
|
||||
@staticmethod
|
||||
def __write_single_location_buffer(addr: list, data: list) -> list:
|
||||
"""Prepare buffer for writing to a single memory location (hex address).
|
||||
Number of bytes: 10
|
||||
Format: 0xA2 [addr] [data] 0x55
|
||||
Return Value: none
|
||||
Sample Hex Transmission from PC to C.400: A2 18 12 83 11 E8 03 00 00 55
|
||||
|
||||
Args:
|
||||
addr (list): List of hex values representing the address to write to.
|
||||
data (list): List of hex values representing the data that should be written.
|
||||
|
||||
Returns:
|
||||
list: List of hex values representing the write instruction.
|
||||
"""
|
||||
buffer = []
|
||||
buffer.append(NPointController._write_single_loc_bit)
|
||||
if isinstance(addr, list):
|
||||
addr.reverse()
|
||||
buffer.extend(addr)
|
||||
else:
|
||||
buffer.append(addr)
|
||||
|
||||
if isinstance(data, list):
|
||||
data.reverse()
|
||||
buffer.extend(data)
|
||||
else:
|
||||
buffer.append(data)
|
||||
buffer.append(NPointController._trailing_bit)
|
||||
return buffer
|
||||
|
||||
@staticmethod
|
||||
def __read_array():
|
||||
raise NotImplementedError
|
||||
|
||||
@staticmethod
|
||||
def __write_next_command():
|
||||
raise NotImplementedError
|
||||
|
||||
def __del__(self):
|
||||
if self.connected:
|
||||
print("Closing npoint socket")
|
||||
self.off()
|
||||
|
||||
|
||||
class SocketSignal(abc.ABC, Signal):
|
||||
def __init__(self, *, name, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
|
||||
@abc.abstractmethod
|
||||
def _socket_get(self):
|
||||
...
|
||||
|
||||
@abc.abstractmethod
|
||||
def _socket_set(self, val):
|
||||
...
|
||||
|
||||
|
||||
class NPointSignalBase(SocketSignal):
|
||||
def __init__(self, controller, signal_name, **kwargs):
|
||||
self.controller = controller
|
||||
self.signal_name = signal_name
|
||||
super().__init__(**kwargs)
|
||||
|
||||
|
||||
class NPointReadbackSignal(NPointSignalBase):
|
||||
def _socket_get(self):
|
||||
pass
|
||||
|
||||
def _socket_set(self, val):
|
||||
pass
|
||||
|
||||
|
||||
class NPointAxis(Device, PositionerBase):
|
||||
def __init__(
|
||||
self,
|
||||
prefix="",
|
||||
*,
|
||||
name,
|
||||
channel=None,
|
||||
kind=None,
|
||||
read_attrs=None,
|
||||
configuration_attrs=None,
|
||||
parent=None,
|
||||
**kwargs,
|
||||
):
|
||||
self.channel = channel
|
||||
self.controller = self._get_controller()
|
||||
|
||||
self.readback = Cpt(
|
||||
NPointSignal, controller=self.controller, signal_name="RBV", kind="hinted"
|
||||
)
|
||||
self.user_setpoint = Cpt(
|
||||
NPointSignal, controller=self.controller, signal_name="VAL", kind="normal"
|
||||
)
|
||||
|
||||
self.motor_resolution = Cpt(
|
||||
NPointSignal, controller=self.controller, signal_name="RNGE", kind="config"
|
||||
)
|
||||
self.motor_is_moving = Cpt(
|
||||
NPointSignal, controller=self.controller, signal_name="MOVN", kind="config"
|
||||
)
|
||||
self.axes_referenced = Cpt(
|
||||
NPointSignal, controller=self.controller, signal_name="XREF", kind="config"
|
||||
)
|
||||
|
||||
def _get_controller(self):
|
||||
return NPointController()
|
@ -11,7 +11,7 @@ from ophyd_devices.smaract.smaract_errors import (
|
||||
SmaractCommunicationError,
|
||||
SmaractErrorCode,
|
||||
)
|
||||
from ophyd_devices.utils.controller import Controller, threadlocked
|
||||
from ophyd_devices.utils.controller import Controller, axis_checked, threadlocked
|
||||
|
||||
logger = logging.getLogger("smaract_controller")
|
||||
|
||||
@ -21,17 +21,6 @@ class SmaractCommunicationMode(enum.Enum):
|
||||
ASYNC = 1
|
||||
|
||||
|
||||
def axis_checked(fcn):
|
||||
"""Decorator to catch attempted access to channels that are not available."""
|
||||
|
||||
@functools.wraps(fcn)
|
||||
def wrapper(self, *args, **kwargs):
|
||||
self._check_axis_number(args[0])
|
||||
return fcn(self, *args, **kwargs)
|
||||
|
||||
return wrapper
|
||||
|
||||
|
||||
def retry_once(fcn):
|
||||
"""Decorator to rerun a function in case a SmaractCommunicationError was raised. This may happen if the buffer was not empty."""
|
||||
|
||||
@ -82,6 +71,8 @@ class SmaractSensors:
|
||||
|
||||
|
||||
class SmaractController(Controller):
|
||||
_axes_per_controller = 6
|
||||
_initialized = False
|
||||
USER_ACCESS = ["socket_put_and_receive", "smaract_show_all", "move_open_loop_steps"]
|
||||
|
||||
def __init__(
|
||||
@ -96,9 +87,7 @@ class SmaractController(Controller):
|
||||
attr_name="",
|
||||
labels=None,
|
||||
):
|
||||
if not hasattr(self, "_initialized") or not self._initialized:
|
||||
self._Smaract_axis_per_controller = 6
|
||||
self._axis = [None for axis_num in range(self._Smaract_axis_per_controller)]
|
||||
if not self._initialized:
|
||||
super().__init__(
|
||||
name=name,
|
||||
socket_cls=socket_cls,
|
||||
@ -111,10 +100,6 @@ class SmaractController(Controller):
|
||||
)
|
||||
self._sensors = SmaractSensors()
|
||||
|
||||
@axis_checked
|
||||
def set_axis(self, axis_nr, axis):
|
||||
self._axis[axis_nr] = axis
|
||||
|
||||
@threadlocked
|
||||
def socket_put(self, val: str):
|
||||
self.sock.put(f":{val}\n".encode())
|
||||
@ -451,12 +436,6 @@ class SmaractController(Controller):
|
||||
t.add_row([None for t in t.field_names])
|
||||
print(t)
|
||||
|
||||
def _check_axis_number(self, axis_Id_numeric: int) -> None:
|
||||
if axis_Id_numeric >= self._Smaract_axis_per_controller:
|
||||
raise ValueError(
|
||||
f"Axis {axis_Id_numeric} exceeds the available number of axes ({self._Smaract_axis_per_controller})"
|
||||
)
|
||||
|
||||
@axis_checked
|
||||
def _error_str(self, axis_Id_numeric: int, error_number: int):
|
||||
return f":E{axis_Id_numeric},{error_number}"
|
||||
|
@ -130,12 +130,12 @@ class SmaractMotor(Device, PositionerBase):
|
||||
socket_cls=SocketIO,
|
||||
**kwargs,
|
||||
):
|
||||
self.axis_Id = axis_Id
|
||||
self.sign = sign
|
||||
self.controller = SmaractController(
|
||||
socket_cls=socket_cls, socket_host=host, socket_port=port
|
||||
)
|
||||
self.controller.set_axis(self.axis_Id_numeric, axis=self)
|
||||
self.axis_Id = axis_Id
|
||||
self.sign = sign
|
||||
self.controller.set_axis(axis=self, axis_nr=self.axis_Id_numeric)
|
||||
self.tolerance = kwargs.pop("tolerance", 0.5)
|
||||
|
||||
super().__init__(
|
||||
|
@ -2,6 +2,7 @@ import functools
|
||||
import threading
|
||||
|
||||
from bec_lib.core import bec_logger
|
||||
from ophyd import Device
|
||||
from ophyd.ophydobj import OphydObject
|
||||
|
||||
logger = bec_logger.logger
|
||||
@ -19,8 +20,28 @@ def threadlocked(fcn):
|
||||
return wrapper
|
||||
|
||||
|
||||
def axis_checked(fcn):
|
||||
"""Decorator to catch attempted access to channels that are not available."""
|
||||
|
||||
@functools.wraps(fcn)
|
||||
def wrapper(self, *args, **kwargs):
|
||||
if "axis_nr" in kwargs:
|
||||
self._check_axis_number(kwargs["axis_nr"])
|
||||
elif "axis_Id_numeric" in kwargs:
|
||||
self._check_axis_number(kwargs["axis_Id_numeric"])
|
||||
elif args:
|
||||
self._check_axis_number(args[0])
|
||||
return fcn(self, *args, **kwargs)
|
||||
|
||||
return wrapper
|
||||
|
||||
|
||||
class Controller(OphydObject):
|
||||
"""Base class for all socker-based controllers."""
|
||||
|
||||
_controller_instances = {}
|
||||
_initialized = False
|
||||
_axes_per_controller = 1
|
||||
|
||||
SUB_CONNECTION_CHANGE = "connection_change"
|
||||
|
||||
@ -40,11 +61,12 @@ class Controller(OphydObject):
|
||||
self._socket_cls = socket_cls
|
||||
self._socket_host = socket_host
|
||||
self._socket_port = socket_port
|
||||
if not hasattr(self, "_initialized"):
|
||||
if not self._initialized:
|
||||
super().__init__(
|
||||
name=name, attr_name=attr_name, parent=parent, labels=labels, kind=kind
|
||||
)
|
||||
self._lock = threading.RLock()
|
||||
self._axis = []
|
||||
self._initialize()
|
||||
self._initialized = True
|
||||
|
||||
@ -54,8 +76,12 @@ class Controller(OphydObject):
|
||||
|
||||
def _set_default_values(self):
|
||||
# no. of axes controlled by each controller
|
||||
self._axis_per_controller = 8
|
||||
self._motors = [None for axis_num in range(self._axis_per_controller)]
|
||||
self._axis = [None for axis_num in range(self._axes_per_controller)]
|
||||
|
||||
@classmethod
|
||||
def _reset_controller(cls):
|
||||
cls._controller_instances = {}
|
||||
cls._initialized = False
|
||||
|
||||
@property
|
||||
def connected(self):
|
||||
@ -66,13 +92,35 @@ class Controller(OphydObject):
|
||||
self._connected = value
|
||||
self._run_subs(sub_type=self.SUB_CONNECTION_CHANGE)
|
||||
|
||||
def set_motor(self, motor, axis):
|
||||
"""Set the motor instance for a specified controller axis."""
|
||||
self._motors[axis] = motor
|
||||
@axis_checked
|
||||
def set_axis(self, *, axis: Device, axis_nr: int) -> None:
|
||||
"""Assign an axis to a device instance.
|
||||
|
||||
def get_motor(self, axis):
|
||||
"""Get motor instance for a specified controller axis."""
|
||||
return self._motors[axis]
|
||||
Args:
|
||||
axis (Device): Device instance (e.g. GalilMotor)
|
||||
axis_nr (int): Controller axis number
|
||||
|
||||
"""
|
||||
self._axis[axis_nr] = axis
|
||||
|
||||
@axis_checked
|
||||
def get_axis(self, axis_nr: int) -> Device:
|
||||
"""Get device instance for a specified controller axis.
|
||||
|
||||
Args:
|
||||
axis_nr (int): Controller axis number
|
||||
|
||||
Returns:
|
||||
Device: Device instance (e.g. GalilMotor)
|
||||
|
||||
"""
|
||||
return self._axis[axis_nr]
|
||||
|
||||
def _check_axis_number(self, axis_Id_numeric: int) -> None:
|
||||
if axis_Id_numeric >= self._axes_per_controller:
|
||||
raise ValueError(
|
||||
f"Axis {axis_Id_numeric} exceeds the available number of axes ({self._axes_per_controller})"
|
||||
)
|
||||
|
||||
def on(self) -> None:
|
||||
"""Open a new socket connection to the controller"""
|
||||
@ -103,6 +151,6 @@ class Controller(OphydObject):
|
||||
if not socket_port:
|
||||
raise RuntimeError("Socket port must be specified.")
|
||||
host_port = f"{socket_host}:{socket_port}"
|
||||
if host_port not in Controller._controller_instances:
|
||||
Controller._controller_instances[host_port] = object.__new__(cls)
|
||||
return Controller._controller_instances[host_port]
|
||||
if host_port not in cls._controller_instances:
|
||||
cls._controller_instances[host_port] = object.__new__(cls)
|
||||
return cls._controller_instances[host_port]
|
||||
|
2
setup.py
2
setup.py
@ -14,6 +14,6 @@ if __name__ == "__main__":
|
||||
"std_daq_client",
|
||||
"pyepics",
|
||||
],
|
||||
extras_require={"dev": ["pytest", "pytest-random-order", "black"]},
|
||||
extras_require={"dev": ["pytest", "pytest-random-order", "black", "coverage"]},
|
||||
version=__version__,
|
||||
)
|
||||
|
30
tests/test_controller.py
Normal file
30
tests/test_controller.py
Normal file
@ -0,0 +1,30 @@
|
||||
from unittest import mock
|
||||
|
||||
from ophyd_devices.utils.controller import Controller
|
||||
|
||||
|
||||
def test_controller_off():
|
||||
controller = Controller(socket_cls=mock.MagicMock(), socket_host="dummy", socket_port=123)
|
||||
controller.on()
|
||||
with mock.patch.object(controller.sock, "close") as mock_close:
|
||||
controller.off()
|
||||
assert controller.sock is None
|
||||
assert controller.connected is False
|
||||
mock_close.assert_called_once()
|
||||
|
||||
# make sure it is indempotent
|
||||
controller.off()
|
||||
|
||||
|
||||
def test_controller_on():
|
||||
socket_cls = mock.MagicMock()
|
||||
Controller._controller_instances = {}
|
||||
controller = Controller(socket_cls=socket_cls, socket_host="dummy", socket_port=123)
|
||||
controller.on()
|
||||
assert controller.sock is not None
|
||||
assert controller.connected is True
|
||||
socket_cls().open.assert_called_once()
|
||||
|
||||
# make sure it is indempotent
|
||||
controller.on()
|
||||
socket_cls().open.assert_called_once()
|
@ -1,3 +1,5 @@
|
||||
from unittest import mock
|
||||
|
||||
import pytest
|
||||
from utils import SocketMock
|
||||
|
||||
@ -61,3 +63,129 @@ def test_axis_put(target_pos, socket_put_messages, socket_get_messages):
|
||||
leyey.controller.sock.buffer_recv = socket_get_messages
|
||||
leyey.user_setpoint.put(target_pos)
|
||||
assert leyey.controller.sock.buffer_put == socket_put_messages
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"axis_nr,direction,socket_put_messages,socket_get_messages",
|
||||
[
|
||||
(
|
||||
0,
|
||||
"forward",
|
||||
[
|
||||
b"naxis=0\r",
|
||||
b"ndir=1\r",
|
||||
b"XQ#NEWPAR\r",
|
||||
b"XQ#FES\r",
|
||||
b"MG_BGA\r",
|
||||
b"MGbcklact[axis]\r",
|
||||
b"MG_XQ0\r",
|
||||
b"MG_XQ2\r",
|
||||
b"MG _LRA, _LFA\r",
|
||||
],
|
||||
[
|
||||
b":",
|
||||
b":",
|
||||
b":",
|
||||
b":",
|
||||
b"0",
|
||||
b"0",
|
||||
b"-1",
|
||||
b"-1",
|
||||
b"1.000 0.000",
|
||||
],
|
||||
),
|
||||
(
|
||||
1,
|
||||
"reverse",
|
||||
[
|
||||
b"naxis=1\r",
|
||||
b"ndir=-1\r",
|
||||
b"XQ#NEWPAR\r",
|
||||
b"XQ#FES\r",
|
||||
b"MG_BGB\r",
|
||||
b"MGbcklact[axis]\r",
|
||||
b"MG_XQ0\r",
|
||||
b"MG_XQ2\r",
|
||||
b"MG _LRB, _LFB\r",
|
||||
],
|
||||
[
|
||||
b":",
|
||||
b":",
|
||||
b":",
|
||||
b":",
|
||||
b"0",
|
||||
b"0",
|
||||
b"-1",
|
||||
b"-1",
|
||||
b"0.000 1.000",
|
||||
],
|
||||
),
|
||||
],
|
||||
)
|
||||
def test_drive_axis_to_limit(axis_nr, direction, socket_put_messages, socket_get_messages):
|
||||
leyey = GalilMotor("A", name="leyey", host="mpc2680.psi.ch", port=8081, socket_cls=SocketMock)
|
||||
leyey.controller.on()
|
||||
leyey.controller.sock.flush_buffer()
|
||||
leyey.controller.sock.buffer_recv = socket_get_messages
|
||||
leyey.controller.drive_axis_to_limit(axis_nr, direction)
|
||||
assert leyey.controller.sock.buffer_put == socket_put_messages
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"axis_nr,socket_put_messages,socket_get_messages",
|
||||
[
|
||||
(
|
||||
0,
|
||||
[
|
||||
b"naxis=0\r",
|
||||
b"XQ#NEWPAR\r",
|
||||
b"XQ#FRM\r",
|
||||
b"MG_BGA\r",
|
||||
b"MGbcklact[axis]\r",
|
||||
b"MG_XQ0\r",
|
||||
b"MG_XQ2\r",
|
||||
b"MG axisref[0]\r",
|
||||
],
|
||||
[
|
||||
b":",
|
||||
b":",
|
||||
b":",
|
||||
b"0",
|
||||
b"0",
|
||||
b"-1",
|
||||
b"-1",
|
||||
b"1.00",
|
||||
],
|
||||
),
|
||||
(
|
||||
1,
|
||||
[
|
||||
b"naxis=1\r",
|
||||
b"XQ#NEWPAR\r",
|
||||
b"XQ#FRM\r",
|
||||
b"MG_BGB\r",
|
||||
b"MGbcklact[axis]\r",
|
||||
b"MG_XQ0\r",
|
||||
b"MG_XQ2\r",
|
||||
b"MG axisref[1]\r",
|
||||
],
|
||||
[
|
||||
b":",
|
||||
b":",
|
||||
b":",
|
||||
b"0",
|
||||
b"0",
|
||||
b"-1",
|
||||
b"-1",
|
||||
b"1.00",
|
||||
],
|
||||
),
|
||||
],
|
||||
)
|
||||
def test_find_reference(axis_nr, socket_put_messages, socket_get_messages):
|
||||
leyey = GalilMotor("A", name="leyey", host="mpc2680.psi.ch", port=8081, socket_cls=SocketMock)
|
||||
leyey.controller.on()
|
||||
leyey.controller.sock.flush_buffer()
|
||||
leyey.controller.sock.buffer_recv = socket_get_messages
|
||||
leyey.controller.find_reference(axis_nr)
|
||||
assert leyey.controller.sock.buffer_put == socket_put_messages
|
||||
|
@ -20,6 +20,7 @@ from ophyd_devices.smaract.smaract_ophyd import SmaractMotor
|
||||
],
|
||||
)
|
||||
def test_get_position(axis, position, get_message, return_msg):
|
||||
SmaractController._reset_controller()
|
||||
controller = SmaractController(socket_cls=SocketMock, socket_host="dummy", socket_port=123)
|
||||
controller.on()
|
||||
controller.sock.flush_buffer()
|
||||
@ -39,6 +40,7 @@ def test_get_position(axis, position, get_message, return_msg):
|
||||
],
|
||||
)
|
||||
def test_axis_is_referenced(axis, is_referenced, get_message, return_msg, exception):
|
||||
SmaractController._reset_controller()
|
||||
controller = SmaractController(socket_cls=SocketMock, socket_host="dummy", socket_port=123)
|
||||
controller.on()
|
||||
controller.sock.flush_buffer()
|
||||
@ -62,6 +64,7 @@ def test_axis_is_referenced(axis, is_referenced, get_message, return_msg, except
|
||||
],
|
||||
)
|
||||
def test_socket_put_and_receive_raises_exception(return_msg, exception, raised):
|
||||
SmaractController._reset_controller()
|
||||
controller = SmaractController(socket_cls=SocketMock, socket_host="dummy", socket_port=123)
|
||||
controller.on()
|
||||
controller.sock.flush_buffer()
|
||||
@ -87,6 +90,7 @@ def test_socket_put_and_receive_raises_exception(return_msg, exception, raised):
|
||||
],
|
||||
)
|
||||
def test_communication_mode(mode, get_message, return_msg):
|
||||
SmaractController._reset_controller()
|
||||
controller = SmaractController(socket_cls=SocketMock, socket_host="dummy", socket_port=123)
|
||||
controller.on()
|
||||
controller.sock.flush_buffer()
|
||||
@ -112,6 +116,7 @@ def test_communication_mode(mode, get_message, return_msg):
|
||||
],
|
||||
)
|
||||
def test_axis_is_moving(is_moving, get_message, return_msg):
|
||||
SmaractController._reset_controller()
|
||||
controller = SmaractController(socket_cls=SocketMock, socket_host="dummy", socket_port=123)
|
||||
controller.on()
|
||||
controller.sock.flush_buffer()
|
||||
@ -132,6 +137,7 @@ def test_axis_is_moving(is_moving, get_message, return_msg):
|
||||
],
|
||||
)
|
||||
def test_get_sensor_definition(sensor_id, axis, get_msg, return_msg):
|
||||
SmaractController._reset_controller()
|
||||
controller = SmaractController(socket_cls=SocketMock, socket_host="dummy", socket_port=123)
|
||||
controller.on()
|
||||
controller.sock.flush_buffer()
|
||||
@ -149,6 +155,7 @@ def test_get_sensor_definition(sensor_id, axis, get_msg, return_msg):
|
||||
],
|
||||
)
|
||||
def test_set_move_speed(move_speed, axis, get_msg, return_msg):
|
||||
SmaractController._reset_controller()
|
||||
controller = SmaractController(socket_cls=SocketMock, socket_host="dummy", socket_port=123)
|
||||
controller.on()
|
||||
controller.sock.flush_buffer()
|
||||
@ -166,6 +173,7 @@ def test_set_move_speed(move_speed, axis, get_msg, return_msg):
|
||||
],
|
||||
)
|
||||
def test_move_axis_to_absolute_position(pos, axis, hold_time, get_msg, return_msg):
|
||||
SmaractController._reset_controller()
|
||||
controller = SmaractController(socket_cls=SocketMock, socket_host="dummy", socket_port=123)
|
||||
controller.on()
|
||||
controller.sock.flush_buffer()
|
||||
@ -203,6 +211,7 @@ def test_move_axis_to_absolute_position(pos, axis, hold_time, get_msg, return_ms
|
||||
],
|
||||
)
|
||||
def test_move_axis(pos, get_msg, return_msg):
|
||||
SmaractController._reset_controller()
|
||||
lsmarA = SmaractMotor(
|
||||
"A",
|
||||
name="lsmarA",
|
||||
@ -230,6 +239,7 @@ def test_move_axis(pos, get_msg, return_msg):
|
||||
],
|
||||
)
|
||||
def test_stop_axis(num_axes, get_msg, return_msg):
|
||||
SmaractController._reset_controller()
|
||||
lsmarA = SmaractMotor(
|
||||
"A",
|
||||
name="lsmarA",
|
||||
|
Loading…
x
Reference in New Issue
Block a user