0
0
mirror of https://github.com/bec-project/bec_widgets.git synced 2025-07-14 11:41:49 +02:00

fix(tests): ensure threads started during plot tests are properly stopped

This commit is contained in:
2024-01-20 23:01:41 +01:00
parent d909673071
commit 3fb6644543
4 changed files with 53 additions and 29 deletions

View File

@ -47,7 +47,13 @@ class EigerPlot(QWidget):
self.key_bindings() self.key_bindings()
# ZMQ Consumer # ZMQ Consumer
self.start_zmq_consumer() self._zmq_consumer_exit_event = threading.Event()
self._zmq_consumer_thread = self.start_zmq_consumer()
def close(self):
super().close()
self._zmq_consumer_exit_event.set()
self._zmq_consumer_thread.join()
def init_ui(self): def init_ui(self):
# Create Plot and add ImageItem # Create Plot and add ImageItem
@ -182,25 +188,36 @@ class EigerPlot(QWidget):
############################### ###############################
def start_zmq_consumer(self): def start_zmq_consumer(self):
consumer_thread = threading.Thread(target=self.zmq_consumer, daemon=True).start() consumer_thread = threading.Thread(
target=self.zmq_consumer, args=(self._zmq_consumer_exit_event,), daemon=True
)
consumer_thread.start()
return consumer_thread
def zmq_consumer(self): def zmq_consumer(self, exit_event):
try: print("starting consumer")
print("starting consumer") live_stream_url = "tcp://129.129.95.38:20000"
live_stream_url = "tcp://129.129.95.38:20000" receiver = zmq.Context().socket(zmq.SUB)
receiver = zmq.Context().socket(zmq.SUB) receiver.connect(live_stream_url)
receiver.connect(live_stream_url) receiver.setsockopt_string(zmq.SUBSCRIBE, "")
receiver.setsockopt_string(zmq.SUBSCRIBE, "")
poller = zmq.Poller()
poller.register(receiver, zmq.POLLIN)
# code could be a bit simpler here, testing exit_event in
# 'while' condition, but like this it is easier for the
# 'test_zmq_consumer' test
while True:
if poller.poll(1000): # 1s timeout
raw_meta, raw_data = receiver.recv_multipart(zmq.NOBLOCK)
while True:
raw_meta, raw_data = receiver.recv_multipart()
meta = json.loads(raw_meta.decode("utf-8")) meta = json.loads(raw_meta.decode("utf-8"))
self.image = np.frombuffer(raw_data, dtype=meta["type"]).reshape(meta["shape"]) self.image = np.frombuffer(raw_data, dtype=meta["type"]).reshape(meta["shape"])
self.update_signal.emit() self.update_signal.emit()
if exit_event.is_set():
break
finally: receiver.disconnect(live_stream_url)
receiver.disconnect(live_stream_url)
receiver.context.term()
############################### ###############################
# just simulations from here # just simulations from here

View File

@ -54,7 +54,10 @@ class StreamPlot(QtWidgets.QWidget):
self.proxy_update = pg.SignalProxy(self.update_signal, rateLimit=25, slot=self.update) self.proxy_update = pg.SignalProxy(self.update_signal, rateLimit=25, slot=self.update)
self.data_retriever = threading.Thread(target=self.on_projection, daemon=True) self._data_retriever_thread_exit_event = threading.Event()
self.data_retriever = threading.Thread(
target=self.on_projection, args=(self._data_retriever_thread_exit_event,), daemon=True
)
self.data_retriever.start() self.data_retriever.start()
########################## ##########################
@ -64,6 +67,11 @@ class StreamPlot(QtWidgets.QWidget):
self.init_curves() self.init_curves()
self.hook_crosshair() self.hook_crosshair()
def close(self):
super().close()
self._data_retriever_thread_exit_event.set()
self.data_retriever.join()
def init_ui(self): def init_ui(self):
"""Setup all ui elements""" """Setup all ui elements"""
########################## ##########################
@ -257,8 +265,8 @@ class StreamPlot(QtWidgets.QWidget):
# else: # else:
# return # return
def on_projection(self): def on_projection(self, exit_event):
while True: while not exit_event.is_set():
if self._current_proj is None: if self._current_proj is None:
time.sleep(0.1) time.sleep(0.1)
continue continue

View File

@ -14,6 +14,7 @@ def eiger_plot_instance(qtbot):
qtbot.addWidget(widget) qtbot.addWidget(widget)
qtbot.waitExposed(widget) qtbot.waitExposed(widget)
yield widget yield widget
widget.close()
@pytest.mark.parametrize( @pytest.mark.parametrize(
@ -80,27 +81,24 @@ def test_zmq_consumer(eiger_plot_instance, qtbot):
fake_meta = json.dumps({"type": "int32", "shape": (2, 2)}).encode("utf-8") fake_meta = json.dumps({"type": "int32", "shape": (2, 2)}).encode("utf-8")
fake_data = np.array([[1, 2], [3, 4]], dtype="int32").tobytes() fake_data = np.array([[1, 2], [3, 4]], dtype="int32").tobytes()
with patch("zmq.Context") as MockContext: with patch("zmq.Context", autospec=True) as MockContext:
MockContext.reset_mock() # Reset the mock here
# Mocking zmq socket and its methods
mock_socket = MagicMock() mock_socket = MagicMock()
MockContext().socket.return_value = mock_socket mock_socket.recv_multipart.side_effect = ((fake_meta, fake_data),)
mock_socket.recv_multipart.side_effect = [[fake_meta, fake_data], Exception("Break loop")] MockContext.return_value.socket.return_value = mock_socket
# Mocking the update_signal to check if it gets emitted # Mocking the update_signal to check if it gets emitted
eiger_plot_instance.update_signal = MagicMock() eiger_plot_instance.update_signal = MagicMock()
try: with patch("zmq.Poller"):
# will do only 1 iteration of the loop in the thread
eiger_plot_instance._zmq_consumer_exit_event.set()
# Run the method under test # Run the method under test
eiger_plot_instance.zmq_consumer() consumer_thread = eiger_plot_instance.start_zmq_consumer()
except Exception as e: consumer_thread.join()
# Ensure the loop was broken by our mocked exception
assert str(e) == "Break loop"
# Check if zmq methods are called # Check if zmq methods are called
# MockContext.assert_called_once() # MockContext.assert_called_once()
assert MockContext.call_count == 2 # TODO why 2? assert MockContext.call_count == 1
mock_socket.connect.assert_called_with("tcp://129.129.95.38:20000") mock_socket.connect.assert_called_with("tcp://129.129.95.38:20000")
mock_socket.setsockopt_string.assert_called_with(zmq.SUBSCRIBE, "") mock_socket.setsockopt_string.assert_called_with(zmq.SUBSCRIBE, "")
mock_socket.recv_multipart.assert_called() mock_socket.recv_multipart.assert_called()

View File

@ -17,6 +17,7 @@ def stream_app(qtbot):
qtbot.addWidget(widget) qtbot.addWidget(widget)
qtbot.waitExposed(widget) qtbot.waitExposed(widget)
yield widget yield widget
widget.close()
def test_roi_signals_emitted(qtbot, stream_app): def test_roi_signals_emitted(qtbot, stream_app):