diff --git a/csaxs_bec/devices/npoint/npoint.py b/csaxs_bec/devices/npoint/npoint.py index 9fcb5d9..30427be 100644 --- a/csaxs_bec/devices/npoint/npoint.py +++ b/csaxs_bec/devices/npoint/npoint.py @@ -1,9 +1,7 @@ import functools -import socket -import threading import time -from ophyd_devices.utils.controller import threadlocked +from ophyd_devices.utils.controller import Controller, threadlocked from ophyd_devices.utils.socket import raise_if_disconnected from prettytable import PrettyTable from typeguard import typechecked @@ -20,75 +18,15 @@ def channel_checked(fcn): return wrapper -class SocketIO: - """SocketIO helper class for TCP IP connections""" +class NPointController(Controller): - def __init__(self, sock=None): - self.is_open = False - if sock is None: - self.open() - else: - self.sock = sock - - def connect(self, host, port): - print(f"connecting to {host} port {port}") - # self.sock.create_connection((host, port)) - self.sock.connect((host, port)) - - def _put(self, msg_bytes): - return self.sock.send(msg_bytes) - - def _recv(self, buffer_length=1024): - return self.sock.recv(buffer_length) - - def _initialize_socket(self): - self.sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) - self.sock.settimeout(5) - - def put(self, msg): - return self._put(msg) - - def receive(self, buffer_length=1024): - return self._recv(buffer_length=buffer_length) - - def open(self): - self._initialize_socket() - self.is_open = True - - def close(self): - self.sock.close() - self.sock = None - self.is_open = False - - -class NPointController: - _controller_instance = None - - NUM_CHANNELS = 3 + _axes_per_controller = 3 _read_single_loc_bit = "A0" _write_single_loc_bit = "A2" _trailing_bit = "55" _range_offset = "78" _channel_base = ["11", "83"] - def __init__( - self, comm_socket: SocketIO, server_ip: str = "129.129.99.87", server_port: int = 23 - ) -> None: - self._lock = threading.RLock() - super().__init__() - self._server_and_port_name = (server_ip, server_port) - self.socket = comm_socket - self.connected = False - - def __new__(cls, *args, **kwargs): - if not NPointController._controller_instance: - NPointController._controller_instance = object.__new__(cls) - return NPointController._controller_instance - - @classmethod - def create(cls): - return cls(SocketIO()) - def show_all(self) -> None: """Display current status of all channels @@ -98,54 +36,15 @@ class NPointController: if not self.connected: print("npoint controller is currently disabled.") return - print(f"Connected to controller at {self._server_and_port_name}") + print(f"Connected to controller at {self._socket_host}:{self._socket_port}") t = PrettyTable() t.field_names = ["Channel", "Range", "Position", "Target"] - for ii in range(self.NUM_CHANNELS): + for ii in range(self._axes_per_controller): t.add_row( [ii, self._get_range(ii), self._get_current_pos(ii), self._get_target_pos(ii)] ) print(t) - @threadlocked - def on(self) -> None: - """Enable the NPoint controller and open a new socket. - - Raises: - TimeoutError: Raised if the socket connection raises a timeout. - - Returns: - None - """ - if self.connected: - print("You are already connected to the NPoint controller.") - return - if not self.socket.is_open: - self.socket.open() - try: - self.socket.connect(self._server_and_port_name[0], self._server_and_port_name[1]) - except socket.timeout: - raise TimeoutError( - f"Failed to connect to the specified server and port {self._server_and_port_name}." - ) - except OSError: - print("ERROR while connecting. Let's try again") - self.socket.close() - time.sleep(0.5) - self.socket.open() - self.socket.connect(self._server_and_port_name[0], self._server_and_port_name[1]) - self.connected = True - - @threadlocked - def off(self) -> None: - """Disable the controller and close the socket. - - Returns: - None - """ - self.socket.close() - self.connected = False - @channel_checked def _get_range(self, channel: int) -> int: """Get the range of the specified channel axis. @@ -250,7 +149,7 @@ class NPointController: """ buffer = b"".join([bytes.fromhex(m) for m in buffer]) - self.socket.put(buffer) + self.sock.put(buffer) @threadlocked def _put_and_receive(self, msg_hex_list: list) -> list: @@ -264,8 +163,8 @@ class NPointController: """ buffer = b"".join([bytes.fromhex(m) for m in msg_hex_list]) - self.socket.put(buffer) - recv_msg = self.socket.receive() + self.sock.put(buffer) + recv_msg = self.sock.receive() recv_hex_list = [hex(m) for m in recv_msg] self._verify_received_msg(msg_hex_list, recv_hex_list) return recv_hex_list @@ -293,9 +192,9 @@ class NPointController: raise RuntimeError("Connection failure. Please restart the controller.") def _check_channel(self, channel: int) -> None: - if channel >= self.NUM_CHANNELS: + if channel >= self._axes_per_controller: raise ValueError( - f"Channel {channel+1} exceeds the available number of channels ({self.NUM_CHANNELS})" + f"Channel {channel+1} exceeds the available number of channels ({self._axes_per_controller})" ) @staticmethod @@ -498,29 +397,6 @@ class NPointAxis: print(f"Setting the npoint settling time to {val:.2f} s.") -class NPointEpics(NPointAxis): - def __init__(self, controller: NPointController, channel: int, name: str) -> None: - super().__init__(controller, channel, name) - self.low_limit = -50 - self.high_limit = 50 - self._prefix = name - - def get_pv(self) -> str: - return self.name - - def get_position(self, readback=True) -> float: - if readback: - return self.get() - else: - return self.get_target_pos() - - def within_limits(self, pos: float) -> bool: - return pos > self.low_limit and pos < self.high_limit - - def move(self, position: float, wait=True) -> None: - self.set(position) - - if __name__ == "__main__": ## EXAMPLES ## # diff --git a/tests/tests_devices/test_npoint_piezo.py b/tests/tests_devices/test_npoint_piezo.py index 795be36..e03e3e2 100644 --- a/tests/tests_devices/test_npoint_piezo.py +++ b/tests/tests_devices/test_npoint_piezo.py @@ -1,49 +1,27 @@ import copy +from unittest import mock import pytest from csaxs_bec.devices.npoint import NPointAxis, NPointController -class SocketMock: - def __init__(self, sock=None): - self.buffer_put = "" - self.buffer_recv = "" - self.is_open = False - if sock is None: - self.open() - else: - self.sock = sock +@pytest.fixture +def controller(): + with mock.patch("ophyd_devices.utils.socket.SocketIO") as socket_cls: + controller = NPointController( + socket_cls=socket_cls, socket_host="localhost", socket_port=1234 + ) + controller.on() + controller.sock.reset_mock() + yield controller + controller.off() - def connect(self, host, port): - print(f"connecting to {host} port {port}") - # self.sock.create_connection((host, port)) - # self.sock.connect((host, port)) - def _put(self, msg_bytes): - self.buffer_put = msg_bytes - print(self.buffer_put) - - def _recv(self, buffer_length=1024): - print(self.buffer_recv) - return self.buffer_recv - - def _initialize_socket(self): - pass - - def put(self, msg): - return self._put(msg) - - def receive(self, buffer_length=1024): - return self._recv(buffer_length=buffer_length) - - def open(self): - self._initialize_socket() - self.is_open = True - - def close(self): - self.sock = None - self.is_open = False +@pytest.fixture +def npointx(controller): + npointx = NPointAxis(controller, 0, "nx") + yield npointx @pytest.mark.parametrize( @@ -54,12 +32,9 @@ class SocketMock: (-5, b"\xa2\x18\x12\x83\x1133\xff\xffU"), ], ) -def test_axis_put(pos, msg): - controller = NPointController(SocketMock()) - npointx = NPointAxis(controller, 0, "nx") - controller.on() +def test_axis_put(npointx, pos, msg): npointx.set(pos) - assert npointx.controller.socket.buffer_put == msg + npointx.controller.sock.put.assert_called_with(msg) @pytest.mark.parametrize( @@ -70,13 +45,9 @@ def test_axis_put(pos, msg): (-5, b"\xa04\x13\x83\x11U", b"\xa0\x34\x13\x83\x1133\xff\xffU"), ], ) -def test_axis_get_out(pos, msg_in, msg_out): - controller = NPointController(SocketMock()) - npointx = NPointAxis(controller, 0, "nx") - controller.on() - npointx.controller.socket.buffer_recv = msg_out +def test_axis_get_out(npointx, pos, msg_in, msg_out): + npointx.controller.sock.receive.return_value = msg_out assert pytest.approx(npointx.get(), rel=0.01) == pos - # assert controller.socket.buffer_put == msg_in @pytest.mark.parametrize( @@ -87,29 +58,23 @@ def test_axis_get_out(pos, msg_in, msg_out): (2, b"\xa043\x83\x11U", b"\xa0\x34\x13\x83\x1133\xff\xffU"), ], ) -def test_axis_get_in(axis, msg_in, msg_out): - controller = NPointController(SocketMock()) - npointx = NPointAxis(controller, 0, "nx") - controller.on() - controller.socket.buffer_recv = msg_out - controller._get_current_pos(axis) - assert controller.socket.buffer_put == msg_in +def test_axis_get_in(npointx, axis, msg_in, msg_out): + npointx.controller.sock.receive.return_value = msg_out + npointx.controller._get_current_pos(axis) + npointx.controller.sock.put.assert_called_once_with(msg_in) -def test_axis_out_of_range(): - controller = NPointController(SocketMock()) +def test_axis_out_of_range(controller): with pytest.raises(ValueError): npointx = NPointAxis(controller, 3, "nx") -def test_get_axis_out_of_range(): - controller = NPointController(SocketMock()) +def test_get_axis_out_of_range(controller): with pytest.raises(ValueError): controller._get_current_pos(3) -def test_set_axis_out_of_range(): - controller = NPointController(SocketMock()) +def test_set_axis_out_of_range(controller): with pytest.raises(ValueError): controller._set_target_pos(3, 5) @@ -139,10 +104,8 @@ def test_hex_list_to_int(in_buffer, byteorder, signed, val): (2, b"\xa0x0\x83\x11U", b"\xa0\x78\x13\x83\x11\x64\x00\x00\x00U"), ], ) -def test_get_range(axis, msg_in, msg_out): - controller = NPointController(SocketMock()) - npointx = NPointAxis(controller, 0, "nx") - controller.on() - controller.socket.buffer_recv = msg_out - val = controller._get_range(axis) - assert controller.socket.buffer_put == msg_in and val == 100 +def test_get_range(npointx, axis, msg_in, msg_out): + npointx.controller.sock.receive.return_value = msg_out + val = npointx.controller._get_range(axis) + npointx.controller.sock.put.assert_called_once_with(msg_in) + assert val == 100