refactor(npoint): cleanup
This commit is contained in:
@@ -1,9 +1,7 @@
|
||||
import functools
|
||||
import socket
|
||||
import threading
|
||||
import time
|
||||
|
||||
from ophyd_devices.utils.controller import threadlocked
|
||||
from ophyd_devices.utils.controller import Controller, threadlocked
|
||||
from ophyd_devices.utils.socket import raise_if_disconnected
|
||||
from prettytable import PrettyTable
|
||||
from typeguard import typechecked
|
||||
@@ -20,75 +18,15 @@ def channel_checked(fcn):
|
||||
return wrapper
|
||||
|
||||
|
||||
class SocketIO:
|
||||
"""SocketIO helper class for TCP IP connections"""
|
||||
class NPointController(Controller):
|
||||
|
||||
def __init__(self, sock=None):
|
||||
self.is_open = False
|
||||
if sock is None:
|
||||
self.open()
|
||||
else:
|
||||
self.sock = sock
|
||||
|
||||
def connect(self, host, port):
|
||||
print(f"connecting to {host} port {port}")
|
||||
# self.sock.create_connection((host, port))
|
||||
self.sock.connect((host, port))
|
||||
|
||||
def _put(self, msg_bytes):
|
||||
return self.sock.send(msg_bytes)
|
||||
|
||||
def _recv(self, buffer_length=1024):
|
||||
return self.sock.recv(buffer_length)
|
||||
|
||||
def _initialize_socket(self):
|
||||
self.sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
|
||||
self.sock.settimeout(5)
|
||||
|
||||
def put(self, msg):
|
||||
return self._put(msg)
|
||||
|
||||
def receive(self, buffer_length=1024):
|
||||
return self._recv(buffer_length=buffer_length)
|
||||
|
||||
def open(self):
|
||||
self._initialize_socket()
|
||||
self.is_open = True
|
||||
|
||||
def close(self):
|
||||
self.sock.close()
|
||||
self.sock = None
|
||||
self.is_open = False
|
||||
|
||||
|
||||
class NPointController:
|
||||
_controller_instance = None
|
||||
|
||||
NUM_CHANNELS = 3
|
||||
_axes_per_controller = 3
|
||||
_read_single_loc_bit = "A0"
|
||||
_write_single_loc_bit = "A2"
|
||||
_trailing_bit = "55"
|
||||
_range_offset = "78"
|
||||
_channel_base = ["11", "83"]
|
||||
|
||||
def __init__(
|
||||
self, comm_socket: SocketIO, server_ip: str = "129.129.99.87", server_port: int = 23
|
||||
) -> None:
|
||||
self._lock = threading.RLock()
|
||||
super().__init__()
|
||||
self._server_and_port_name = (server_ip, server_port)
|
||||
self.socket = comm_socket
|
||||
self.connected = False
|
||||
|
||||
def __new__(cls, *args, **kwargs):
|
||||
if not NPointController._controller_instance:
|
||||
NPointController._controller_instance = object.__new__(cls)
|
||||
return NPointController._controller_instance
|
||||
|
||||
@classmethod
|
||||
def create(cls):
|
||||
return cls(SocketIO())
|
||||
|
||||
def show_all(self) -> None:
|
||||
"""Display current status of all channels
|
||||
|
||||
@@ -98,54 +36,15 @@ class NPointController:
|
||||
if not self.connected:
|
||||
print("npoint controller is currently disabled.")
|
||||
return
|
||||
print(f"Connected to controller at {self._server_and_port_name}")
|
||||
print(f"Connected to controller at {self._socket_host}:{self._socket_port}")
|
||||
t = PrettyTable()
|
||||
t.field_names = ["Channel", "Range", "Position", "Target"]
|
||||
for ii in range(self.NUM_CHANNELS):
|
||||
for ii in range(self._axes_per_controller):
|
||||
t.add_row(
|
||||
[ii, self._get_range(ii), self._get_current_pos(ii), self._get_target_pos(ii)]
|
||||
)
|
||||
print(t)
|
||||
|
||||
@threadlocked
|
||||
def on(self) -> None:
|
||||
"""Enable the NPoint controller and open a new socket.
|
||||
|
||||
Raises:
|
||||
TimeoutError: Raised if the socket connection raises a timeout.
|
||||
|
||||
Returns:
|
||||
None
|
||||
"""
|
||||
if self.connected:
|
||||
print("You are already connected to the NPoint controller.")
|
||||
return
|
||||
if not self.socket.is_open:
|
||||
self.socket.open()
|
||||
try:
|
||||
self.socket.connect(self._server_and_port_name[0], self._server_and_port_name[1])
|
||||
except socket.timeout:
|
||||
raise TimeoutError(
|
||||
f"Failed to connect to the specified server and port {self._server_and_port_name}."
|
||||
)
|
||||
except OSError:
|
||||
print("ERROR while connecting. Let's try again")
|
||||
self.socket.close()
|
||||
time.sleep(0.5)
|
||||
self.socket.open()
|
||||
self.socket.connect(self._server_and_port_name[0], self._server_and_port_name[1])
|
||||
self.connected = True
|
||||
|
||||
@threadlocked
|
||||
def off(self) -> None:
|
||||
"""Disable the controller and close the socket.
|
||||
|
||||
Returns:
|
||||
None
|
||||
"""
|
||||
self.socket.close()
|
||||
self.connected = False
|
||||
|
||||
@channel_checked
|
||||
def _get_range(self, channel: int) -> int:
|
||||
"""Get the range of the specified channel axis.
|
||||
@@ -250,7 +149,7 @@ class NPointController:
|
||||
"""
|
||||
|
||||
buffer = b"".join([bytes.fromhex(m) for m in buffer])
|
||||
self.socket.put(buffer)
|
||||
self.sock.put(buffer)
|
||||
|
||||
@threadlocked
|
||||
def _put_and_receive(self, msg_hex_list: list) -> list:
|
||||
@@ -264,8 +163,8 @@ class NPointController:
|
||||
"""
|
||||
|
||||
buffer = b"".join([bytes.fromhex(m) for m in msg_hex_list])
|
||||
self.socket.put(buffer)
|
||||
recv_msg = self.socket.receive()
|
||||
self.sock.put(buffer)
|
||||
recv_msg = self.sock.receive()
|
||||
recv_hex_list = [hex(m) for m in recv_msg]
|
||||
self._verify_received_msg(msg_hex_list, recv_hex_list)
|
||||
return recv_hex_list
|
||||
@@ -293,9 +192,9 @@ class NPointController:
|
||||
raise RuntimeError("Connection failure. Please restart the controller.")
|
||||
|
||||
def _check_channel(self, channel: int) -> None:
|
||||
if channel >= self.NUM_CHANNELS:
|
||||
if channel >= self._axes_per_controller:
|
||||
raise ValueError(
|
||||
f"Channel {channel+1} exceeds the available number of channels ({self.NUM_CHANNELS})"
|
||||
f"Channel {channel+1} exceeds the available number of channels ({self._axes_per_controller})"
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
@@ -498,29 +397,6 @@ class NPointAxis:
|
||||
print(f"Setting the npoint settling time to {val:.2f} s.")
|
||||
|
||||
|
||||
class NPointEpics(NPointAxis):
|
||||
def __init__(self, controller: NPointController, channel: int, name: str) -> None:
|
||||
super().__init__(controller, channel, name)
|
||||
self.low_limit = -50
|
||||
self.high_limit = 50
|
||||
self._prefix = name
|
||||
|
||||
def get_pv(self) -> str:
|
||||
return self.name
|
||||
|
||||
def get_position(self, readback=True) -> float:
|
||||
if readback:
|
||||
return self.get()
|
||||
else:
|
||||
return self.get_target_pos()
|
||||
|
||||
def within_limits(self, pos: float) -> bool:
|
||||
return pos > self.low_limit and pos < self.high_limit
|
||||
|
||||
def move(self, position: float, wait=True) -> None:
|
||||
self.set(position)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
## EXAMPLES ##
|
||||
#
|
||||
|
||||
@@ -1,49 +1,27 @@
|
||||
import copy
|
||||
from unittest import mock
|
||||
|
||||
import pytest
|
||||
|
||||
from csaxs_bec.devices.npoint import NPointAxis, NPointController
|
||||
|
||||
|
||||
class SocketMock:
|
||||
def __init__(self, sock=None):
|
||||
self.buffer_put = ""
|
||||
self.buffer_recv = ""
|
||||
self.is_open = False
|
||||
if sock is None:
|
||||
self.open()
|
||||
else:
|
||||
self.sock = sock
|
||||
@pytest.fixture
|
||||
def controller():
|
||||
with mock.patch("ophyd_devices.utils.socket.SocketIO") as socket_cls:
|
||||
controller = NPointController(
|
||||
socket_cls=socket_cls, socket_host="localhost", socket_port=1234
|
||||
)
|
||||
controller.on()
|
||||
controller.sock.reset_mock()
|
||||
yield controller
|
||||
controller.off()
|
||||
|
||||
def connect(self, host, port):
|
||||
print(f"connecting to {host} port {port}")
|
||||
# self.sock.create_connection((host, port))
|
||||
# self.sock.connect((host, port))
|
||||
|
||||
def _put(self, msg_bytes):
|
||||
self.buffer_put = msg_bytes
|
||||
print(self.buffer_put)
|
||||
|
||||
def _recv(self, buffer_length=1024):
|
||||
print(self.buffer_recv)
|
||||
return self.buffer_recv
|
||||
|
||||
def _initialize_socket(self):
|
||||
pass
|
||||
|
||||
def put(self, msg):
|
||||
return self._put(msg)
|
||||
|
||||
def receive(self, buffer_length=1024):
|
||||
return self._recv(buffer_length=buffer_length)
|
||||
|
||||
def open(self):
|
||||
self._initialize_socket()
|
||||
self.is_open = True
|
||||
|
||||
def close(self):
|
||||
self.sock = None
|
||||
self.is_open = False
|
||||
@pytest.fixture
|
||||
def npointx(controller):
|
||||
npointx = NPointAxis(controller, 0, "nx")
|
||||
yield npointx
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
@@ -54,12 +32,9 @@ class SocketMock:
|
||||
(-5, b"\xa2\x18\x12\x83\x1133\xff\xffU"),
|
||||
],
|
||||
)
|
||||
def test_axis_put(pos, msg):
|
||||
controller = NPointController(SocketMock())
|
||||
npointx = NPointAxis(controller, 0, "nx")
|
||||
controller.on()
|
||||
def test_axis_put(npointx, pos, msg):
|
||||
npointx.set(pos)
|
||||
assert npointx.controller.socket.buffer_put == msg
|
||||
npointx.controller.sock.put.assert_called_with(msg)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
@@ -70,13 +45,9 @@ def test_axis_put(pos, msg):
|
||||
(-5, b"\xa04\x13\x83\x11U", b"\xa0\x34\x13\x83\x1133\xff\xffU"),
|
||||
],
|
||||
)
|
||||
def test_axis_get_out(pos, msg_in, msg_out):
|
||||
controller = NPointController(SocketMock())
|
||||
npointx = NPointAxis(controller, 0, "nx")
|
||||
controller.on()
|
||||
npointx.controller.socket.buffer_recv = msg_out
|
||||
def test_axis_get_out(npointx, pos, msg_in, msg_out):
|
||||
npointx.controller.sock.receive.return_value = msg_out
|
||||
assert pytest.approx(npointx.get(), rel=0.01) == pos
|
||||
# assert controller.socket.buffer_put == msg_in
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
@@ -87,29 +58,23 @@ def test_axis_get_out(pos, msg_in, msg_out):
|
||||
(2, b"\xa043\x83\x11U", b"\xa0\x34\x13\x83\x1133\xff\xffU"),
|
||||
],
|
||||
)
|
||||
def test_axis_get_in(axis, msg_in, msg_out):
|
||||
controller = NPointController(SocketMock())
|
||||
npointx = NPointAxis(controller, 0, "nx")
|
||||
controller.on()
|
||||
controller.socket.buffer_recv = msg_out
|
||||
controller._get_current_pos(axis)
|
||||
assert controller.socket.buffer_put == msg_in
|
||||
def test_axis_get_in(npointx, axis, msg_in, msg_out):
|
||||
npointx.controller.sock.receive.return_value = msg_out
|
||||
npointx.controller._get_current_pos(axis)
|
||||
npointx.controller.sock.put.assert_called_once_with(msg_in)
|
||||
|
||||
|
||||
def test_axis_out_of_range():
|
||||
controller = NPointController(SocketMock())
|
||||
def test_axis_out_of_range(controller):
|
||||
with pytest.raises(ValueError):
|
||||
npointx = NPointAxis(controller, 3, "nx")
|
||||
|
||||
|
||||
def test_get_axis_out_of_range():
|
||||
controller = NPointController(SocketMock())
|
||||
def test_get_axis_out_of_range(controller):
|
||||
with pytest.raises(ValueError):
|
||||
controller._get_current_pos(3)
|
||||
|
||||
|
||||
def test_set_axis_out_of_range():
|
||||
controller = NPointController(SocketMock())
|
||||
def test_set_axis_out_of_range(controller):
|
||||
with pytest.raises(ValueError):
|
||||
controller._set_target_pos(3, 5)
|
||||
|
||||
@@ -139,10 +104,8 @@ def test_hex_list_to_int(in_buffer, byteorder, signed, val):
|
||||
(2, b"\xa0x0\x83\x11U", b"\xa0\x78\x13\x83\x11\x64\x00\x00\x00U"),
|
||||
],
|
||||
)
|
||||
def test_get_range(axis, msg_in, msg_out):
|
||||
controller = NPointController(SocketMock())
|
||||
npointx = NPointAxis(controller, 0, "nx")
|
||||
controller.on()
|
||||
controller.socket.buffer_recv = msg_out
|
||||
val = controller._get_range(axis)
|
||||
assert controller.socket.buffer_put == msg_in and val == 100
|
||||
def test_get_range(npointx, axis, msg_in, msg_out):
|
||||
npointx.controller.sock.receive.return_value = msg_out
|
||||
val = npointx.controller._get_range(axis)
|
||||
npointx.controller.sock.put.assert_called_once_with(msg_in)
|
||||
assert val == 100
|
||||
|
||||
Reference in New Issue
Block a user