feat(controller): added option to specify a socket to use instead of the internal socket

This commit is contained in:
2024-07-29 20:58:31 +02:00
parent 9f46069959
commit 5ce19e9c5e

View File

@ -93,34 +93,48 @@ class Controller(OphydObject):
self._socket_port = socket_port self._socket_port = socket_port
@threadlocked @threadlocked
def socket_put(self, val: str): def socket_put(self, val: str, socket=None) -> None:
""" """
Send a command to the controller through the socket. Send a command to the controller through the socket.
Args: Args:
val (str): Command to send val (str): Command to send
socket (socket.socket): Socket object to use. If None, the default socket is used.
""" """
self.sock.put(f"{val}\n".encode()) sock = socket or self.sock
sock.put(f"{val}\n".encode())
@threadlocked @threadlocked
def socket_get(self): def socket_get(self, socket=None) -> str:
""" """
Receive a response from the controller through the socket. Receive a response from the controller through the socket.
Args:
socket (socket.socket): Socket object to use. If None, the default socket is used.
Returns:
str: Response from the controller
""" """
return self.sock.receive().decode() sock = socket or self.sock
return sock.receive().decode()
@retry_once @retry_once
@threadlocked @threadlocked
def socket_put_and_receive(self, val: str, remove_trailing_chars=True) -> str: def socket_put_and_receive(self, val: str, remove_trailing_chars=True, socket=None) -> str:
""" """
Send a command to the controller and receive the response. Send a command to the controller and receive the response.
Override this method in the derived class if necessary, especially if the response Override this method in the derived class if necessary, especially if the response
needs to be parsed differently. needs to be parsed differently.
Args:
val (str): Command to send
remove_trailing_chars (bool): Remove trailing characters from the response
socket (socket.socket): Socket object to use. If None, the default socket is used.
""" """
self.socket_put(val) self.socket_put(val, socket)
if remove_trailing_chars: if remove_trailing_chars:
return self._remove_trailing_characters(self.sock.receive().decode()) return self._remove_trailing_characters(self.sock.receive().decode())
return self.socket_get() return self.socket_get(socket)
def _remove_trailing_characters(self, var) -> str: def _remove_trailing_characters(self, var) -> str:
if len(var) > 1: if len(var) > 1: