refactor(npoint): cleanup

This commit is contained in:
2024-09-18 21:49:49 +02:00
parent 59e0755e14
commit 0d2b4c4423
2 changed files with 41 additions and 202 deletions

View File

@@ -1,9 +1,7 @@
import functools
import socket
import threading
import time
from ophyd_devices.utils.controller import threadlocked
from ophyd_devices.utils.controller import Controller, threadlocked
from ophyd_devices.utils.socket import raise_if_disconnected
from prettytable import PrettyTable
from typeguard import typechecked
@@ -20,75 +18,15 @@ def channel_checked(fcn):
return wrapper
class SocketIO:
"""SocketIO helper class for TCP IP connections"""
class NPointController(Controller):
def __init__(self, sock=None):
self.is_open = False
if sock is None:
self.open()
else:
self.sock = sock
def connect(self, host, port):
print(f"connecting to {host} port {port}")
# self.sock.create_connection((host, port))
self.sock.connect((host, port))
def _put(self, msg_bytes):
return self.sock.send(msg_bytes)
def _recv(self, buffer_length=1024):
return self.sock.recv(buffer_length)
def _initialize_socket(self):
self.sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
self.sock.settimeout(5)
def put(self, msg):
return self._put(msg)
def receive(self, buffer_length=1024):
return self._recv(buffer_length=buffer_length)
def open(self):
self._initialize_socket()
self.is_open = True
def close(self):
self.sock.close()
self.sock = None
self.is_open = False
class NPointController:
_controller_instance = None
NUM_CHANNELS = 3
_axes_per_controller = 3
_read_single_loc_bit = "A0"
_write_single_loc_bit = "A2"
_trailing_bit = "55"
_range_offset = "78"
_channel_base = ["11", "83"]
def __init__(
self, comm_socket: SocketIO, server_ip: str = "129.129.99.87", server_port: int = 23
) -> None:
self._lock = threading.RLock()
super().__init__()
self._server_and_port_name = (server_ip, server_port)
self.socket = comm_socket
self.connected = False
def __new__(cls, *args, **kwargs):
if not NPointController._controller_instance:
NPointController._controller_instance = object.__new__(cls)
return NPointController._controller_instance
@classmethod
def create(cls):
return cls(SocketIO())
def show_all(self) -> None:
"""Display current status of all channels
@@ -98,54 +36,15 @@ class NPointController:
if not self.connected:
print("npoint controller is currently disabled.")
return
print(f"Connected to controller at {self._server_and_port_name}")
print(f"Connected to controller at {self._socket_host}:{self._socket_port}")
t = PrettyTable()
t.field_names = ["Channel", "Range", "Position", "Target"]
for ii in range(self.NUM_CHANNELS):
for ii in range(self._axes_per_controller):
t.add_row(
[ii, self._get_range(ii), self._get_current_pos(ii), self._get_target_pos(ii)]
)
print(t)
@threadlocked
def on(self) -> None:
"""Enable the NPoint controller and open a new socket.
Raises:
TimeoutError: Raised if the socket connection raises a timeout.
Returns:
None
"""
if self.connected:
print("You are already connected to the NPoint controller.")
return
if not self.socket.is_open:
self.socket.open()
try:
self.socket.connect(self._server_and_port_name[0], self._server_and_port_name[1])
except socket.timeout:
raise TimeoutError(
f"Failed to connect to the specified server and port {self._server_and_port_name}."
)
except OSError:
print("ERROR while connecting. Let's try again")
self.socket.close()
time.sleep(0.5)
self.socket.open()
self.socket.connect(self._server_and_port_name[0], self._server_and_port_name[1])
self.connected = True
@threadlocked
def off(self) -> None:
"""Disable the controller and close the socket.
Returns:
None
"""
self.socket.close()
self.connected = False
@channel_checked
def _get_range(self, channel: int) -> int:
"""Get the range of the specified channel axis.
@@ -250,7 +149,7 @@ class NPointController:
"""
buffer = b"".join([bytes.fromhex(m) for m in buffer])
self.socket.put(buffer)
self.sock.put(buffer)
@threadlocked
def _put_and_receive(self, msg_hex_list: list) -> list:
@@ -264,8 +163,8 @@ class NPointController:
"""
buffer = b"".join([bytes.fromhex(m) for m in msg_hex_list])
self.socket.put(buffer)
recv_msg = self.socket.receive()
self.sock.put(buffer)
recv_msg = self.sock.receive()
recv_hex_list = [hex(m) for m in recv_msg]
self._verify_received_msg(msg_hex_list, recv_hex_list)
return recv_hex_list
@@ -293,9 +192,9 @@ class NPointController:
raise RuntimeError("Connection failure. Please restart the controller.")
def _check_channel(self, channel: int) -> None:
if channel >= self.NUM_CHANNELS:
if channel >= self._axes_per_controller:
raise ValueError(
f"Channel {channel+1} exceeds the available number of channels ({self.NUM_CHANNELS})"
f"Channel {channel+1} exceeds the available number of channels ({self._axes_per_controller})"
)
@staticmethod
@@ -498,29 +397,6 @@ class NPointAxis:
print(f"Setting the npoint settling time to {val:.2f} s.")
class NPointEpics(NPointAxis):
def __init__(self, controller: NPointController, channel: int, name: str) -> None:
super().__init__(controller, channel, name)
self.low_limit = -50
self.high_limit = 50
self._prefix = name
def get_pv(self) -> str:
return self.name
def get_position(self, readback=True) -> float:
if readback:
return self.get()
else:
return self.get_target_pos()
def within_limits(self, pos: float) -> bool:
return pos > self.low_limit and pos < self.high_limit
def move(self, position: float, wait=True) -> None:
self.set(position)
if __name__ == "__main__":
## EXAMPLES ##
#

View File

@@ -1,49 +1,27 @@
import copy
from unittest import mock
import pytest
from csaxs_bec.devices.npoint import NPointAxis, NPointController
class SocketMock:
def __init__(self, sock=None):
self.buffer_put = ""
self.buffer_recv = ""
self.is_open = False
if sock is None:
self.open()
else:
self.sock = sock
@pytest.fixture
def controller():
with mock.patch("ophyd_devices.utils.socket.SocketIO") as socket_cls:
controller = NPointController(
socket_cls=socket_cls, socket_host="localhost", socket_port=1234
)
controller.on()
controller.sock.reset_mock()
yield controller
controller.off()
def connect(self, host, port):
print(f"connecting to {host} port {port}")
# self.sock.create_connection((host, port))
# self.sock.connect((host, port))
def _put(self, msg_bytes):
self.buffer_put = msg_bytes
print(self.buffer_put)
def _recv(self, buffer_length=1024):
print(self.buffer_recv)
return self.buffer_recv
def _initialize_socket(self):
pass
def put(self, msg):
return self._put(msg)
def receive(self, buffer_length=1024):
return self._recv(buffer_length=buffer_length)
def open(self):
self._initialize_socket()
self.is_open = True
def close(self):
self.sock = None
self.is_open = False
@pytest.fixture
def npointx(controller):
npointx = NPointAxis(controller, 0, "nx")
yield npointx
@pytest.mark.parametrize(
@@ -54,12 +32,9 @@ class SocketMock:
(-5, b"\xa2\x18\x12\x83\x1133\xff\xffU"),
],
)
def test_axis_put(pos, msg):
controller = NPointController(SocketMock())
npointx = NPointAxis(controller, 0, "nx")
controller.on()
def test_axis_put(npointx, pos, msg):
npointx.set(pos)
assert npointx.controller.socket.buffer_put == msg
npointx.controller.sock.put.assert_called_with(msg)
@pytest.mark.parametrize(
@@ -70,13 +45,9 @@ def test_axis_put(pos, msg):
(-5, b"\xa04\x13\x83\x11U", b"\xa0\x34\x13\x83\x1133\xff\xffU"),
],
)
def test_axis_get_out(pos, msg_in, msg_out):
controller = NPointController(SocketMock())
npointx = NPointAxis(controller, 0, "nx")
controller.on()
npointx.controller.socket.buffer_recv = msg_out
def test_axis_get_out(npointx, pos, msg_in, msg_out):
npointx.controller.sock.receive.return_value = msg_out
assert pytest.approx(npointx.get(), rel=0.01) == pos
# assert controller.socket.buffer_put == msg_in
@pytest.mark.parametrize(
@@ -87,29 +58,23 @@ def test_axis_get_out(pos, msg_in, msg_out):
(2, b"\xa043\x83\x11U", b"\xa0\x34\x13\x83\x1133\xff\xffU"),
],
)
def test_axis_get_in(axis, msg_in, msg_out):
controller = NPointController(SocketMock())
npointx = NPointAxis(controller, 0, "nx")
controller.on()
controller.socket.buffer_recv = msg_out
controller._get_current_pos(axis)
assert controller.socket.buffer_put == msg_in
def test_axis_get_in(npointx, axis, msg_in, msg_out):
npointx.controller.sock.receive.return_value = msg_out
npointx.controller._get_current_pos(axis)
npointx.controller.sock.put.assert_called_once_with(msg_in)
def test_axis_out_of_range():
controller = NPointController(SocketMock())
def test_axis_out_of_range(controller):
with pytest.raises(ValueError):
npointx = NPointAxis(controller, 3, "nx")
def test_get_axis_out_of_range():
controller = NPointController(SocketMock())
def test_get_axis_out_of_range(controller):
with pytest.raises(ValueError):
controller._get_current_pos(3)
def test_set_axis_out_of_range():
controller = NPointController(SocketMock())
def test_set_axis_out_of_range(controller):
with pytest.raises(ValueError):
controller._set_target_pos(3, 5)
@@ -139,10 +104,8 @@ def test_hex_list_to_int(in_buffer, byteorder, signed, val):
(2, b"\xa0x0\x83\x11U", b"\xa0\x78\x13\x83\x11\x64\x00\x00\x00U"),
],
)
def test_get_range(axis, msg_in, msg_out):
controller = NPointController(SocketMock())
npointx = NPointAxis(controller, 0, "nx")
controller.on()
controller.socket.buffer_recv = msg_out
val = controller._get_range(axis)
assert controller.socket.buffer_put == msg_in and val == 100
def test_get_range(npointx, axis, msg_in, msg_out):
npointx.controller.sock.receive.return_value = msg_out
val = npointx.controller._get_range(axis)
npointx.controller.sock.put.assert_called_once_with(msg_in)
assert val == 100