0
0
mirror of https://github.com/bec-project/bec_widgets.git synced 2025-07-14 11:41:49 +02:00

fix(widgets/figure): added widgets can be accessed as a list (fig.axes) or as a dictionary (fig.widgets)

This commit is contained in:
wyzula-jan
2024-03-14 17:25:16 +01:00
parent 32747baa27
commit fcf918c488
6 changed files with 97 additions and 59 deletions

View File

@ -243,6 +243,22 @@ class BECWaveform1D(RPCBase):
class BECFigure(RPCBase, BECFigureClientMixin):
@property
@rpc_call
def axes(self) -> "list[BECPlotBase]":
"""
Access all widget in BECFigure as a list
Returns:
list[BECPlotBase]: List of all widgets in the figure.
"""
@property
@rpc_call
def widgets(self) -> "dict":
"""
None
"""
@rpc_call
def add_plot(
self,

View File

@ -98,6 +98,11 @@ class RPCBase:
super().__init__()
# print(f"RPCBase: {self._gui_id}")
def __repr__(self):
type_ = type(self)
qualname = type_.__qualname__
return f"<{qualname} object at {hex(id(self))}>"
@property
def _root(self):
"""
@ -150,7 +155,9 @@ class RPCBase:
return [self._create_widget_from_msg_result(res) for res in msg_result]
if isinstance(msg_result, dict):
if "__rpc__" not in msg_result:
return msg_result
return {
key: self._create_widget_from_msg_result(val) for key, val in msg_result.items()
}
cls = msg_result.pop("widget_class", None)
msg_result.pop("__rpc__", None)

View File

@ -54,11 +54,11 @@ class BECWidgetsCLIServer:
if gui_id == self.fig.gui_id:
return self.fig
# check if the object is a widget
if gui_id in self.fig.widgets:
obj = self.fig.widgets[config["gui_id"]]
if gui_id in self.fig._widgets:
obj = self.fig._widgets[config["gui_id"]]
return obj
if self.fig.widgets:
for widget in self.fig.widgets.values():
if self.fig._widgets:
for widget in self.fig._widgets.values():
item = widget.find_widget_by_id(gui_id)
if item:
return item
@ -79,6 +79,8 @@ class BECWidgetsCLIServer:
if isinstance(res, list):
res = [self.serialize_object(obj) for obj in res]
elif isinstance(res, dict):
res = {key: self.serialize_object(val) for key, val in res.items()}
else:
res = self.serialize_object(res)
return res

View File

@ -95,6 +95,8 @@ class WidgetHandler:
class BECFigure(BECConnector, pg.GraphicsLayoutWidget):
USER_ACCESS = [
"axes",
"widgets",
"add_plot",
"add_image",
"plot",
@ -127,11 +129,33 @@ class BECFigure(BECConnector, pg.GraphicsLayoutWidget):
self.widget_handler = WidgetHandler()
# Widget container to reference widgets by 'widget_id'
self.widgets = defaultdict(dict)
self._widgets = defaultdict(dict)
# Container to keep track of the grid
self.grid = []
@property
def axes(self) -> list[BECPlotBase]:
"""
Access all widget in BECFigure as a list
Returns:
list[BECPlotBase]: List of all widgets in the figure.
"""
axes = [value for value in self._widgets.values() if isinstance(value, BECPlotBase)]
return axes
@axes.setter
def axes(self, value: list[BECPlotBase]):
self._axes = value
@property
def widgets(self) -> dict:
return self._widgets
@widgets.setter
def widgets(self, value: dict):
self._widgets = value
def add_plot(
self, widget_id: str = None, row: int = None, col: int = None, config=None, **axis_kwargs
) -> BECWaveform1D:
@ -314,7 +338,7 @@ class BECFigure(BECConnector, pg.GraphicsLayoutWidget):
"""
if not widget_id:
widget_id = self._generate_unique_widget_id()
if widget_id in self.widgets:
if widget_id in self._widgets:
raise ValueError(f"Widget with ID '{widget_id}' already exists.")
widget = self.widget_handler.create_widget(
@ -350,7 +374,7 @@ class BECFigure(BECConnector, pg.GraphicsLayoutWidget):
# Saving config for future referencing
self.config.widgets[widget_id] = widget.config
self.widgets[widget_id] = widget
self._widgets[widget_id] = widget
# Reflect the grid coordinates
self._change_grid(widget_id, row, col)
@ -402,7 +426,7 @@ class BECFigure(BECConnector, pg.GraphicsLayoutWidget):
Returns:
BECPlotBase: The widget of the given class.
"""
for widget_id, widget in self.widgets.items():
for widget_id, widget in self._widgets.items():
if isinstance(widget, widget_class):
return widget
if can_fail:
@ -420,7 +444,7 @@ class BECFigure(BECConnector, pg.GraphicsLayoutWidget):
widget = self._get_widget_by_coordinates(row, col)
if widget:
widget_id = widget.config.gui_id
if widget_id in self.widgets:
if widget_id in self._widgets:
self._remove_by_id(widget_id)
def _remove_by_id(self, widget_id: str) -> None:
@ -429,8 +453,8 @@ class BECFigure(BECConnector, pg.GraphicsLayoutWidget):
Args:
widget_id(str): The unique identifier of the widget to remove.
"""
if widget_id in self.widgets:
widget = self.widgets.pop(widget_id)
if widget_id in self._widgets:
widget = self._widgets.pop(widget_id)
widget.cleanup()
self.removeItem(widget)
self.grid[widget.config.row][widget.config.col] = None
@ -445,10 +469,10 @@ class BECFigure(BECConnector, pg.GraphicsLayoutWidget):
if isinstance(key, tuple) and len(key) == 2:
return self._get_widget_by_coordinates(*key)
elif isinstance(key, str):
widget = self.widgets.get(key)
widget = self._widgets.get(key)
if widget is None:
raise KeyError(f"No widget with ID {key}")
return self.widgets.get(key)
return self._widgets.get(key)
else:
raise TypeError(
"Key must be a string (widget id) or a tuple of two integers (grid coordinates)"
@ -478,7 +502,7 @@ class BECFigure(BECConnector, pg.GraphicsLayoutWidget):
def _generate_unique_widget_id(self):
"""Generate a unique widget ID."""
existing_ids = set(self.widgets.keys())
existing_ids = set(self._widgets.keys())
for i in itertools.count(1):
widget_id = f"widget_{i}"
if widget_id not in existing_ids:
@ -511,7 +535,10 @@ class BECFigure(BECConnector, pg.GraphicsLayoutWidget):
# Update the config of each object to reflect its new position
for row_idx, row in enumerate(new_grid):
for col_idx, widget in enumerate(row):
self.widgets[widget].config.row, self.widgets[widget].config.col = row_idx, col_idx
self._widgets[widget].config.row, self._widgets[widget].config.col = (
row_idx,
col_idx,
)
self.grid = new_grid
self._replot_layout()
@ -521,7 +548,7 @@ class BECFigure(BECConnector, pg.GraphicsLayoutWidget):
self.clear()
for row_idx, row in enumerate(self.grid):
for col_idx, widget in enumerate(row):
self.addItem(self.widgets[widget], row=row_idx, col=col_idx)
self.addItem(self._widgets[widget], row=row_idx, col=col_idx)
def change_layout(self, max_columns=None, max_rows=None):
"""
@ -533,7 +560,7 @@ class BECFigure(BECConnector, pg.GraphicsLayoutWidget):
max_rows (Optional[int]): The new maximum number of rows in the figure.
"""
# Calculate total number of widgets
total_widgets = len(self.widgets)
total_widgets = len(self._widgets)
if max_columns:
# Calculate the required number of rows based on max_columns
@ -549,7 +576,7 @@ class BECFigure(BECConnector, pg.GraphicsLayoutWidget):
# Populate the new grid with widgets' IDs
current_idx = 0
for widget_id, widget in self.widgets.items():
for widget_id, widget in self._widgets.items():
row = current_idx // len(new_grid[0])
col = current_idx % len(new_grid[0])
new_grid[row][col] = widget_id
@ -565,10 +592,10 @@ class BECFigure(BECConnector, pg.GraphicsLayoutWidget):
def clear_all(self):
"""Clear all widgets from the figure and reset to default state"""
for widget in self.widgets.values():
for widget in self._widgets.values():
widget.cleanup()
self.clear()
self.widgets = defaultdict(dict)
self._widgets = defaultdict(dict)
self.grid = []
theme = self.config.theme
self.config = FigureConfig(

View File

@ -219,7 +219,6 @@ class BECWaveform1D(BECPlotBase):
"remove_curve",
"scan_history",
"curves",
# "curves_data",
"get_curve",
"get_curve_config",
"apply_config",
@ -356,19 +355,6 @@ class BECWaveform1D(BECPlotBase):
def curves(self, value: list[BECCurve]):
self._curves = value
@property
def curves_data(self) -> dict[str, dict[str, BECCurve]]:
"""
Get the curves data of the plot widget as a dictionary
Returns:
dict: Dictionary of curves data.
"""
return self._curves_data
@curves_data.setter
def curves_data(self, value: dict[str, dict[str, BECCurve]]):
self._curves_data = value
def get_curve(self, identifier) -> BECCurve:
"""
Get the curve by its index or ID.

View File

@ -36,7 +36,7 @@ def test_bec_figure_init_with_config(mocked_client):
def test_bec_figure_add_remove_plot(bec_figure):
initial_count = len(bec_figure.widgets)
initial_count = len(bec_figure._widgets)
# Adding 3 widgets - 2 WaveformBase and 1 PlotBase
w0 = bec_figure.add_plot()
@ -44,13 +44,13 @@ def test_bec_figure_add_remove_plot(bec_figure):
w2 = bec_figure.add_widget(widget_id="test_plot", widget_type="PlotBase")
# Check if the widgets were added
assert len(bec_figure.widgets) == initial_count + 3
assert "widget_1" in bec_figure.widgets
assert "test_plot" in bec_figure.widgets
assert "test_waveform" in bec_figure.widgets
assert bec_figure.widgets["widget_1"].config.widget_class == "BECWaveform1D"
assert bec_figure.widgets["test_plot"].config.widget_class == "BECPlotBase"
assert bec_figure.widgets["test_waveform"].config.widget_class == "BECWaveform1D"
assert len(bec_figure._widgets) == initial_count + 3
assert "widget_1" in bec_figure._widgets
assert "test_plot" in bec_figure._widgets
assert "test_waveform" in bec_figure._widgets
assert bec_figure._widgets["widget_1"].config.widget_class == "BECWaveform1D"
assert bec_figure._widgets["test_plot"].config.widget_class == "BECPlotBase"
assert bec_figure._widgets["test_waveform"].config.widget_class == "BECWaveform1D"
# Check accessing positions by the grid in figure
assert bec_figure[0, 0] == w0
@ -59,10 +59,10 @@ def test_bec_figure_add_remove_plot(bec_figure):
# Removing 1 widget - PlotBase
bec_figure.remove(widget_id="test_plot")
assert len(bec_figure.widgets) == initial_count + 2
assert "test_plot" not in bec_figure.widgets
assert "test_waveform" in bec_figure.widgets
assert bec_figure.widgets["test_waveform"].config.widget_class == "BECWaveform1D"
assert len(bec_figure._widgets) == initial_count + 2
assert "test_plot" not in bec_figure._widgets
assert "test_waveform" in bec_figure._widgets
assert bec_figure._widgets["test_waveform"].config.widget_class == "BECWaveform1D"
def test_access_widgets_access_errors(bec_figure):
@ -116,21 +116,21 @@ def test_remove_plots(bec_figure):
# remove by coordinates
bec_figure[0, 0].remove()
assert "test_waveform_1" not in bec_figure.widgets
assert "test_waveform_1" not in bec_figure._widgets
# remove by widget_id
bec_figure.remove(widget_id="test_waveform_2")
assert "test_waveform_2" not in bec_figure.widgets
assert "test_waveform_2" not in bec_figure._widgets
# remove by widget object
w3.remove()
assert "test_waveform_3" not in bec_figure.widgets
assert "test_waveform_3" not in bec_figure._widgets
# check the remaining widget 4
assert bec_figure[0, 0] == w4
assert bec_figure["test_waveform_4"] == w4
assert "test_waveform_4" in bec_figure.widgets
assert len(bec_figure.widgets) == 1
assert "test_waveform_4" in bec_figure._widgets
assert len(bec_figure._widgets) == 1
def test_remove_plots_by_coordinates_ints(bec_figure):
@ -138,10 +138,10 @@ def test_remove_plots_by_coordinates_ints(bec_figure):
w2 = bec_figure.add_plot(widget_id="test_waveform_2", row=0, col=1)
bec_figure.remove(0, 0)
assert "test_waveform_1" not in bec_figure.widgets
assert "test_waveform_2" in bec_figure.widgets
assert "test_waveform_1" not in bec_figure._widgets
assert "test_waveform_2" in bec_figure._widgets
assert bec_figure[0, 0] == w2
assert len(bec_figure.widgets) == 1
assert len(bec_figure._widgets) == 1
def test_remove_plots_by_coordinates_tuple(bec_figure):
@ -149,10 +149,10 @@ def test_remove_plots_by_coordinates_tuple(bec_figure):
w2 = bec_figure.add_plot(widget_id="test_waveform_2", row=0, col=1)
bec_figure.remove(coordinates=(0, 0))
assert "test_waveform_1" not in bec_figure.widgets
assert "test_waveform_2" in bec_figure.widgets
assert "test_waveform_1" not in bec_figure._widgets
assert "test_waveform_2" in bec_figure._widgets
assert bec_figure[0, 0] == w2
assert len(bec_figure.widgets) == 1
assert len(bec_figure._widgets) == 1
def test_remove_plot_by_id_error(bec_figure):
@ -222,5 +222,5 @@ def test_clear_all(bec_figure):
bec_figure.clear_all()
assert len(bec_figure.widgets) == 0
assert len(bec_figure._widgets) == 0
assert np.shape(bec_figure.grid) == (0,)