mirror of
https://github.com/bec-project/bec_widgets.git
synced 2025-07-13 19:21:50 +02:00
fix: add support for 'add_slice', add downsampling for performance improvements. add tests
This commit is contained in:
@ -97,6 +97,7 @@ class Curve(BECConnector, pg.PlotDataItem):
|
||||
self.apply_config()
|
||||
self.dap_params = None
|
||||
self.dap_summary = None
|
||||
self.slice_index = None
|
||||
if kwargs:
|
||||
self.set(**kwargs)
|
||||
|
||||
@ -303,14 +304,14 @@ class Curve(BECConnector, pg.PlotDataItem):
|
||||
self.apply_config()
|
||||
self.parent_item.update_with_scan_history(-1)
|
||||
|
||||
def get_data(self) -> tuple[np.ndarray, np.ndarray]:
|
||||
def get_data(self) -> tuple[np.ndarray | None, np.ndarray | None]:
|
||||
"""
|
||||
Get the data of the curve.
|
||||
Returns:
|
||||
tuple[np.ndarray,np.ndarray]: X and Y data of the curve.
|
||||
"""
|
||||
try:
|
||||
x_data, y_data = self.getData()
|
||||
x_data, y_data = self.getOriginalDataset()
|
||||
except TypeError:
|
||||
x_data, y_data = np.array([]), np.array([])
|
||||
return x_data, y_data
|
||||
|
@ -134,6 +134,7 @@ class Waveform(PlotBase):
|
||||
# Curve data
|
||||
self._sync_curves = []
|
||||
self._async_curves = []
|
||||
self._slice_index = None
|
||||
self._dap_curves = []
|
||||
self._mode: Literal["none", "sync", "async", "mixed"] = "none"
|
||||
|
||||
@ -947,6 +948,7 @@ class Waveform(PlotBase):
|
||||
self.old_scan_id = self.scan_id
|
||||
self.scan_id = current_scan_id
|
||||
self.scan_item = self.queue.scan_storage.find_scan_by_ID(self.scan_id) # live scan
|
||||
self._slice_index = None # Reset the slice index
|
||||
|
||||
self._mode = self._categorise_device_curves()
|
||||
|
||||
@ -1069,6 +1071,7 @@ class Waveform(PlotBase):
|
||||
|
||||
# If there's actual data, set it
|
||||
if device_data is not None:
|
||||
self._auto_adjust_async_curve_settings(curve, len(device_data))
|
||||
if x_data is not None:
|
||||
curve.setData(x_data, device_data)
|
||||
else:
|
||||
@ -1107,16 +1110,18 @@ class Waveform(PlotBase):
|
||||
msg(dict): Message with the async data.
|
||||
metadata(dict): Metadata of the message.
|
||||
"""
|
||||
y_data = None
|
||||
x_data = None
|
||||
instruction = metadata.get("async_update", {}).get("type")
|
||||
max_shape = metadata.get("async_update", {}).get("max_shape", [])
|
||||
for curve in self._async_curves:
|
||||
new_data = None
|
||||
y_data = None
|
||||
x_data = None
|
||||
y_entry = curve.config.signal.entry
|
||||
x_name = self.x_axis_mode["name"]
|
||||
for device, async_data in msg["signals"].items():
|
||||
if device == y_entry:
|
||||
data_plot = async_data["value"]
|
||||
# Add
|
||||
if instruction == "add":
|
||||
if len(max_shape) > 1:
|
||||
if len(data_plot.shape) > 1:
|
||||
@ -1134,17 +1139,70 @@ class Waveform(PlotBase):
|
||||
else:
|
||||
x_data = async_data["timestamp"]
|
||||
# FIXME x axis wrong if timestamp switched during scan
|
||||
curve.setData(x_data, new_data)
|
||||
else: # this means index as x
|
||||
curve.setData(new_data)
|
||||
# Add slice
|
||||
elif instruction == "add_slice":
|
||||
current_slice_id = metadata.get("async_update", {}).get("index")
|
||||
data_plot = async_data["value"]
|
||||
if current_slice_id != curve.slice_index:
|
||||
curve.slice_index = current_slice_id
|
||||
else:
|
||||
x_data, y_data = curve.get_data()
|
||||
if y_data is not None:
|
||||
new_data = np.hstack((y_data, data_plot))
|
||||
else:
|
||||
new_data = data_plot
|
||||
# Replace
|
||||
elif instruction == "replace":
|
||||
if x_name == "timestamp":
|
||||
x_data = async_data["timestamp"]
|
||||
curve.setData(x_data, data_plot)
|
||||
else:
|
||||
curve.setData(data_plot)
|
||||
new_data = data_plot
|
||||
|
||||
# If update is not add, add_slice or replace, continue.
|
||||
if new_data is None:
|
||||
continue
|
||||
# Hide symbol, activate downsampling if data >1000
|
||||
self._auto_adjust_async_curve_settings(curve, len(new_data))
|
||||
# Set data on the curve
|
||||
if x_name == "timestamp" and instruction != "add_slice":
|
||||
curve.setData(x_data, new_data)
|
||||
else:
|
||||
curve.setData(np.linspace(0, len(new_data) - 1, len(new_data)), new_data)
|
||||
|
||||
self.request_dap_update.emit()
|
||||
|
||||
def _auto_adjust_async_curve_settings(
|
||||
self,
|
||||
curve: Curve,
|
||||
data_length: int,
|
||||
limit: int = 1000,
|
||||
method: Literal["subsample", "mean", "peak"] | None = "mean",
|
||||
) -> None:
|
||||
"""
|
||||
Based on the length of the data this method will adjust the plotting settings of
|
||||
Curve items, by deactivating the symbol and activating downsampling auto, method='mean',
|
||||
if the data length exceeds N points. If the data length is less than N points, the
|
||||
symbol will be activated and downsampling will be deactivated. Maximum points will be
|
||||
5x the limit.
|
||||
|
||||
Args:
|
||||
curve(Curve): The curve to adjust.
|
||||
data_length(int): The length of the data.
|
||||
limit(int): The limit of the data length to activate the downsampling.
|
||||
|
||||
"""
|
||||
if limit <= 1:
|
||||
logger.warning("Limit must be greater than 1.")
|
||||
return
|
||||
if data_length > limit:
|
||||
if curve.config.symbol is not None:
|
||||
curve.set_symbol(None)
|
||||
sampling_factor = int(data_length / (5 * limit)) # increase by limit 5x
|
||||
curve.setDownsampling(ds=sampling_factor, auto=None, method=method)
|
||||
elif data_length <= limit:
|
||||
curve.set_symbol("o")
|
||||
sampling_factor = 1
|
||||
curve.setDownsampling(ds=sampling_factor, auto=None, method=method)
|
||||
|
||||
def setup_dap_for_scan(self):
|
||||
"""Setup DAP updates for the new scan."""
|
||||
self.bec_dispatcher.disconnect_slot(
|
||||
|
@ -529,16 +529,14 @@ def test_setup_async_curve(qtbot, mocked_client, monkeypatch):
|
||||
|
||||
|
||||
@pytest.mark.parametrize("x_mode", ("timestamp", "index"))
|
||||
def test_on_async_readback(qtbot, mocked_client, x_mode):
|
||||
def test_on_async_readback_add_update(qtbot, mocked_client, x_mode):
|
||||
"""
|
||||
Test that on_async_readback extends or replaces async data depending on metadata instruction.
|
||||
For 'timestamp' mode, new timestamps are appended to x_data.
|
||||
For 'index' mode, x_data simply increases by integer index.
|
||||
"""
|
||||
wf = create_widget(qtbot, Waveform, client=mocked_client)
|
||||
dummy_scan = create_dummy_scan_item()
|
||||
wf.scan_item = dummy_scan
|
||||
|
||||
wf.scan_item = create_dummy_scan_item()
|
||||
c = wf.plot(arg1="async_device", label="async_device-async_device")
|
||||
wf._async_curves = [c]
|
||||
# Suppose existing data
|
||||
@ -547,7 +545,8 @@ def test_on_async_readback(qtbot, mocked_client, x_mode):
|
||||
# Set the x_axis_mode
|
||||
wf.x_axis_mode["name"] = x_mode
|
||||
|
||||
# Extend readback
|
||||
############# Test add ################
|
||||
|
||||
msg = {"signals": {"async_device": {"value": [100, 200], "timestamp": [1001, 1002]}}}
|
||||
metadata = {"async_update": {"max_shape": [None], "type": "add"}}
|
||||
wf.on_async_readback(msg, metadata)
|
||||
@ -575,6 +574,72 @@ def test_on_async_readback(qtbot, mocked_client, x_mode):
|
||||
|
||||
np.testing.assert_array_equal(y_data2, [999])
|
||||
|
||||
############# Test add_slice ################
|
||||
|
||||
# Few updates, no downsampling, no symbol removed
|
||||
waveform_shape = 10
|
||||
for ii in range(10):
|
||||
msg = {"signals": {"async_device": {"value": [100], "timestamp": [1001]}}}
|
||||
metadata = {
|
||||
"async_update": {"max_shape": [None, waveform_shape], "index": 0, "type": "add_slice"}
|
||||
}
|
||||
wf.on_async_readback(msg, metadata)
|
||||
|
||||
# Old data should be deleted since the slice_index did not match
|
||||
x_data, y_data = c.get_data()
|
||||
assert len(y_data) == 10
|
||||
assert len(x_data) == 10
|
||||
assert c.opts["symbol"] == "o"
|
||||
|
||||
# Clear data from curve
|
||||
c.setData([], [])
|
||||
|
||||
# Test large updates, limit 1000 to deactivate symbols, downsampling for 8000 should be factor 2.
|
||||
waveform_shape = 12000
|
||||
for ii in range(12):
|
||||
msg = {
|
||||
"signals": {
|
||||
"async_device": {
|
||||
"value": np.array(range(1000)),
|
||||
"timestamp": (ii + 1) * np.linspace(0, 1000 - 1, 1000),
|
||||
}
|
||||
}
|
||||
}
|
||||
metadata = {
|
||||
"async_update": {"max_shape": [None, waveform_shape], "index": 0, "type": "add_slice"}
|
||||
}
|
||||
wf.on_async_readback(msg, metadata)
|
||||
x_data, y_data = c.get_data()
|
||||
assert len(y_data) == waveform_shape
|
||||
assert len(x_data) == waveform_shape
|
||||
assert c.opts["symbol"] == None
|
||||
# Get displayed data
|
||||
displayed_x, displayed_y = c.getData()
|
||||
assert len(displayed_y) == waveform_shape / 2
|
||||
assert len(displayed_x) == waveform_shape / 2
|
||||
assert displayed_x[-1] == waveform_shape - 1 # Should be the correct index stil.
|
||||
|
||||
############# Test replace ################
|
||||
waveform_shape = 10
|
||||
for ii in range(10):
|
||||
msg = {
|
||||
"signals": {
|
||||
"async_device": {
|
||||
"value": np.array(range(waveform_shape)),
|
||||
"timestamp": np.array(range(waveform_shape)),
|
||||
}
|
||||
}
|
||||
}
|
||||
metadata = {"async_update": {"type": "replace"}}
|
||||
wf.on_async_readback(msg, metadata)
|
||||
|
||||
x_data, y_data = c.get_data()
|
||||
assert np.array_equal(y_data, np.array(range(waveform_shape)))
|
||||
assert len(x_data) == waveform_shape
|
||||
assert c.opts["symbol"] == "o"
|
||||
y_displayed, x_displayed = c.getData()
|
||||
assert len(y_displayed) == waveform_shape
|
||||
|
||||
|
||||
def test_get_x_data(qtbot, mocked_client, monkeypatch):
|
||||
"""
|
||||
|
Reference in New Issue
Block a user