mirror of
https://github.com/bec-project/bec_widgets.git
synced 2025-07-14 03:31:50 +02:00
fix: add support for 'add_slice', add downsampling for performance improvements. add tests
This commit is contained in:
@ -97,6 +97,7 @@ class Curve(BECConnector, pg.PlotDataItem):
|
|||||||
self.apply_config()
|
self.apply_config()
|
||||||
self.dap_params = None
|
self.dap_params = None
|
||||||
self.dap_summary = None
|
self.dap_summary = None
|
||||||
|
self.slice_index = None
|
||||||
if kwargs:
|
if kwargs:
|
||||||
self.set(**kwargs)
|
self.set(**kwargs)
|
||||||
|
|
||||||
@ -303,14 +304,14 @@ class Curve(BECConnector, pg.PlotDataItem):
|
|||||||
self.apply_config()
|
self.apply_config()
|
||||||
self.parent_item.update_with_scan_history(-1)
|
self.parent_item.update_with_scan_history(-1)
|
||||||
|
|
||||||
def get_data(self) -> tuple[np.ndarray, np.ndarray]:
|
def get_data(self) -> tuple[np.ndarray | None, np.ndarray | None]:
|
||||||
"""
|
"""
|
||||||
Get the data of the curve.
|
Get the data of the curve.
|
||||||
Returns:
|
Returns:
|
||||||
tuple[np.ndarray,np.ndarray]: X and Y data of the curve.
|
tuple[np.ndarray,np.ndarray]: X and Y data of the curve.
|
||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
x_data, y_data = self.getData()
|
x_data, y_data = self.getOriginalDataset()
|
||||||
except TypeError:
|
except TypeError:
|
||||||
x_data, y_data = np.array([]), np.array([])
|
x_data, y_data = np.array([]), np.array([])
|
||||||
return x_data, y_data
|
return x_data, y_data
|
||||||
|
@ -134,6 +134,7 @@ class Waveform(PlotBase):
|
|||||||
# Curve data
|
# Curve data
|
||||||
self._sync_curves = []
|
self._sync_curves = []
|
||||||
self._async_curves = []
|
self._async_curves = []
|
||||||
|
self._slice_index = None
|
||||||
self._dap_curves = []
|
self._dap_curves = []
|
||||||
self._mode: Literal["none", "sync", "async", "mixed"] = "none"
|
self._mode: Literal["none", "sync", "async", "mixed"] = "none"
|
||||||
|
|
||||||
@ -947,6 +948,7 @@ class Waveform(PlotBase):
|
|||||||
self.old_scan_id = self.scan_id
|
self.old_scan_id = self.scan_id
|
||||||
self.scan_id = current_scan_id
|
self.scan_id = current_scan_id
|
||||||
self.scan_item = self.queue.scan_storage.find_scan_by_ID(self.scan_id) # live scan
|
self.scan_item = self.queue.scan_storage.find_scan_by_ID(self.scan_id) # live scan
|
||||||
|
self._slice_index = None # Reset the slice index
|
||||||
|
|
||||||
self._mode = self._categorise_device_curves()
|
self._mode = self._categorise_device_curves()
|
||||||
|
|
||||||
@ -1069,6 +1071,7 @@ class Waveform(PlotBase):
|
|||||||
|
|
||||||
# If there's actual data, set it
|
# If there's actual data, set it
|
||||||
if device_data is not None:
|
if device_data is not None:
|
||||||
|
self._auto_adjust_async_curve_settings(curve, len(device_data))
|
||||||
if x_data is not None:
|
if x_data is not None:
|
||||||
curve.setData(x_data, device_data)
|
curve.setData(x_data, device_data)
|
||||||
else:
|
else:
|
||||||
@ -1107,16 +1110,18 @@ class Waveform(PlotBase):
|
|||||||
msg(dict): Message with the async data.
|
msg(dict): Message with the async data.
|
||||||
metadata(dict): Metadata of the message.
|
metadata(dict): Metadata of the message.
|
||||||
"""
|
"""
|
||||||
y_data = None
|
|
||||||
x_data = None
|
|
||||||
instruction = metadata.get("async_update", {}).get("type")
|
instruction = metadata.get("async_update", {}).get("type")
|
||||||
max_shape = metadata.get("async_update", {}).get("max_shape", [])
|
max_shape = metadata.get("async_update", {}).get("max_shape", [])
|
||||||
for curve in self._async_curves:
|
for curve in self._async_curves:
|
||||||
|
new_data = None
|
||||||
|
y_data = None
|
||||||
|
x_data = None
|
||||||
y_entry = curve.config.signal.entry
|
y_entry = curve.config.signal.entry
|
||||||
x_name = self.x_axis_mode["name"]
|
x_name = self.x_axis_mode["name"]
|
||||||
for device, async_data in msg["signals"].items():
|
for device, async_data in msg["signals"].items():
|
||||||
if device == y_entry:
|
if device == y_entry:
|
||||||
data_plot = async_data["value"]
|
data_plot = async_data["value"]
|
||||||
|
# Add
|
||||||
if instruction == "add":
|
if instruction == "add":
|
||||||
if len(max_shape) > 1:
|
if len(max_shape) > 1:
|
||||||
if len(data_plot.shape) > 1:
|
if len(data_plot.shape) > 1:
|
||||||
@ -1134,17 +1139,70 @@ class Waveform(PlotBase):
|
|||||||
else:
|
else:
|
||||||
x_data = async_data["timestamp"]
|
x_data = async_data["timestamp"]
|
||||||
# FIXME x axis wrong if timestamp switched during scan
|
# FIXME x axis wrong if timestamp switched during scan
|
||||||
curve.setData(x_data, new_data)
|
# Add slice
|
||||||
else: # this means index as x
|
elif instruction == "add_slice":
|
||||||
curve.setData(new_data)
|
current_slice_id = metadata.get("async_update", {}).get("index")
|
||||||
|
data_plot = async_data["value"]
|
||||||
|
if current_slice_id != curve.slice_index:
|
||||||
|
curve.slice_index = current_slice_id
|
||||||
|
else:
|
||||||
|
x_data, y_data = curve.get_data()
|
||||||
|
if y_data is not None:
|
||||||
|
new_data = np.hstack((y_data, data_plot))
|
||||||
|
else:
|
||||||
|
new_data = data_plot
|
||||||
|
# Replace
|
||||||
elif instruction == "replace":
|
elif instruction == "replace":
|
||||||
if x_name == "timestamp":
|
if x_name == "timestamp":
|
||||||
x_data = async_data["timestamp"]
|
x_data = async_data["timestamp"]
|
||||||
curve.setData(x_data, data_plot)
|
new_data = data_plot
|
||||||
else:
|
|
||||||
curve.setData(data_plot)
|
# If update is not add, add_slice or replace, continue.
|
||||||
|
if new_data is None:
|
||||||
|
continue
|
||||||
|
# Hide symbol, activate downsampling if data >1000
|
||||||
|
self._auto_adjust_async_curve_settings(curve, len(new_data))
|
||||||
|
# Set data on the curve
|
||||||
|
if x_name == "timestamp" and instruction != "add_slice":
|
||||||
|
curve.setData(x_data, new_data)
|
||||||
|
else:
|
||||||
|
curve.setData(np.linspace(0, len(new_data) - 1, len(new_data)), new_data)
|
||||||
|
|
||||||
self.request_dap_update.emit()
|
self.request_dap_update.emit()
|
||||||
|
|
||||||
|
def _auto_adjust_async_curve_settings(
|
||||||
|
self,
|
||||||
|
curve: Curve,
|
||||||
|
data_length: int,
|
||||||
|
limit: int = 1000,
|
||||||
|
method: Literal["subsample", "mean", "peak"] | None = "mean",
|
||||||
|
) -> None:
|
||||||
|
"""
|
||||||
|
Based on the length of the data this method will adjust the plotting settings of
|
||||||
|
Curve items, by deactivating the symbol and activating downsampling auto, method='mean',
|
||||||
|
if the data length exceeds N points. If the data length is less than N points, the
|
||||||
|
symbol will be activated and downsampling will be deactivated. Maximum points will be
|
||||||
|
5x the limit.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
curve(Curve): The curve to adjust.
|
||||||
|
data_length(int): The length of the data.
|
||||||
|
limit(int): The limit of the data length to activate the downsampling.
|
||||||
|
|
||||||
|
"""
|
||||||
|
if limit <= 1:
|
||||||
|
logger.warning("Limit must be greater than 1.")
|
||||||
|
return
|
||||||
|
if data_length > limit:
|
||||||
|
if curve.config.symbol is not None:
|
||||||
|
curve.set_symbol(None)
|
||||||
|
sampling_factor = int(data_length / (5 * limit)) # increase by limit 5x
|
||||||
|
curve.setDownsampling(ds=sampling_factor, auto=None, method=method)
|
||||||
|
elif data_length <= limit:
|
||||||
|
curve.set_symbol("o")
|
||||||
|
sampling_factor = 1
|
||||||
|
curve.setDownsampling(ds=sampling_factor, auto=None, method=method)
|
||||||
|
|
||||||
def setup_dap_for_scan(self):
|
def setup_dap_for_scan(self):
|
||||||
"""Setup DAP updates for the new scan."""
|
"""Setup DAP updates for the new scan."""
|
||||||
self.bec_dispatcher.disconnect_slot(
|
self.bec_dispatcher.disconnect_slot(
|
||||||
|
@ -529,16 +529,14 @@ def test_setup_async_curve(qtbot, mocked_client, monkeypatch):
|
|||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize("x_mode", ("timestamp", "index"))
|
@pytest.mark.parametrize("x_mode", ("timestamp", "index"))
|
||||||
def test_on_async_readback(qtbot, mocked_client, x_mode):
|
def test_on_async_readback_add_update(qtbot, mocked_client, x_mode):
|
||||||
"""
|
"""
|
||||||
Test that on_async_readback extends or replaces async data depending on metadata instruction.
|
Test that on_async_readback extends or replaces async data depending on metadata instruction.
|
||||||
For 'timestamp' mode, new timestamps are appended to x_data.
|
For 'timestamp' mode, new timestamps are appended to x_data.
|
||||||
For 'index' mode, x_data simply increases by integer index.
|
For 'index' mode, x_data simply increases by integer index.
|
||||||
"""
|
"""
|
||||||
wf = create_widget(qtbot, Waveform, client=mocked_client)
|
wf = create_widget(qtbot, Waveform, client=mocked_client)
|
||||||
dummy_scan = create_dummy_scan_item()
|
wf.scan_item = create_dummy_scan_item()
|
||||||
wf.scan_item = dummy_scan
|
|
||||||
|
|
||||||
c = wf.plot(arg1="async_device", label="async_device-async_device")
|
c = wf.plot(arg1="async_device", label="async_device-async_device")
|
||||||
wf._async_curves = [c]
|
wf._async_curves = [c]
|
||||||
# Suppose existing data
|
# Suppose existing data
|
||||||
@ -547,7 +545,8 @@ def test_on_async_readback(qtbot, mocked_client, x_mode):
|
|||||||
# Set the x_axis_mode
|
# Set the x_axis_mode
|
||||||
wf.x_axis_mode["name"] = x_mode
|
wf.x_axis_mode["name"] = x_mode
|
||||||
|
|
||||||
# Extend readback
|
############# Test add ################
|
||||||
|
|
||||||
msg = {"signals": {"async_device": {"value": [100, 200], "timestamp": [1001, 1002]}}}
|
msg = {"signals": {"async_device": {"value": [100, 200], "timestamp": [1001, 1002]}}}
|
||||||
metadata = {"async_update": {"max_shape": [None], "type": "add"}}
|
metadata = {"async_update": {"max_shape": [None], "type": "add"}}
|
||||||
wf.on_async_readback(msg, metadata)
|
wf.on_async_readback(msg, metadata)
|
||||||
@ -575,6 +574,72 @@ def test_on_async_readback(qtbot, mocked_client, x_mode):
|
|||||||
|
|
||||||
np.testing.assert_array_equal(y_data2, [999])
|
np.testing.assert_array_equal(y_data2, [999])
|
||||||
|
|
||||||
|
############# Test add_slice ################
|
||||||
|
|
||||||
|
# Few updates, no downsampling, no symbol removed
|
||||||
|
waveform_shape = 10
|
||||||
|
for ii in range(10):
|
||||||
|
msg = {"signals": {"async_device": {"value": [100], "timestamp": [1001]}}}
|
||||||
|
metadata = {
|
||||||
|
"async_update": {"max_shape": [None, waveform_shape], "index": 0, "type": "add_slice"}
|
||||||
|
}
|
||||||
|
wf.on_async_readback(msg, metadata)
|
||||||
|
|
||||||
|
# Old data should be deleted since the slice_index did not match
|
||||||
|
x_data, y_data = c.get_data()
|
||||||
|
assert len(y_data) == 10
|
||||||
|
assert len(x_data) == 10
|
||||||
|
assert c.opts["symbol"] == "o"
|
||||||
|
|
||||||
|
# Clear data from curve
|
||||||
|
c.setData([], [])
|
||||||
|
|
||||||
|
# Test large updates, limit 1000 to deactivate symbols, downsampling for 8000 should be factor 2.
|
||||||
|
waveform_shape = 12000
|
||||||
|
for ii in range(12):
|
||||||
|
msg = {
|
||||||
|
"signals": {
|
||||||
|
"async_device": {
|
||||||
|
"value": np.array(range(1000)),
|
||||||
|
"timestamp": (ii + 1) * np.linspace(0, 1000 - 1, 1000),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
metadata = {
|
||||||
|
"async_update": {"max_shape": [None, waveform_shape], "index": 0, "type": "add_slice"}
|
||||||
|
}
|
||||||
|
wf.on_async_readback(msg, metadata)
|
||||||
|
x_data, y_data = c.get_data()
|
||||||
|
assert len(y_data) == waveform_shape
|
||||||
|
assert len(x_data) == waveform_shape
|
||||||
|
assert c.opts["symbol"] == None
|
||||||
|
# Get displayed data
|
||||||
|
displayed_x, displayed_y = c.getData()
|
||||||
|
assert len(displayed_y) == waveform_shape / 2
|
||||||
|
assert len(displayed_x) == waveform_shape / 2
|
||||||
|
assert displayed_x[-1] == waveform_shape - 1 # Should be the correct index stil.
|
||||||
|
|
||||||
|
############# Test replace ################
|
||||||
|
waveform_shape = 10
|
||||||
|
for ii in range(10):
|
||||||
|
msg = {
|
||||||
|
"signals": {
|
||||||
|
"async_device": {
|
||||||
|
"value": np.array(range(waveform_shape)),
|
||||||
|
"timestamp": np.array(range(waveform_shape)),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
metadata = {"async_update": {"type": "replace"}}
|
||||||
|
wf.on_async_readback(msg, metadata)
|
||||||
|
|
||||||
|
x_data, y_data = c.get_data()
|
||||||
|
assert np.array_equal(y_data, np.array(range(waveform_shape)))
|
||||||
|
assert len(x_data) == waveform_shape
|
||||||
|
assert c.opts["symbol"] == "o"
|
||||||
|
y_displayed, x_displayed = c.getData()
|
||||||
|
assert len(y_displayed) == waveform_shape
|
||||||
|
|
||||||
|
|
||||||
def test_get_x_data(qtbot, mocked_client, monkeypatch):
|
def test_get_x_data(qtbot, mocked_client, monkeypatch):
|
||||||
"""
|
"""
|
||||||
|
Reference in New Issue
Block a user