refactor: cleanup and unifying galil classes

This commit is contained in:
wakonig_k 2023-11-08 14:01:20 +01:00
parent 89cf412551
commit 981b877038
10 changed files with 242 additions and 532 deletions

View File

@ -30,29 +30,7 @@ logger = bec_logger.logger
class FuprGalilController(GalilController):
def __init__(
self,
*,
name="GalilController",
kind=None,
parent=None,
socket_cls=None,
socket_host=None,
socket_port=None,
attr_name="",
labels=None,
):
super().__init__(
name=name,
kind=kind,
parent=parent,
socket_cls=socket_cls,
socket_host=socket_host,
socket_port=socket_port,
attr_name=attr_name,
labels=labels,
)
self._galil_axis_per_controller = 1
_axes_per_controller = 1
def is_axis_moving(self, axis_Id, axis_Id_numeric) -> bool:
if axis_Id is None and axis_Id_numeric is not None:

View File

@ -44,6 +44,7 @@ def retry_once(fcn):
class GalilController(Controller):
_axes_per_controller = 8
USER_ACCESS = [
"describe",
"show_running_threads",
@ -53,42 +54,6 @@ class GalilController(Controller):
"lgalil_is_air_off_and_orchestra_enabled",
]
def __init__(
self,
*,
name="GalilController",
kind=None,
parent=None,
socket_cls=None,
socket_host=None,
socket_port=None,
attr_name="",
labels=None,
):
if not hasattr(self, "_initialized") or not self._initialized:
self._galil_axis_per_controller = 8
self._axis = [None for axis_num in range(self._galil_axis_per_controller)]
super().__init__(
name=name,
socket_cls=socket_cls,
socket_host=socket_host,
socket_port=socket_port,
attr_name=attr_name,
parent=parent,
labels=labels,
kind=kind,
)
def set_axis(self, axis: Device, axis_nr: int) -> None:
"""Assign an axis to a device instance.
Args:
axis (Device): Device instance (e.g. GalilMotor)
axis_nr (int): Controller axis number
"""
self._axis[axis_nr] = axis
@threadlocked
def socket_put(self, val: str) -> None:
self.sock.put(f"{val}\r".encode())
@ -160,7 +125,7 @@ class GalilController(Controller):
"""
return bool(float(self.socket_put_and_receive("MG allaxref").strip()))
def drive_axis_to_limit(self, axis_Id_numeric, direction: str) -> None:
def drive_axis_to_limit(self, axis_Id_numeric: int, direction: str) -> None:
"""
Drive an axis to the limit in a specified direction.
@ -215,11 +180,11 @@ class GalilController(Controller):
def show_running_threads(self) -> None:
t = PrettyTable()
t.title = f"Threads on {self.sock.host}:{self.sock.port}"
t.field_names = [str(ax) for ax in range(self._galil_axis_per_controller)]
t.field_names = [str(ax) for ax in range(self._axes_per_controller)]
t.add_row(
[
"active" if self.is_thread_active(t) else "inactive"
for t in range(self._galil_axis_per_controller)
for t in range(self._axes_per_controller)
]
)
print(t)
@ -253,7 +218,7 @@ class GalilController(Controller):
"Limits",
"Position",
]
for ax in range(self._galil_axis_per_controller):
for ax in range(self._axes_per_controller):
axis = self._axis[ax]
if axis is not None:
t.add_row(

View File

@ -1,428 +0,0 @@
import abc
import functools
import socket
import threading
import time
from ophyd import PositionerBase, Signal
from ophyd.device import Component as Cpt
from ophyd.device import Device
from prettytable import PrettyTable
from typeguard import typechecked
from ophyd_devices.utils.controller import threadlocked
from ophyd_devices.utils.socket import raise_if_disconnected
def channel_checked(fcn):
"""Decorator to catch attempted access to channels that are not available."""
@functools.wraps(fcn)
def wrapper(self, *args, **kwargs):
self._check_channel(args[0])
return fcn(self, *args, **kwargs)
return wrapper
class NPointController:
_controller_instance = None
NUM_CHANNELS = 3
_read_single_loc_bit = "A0"
_write_single_loc_bit = "A2"
_trailing_bit = "55"
_range_offset = "78"
_channel_base = ["11", "83"]
def __init__(
self,
comm_socket: SocketIO,
server_ip: str = "129.129.99.87",
server_port: int = 23,
) -> None:
self._lock = threading.RLock()
super().__init__()
self._server_and_port_name = (server_ip, server_port)
self.socket = comm_socket
self.connected = False
def __new__(cls, *args, **kwargs):
if not NPointController._controller_instance:
NPointController._controller_instance = object.__new__(cls)
return NPointController._controller_instance
@classmethod
def create(cls):
return cls(SocketIO())
def show_all(self) -> None:
"""Display current status of all channels
Returns:
None
"""
if not self.connected:
print("npoint controller is currently disabled.")
return
print(f"Connected to controller at {self._server_and_port_name}")
t = PrettyTable()
t.field_names = ["Channel", "Range", "Position", "Target"]
for ii in range(self.NUM_CHANNELS):
t.add_row(
[
ii,
self._get_range(ii),
self._get_current_pos(ii),
self._get_target_pos(ii),
]
)
print(t)
@threadlocked
def on(self) -> None:
"""Enable the NPoint controller and open a new socket.
Raises:
TimeoutError: Raised if the socket connection raises a timeout.
Returns:
None
"""
if self.connected:
print("You are already connected to the NPoint controller.")
return
if not self.socket.is_open:
self.socket.open()
try:
self.socket.connect(self._server_and_port_name[0], self._server_and_port_name[1])
except socket.timeout:
raise TimeoutError(
f"Failed to connect to the specified server and port {self._server_and_port_name}."
)
except OSError:
print("ERROR while connecting. Let's try again")
self.socket.close()
time.sleep(0.5)
self.socket.open()
self.socket.connect(self._server_and_port_name[0], self._server_and_port_name[1])
self.connected = True
@threadlocked
def off(self) -> None:
"""Disable the controller and close the socket.
Returns:
None
"""
self.socket.close()
self.connected = False
@channel_checked
def _get_range(self, channel: int) -> int:
"""Get the range of the specified channel axis.
Args:
channel (int): Channel for which the range should be requested.
Raises:
RuntimeError: Raised if the received message doesn't have the expected number of bytes (10).
Returns:
int: Range
"""
# for first channel: 0x11 83 10 78
addr = self._channel_base.copy()
addr.extend([f"{16 + 16 * channel:x}", self._range_offset])
send_buffer = self.__read_single_location_buffer(addr)
recvd = self._put_and_receive(send_buffer)
if len(recvd) != 10:
raise RuntimeError(
f"Received buffer is corrupted. Expected 10 bytes and instead got {len(recvd)}"
)
device_range = self._hex_list_to_int(recvd[5:-1], signed=False)
return device_range
@channel_checked
def _get_current_pos(self, channel: int) -> float:
# for first channel: 0x11 83 13 34
addr = self._channel_base.copy()
addr.extend([f"{19 + 16 * channel:x}", "34"])
send_buffer = self.__read_single_location_buffer(addr)
recvd = self._put_and_receive(send_buffer)
pos_buffer = recvd[5:-1]
pos = self._hex_list_to_int(pos_buffer) / 1048574 * 100
return pos
@channel_checked
def _set_target_pos(self, channel: int, pos: float) -> None:
# for first channel: 0x11 83 12 18 00 00 00 00
addr = self._channel_base.copy()
addr.extend([f"{18 + channel * 16:x}", "18"])
target = int(round(1048574 / 100 * pos))
data = [f"{m:02x}" for m in target.to_bytes(4, byteorder="big", signed=True)]
send_buffer = self.__write_single_location_buffer(addr, data)
self._put(send_buffer)
@channel_checked
def _get_target_pos(self, channel: int) -> float:
# for first channel: 0x11 83 12 18
addr = self._channel_base.copy()
addr.extend([f"{18 + channel * 16:x}", "18"])
send_buffer = self.__read_single_location_buffer(addr)
recvd = self._put_and_receive(send_buffer)
pos_buffer = recvd[5:-1]
pos = self._hex_list_to_int(pos_buffer) / 1048574 * 100
return pos
@channel_checked
def _set_servo(self, channel: int, enable: bool) -> None:
print("Not tested")
return
# for first channel: 0x11 83 10 84 00 00 00 00
addr = self._channel_base.copy()
addr.extend([f"{16 + channel * 16:x}", "84"])
if enable:
data = ["00"] * 3 + ["01"]
else:
data = ["00"] * 4
send_buffer = self.__write_single_location_buffer(addr, data)
self._put(send_buffer)
@channel_checked
def _get_servo(self, channel: int) -> int:
# for first channel: 0x11 83 10 84 00 00 00 00
addr = self._channel_base.copy()
addr.extend([f"{16 + channel * 16:x}", "84"])
send_buffer = self.__read_single_location_buffer(addr)
recvd = self._put_and_receive(send_buffer)
buffer = recvd[5:-1]
status = self._hex_list_to_int(buffer)
return status
@threadlocked
def _put(self, buffer: list) -> None:
"""Translates a list of hex values to bytes and sends them to the socket.
Args:
buffer (list): List of hex values without leading 0x
Returns:
None
"""
buffer = b"".join([bytes.fromhex(m) for m in buffer])
self.socket.put(buffer)
@threadlocked
def _put_and_receive(self, msg_hex_list: list) -> list:
"""Send msg to socket and wait for a reply.
Args:
msg_hex_list (list): List of hex values without leading 0x.
Returns:
list: Received message as a list of hex values
"""
buffer = b"".join([bytes.fromhex(m) for m in msg_hex_list])
self.socket.put(buffer)
recv_msg = self.socket.receive()
recv_hex_list = [hex(m) for m in recv_msg]
self._verify_received_msg(msg_hex_list, recv_hex_list)
return recv_hex_list
def _verify_received_msg(self, in_list: list, out_list: list) -> None:
"""Ensure that the first address bits of sent and received messages are the same.
Args:
in_list (list): list containing the sent message
out_list (list): list containing the received message
Raises:
RuntimeError: Raised if first two address bits of 'in' and 'out' are not identical
Returns:
None
"""
# first, translate hex (str) values to int
in_list_int = [int(val, 16) for val in in_list]
out_list_int = [int(val, 16) for val in out_list]
# first ints of the reply should be the same. Otherwise something went wrong
if not in_list_int[:2] == out_list_int[:2]:
raise RuntimeError("Connection failure. Please restart the controller.")
def _check_channel(self, channel: int) -> None:
if channel >= self.NUM_CHANNELS:
raise ValueError(
f"Channel {channel+1} exceeds the available number of channels ({self.NUM_CHANNELS})"
)
@staticmethod
def _hex_list_to_int(in_buffer: list, byteorder="little", signed=True) -> int:
"""Translate hex list to int.
Args:
in_buffer (list): Input buffer; received as list of hex values
byteorder (str, optional): Byteorder of in_buffer. Defaults to "little".
signed (bool, optional): Whether the hex list represents a signed int. Defaults to True.
Returns:
int: Translated integer.
"""
if byteorder == "little":
in_buffer.reverse()
# make sure that all hex strings have the same format ("FF")
val_hex = [f"{int(m, 16):02x}" for m in in_buffer]
val_bytes = [bytes.fromhex(m) for m in val_hex]
val = int.from_bytes(b"".join(val_bytes), byteorder="big", signed=signed)
return val
@staticmethod
def __read_single_location_buffer(addr) -> list:
"""Prepare buffer for reading from a single memory location (hex address).
Number of bytes: 6
Format: 0xA0 [addr] 0x55
Return Value: 0xA0 [addr] [data] 0x55
Sample Hex Transmission from PC to LC.400: A0 18 12 83 11 55
Sample Hex Return Transmission from LC.400 to PC: A0 18 12 83 11 64 00 00 00 55
Args:
addr (list): Hex address to read from
Returns:
list: List of hex values representing the read instruction.
"""
buffer = []
buffer.append(NPointController._read_single_loc_bit)
if isinstance(addr, list):
addr.reverse()
buffer.extend(addr)
else:
buffer.append(addr)
buffer.append(NPointController._trailing_bit)
return buffer
@staticmethod
def __write_single_location_buffer(addr: list, data: list) -> list:
"""Prepare buffer for writing to a single memory location (hex address).
Number of bytes: 10
Format: 0xA2 [addr] [data] 0x55
Return Value: none
Sample Hex Transmission from PC to C.400: A2 18 12 83 11 E8 03 00 00 55
Args:
addr (list): List of hex values representing the address to write to.
data (list): List of hex values representing the data that should be written.
Returns:
list: List of hex values representing the write instruction.
"""
buffer = []
buffer.append(NPointController._write_single_loc_bit)
if isinstance(addr, list):
addr.reverse()
buffer.extend(addr)
else:
buffer.append(addr)
if isinstance(data, list):
data.reverse()
buffer.extend(data)
else:
buffer.append(data)
buffer.append(NPointController._trailing_bit)
return buffer
@staticmethod
def __read_array():
raise NotImplementedError
@staticmethod
def __write_next_command():
raise NotImplementedError
def __del__(self):
if self.connected:
print("Closing npoint socket")
self.off()
class SocketSignal(abc.ABC, Signal):
def __init__(self, *, name, **kwargs):
super().__init__(**kwargs)
@abc.abstractmethod
def _socket_get(self):
...
@abc.abstractmethod
def _socket_set(self, val):
...
class NPointSignalBase(SocketSignal):
def __init__(self, controller, signal_name, **kwargs):
self.controller = controller
self.signal_name = signal_name
super().__init__(**kwargs)
class NPointReadbackSignal(NPointSignalBase):
def _socket_get(self):
pass
def _socket_set(self, val):
pass
class NPointAxis(Device, PositionerBase):
def __init__(
self,
prefix="",
*,
name,
channel=None,
kind=None,
read_attrs=None,
configuration_attrs=None,
parent=None,
**kwargs,
):
self.channel = channel
self.controller = self._get_controller()
self.readback = Cpt(
NPointSignal, controller=self.controller, signal_name="RBV", kind="hinted"
)
self.user_setpoint = Cpt(
NPointSignal, controller=self.controller, signal_name="VAL", kind="normal"
)
self.motor_resolution = Cpt(
NPointSignal, controller=self.controller, signal_name="RNGE", kind="config"
)
self.motor_is_moving = Cpt(
NPointSignal, controller=self.controller, signal_name="MOVN", kind="config"
)
self.axes_referenced = Cpt(
NPointSignal, controller=self.controller, signal_name="XREF", kind="config"
)
def _get_controller(self):
return NPointController()

View File

@ -11,7 +11,7 @@ from ophyd_devices.smaract.smaract_errors import (
SmaractCommunicationError,
SmaractErrorCode,
)
from ophyd_devices.utils.controller import Controller, threadlocked
from ophyd_devices.utils.controller import Controller, axis_checked, threadlocked
logger = logging.getLogger("smaract_controller")
@ -21,17 +21,6 @@ class SmaractCommunicationMode(enum.Enum):
ASYNC = 1
def axis_checked(fcn):
"""Decorator to catch attempted access to channels that are not available."""
@functools.wraps(fcn)
def wrapper(self, *args, **kwargs):
self._check_axis_number(args[0])
return fcn(self, *args, **kwargs)
return wrapper
def retry_once(fcn):
"""Decorator to rerun a function in case a SmaractCommunicationError was raised. This may happen if the buffer was not empty."""
@ -82,6 +71,8 @@ class SmaractSensors:
class SmaractController(Controller):
_axes_per_controller = 6
_initialized = False
USER_ACCESS = ["socket_put_and_receive", "smaract_show_all", "move_open_loop_steps"]
def __init__(
@ -96,9 +87,7 @@ class SmaractController(Controller):
attr_name="",
labels=None,
):
if not hasattr(self, "_initialized") or not self._initialized:
self._Smaract_axis_per_controller = 6
self._axis = [None for axis_num in range(self._Smaract_axis_per_controller)]
if not self._initialized:
super().__init__(
name=name,
socket_cls=socket_cls,
@ -111,10 +100,6 @@ class SmaractController(Controller):
)
self._sensors = SmaractSensors()
@axis_checked
def set_axis(self, axis_nr, axis):
self._axis[axis_nr] = axis
@threadlocked
def socket_put(self, val: str):
self.sock.put(f":{val}\n".encode())
@ -451,12 +436,6 @@ class SmaractController(Controller):
t.add_row([None for t in t.field_names])
print(t)
def _check_axis_number(self, axis_Id_numeric: int) -> None:
if axis_Id_numeric >= self._Smaract_axis_per_controller:
raise ValueError(
f"Axis {axis_Id_numeric} exceeds the available number of axes ({self._Smaract_axis_per_controller})"
)
@axis_checked
def _error_str(self, axis_Id_numeric: int, error_number: int):
return f":E{axis_Id_numeric},{error_number}"

View File

@ -130,12 +130,12 @@ class SmaractMotor(Device, PositionerBase):
socket_cls=SocketIO,
**kwargs,
):
self.axis_Id = axis_Id
self.sign = sign
self.controller = SmaractController(
socket_cls=socket_cls, socket_host=host, socket_port=port
)
self.controller.set_axis(self.axis_Id_numeric, axis=self)
self.axis_Id = axis_Id
self.sign = sign
self.controller.set_axis(axis=self, axis_nr=self.axis_Id_numeric)
self.tolerance = kwargs.pop("tolerance", 0.5)
super().__init__(

View File

@ -2,6 +2,7 @@ import functools
import threading
from bec_lib.core import bec_logger
from ophyd import Device
from ophyd.ophydobj import OphydObject
logger = bec_logger.logger
@ -19,8 +20,28 @@ def threadlocked(fcn):
return wrapper
def axis_checked(fcn):
"""Decorator to catch attempted access to channels that are not available."""
@functools.wraps(fcn)
def wrapper(self, *args, **kwargs):
if "axis_nr" in kwargs:
self._check_axis_number(kwargs["axis_nr"])
elif "axis_Id_numeric" in kwargs:
self._check_axis_number(kwargs["axis_Id_numeric"])
elif args:
self._check_axis_number(args[0])
return fcn(self, *args, **kwargs)
return wrapper
class Controller(OphydObject):
"""Base class for all socker-based controllers."""
_controller_instances = {}
_initialized = False
_axes_per_controller = 1
SUB_CONNECTION_CHANGE = "connection_change"
@ -40,11 +61,12 @@ class Controller(OphydObject):
self._socket_cls = socket_cls
self._socket_host = socket_host
self._socket_port = socket_port
if not hasattr(self, "_initialized"):
if not self._initialized:
super().__init__(
name=name, attr_name=attr_name, parent=parent, labels=labels, kind=kind
)
self._lock = threading.RLock()
self._axis = []
self._initialize()
self._initialized = True
@ -54,8 +76,12 @@ class Controller(OphydObject):
def _set_default_values(self):
# no. of axes controlled by each controller
self._axis_per_controller = 8
self._motors = [None for axis_num in range(self._axis_per_controller)]
self._axis = [None for axis_num in range(self._axes_per_controller)]
@classmethod
def _reset_controller(cls):
cls._controller_instances = {}
cls._initialized = False
@property
def connected(self):
@ -66,13 +92,35 @@ class Controller(OphydObject):
self._connected = value
self._run_subs(sub_type=self.SUB_CONNECTION_CHANGE)
def set_motor(self, motor, axis):
"""Set the motor instance for a specified controller axis."""
self._motors[axis] = motor
@axis_checked
def set_axis(self, *, axis: Device, axis_nr: int) -> None:
"""Assign an axis to a device instance.
def get_motor(self, axis):
"""Get motor instance for a specified controller axis."""
return self._motors[axis]
Args:
axis (Device): Device instance (e.g. GalilMotor)
axis_nr (int): Controller axis number
"""
self._axis[axis_nr] = axis
@axis_checked
def get_axis(self, axis_nr: int) -> Device:
"""Get device instance for a specified controller axis.
Args:
axis_nr (int): Controller axis number
Returns:
Device: Device instance (e.g. GalilMotor)
"""
return self._axis[axis_nr]
def _check_axis_number(self, axis_Id_numeric: int) -> None:
if axis_Id_numeric >= self._axes_per_controller:
raise ValueError(
f"Axis {axis_Id_numeric} exceeds the available number of axes ({self._axes_per_controller})"
)
def on(self) -> None:
"""Open a new socket connection to the controller"""
@ -103,6 +151,6 @@ class Controller(OphydObject):
if not socket_port:
raise RuntimeError("Socket port must be specified.")
host_port = f"{socket_host}:{socket_port}"
if host_port not in Controller._controller_instances:
Controller._controller_instances[host_port] = object.__new__(cls)
return Controller._controller_instances[host_port]
if host_port not in cls._controller_instances:
cls._controller_instances[host_port] = object.__new__(cls)
return cls._controller_instances[host_port]

View File

@ -14,6 +14,6 @@ if __name__ == "__main__":
"std_daq_client",
"pyepics",
],
extras_require={"dev": ["pytest", "pytest-random-order", "black"]},
extras_require={"dev": ["pytest", "pytest-random-order", "black", "coverage"]},
version=__version__,
)

30
tests/test_controller.py Normal file
View File

@ -0,0 +1,30 @@
from unittest import mock
from ophyd_devices.utils.controller import Controller
def test_controller_off():
controller = Controller(socket_cls=mock.MagicMock(), socket_host="dummy", socket_port=123)
controller.on()
with mock.patch.object(controller.sock, "close") as mock_close:
controller.off()
assert controller.sock is None
assert controller.connected is False
mock_close.assert_called_once()
# make sure it is indempotent
controller.off()
def test_controller_on():
socket_cls = mock.MagicMock()
Controller._controller_instances = {}
controller = Controller(socket_cls=socket_cls, socket_host="dummy", socket_port=123)
controller.on()
assert controller.sock is not None
assert controller.connected is True
socket_cls().open.assert_called_once()
# make sure it is indempotent
controller.on()
socket_cls().open.assert_called_once()

View File

@ -1,3 +1,5 @@
from unittest import mock
import pytest
from utils import SocketMock
@ -61,3 +63,129 @@ def test_axis_put(target_pos, socket_put_messages, socket_get_messages):
leyey.controller.sock.buffer_recv = socket_get_messages
leyey.user_setpoint.put(target_pos)
assert leyey.controller.sock.buffer_put == socket_put_messages
@pytest.mark.parametrize(
"axis_nr,direction,socket_put_messages,socket_get_messages",
[
(
0,
"forward",
[
b"naxis=0\r",
b"ndir=1\r",
b"XQ#NEWPAR\r",
b"XQ#FES\r",
b"MG_BGA\r",
b"MGbcklact[axis]\r",
b"MG_XQ0\r",
b"MG_XQ2\r",
b"MG _LRA, _LFA\r",
],
[
b":",
b":",
b":",
b":",
b"0",
b"0",
b"-1",
b"-1",
b"1.000 0.000",
],
),
(
1,
"reverse",
[
b"naxis=1\r",
b"ndir=-1\r",
b"XQ#NEWPAR\r",
b"XQ#FES\r",
b"MG_BGB\r",
b"MGbcklact[axis]\r",
b"MG_XQ0\r",
b"MG_XQ2\r",
b"MG _LRB, _LFB\r",
],
[
b":",
b":",
b":",
b":",
b"0",
b"0",
b"-1",
b"-1",
b"0.000 1.000",
],
),
],
)
def test_drive_axis_to_limit(axis_nr, direction, socket_put_messages, socket_get_messages):
leyey = GalilMotor("A", name="leyey", host="mpc2680.psi.ch", port=8081, socket_cls=SocketMock)
leyey.controller.on()
leyey.controller.sock.flush_buffer()
leyey.controller.sock.buffer_recv = socket_get_messages
leyey.controller.drive_axis_to_limit(axis_nr, direction)
assert leyey.controller.sock.buffer_put == socket_put_messages
@pytest.mark.parametrize(
"axis_nr,socket_put_messages,socket_get_messages",
[
(
0,
[
b"naxis=0\r",
b"XQ#NEWPAR\r",
b"XQ#FRM\r",
b"MG_BGA\r",
b"MGbcklact[axis]\r",
b"MG_XQ0\r",
b"MG_XQ2\r",
b"MG axisref[0]\r",
],
[
b":",
b":",
b":",
b"0",
b"0",
b"-1",
b"-1",
b"1.00",
],
),
(
1,
[
b"naxis=1\r",
b"XQ#NEWPAR\r",
b"XQ#FRM\r",
b"MG_BGB\r",
b"MGbcklact[axis]\r",
b"MG_XQ0\r",
b"MG_XQ2\r",
b"MG axisref[1]\r",
],
[
b":",
b":",
b":",
b"0",
b"0",
b"-1",
b"-1",
b"1.00",
],
),
],
)
def test_find_reference(axis_nr, socket_put_messages, socket_get_messages):
leyey = GalilMotor("A", name="leyey", host="mpc2680.psi.ch", port=8081, socket_cls=SocketMock)
leyey.controller.on()
leyey.controller.sock.flush_buffer()
leyey.controller.sock.buffer_recv = socket_get_messages
leyey.controller.find_reference(axis_nr)
assert leyey.controller.sock.buffer_put == socket_put_messages

View File

@ -20,6 +20,7 @@ from ophyd_devices.smaract.smaract_ophyd import SmaractMotor
],
)
def test_get_position(axis, position, get_message, return_msg):
SmaractController._reset_controller()
controller = SmaractController(socket_cls=SocketMock, socket_host="dummy", socket_port=123)
controller.on()
controller.sock.flush_buffer()
@ -39,6 +40,7 @@ def test_get_position(axis, position, get_message, return_msg):
],
)
def test_axis_is_referenced(axis, is_referenced, get_message, return_msg, exception):
SmaractController._reset_controller()
controller = SmaractController(socket_cls=SocketMock, socket_host="dummy", socket_port=123)
controller.on()
controller.sock.flush_buffer()
@ -62,6 +64,7 @@ def test_axis_is_referenced(axis, is_referenced, get_message, return_msg, except
],
)
def test_socket_put_and_receive_raises_exception(return_msg, exception, raised):
SmaractController._reset_controller()
controller = SmaractController(socket_cls=SocketMock, socket_host="dummy", socket_port=123)
controller.on()
controller.sock.flush_buffer()
@ -87,6 +90,7 @@ def test_socket_put_and_receive_raises_exception(return_msg, exception, raised):
],
)
def test_communication_mode(mode, get_message, return_msg):
SmaractController._reset_controller()
controller = SmaractController(socket_cls=SocketMock, socket_host="dummy", socket_port=123)
controller.on()
controller.sock.flush_buffer()
@ -112,6 +116,7 @@ def test_communication_mode(mode, get_message, return_msg):
],
)
def test_axis_is_moving(is_moving, get_message, return_msg):
SmaractController._reset_controller()
controller = SmaractController(socket_cls=SocketMock, socket_host="dummy", socket_port=123)
controller.on()
controller.sock.flush_buffer()
@ -132,6 +137,7 @@ def test_axis_is_moving(is_moving, get_message, return_msg):
],
)
def test_get_sensor_definition(sensor_id, axis, get_msg, return_msg):
SmaractController._reset_controller()
controller = SmaractController(socket_cls=SocketMock, socket_host="dummy", socket_port=123)
controller.on()
controller.sock.flush_buffer()
@ -149,6 +155,7 @@ def test_get_sensor_definition(sensor_id, axis, get_msg, return_msg):
],
)
def test_set_move_speed(move_speed, axis, get_msg, return_msg):
SmaractController._reset_controller()
controller = SmaractController(socket_cls=SocketMock, socket_host="dummy", socket_port=123)
controller.on()
controller.sock.flush_buffer()
@ -166,6 +173,7 @@ def test_set_move_speed(move_speed, axis, get_msg, return_msg):
],
)
def test_move_axis_to_absolute_position(pos, axis, hold_time, get_msg, return_msg):
SmaractController._reset_controller()
controller = SmaractController(socket_cls=SocketMock, socket_host="dummy", socket_port=123)
controller.on()
controller.sock.flush_buffer()
@ -203,6 +211,7 @@ def test_move_axis_to_absolute_position(pos, axis, hold_time, get_msg, return_ms
],
)
def test_move_axis(pos, get_msg, return_msg):
SmaractController._reset_controller()
lsmarA = SmaractMotor(
"A",
name="lsmarA",
@ -230,6 +239,7 @@ def test_move_axis(pos, get_msg, return_msg):
],
)
def test_stop_axis(num_axes, get_msg, return_msg):
SmaractController._reset_controller()
lsmarA = SmaractMotor(
"A",
name="lsmarA",