0
0
mirror of https://github.com/bec-project/bec_widgets.git synced 2025-07-13 19:21:50 +02:00

fix: add support for 'add_slice', add downsampling for performance improvements. add tests

This commit is contained in:
2025-03-27 17:15:24 +01:00
committed by wyzula_j
parent b5015e4e72
commit 7f7891dfa5
3 changed files with 139 additions and 15 deletions

View File

@ -97,6 +97,7 @@ class Curve(BECConnector, pg.PlotDataItem):
self.apply_config()
self.dap_params = None
self.dap_summary = None
self.slice_index = None
if kwargs:
self.set(**kwargs)
@ -303,14 +304,14 @@ class Curve(BECConnector, pg.PlotDataItem):
self.apply_config()
self.parent_item.update_with_scan_history(-1)
def get_data(self) -> tuple[np.ndarray, np.ndarray]:
def get_data(self) -> tuple[np.ndarray | None, np.ndarray | None]:
"""
Get the data of the curve.
Returns:
tuple[np.ndarray,np.ndarray]: X and Y data of the curve.
"""
try:
x_data, y_data = self.getData()
x_data, y_data = self.getOriginalDataset()
except TypeError:
x_data, y_data = np.array([]), np.array([])
return x_data, y_data

View File

@ -134,6 +134,7 @@ class Waveform(PlotBase):
# Curve data
self._sync_curves = []
self._async_curves = []
self._slice_index = None
self._dap_curves = []
self._mode: Literal["none", "sync", "async", "mixed"] = "none"
@ -947,6 +948,7 @@ class Waveform(PlotBase):
self.old_scan_id = self.scan_id
self.scan_id = current_scan_id
self.scan_item = self.queue.scan_storage.find_scan_by_ID(self.scan_id) # live scan
self._slice_index = None # Reset the slice index
self._mode = self._categorise_device_curves()
@ -1069,6 +1071,7 @@ class Waveform(PlotBase):
# If there's actual data, set it
if device_data is not None:
self._auto_adjust_async_curve_settings(curve, len(device_data))
if x_data is not None:
curve.setData(x_data, device_data)
else:
@ -1107,16 +1110,18 @@ class Waveform(PlotBase):
msg(dict): Message with the async data.
metadata(dict): Metadata of the message.
"""
y_data = None
x_data = None
instruction = metadata.get("async_update", {}).get("type")
max_shape = metadata.get("async_update", {}).get("max_shape", [])
for curve in self._async_curves:
new_data = None
y_data = None
x_data = None
y_entry = curve.config.signal.entry
x_name = self.x_axis_mode["name"]
for device, async_data in msg["signals"].items():
if device == y_entry:
data_plot = async_data["value"]
# Add
if instruction == "add":
if len(max_shape) > 1:
if len(data_plot.shape) > 1:
@ -1134,17 +1139,70 @@ class Waveform(PlotBase):
else:
x_data = async_data["timestamp"]
# FIXME x axis wrong if timestamp switched during scan
curve.setData(x_data, new_data)
else: # this means index as x
curve.setData(new_data)
# Add slice
elif instruction == "add_slice":
current_slice_id = metadata.get("async_update", {}).get("index")
data_plot = async_data["value"]
if current_slice_id != curve.slice_index:
curve.slice_index = current_slice_id
else:
x_data, y_data = curve.get_data()
if y_data is not None:
new_data = np.hstack((y_data, data_plot))
else:
new_data = data_plot
# Replace
elif instruction == "replace":
if x_name == "timestamp":
x_data = async_data["timestamp"]
curve.setData(x_data, data_plot)
else:
curve.setData(data_plot)
new_data = data_plot
# If update is not add, add_slice or replace, continue.
if new_data is None:
continue
# Hide symbol, activate downsampling if data >1000
self._auto_adjust_async_curve_settings(curve, len(new_data))
# Set data on the curve
if x_name == "timestamp" and instruction != "add_slice":
curve.setData(x_data, new_data)
else:
curve.setData(np.linspace(0, len(new_data) - 1, len(new_data)), new_data)
self.request_dap_update.emit()
def _auto_adjust_async_curve_settings(
self,
curve: Curve,
data_length: int,
limit: int = 1000,
method: Literal["subsample", "mean", "peak"] | None = "mean",
) -> None:
"""
Based on the length of the data this method will adjust the plotting settings of
Curve items, by deactivating the symbol and activating downsampling auto, method='mean',
if the data length exceeds N points. If the data length is less than N points, the
symbol will be activated and downsampling will be deactivated. Maximum points will be
5x the limit.
Args:
curve(Curve): The curve to adjust.
data_length(int): The length of the data.
limit(int): The limit of the data length to activate the downsampling.
"""
if limit <= 1:
logger.warning("Limit must be greater than 1.")
return
if data_length > limit:
if curve.config.symbol is not None:
curve.set_symbol(None)
sampling_factor = int(data_length / (5 * limit)) # increase by limit 5x
curve.setDownsampling(ds=sampling_factor, auto=None, method=method)
elif data_length <= limit:
curve.set_symbol("o")
sampling_factor = 1
curve.setDownsampling(ds=sampling_factor, auto=None, method=method)
def setup_dap_for_scan(self):
"""Setup DAP updates for the new scan."""
self.bec_dispatcher.disconnect_slot(

View File

@ -529,16 +529,14 @@ def test_setup_async_curve(qtbot, mocked_client, monkeypatch):
@pytest.mark.parametrize("x_mode", ("timestamp", "index"))
def test_on_async_readback(qtbot, mocked_client, x_mode):
def test_on_async_readback_add_update(qtbot, mocked_client, x_mode):
"""
Test that on_async_readback extends or replaces async data depending on metadata instruction.
For 'timestamp' mode, new timestamps are appended to x_data.
For 'index' mode, x_data simply increases by integer index.
"""
wf = create_widget(qtbot, Waveform, client=mocked_client)
dummy_scan = create_dummy_scan_item()
wf.scan_item = dummy_scan
wf.scan_item = create_dummy_scan_item()
c = wf.plot(arg1="async_device", label="async_device-async_device")
wf._async_curves = [c]
# Suppose existing data
@ -547,7 +545,8 @@ def test_on_async_readback(qtbot, mocked_client, x_mode):
# Set the x_axis_mode
wf.x_axis_mode["name"] = x_mode
# Extend readback
############# Test add ################
msg = {"signals": {"async_device": {"value": [100, 200], "timestamp": [1001, 1002]}}}
metadata = {"async_update": {"max_shape": [None], "type": "add"}}
wf.on_async_readback(msg, metadata)
@ -575,6 +574,72 @@ def test_on_async_readback(qtbot, mocked_client, x_mode):
np.testing.assert_array_equal(y_data2, [999])
############# Test add_slice ################
# Few updates, no downsampling, no symbol removed
waveform_shape = 10
for ii in range(10):
msg = {"signals": {"async_device": {"value": [100], "timestamp": [1001]}}}
metadata = {
"async_update": {"max_shape": [None, waveform_shape], "index": 0, "type": "add_slice"}
}
wf.on_async_readback(msg, metadata)
# Old data should be deleted since the slice_index did not match
x_data, y_data = c.get_data()
assert len(y_data) == 10
assert len(x_data) == 10
assert c.opts["symbol"] == "o"
# Clear data from curve
c.setData([], [])
# Test large updates, limit 1000 to deactivate symbols, downsampling for 8000 should be factor 2.
waveform_shape = 12000
for ii in range(12):
msg = {
"signals": {
"async_device": {
"value": np.array(range(1000)),
"timestamp": (ii + 1) * np.linspace(0, 1000 - 1, 1000),
}
}
}
metadata = {
"async_update": {"max_shape": [None, waveform_shape], "index": 0, "type": "add_slice"}
}
wf.on_async_readback(msg, metadata)
x_data, y_data = c.get_data()
assert len(y_data) == waveform_shape
assert len(x_data) == waveform_shape
assert c.opts["symbol"] == None
# Get displayed data
displayed_x, displayed_y = c.getData()
assert len(displayed_y) == waveform_shape / 2
assert len(displayed_x) == waveform_shape / 2
assert displayed_x[-1] == waveform_shape - 1 # Should be the correct index stil.
############# Test replace ################
waveform_shape = 10
for ii in range(10):
msg = {
"signals": {
"async_device": {
"value": np.array(range(waveform_shape)),
"timestamp": np.array(range(waveform_shape)),
}
}
}
metadata = {"async_update": {"type": "replace"}}
wf.on_async_readback(msg, metadata)
x_data, y_data = c.get_data()
assert np.array_equal(y_data, np.array(range(waveform_shape)))
assert len(x_data) == waveform_shape
assert c.opts["symbol"] == "o"
y_displayed, x_displayed = c.getData()
assert len(y_displayed) == waveform_shape
def test_get_x_data(qtbot, mocked_client, monkeypatch):
"""