From 006992e43cc56d56261bc4fd3e9cae9abcab2153 Mon Sep 17 00:00:00 2001 From: wyzula-jan Date: Thu, 11 Jul 2024 23:10:35 +0200 Subject: [PATCH] test(waveform): tests extended --- .../widgets/figure/plots/waveform/waveform.py | 132 ++++++++---- tests/unit_tests/client_mocks.py | 17 +- tests/unit_tests/test_device_input_widgets.py | 2 + tests/unit_tests/test_waveform1d.py | 189 +++++++++++++++++- 4 files changed, 293 insertions(+), 47 deletions(-) diff --git a/bec_widgets/widgets/figure/plots/waveform/waveform.py b/bec_widgets/widgets/figure/plots/waveform/waveform.py index d5f684ef..0ba157a2 100644 --- a/bec_widgets/widgets/figure/plots/waveform/waveform.py +++ b/bec_widgets/widgets/figure/plots/waveform/waveform.py @@ -47,7 +47,7 @@ class BECWaveform(BECPlotBase): "_config_dict", "plot", "add_dap", - "change_x_axis", + "set_x", "get_dap_params", "remove_curve", "scan_history", @@ -71,6 +71,7 @@ class BECWaveform(BECPlotBase): scan_signal_update = pyqtSignal() async_signal_update = pyqtSignal() dap_params_update = pyqtSignal(dict) + autorange_signal = pyqtSignal() def __init__( self, @@ -105,6 +106,7 @@ class BECWaveform(BECPlotBase): self.scan_signal_update, rateLimit=25, slot=self.refresh_dap ) self.async_signal_update.connect(self.replot_async_curve) + self.autorange_signal.connect(self.auto_range) # Get bec shortcuts dev, scans, queue, scan_storage, dap self.get_bec_shortcuts() @@ -274,7 +276,11 @@ class BECWaveform(BECPlotBase): arg1(list | np.ndarray | str | None): First argument which can be x data, y data, or y_name. y(list | np.ndarray): Custom y data to plot. x(list | np.ndarray): Custom y data to plot. - x_name(str): The name of the device for the x-axis. + x_name(str): Name of the x signal. + - "best_effort": Use the best effort signal. + - "timestamp": Use the timestamp signal. + - "index": Use the index signal. + - Custom signal name of device from BEC. y_name(str): The name of the device for the y-axis. z_name(str): The name of the device for the z-axis. x_entry(str): The name of the entry for the x-axis. @@ -327,19 +333,25 @@ class BECWaveform(BECPlotBase): ) self.scan_signal_update.emit() self.async_signal_update.emit() + return curve - def change_x_axis(self, x_name: str, x_entry: str | None = None): + def set_x(self, x_name: str, x_entry: str | None = None): """ Change the x axis of the plot widget. Args: x_name(str): Name of the x signal. + - "best_effort": Use the best effort signal. + - "timestamp": Use the timestamp signal. + - "index": Use the index signal. + - Custom signal name of device from BEC. x_entry(str): Entry of the x signal. """ curve_configs = self.config.curves curve_ids = list(curve_configs.keys()) curve_configs = list(curve_configs.values()) + self.set_auto_range(True, "xy") x_entry, _, _ = self._validate_signal_entries( x_name, None, None, x_entry, None, None, validate_bec=True @@ -374,6 +386,23 @@ class BECWaveform(BECPlotBase): self.async_signal_update.emit() self.scan_signal_update.emit() + @pyqtSlot() + def auto_range(self): + self.plot_item.autoRange() + + def set_auto_range(self, enabled: bool, axis: str = "xy"): + """ + Set the auto range of the plot widget. + + Args: + enabled(bool): If True, enable the auto range. + axis(str, optional): The axis to enable the auto range. + - "xy": Enable auto range for both x and y axis. + - "x": Enable auto range for x axis. + - "y": Enable auto range for y axis. + """ + self.plot_item.enableAutoRange(axis, enabled) + def add_curve_custom( self, x: list | np.ndarray, @@ -468,19 +497,24 @@ class BECWaveform(BECPlotBase): if y_name is None: raise ValueError("y_name must be provided.") - # 2. Check - get source of the device - if source is None: - source = self._validate_device_source_compatibity(y_name) - - # 3. Check - check if there is already a x axis signal + # 2. Check - check if there is already a x axis signal if x_name is None: - x_name = self.x_axis_mode["name"] + mode = self.x_axis_mode["name"] + x_name = mode if mode is not None else "best_effort" + self.x_axis_mode["name"] = x_name - # 4. Check - Get entry if not provided and validate + # 3. Check - Get entry if not provided and validate x_entry, y_entry, z_entry = self._validate_signal_entries( x_name, y_name, z_name, x_entry, y_entry, z_entry, validate_bec ) + # 4. Check - get source of the device + if source is None: + if validate_bec is True: + source = self._validate_device_source_compatibity(y_name) + else: + source = "scan_segment" + if z_name is not None and z_entry is not None: label = label or f"{z_name}-{z_entry}" else: @@ -492,7 +526,8 @@ class BECWaveform(BECPlotBase): raise ValueError(f"Curve with ID '{label}' already exists in widget '{self.gui_id}'.") # Validate or define x axis behaviour and compatibility with y_name readoutPriority - self._validate_x_axis_behaviour(y_name, x_name, x_entry) + if validate_bec is True: + self._validate_x_axis_behaviour(y_name, x_name, x_entry) # Create color if not specified color = ( @@ -521,7 +556,6 @@ class BECWaveform(BECPlotBase): ) curve = self._add_curve_object(name=label, source=source, config=curve_config) - return curve def add_dap( @@ -669,7 +703,7 @@ class BECWaveform(BECPlotBase): f"All curves must have the same x axis.\n" f" Current valid x axis: '{self._x_axis_mode['name']}'\n" f" Attempted to add curve with x axis: '{x_name}'\n" - f"If you want to change the x-axis of the curve, please remove previous curves or change the x axis of the plot widget with '.change_x_axis({x_name})'." + f"If you want to change the x-axis of the curve, please remove previous curves or change the x axis of the plot widget with '.set_x({x_name})'." ) # If x_axis_mode["name"] is None, determine the mode based on x_name @@ -688,7 +722,7 @@ class BECWaveform(BECPlotBase): raise ValueError( f"Async devices '{y_name}' cannot be used with custom x signal '{x_name}-{x_entry}'.\n" f"Please use mode 'best_effort', 'timestamp', or 'index' signal for x axis." - f"You can change the x axis mode with '.change_x_axis(mode)'" + f"You can change the x axis mode with '.set_x(mode)'" ) if auto_switch is True: @@ -766,6 +800,7 @@ class BECWaveform(BECPlotBase): if validate_bec: if x_name is None: x_name = "best_effort" + x_entry = "best_effort" if x_name: if x_name == "index" or x_name == "timestamp" or x_name == "best_effort": x_entry = x_name @@ -868,6 +903,7 @@ class BECWaveform(BECPlotBase): return if current_scan_id != self.scan_id: + self.set_auto_range(True, "xy") self.old_scan_id = self.scan_id self.scan_id = current_scan_id self.scan_item = self.queue.scan_storage.find_scan_by_ID(self.scan_id) @@ -888,7 +924,6 @@ class BECWaveform(BECPlotBase): metadata (dict): Metadata of the scan. """ self.on_scan_status(msg) - self.scan_signal_update.emit() def set_x_label(self, label: str, size: int = None): @@ -942,7 +977,18 @@ class BECWaveform(BECPlotBase): Refresh the DAP curves with the latest data from the DAP model MessageEndpoints.dap_response(). """ for curve_id, curve in self._curves_data["DAP"].items(): - if curve.config.signals.x is not None: + if len(self._curves_data["async"]) > 0: + curve.remove() + raise ValueError( + f"Cannot refresh DAP curve '{curve_id}' while async curves are present. Removing {curve_id} from display." + ) + if self._x_axis_mode["name"] == "best_effort": + try: + x_name = self.scan_item.status_message.info["scan_report_devices"][0] + x_entry = self.entry_validator.validate_signal(x_name, None) + except AttributeError: + return + elif curve.config.signals.x is not None: x_name = curve.config.signals.x.name x_entry = curve.config.signals.x.entry if ( @@ -955,12 +1001,7 @@ class BECWaveform(BECPlotBase): return except AttributeError: return - else: - try: - x_name = self.scan_item.status_message.info["scan_report_devices"][0] - x_entry = self.entry_validator.validate_signal(x_name, None) - except AttributeError: - return + y_name = curve.config.signals.y.name y_entry = curve.config.signals.y.entry model_name = curve.config.signals.dap @@ -1046,6 +1087,7 @@ class BECWaveform(BECPlotBase): for curve_id, curve in self._curves_data["async"].items(): y_name = curve.config.signals.y.name y_entry = curve.config.signals.y.entry + x_name = None if curve.config.signals.x: x_name = curve.config.signals.x.name @@ -1084,7 +1126,7 @@ class BECWaveform(BECPlotBase): z_entry = curve.config.signals.z.entry data_x = self._get_x_data(curve, y_name, y_entry) - if data_x == []: # case if the data is empty because motor is not scanned + if len(data) == 0: # case if the data is empty because motor is not scanned return try: @@ -1119,28 +1161,41 @@ class BECWaveform(BECPlotBase): Returns: list|np.ndarray|None: X data for the curve. """ - if curve.config.signals.x is not None: - if curve.config.signals.x.name == "timestamp": - timestamps = self.scan_item.data[y_name][y_entry].timestamps - x_data = self.convert_timestamps(timestamps) - elif curve.config.signals.x.name == "index": - x_data = None - else: - x_name = curve.config.signals.x.name - x_entry = curve.config.signals.x.entry - try: - x_data = self.scan_item.data[x_name][x_entry].val - except TypeError: - x_data = [] - else: + x_data = None + if self._x_axis_mode["name"] == "timestamp": + timestamps = self.scan_item.data[y_name][y_entry].timestamps + x_data = self.convert_timestamps(timestamps) + return x_data + if self._x_axis_mode["name"] == "index": + x_data = None + return x_data + + if self._x_axis_mode["name"] is None or self._x_axis_mode["name"] == "best_effort": if len(self._curves_data["async"]) > 0: x_data = None + self._x_axis_mode["label_suffix"] = f" [auto: index]" + current_label = "" if self.config.axis.x_label is None else self.config.axis.x_label + self.plot_item.setLabel( + "bottom", f"{current_label}{self._x_axis_mode['label_suffix']}" + ) + return x_data else: x_name = self.scan_item.status_message.info["scan_report_devices"][0] x_entry = self.entry_validator.validate_signal(x_name, None) x_data = self.scan_item.data[x_name][x_entry].val - self.set_x_label(f"[auto: {x_name}-{x_entry}]") + self._x_axis_mode["label_suffix"] = f" [auto: {x_name}-{x_entry}]" + current_label = "" if self.config.axis.x_label is None else self.config.axis.x_label + self.plot_item.setLabel( + "bottom", f"{current_label}{self._x_axis_mode['label_suffix']}" + ) + else: + x_name = curve.config.signals.x.name + x_entry = curve.config.signals.x.entry + try: + x_data = self.scan_item.data[x_name][x_entry].val + except TypeError: + x_data = [] return x_data def _make_z_gradient(self, data_z: list | np.ndarray, colormap: str) -> list | None: @@ -1193,6 +1248,7 @@ class BECWaveform(BECPlotBase): self.setup_dap(self.old_scan_id, self.scan_id) self.scan_item = self.queue.scan_storage.find_scan_by_ID(self.scan_id) self.scan_signal_update.emit() + self.async_signal_update.emit() def get_all_data(self, output: Literal["dict", "pandas"] = "dict") -> dict | pd.DataFrame: """ diff --git a/tests/unit_tests/client_mocks.py b/tests/unit_tests/client_mocks.py index 7679391c..06b753e9 100644 --- a/tests/unit_tests/client_mocks.py +++ b/tests/unit_tests/client_mocks.py @@ -4,7 +4,7 @@ from unittest.mock import MagicMock, patch import fakeredis import pytest from bec_lib.client import BECClient -from bec_lib.device import Positioner +from bec_lib.device import Positioner, ReadoutPriority from bec_lib.devicemanager import DeviceContainer from bec_lib.redis_connector import RedisConnector @@ -12,11 +12,12 @@ from bec_lib.redis_connector import RedisConnector class FakeDevice: """Fake minimal positioner class for testing.""" - def __init__(self, name, enabled=True): + def __init__(self, name, enabled=True, readout_priority=ReadoutPriority.MONITORED): self.name = name self.enabled = enabled self.signals = {self.name: {"value": 1.0}} self.description = {self.name: {"source": self.name, "dtype": "number", "shape": []}} + self.readout_priority = readout_priority def __contains__(self, item): return item == self.name @@ -43,8 +44,15 @@ class FakeDevice: class FakePositioner(FakeDevice): - def __init__(self, name, enabled=True, limits=None, read_value=1.0): - super().__init__(name, enabled) + def __init__( + self, + name, + enabled=True, + limits=None, + read_value=1.0, + readout_priority=ReadoutPriority.MONITORED, + ): + super().__init__(name, enabled, readout_priority) self.limits = limits if limits is not None else [0, 0] self.read_value = read_value self.name = name @@ -110,6 +118,7 @@ DEVICES = [ FakeDevice("bpm3a"), FakeDevice("bpm3i"), FakeDevice("eiger"), + FakeDevice("async_device", readout_priority=ReadoutPriority.ASYNC), Positioner("test", limits=[-10, 10], read_value=2.0), ] diff --git a/tests/unit_tests/test_device_input_widgets.py b/tests/unit_tests/test_device_input_widgets.py index 4078fb3d..b5f654aa 100644 --- a/tests/unit_tests/test_device_input_widgets.py +++ b/tests/unit_tests/test_device_input_widgets.py @@ -67,6 +67,7 @@ def test_device_input_combobox_init(device_input_combobox): "bpm3a", "bpm3i", "eiger", + "async_device", "test", ] @@ -154,6 +155,7 @@ def test_device_input_line_edit_init(device_input_line_edit): "bpm3a", "bpm3i", "eiger", + "async_device", "test", ] diff --git a/tests/unit_tests/test_waveform1d.py b/tests/unit_tests/test_waveform1d.py index 39dbfdf3..b5848603 100644 --- a/tests/unit_tests/test_waveform1d.py +++ b/tests/unit_tests/test_waveform1d.py @@ -166,6 +166,8 @@ def test_getting_curve(bec_figure): assert w1.curves[0].config == c1_expected_config assert w1._curves_data["scan_segment"]["bpm4i-bpm4i"].config == c1_expected_config assert w1.get_curve(0).config == c1_expected_config + assert w1.get_curve_config("bpm4i-bpm4i", dict_output=True) == c1_expected_config.model_dump() + assert w1.get_curve_config("bpm4i-bpm4i", dict_output=False) == c1_expected_config assert w1.get_curve("bpm4i-bpm4i").config == c1_expected_config assert c1.get_config(False) == c1_expected_config assert c1.get_config() == c1_expected_config.model_dump() @@ -448,7 +450,7 @@ def test_scan_update(bec_figure, qtbot): def test_scan_history_with_val_access(bec_figure, qtbot): w1 = bec_figure.plot() - c1 = w1.add_curve_bec(x_name="samx", y_name="bpm4i") + w1.plot(x_name="samx", y_name="bpm4i") mock_scan_data = { "samx": {"samx": mock.MagicMock(val=np.array([1, 2, 3]))}, # Use mock.MagicMock for .val @@ -464,7 +466,7 @@ def test_scan_history_with_val_access(bec_figure, qtbot): qtbot.wait(500) - x_data, y_data = c1.get_data() + x_data, y_data = w1.curves[0].get_data() assert np.array_equal(x_data, [1, 2, 3]) assert np.array_equal(y_data, [4, 5, 6]) @@ -485,8 +487,8 @@ def test_scatter_2d_update(bec_figure, qtbot): } msg_metadata = {"scan_name": "line_scan"} - mock_scan_data = mock.MagicMock() - mock_scan_data.data = { + mock_scan_item = mock.MagicMock() + mock_scan_item.data = { device_name: { entry: mock.MagicMock(val=msg["data"][device_name][entry]["value"]) for entry in msg["data"][device_name] @@ -494,7 +496,7 @@ def test_scatter_2d_update(bec_figure, qtbot): for device_name in msg["data"] } - w1.queue.scan_storage.find_scan_by_ID.return_value = mock_scan_data + w1.queue.scan_storage.find_scan_by_ID.return_value = mock_scan_item w1.on_scan_segment(msg, msg_metadata) qtbot.wait(500) @@ -508,3 +510,180 @@ def test_scatter_2d_update(bec_figure, qtbot): assert np.array_equal(data, expected_x_y_data) assert colors == expected_z_colors + + +def test_waveform_single_arg_inputs(bec_figure, qtbot): + w1 = bec_figure.plot() + + w1.plot("bpm4i") + w1.plot([1, 2, 3], label="just_y") + w1.plot([3, 4, 5], [7, 8, 9], label="x_y") + w1.plot(x=[1, 2, 3], y=[4, 5, 6], label="x_y_kwargs") + data_array_1D = np.random.rand(10) + data_array_2D = np.random.rand(10, 2) + w1.plot(data_array_1D, label="np_ndarray 1D") + w1.plot(data_array_2D, label="np_ndarray 2D") + + qtbot.wait(200) + + assert w1._curves_data["scan_segment"]["bpm4i-bpm4i"].config.label == "bpm4i-bpm4i" + assert w1._curves_data["custom"]["just_y"].config.label == "just_y" + assert w1._curves_data["custom"]["x_y"].config.label == "x_y" + assert w1._curves_data["custom"]["x_y_kwargs"].config.label == "x_y_kwargs" + + assert np.array_equal(w1._curves_data["custom"]["just_y"].get_data(), ([0, 1, 2], [1, 2, 3])) + assert np.array_equal(w1._curves_data["custom"]["just_y"].get_data(), ([0, 1, 2], [1, 2, 3])) + assert np.array_equal(w1._curves_data["custom"]["x_y"].get_data(), ([3, 4, 5], [7, 8, 9])) + assert np.array_equal( + w1._curves_data["custom"]["x_y_kwargs"].get_data(), ([1, 2, 3], [4, 5, 6]) + ) + assert np.array_equal( + w1._curves_data["custom"]["np_ndarray 1D"].get_data(), + (np.arange(data_array_1D.size), data_array_1D.T), + ) + assert np.array_equal(w1._curves_data["custom"]["np_ndarray 2D"].get_data(), data_array_2D.T) + + +def test_waveform_set_x_sync(bec_figure, qtbot): + w1 = bec_figure.plot() + custom_label = "custom_label" + w1.plot("bpm4i") + w1.set_x_label(custom_label) + + scan_item_mock = mock.MagicMock() + mock_data = { + "samx": {"samx": mock.MagicMock(val=np.array([1, 2, 3]))}, + "samy": {"samy": mock.MagicMock(val=np.array([4, 5, 6]))}, + "bpm4i": { + "bpm4i": mock.MagicMock( + val=np.array([7, 8, 9]), + timestamps=np.array([1720520189.959115, 1720520189.986618, 1720520190.0157812]), + ) + }, + } + + scan_item_mock.data = mock_data + scan_item_mock.status_message.info = {"scan_report_devices": ["samx"]} + + w1.queue.scan_storage.find_scan_by_ID.return_value = scan_item_mock + + w1.on_scan_segment({"scan_id": 1}, {}) + qtbot.wait(200) + + # Best effort - samx + x_data, y_data = w1.curves[0].get_data() + assert np.array_equal(x_data, [1, 2, 3]) + assert np.array_equal(y_data, [7, 8, 9]) + assert w1.plot_item.getAxis("bottom").labelText == custom_label + " [auto: samx-samx]" + + # Change to samy + w1.set_x("samy") + qtbot.wait(200) + x_data, y_data = w1.curves[0].get_data() + assert np.array_equal(x_data, [4, 5, 6]) + assert np.array_equal(y_data, [7, 8, 9]) + assert w1.plot_item.getAxis("bottom").labelText == custom_label + " [samy-samy]" + + # change to index + w1.set_x("index") + qtbot.wait(200) + x_data, y_data = w1.curves[0].get_data() + assert np.array_equal(x_data, [0, 1, 2]) + assert np.array_equal(y_data, [7, 8, 9]) + assert w1.plot_item.getAxis("bottom").labelText == custom_label + " [index]" + + # change to timestamp + w1.set_x("timestamp") + qtbot.wait(200) + x_data, y_data = w1.curves[0].get_data() + assert np.allclose(x_data, np.array([1.72052019e09, 1.72052019e09, 1.72052019e09])) + assert np.array_equal(y_data, [7, 8, 9]) + assert w1.plot_item.getAxis("bottom").labelText == custom_label + " [timestamp]" + + +def test_waveform_async_data_update(bec_figure, qtbot): + w1 = bec_figure.plot("async_device") + custom_label = "custom_label" + w1.set_x_label(custom_label) + + # scan_item_mock = mock.MagicMock() + # mock_data = { + # "async_device": { + # "async_device": mock.MagicMock( + # val=np.array([7, 8, 9]), + # timestamps=np.array([1720520189.959115, 1720520189.986618, 1720520190.0157812]), + # ) + # } + # } + # + # scan_item_mock.async_data = mock_data + # w1.queue.scan_storage.find_scan_by_ID.return_value = scan_item_mock + + msg_1 = {"signals": {"async_device": {"value": [7, 8, 9]}}} + w1.on_async_readback(msg_1, {"async_update": "extend"}) + + qtbot.wait(200) + x_data, y_data = w1.curves[0].get_data() + assert np.array_equal(x_data, [0, 1, 2]) + assert np.array_equal(y_data, [7, 8, 9]) + assert w1.plot_item.getAxis("bottom").labelText == custom_label + " [best_effort]" + + msg_2 = {"signals": {"async_device": {"value": [10, 11, 12]}}} + w1.on_async_readback(msg_2, {"async_update": "extend"}) + + qtbot.wait(200) + x_data, y_data = w1.curves[0].get_data() + assert np.array_equal(x_data, [0, 1, 2, 3, 4, 5]) + assert np.array_equal(y_data, [7, 8, 9, 10, 11, 12]) + assert w1.plot_item.getAxis("bottom").labelText == custom_label + " [best_effort]" + + msg_3 = {"signals": {"async_device": {"value": [20, 21, 22]}}} + w1.on_async_readback(msg_3, {"async_update": "replace"}) + + qtbot.wait(200) + x_data, y_data = w1.curves[0].get_data() + assert np.array_equal(x_data, [0, 1, 2]) + assert np.array_equal(y_data, [20, 21, 22]) + assert w1.plot_item.getAxis("bottom").labelText == custom_label + " [best_effort]" + + +def test_waveform_set_x_async(bec_figure, qtbot): + w1 = bec_figure.plot("async_device") + custom_label = "custom_label" + w1.set_x_label(custom_label) + + scan_item_mock = mock.MagicMock() + mock_data = { + "async_device": { + "async_device": { + "value": np.array([7, 8, 9]), + "timestamp": np.array([1720520189.959115, 1720520189.986618, 1720520190.0157812]), + } + } + } + + scan_item_mock.async_data = mock_data + w1.queue.scan_storage.find_scan_by_ID.return_value = scan_item_mock + + w1.on_scan_status({"scan_id": 1}) + w1.replot_async_curve() + + qtbot.wait(200) + x_data, y_data = w1.curves[0].get_data() + assert np.array_equal(x_data, [0, 1, 2]) + assert np.array_equal(y_data, [7, 8, 9]) + assert w1.plot_item.getAxis("bottom").labelText == custom_label + " [best_effort]" + + w1.set_x("timestamp") + qtbot.wait(200) + x_data, y_data = w1.curves[0].get_data() + assert np.allclose(x_data, np.array([1.72052019e09, 1.72052019e09, 1.72052019e09])) + assert np.array_equal(y_data, [7, 8, 9]) + assert w1.plot_item.getAxis("bottom").labelText == custom_label + " [timestamp]" + + w1.set_x("index") + qtbot.wait(200) + x_data, y_data = w1.curves[0].get_data() + assert np.array_equal(x_data, [0, 1, 2]) + assert np.array_equal(y_data, [7, 8, 9]) + assert w1.plot_item.getAxis("bottom").labelText == custom_label + " [index]"