diff --git a/bec_widgets/examples/eiger_plot/eiger_plot.py b/bec_widgets/examples/eiger_plot/eiger_plot.py index 1cbc7bd6..d6f8e5ea 100644 --- a/bec_widgets/examples/eiger_plot/eiger_plot.py +++ b/bec_widgets/examples/eiger_plot/eiger_plot.py @@ -47,7 +47,13 @@ class EigerPlot(QWidget): self.key_bindings() # ZMQ Consumer - self.start_zmq_consumer() + self._zmq_consumer_exit_event = threading.Event() + self._zmq_consumer_thread = self.start_zmq_consumer() + + def close(self): + super().close() + self._zmq_consumer_exit_event.set() + self._zmq_consumer_thread.join() def init_ui(self): # Create Plot and add ImageItem @@ -182,25 +188,36 @@ class EigerPlot(QWidget): ############################### def start_zmq_consumer(self): - consumer_thread = threading.Thread(target=self.zmq_consumer, daemon=True).start() + consumer_thread = threading.Thread( + target=self.zmq_consumer, args=(self._zmq_consumer_exit_event,), daemon=True + ) + consumer_thread.start() + return consumer_thread - def zmq_consumer(self): - try: - print("starting consumer") - live_stream_url = "tcp://129.129.95.38:20000" - receiver = zmq.Context().socket(zmq.SUB) - receiver.connect(live_stream_url) - receiver.setsockopt_string(zmq.SUBSCRIBE, "") + def zmq_consumer(self, exit_event): + print("starting consumer") + live_stream_url = "tcp://129.129.95.38:20000" + receiver = zmq.Context().socket(zmq.SUB) + receiver.connect(live_stream_url) + receiver.setsockopt_string(zmq.SUBSCRIBE, "") + + poller = zmq.Poller() + poller.register(receiver, zmq.POLLIN) + + # code could be a bit simpler here, testing exit_event in + # 'while' condition, but like this it is easier for the + # 'test_zmq_consumer' test + while True: + if poller.poll(1000): # 1s timeout + raw_meta, raw_data = receiver.recv_multipart(zmq.NOBLOCK) - while True: - raw_meta, raw_data = receiver.recv_multipart() meta = json.loads(raw_meta.decode("utf-8")) self.image = np.frombuffer(raw_data, dtype=meta["type"]).reshape(meta["shape"]) self.update_signal.emit() + if exit_event.is_set(): + break - finally: - receiver.disconnect(live_stream_url) - receiver.context.term() + receiver.disconnect(live_stream_url) ############################### # just simulations from here diff --git a/bec_widgets/examples/stream_plot/stream_plot.py b/bec_widgets/examples/stream_plot/stream_plot.py index 3fe12260..688ca603 100644 --- a/bec_widgets/examples/stream_plot/stream_plot.py +++ b/bec_widgets/examples/stream_plot/stream_plot.py @@ -54,7 +54,10 @@ class StreamPlot(QtWidgets.QWidget): self.proxy_update = pg.SignalProxy(self.update_signal, rateLimit=25, slot=self.update) - self.data_retriever = threading.Thread(target=self.on_projection, daemon=True) + self._data_retriever_thread_exit_event = threading.Event() + self.data_retriever = threading.Thread( + target=self.on_projection, args=(self._data_retriever_thread_exit_event,), daemon=True + ) self.data_retriever.start() ########################## @@ -64,6 +67,11 @@ class StreamPlot(QtWidgets.QWidget): self.init_curves() self.hook_crosshair() + def close(self): + super().close() + self._data_retriever_thread_exit_event.set() + self.data_retriever.join() + def init_ui(self): """Setup all ui elements""" ########################## @@ -257,8 +265,8 @@ class StreamPlot(QtWidgets.QWidget): # else: # return - def on_projection(self): - while True: + def on_projection(self, exit_event): + while not exit_event.is_set(): if self._current_proj is None: time.sleep(0.1) continue diff --git a/tests/test_eiger_plot.py b/tests/test_eiger_plot.py index f3e17bf8..943785e7 100644 --- a/tests/test_eiger_plot.py +++ b/tests/test_eiger_plot.py @@ -14,6 +14,7 @@ def eiger_plot_instance(qtbot): qtbot.addWidget(widget) qtbot.waitExposed(widget) yield widget + widget.close() @pytest.mark.parametrize( @@ -80,27 +81,24 @@ def test_zmq_consumer(eiger_plot_instance, qtbot): fake_meta = json.dumps({"type": "int32", "shape": (2, 2)}).encode("utf-8") fake_data = np.array([[1, 2], [3, 4]], dtype="int32").tobytes() - with patch("zmq.Context") as MockContext: - MockContext.reset_mock() # Reset the mock here - - # Mocking zmq socket and its methods + with patch("zmq.Context", autospec=True) as MockContext: mock_socket = MagicMock() - MockContext().socket.return_value = mock_socket - mock_socket.recv_multipart.side_effect = [[fake_meta, fake_data], Exception("Break loop")] + mock_socket.recv_multipart.side_effect = ((fake_meta, fake_data),) + MockContext.return_value.socket.return_value = mock_socket # Mocking the update_signal to check if it gets emitted eiger_plot_instance.update_signal = MagicMock() - try: + with patch("zmq.Poller"): + # will do only 1 iteration of the loop in the thread + eiger_plot_instance._zmq_consumer_exit_event.set() # Run the method under test - eiger_plot_instance.zmq_consumer() - except Exception as e: - # Ensure the loop was broken by our mocked exception - assert str(e) == "Break loop" + consumer_thread = eiger_plot_instance.start_zmq_consumer() + consumer_thread.join() # Check if zmq methods are called # MockContext.assert_called_once() - assert MockContext.call_count == 2 # TODO why 2? + assert MockContext.call_count == 1 mock_socket.connect.assert_called_with("tcp://129.129.95.38:20000") mock_socket.setsockopt_string.assert_called_with(zmq.SUBSCRIBE, "") mock_socket.recv_multipart.assert_called() diff --git a/tests/test_stream_plot.py b/tests/test_stream_plot.py index 8230a6db..7f0bbb0c 100644 --- a/tests/test_stream_plot.py +++ b/tests/test_stream_plot.py @@ -17,6 +17,7 @@ def stream_app(qtbot): qtbot.addWidget(widget) qtbot.waitExposed(widget) yield widget + widget.close() def test_roi_signals_emitted(qtbot, stream_app):