mirror of
https://github.com/bec-project/bec_widgets.git
synced 2025-07-14 11:41:49 +02:00
fix(tests): ensure threads started during plot tests are properly stopped
This commit is contained in:
@ -47,7 +47,13 @@ class EigerPlot(QWidget):
|
|||||||
self.key_bindings()
|
self.key_bindings()
|
||||||
|
|
||||||
# ZMQ Consumer
|
# ZMQ Consumer
|
||||||
self.start_zmq_consumer()
|
self._zmq_consumer_exit_event = threading.Event()
|
||||||
|
self._zmq_consumer_thread = self.start_zmq_consumer()
|
||||||
|
|
||||||
|
def close(self):
|
||||||
|
super().close()
|
||||||
|
self._zmq_consumer_exit_event.set()
|
||||||
|
self._zmq_consumer_thread.join()
|
||||||
|
|
||||||
def init_ui(self):
|
def init_ui(self):
|
||||||
# Create Plot and add ImageItem
|
# Create Plot and add ImageItem
|
||||||
@ -182,25 +188,36 @@ class EigerPlot(QWidget):
|
|||||||
###############################
|
###############################
|
||||||
|
|
||||||
def start_zmq_consumer(self):
|
def start_zmq_consumer(self):
|
||||||
consumer_thread = threading.Thread(target=self.zmq_consumer, daemon=True).start()
|
consumer_thread = threading.Thread(
|
||||||
|
target=self.zmq_consumer, args=(self._zmq_consumer_exit_event,), daemon=True
|
||||||
|
)
|
||||||
|
consumer_thread.start()
|
||||||
|
return consumer_thread
|
||||||
|
|
||||||
def zmq_consumer(self):
|
def zmq_consumer(self, exit_event):
|
||||||
try:
|
print("starting consumer")
|
||||||
print("starting consumer")
|
live_stream_url = "tcp://129.129.95.38:20000"
|
||||||
live_stream_url = "tcp://129.129.95.38:20000"
|
receiver = zmq.Context().socket(zmq.SUB)
|
||||||
receiver = zmq.Context().socket(zmq.SUB)
|
receiver.connect(live_stream_url)
|
||||||
receiver.connect(live_stream_url)
|
receiver.setsockopt_string(zmq.SUBSCRIBE, "")
|
||||||
receiver.setsockopt_string(zmq.SUBSCRIBE, "")
|
|
||||||
|
poller = zmq.Poller()
|
||||||
|
poller.register(receiver, zmq.POLLIN)
|
||||||
|
|
||||||
|
# code could be a bit simpler here, testing exit_event in
|
||||||
|
# 'while' condition, but like this it is easier for the
|
||||||
|
# 'test_zmq_consumer' test
|
||||||
|
while True:
|
||||||
|
if poller.poll(1000): # 1s timeout
|
||||||
|
raw_meta, raw_data = receiver.recv_multipart(zmq.NOBLOCK)
|
||||||
|
|
||||||
while True:
|
|
||||||
raw_meta, raw_data = receiver.recv_multipart()
|
|
||||||
meta = json.loads(raw_meta.decode("utf-8"))
|
meta = json.loads(raw_meta.decode("utf-8"))
|
||||||
self.image = np.frombuffer(raw_data, dtype=meta["type"]).reshape(meta["shape"])
|
self.image = np.frombuffer(raw_data, dtype=meta["type"]).reshape(meta["shape"])
|
||||||
self.update_signal.emit()
|
self.update_signal.emit()
|
||||||
|
if exit_event.is_set():
|
||||||
|
break
|
||||||
|
|
||||||
finally:
|
receiver.disconnect(live_stream_url)
|
||||||
receiver.disconnect(live_stream_url)
|
|
||||||
receiver.context.term()
|
|
||||||
|
|
||||||
###############################
|
###############################
|
||||||
# just simulations from here
|
# just simulations from here
|
||||||
|
@ -54,7 +54,10 @@ class StreamPlot(QtWidgets.QWidget):
|
|||||||
|
|
||||||
self.proxy_update = pg.SignalProxy(self.update_signal, rateLimit=25, slot=self.update)
|
self.proxy_update = pg.SignalProxy(self.update_signal, rateLimit=25, slot=self.update)
|
||||||
|
|
||||||
self.data_retriever = threading.Thread(target=self.on_projection, daemon=True)
|
self._data_retriever_thread_exit_event = threading.Event()
|
||||||
|
self.data_retriever = threading.Thread(
|
||||||
|
target=self.on_projection, args=(self._data_retriever_thread_exit_event,), daemon=True
|
||||||
|
)
|
||||||
self.data_retriever.start()
|
self.data_retriever.start()
|
||||||
|
|
||||||
##########################
|
##########################
|
||||||
@ -64,6 +67,11 @@ class StreamPlot(QtWidgets.QWidget):
|
|||||||
self.init_curves()
|
self.init_curves()
|
||||||
self.hook_crosshair()
|
self.hook_crosshair()
|
||||||
|
|
||||||
|
def close(self):
|
||||||
|
super().close()
|
||||||
|
self._data_retriever_thread_exit_event.set()
|
||||||
|
self.data_retriever.join()
|
||||||
|
|
||||||
def init_ui(self):
|
def init_ui(self):
|
||||||
"""Setup all ui elements"""
|
"""Setup all ui elements"""
|
||||||
##########################
|
##########################
|
||||||
@ -257,8 +265,8 @@ class StreamPlot(QtWidgets.QWidget):
|
|||||||
# else:
|
# else:
|
||||||
# return
|
# return
|
||||||
|
|
||||||
def on_projection(self):
|
def on_projection(self, exit_event):
|
||||||
while True:
|
while not exit_event.is_set():
|
||||||
if self._current_proj is None:
|
if self._current_proj is None:
|
||||||
time.sleep(0.1)
|
time.sleep(0.1)
|
||||||
continue
|
continue
|
||||||
|
@ -14,6 +14,7 @@ def eiger_plot_instance(qtbot):
|
|||||||
qtbot.addWidget(widget)
|
qtbot.addWidget(widget)
|
||||||
qtbot.waitExposed(widget)
|
qtbot.waitExposed(widget)
|
||||||
yield widget
|
yield widget
|
||||||
|
widget.close()
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize(
|
@pytest.mark.parametrize(
|
||||||
@ -80,27 +81,24 @@ def test_zmq_consumer(eiger_plot_instance, qtbot):
|
|||||||
fake_meta = json.dumps({"type": "int32", "shape": (2, 2)}).encode("utf-8")
|
fake_meta = json.dumps({"type": "int32", "shape": (2, 2)}).encode("utf-8")
|
||||||
fake_data = np.array([[1, 2], [3, 4]], dtype="int32").tobytes()
|
fake_data = np.array([[1, 2], [3, 4]], dtype="int32").tobytes()
|
||||||
|
|
||||||
with patch("zmq.Context") as MockContext:
|
with patch("zmq.Context", autospec=True) as MockContext:
|
||||||
MockContext.reset_mock() # Reset the mock here
|
|
||||||
|
|
||||||
# Mocking zmq socket and its methods
|
|
||||||
mock_socket = MagicMock()
|
mock_socket = MagicMock()
|
||||||
MockContext().socket.return_value = mock_socket
|
mock_socket.recv_multipart.side_effect = ((fake_meta, fake_data),)
|
||||||
mock_socket.recv_multipart.side_effect = [[fake_meta, fake_data], Exception("Break loop")]
|
MockContext.return_value.socket.return_value = mock_socket
|
||||||
|
|
||||||
# Mocking the update_signal to check if it gets emitted
|
# Mocking the update_signal to check if it gets emitted
|
||||||
eiger_plot_instance.update_signal = MagicMock()
|
eiger_plot_instance.update_signal = MagicMock()
|
||||||
|
|
||||||
try:
|
with patch("zmq.Poller"):
|
||||||
|
# will do only 1 iteration of the loop in the thread
|
||||||
|
eiger_plot_instance._zmq_consumer_exit_event.set()
|
||||||
# Run the method under test
|
# Run the method under test
|
||||||
eiger_plot_instance.zmq_consumer()
|
consumer_thread = eiger_plot_instance.start_zmq_consumer()
|
||||||
except Exception as e:
|
consumer_thread.join()
|
||||||
# Ensure the loop was broken by our mocked exception
|
|
||||||
assert str(e) == "Break loop"
|
|
||||||
|
|
||||||
# Check if zmq methods are called
|
# Check if zmq methods are called
|
||||||
# MockContext.assert_called_once()
|
# MockContext.assert_called_once()
|
||||||
assert MockContext.call_count == 2 # TODO why 2?
|
assert MockContext.call_count == 1
|
||||||
mock_socket.connect.assert_called_with("tcp://129.129.95.38:20000")
|
mock_socket.connect.assert_called_with("tcp://129.129.95.38:20000")
|
||||||
mock_socket.setsockopt_string.assert_called_with(zmq.SUBSCRIBE, "")
|
mock_socket.setsockopt_string.assert_called_with(zmq.SUBSCRIBE, "")
|
||||||
mock_socket.recv_multipart.assert_called()
|
mock_socket.recv_multipart.assert_called()
|
||||||
|
@ -17,6 +17,7 @@ def stream_app(qtbot):
|
|||||||
qtbot.addWidget(widget)
|
qtbot.addWidget(widget)
|
||||||
qtbot.waitExposed(widget)
|
qtbot.waitExposed(widget)
|
||||||
yield widget
|
yield widget
|
||||||
|
widget.close()
|
||||||
|
|
||||||
|
|
||||||
def test_roi_signals_emitted(qtbot, stream_app):
|
def test_roi_signals_emitted(qtbot, stream_app):
|
||||||
|
Reference in New Issue
Block a user