0
0
mirror of https://github.com/bec-project/bec_widgets.git synced 2025-07-14 11:41:49 +02:00

fix(widgets/figure): added widgets can be accessed as a list (fig.axes) or as a dictionary (fig.widgets)

This commit is contained in:
wyzula-jan
2024-03-14 17:25:16 +01:00
parent 32747baa27
commit fcf918c488
6 changed files with 97 additions and 59 deletions

View File

@ -243,6 +243,22 @@ class BECWaveform1D(RPCBase):
class BECFigure(RPCBase, BECFigureClientMixin): class BECFigure(RPCBase, BECFigureClientMixin):
@property
@rpc_call
def axes(self) -> "list[BECPlotBase]":
"""
Access all widget in BECFigure as a list
Returns:
list[BECPlotBase]: List of all widgets in the figure.
"""
@property
@rpc_call
def widgets(self) -> "dict":
"""
None
"""
@rpc_call @rpc_call
def add_plot( def add_plot(
self, self,

View File

@ -98,6 +98,11 @@ class RPCBase:
super().__init__() super().__init__()
# print(f"RPCBase: {self._gui_id}") # print(f"RPCBase: {self._gui_id}")
def __repr__(self):
type_ = type(self)
qualname = type_.__qualname__
return f"<{qualname} object at {hex(id(self))}>"
@property @property
def _root(self): def _root(self):
""" """
@ -150,7 +155,9 @@ class RPCBase:
return [self._create_widget_from_msg_result(res) for res in msg_result] return [self._create_widget_from_msg_result(res) for res in msg_result]
if isinstance(msg_result, dict): if isinstance(msg_result, dict):
if "__rpc__" not in msg_result: if "__rpc__" not in msg_result:
return msg_result return {
key: self._create_widget_from_msg_result(val) for key, val in msg_result.items()
}
cls = msg_result.pop("widget_class", None) cls = msg_result.pop("widget_class", None)
msg_result.pop("__rpc__", None) msg_result.pop("__rpc__", None)

View File

@ -54,11 +54,11 @@ class BECWidgetsCLIServer:
if gui_id == self.fig.gui_id: if gui_id == self.fig.gui_id:
return self.fig return self.fig
# check if the object is a widget # check if the object is a widget
if gui_id in self.fig.widgets: if gui_id in self.fig._widgets:
obj = self.fig.widgets[config["gui_id"]] obj = self.fig._widgets[config["gui_id"]]
return obj return obj
if self.fig.widgets: if self.fig._widgets:
for widget in self.fig.widgets.values(): for widget in self.fig._widgets.values():
item = widget.find_widget_by_id(gui_id) item = widget.find_widget_by_id(gui_id)
if item: if item:
return item return item
@ -79,6 +79,8 @@ class BECWidgetsCLIServer:
if isinstance(res, list): if isinstance(res, list):
res = [self.serialize_object(obj) for obj in res] res = [self.serialize_object(obj) for obj in res]
elif isinstance(res, dict):
res = {key: self.serialize_object(val) for key, val in res.items()}
else: else:
res = self.serialize_object(res) res = self.serialize_object(res)
return res return res

View File

@ -95,6 +95,8 @@ class WidgetHandler:
class BECFigure(BECConnector, pg.GraphicsLayoutWidget): class BECFigure(BECConnector, pg.GraphicsLayoutWidget):
USER_ACCESS = [ USER_ACCESS = [
"axes",
"widgets",
"add_plot", "add_plot",
"add_image", "add_image",
"plot", "plot",
@ -127,11 +129,33 @@ class BECFigure(BECConnector, pg.GraphicsLayoutWidget):
self.widget_handler = WidgetHandler() self.widget_handler = WidgetHandler()
# Widget container to reference widgets by 'widget_id' # Widget container to reference widgets by 'widget_id'
self.widgets = defaultdict(dict) self._widgets = defaultdict(dict)
# Container to keep track of the grid # Container to keep track of the grid
self.grid = [] self.grid = []
@property
def axes(self) -> list[BECPlotBase]:
"""
Access all widget in BECFigure as a list
Returns:
list[BECPlotBase]: List of all widgets in the figure.
"""
axes = [value for value in self._widgets.values() if isinstance(value, BECPlotBase)]
return axes
@axes.setter
def axes(self, value: list[BECPlotBase]):
self._axes = value
@property
def widgets(self) -> dict:
return self._widgets
@widgets.setter
def widgets(self, value: dict):
self._widgets = value
def add_plot( def add_plot(
self, widget_id: str = None, row: int = None, col: int = None, config=None, **axis_kwargs self, widget_id: str = None, row: int = None, col: int = None, config=None, **axis_kwargs
) -> BECWaveform1D: ) -> BECWaveform1D:
@ -314,7 +338,7 @@ class BECFigure(BECConnector, pg.GraphicsLayoutWidget):
""" """
if not widget_id: if not widget_id:
widget_id = self._generate_unique_widget_id() widget_id = self._generate_unique_widget_id()
if widget_id in self.widgets: if widget_id in self._widgets:
raise ValueError(f"Widget with ID '{widget_id}' already exists.") raise ValueError(f"Widget with ID '{widget_id}' already exists.")
widget = self.widget_handler.create_widget( widget = self.widget_handler.create_widget(
@ -350,7 +374,7 @@ class BECFigure(BECConnector, pg.GraphicsLayoutWidget):
# Saving config for future referencing # Saving config for future referencing
self.config.widgets[widget_id] = widget.config self.config.widgets[widget_id] = widget.config
self.widgets[widget_id] = widget self._widgets[widget_id] = widget
# Reflect the grid coordinates # Reflect the grid coordinates
self._change_grid(widget_id, row, col) self._change_grid(widget_id, row, col)
@ -402,7 +426,7 @@ class BECFigure(BECConnector, pg.GraphicsLayoutWidget):
Returns: Returns:
BECPlotBase: The widget of the given class. BECPlotBase: The widget of the given class.
""" """
for widget_id, widget in self.widgets.items(): for widget_id, widget in self._widgets.items():
if isinstance(widget, widget_class): if isinstance(widget, widget_class):
return widget return widget
if can_fail: if can_fail:
@ -420,7 +444,7 @@ class BECFigure(BECConnector, pg.GraphicsLayoutWidget):
widget = self._get_widget_by_coordinates(row, col) widget = self._get_widget_by_coordinates(row, col)
if widget: if widget:
widget_id = widget.config.gui_id widget_id = widget.config.gui_id
if widget_id in self.widgets: if widget_id in self._widgets:
self._remove_by_id(widget_id) self._remove_by_id(widget_id)
def _remove_by_id(self, widget_id: str) -> None: def _remove_by_id(self, widget_id: str) -> None:
@ -429,8 +453,8 @@ class BECFigure(BECConnector, pg.GraphicsLayoutWidget):
Args: Args:
widget_id(str): The unique identifier of the widget to remove. widget_id(str): The unique identifier of the widget to remove.
""" """
if widget_id in self.widgets: if widget_id in self._widgets:
widget = self.widgets.pop(widget_id) widget = self._widgets.pop(widget_id)
widget.cleanup() widget.cleanup()
self.removeItem(widget) self.removeItem(widget)
self.grid[widget.config.row][widget.config.col] = None self.grid[widget.config.row][widget.config.col] = None
@ -445,10 +469,10 @@ class BECFigure(BECConnector, pg.GraphicsLayoutWidget):
if isinstance(key, tuple) and len(key) == 2: if isinstance(key, tuple) and len(key) == 2:
return self._get_widget_by_coordinates(*key) return self._get_widget_by_coordinates(*key)
elif isinstance(key, str): elif isinstance(key, str):
widget = self.widgets.get(key) widget = self._widgets.get(key)
if widget is None: if widget is None:
raise KeyError(f"No widget with ID {key}") raise KeyError(f"No widget with ID {key}")
return self.widgets.get(key) return self._widgets.get(key)
else: else:
raise TypeError( raise TypeError(
"Key must be a string (widget id) or a tuple of two integers (grid coordinates)" "Key must be a string (widget id) or a tuple of two integers (grid coordinates)"
@ -478,7 +502,7 @@ class BECFigure(BECConnector, pg.GraphicsLayoutWidget):
def _generate_unique_widget_id(self): def _generate_unique_widget_id(self):
"""Generate a unique widget ID.""" """Generate a unique widget ID."""
existing_ids = set(self.widgets.keys()) existing_ids = set(self._widgets.keys())
for i in itertools.count(1): for i in itertools.count(1):
widget_id = f"widget_{i}" widget_id = f"widget_{i}"
if widget_id not in existing_ids: if widget_id not in existing_ids:
@ -511,7 +535,10 @@ class BECFigure(BECConnector, pg.GraphicsLayoutWidget):
# Update the config of each object to reflect its new position # Update the config of each object to reflect its new position
for row_idx, row in enumerate(new_grid): for row_idx, row in enumerate(new_grid):
for col_idx, widget in enumerate(row): for col_idx, widget in enumerate(row):
self.widgets[widget].config.row, self.widgets[widget].config.col = row_idx, col_idx self._widgets[widget].config.row, self._widgets[widget].config.col = (
row_idx,
col_idx,
)
self.grid = new_grid self.grid = new_grid
self._replot_layout() self._replot_layout()
@ -521,7 +548,7 @@ class BECFigure(BECConnector, pg.GraphicsLayoutWidget):
self.clear() self.clear()
for row_idx, row in enumerate(self.grid): for row_idx, row in enumerate(self.grid):
for col_idx, widget in enumerate(row): for col_idx, widget in enumerate(row):
self.addItem(self.widgets[widget], row=row_idx, col=col_idx) self.addItem(self._widgets[widget], row=row_idx, col=col_idx)
def change_layout(self, max_columns=None, max_rows=None): def change_layout(self, max_columns=None, max_rows=None):
""" """
@ -533,7 +560,7 @@ class BECFigure(BECConnector, pg.GraphicsLayoutWidget):
max_rows (Optional[int]): The new maximum number of rows in the figure. max_rows (Optional[int]): The new maximum number of rows in the figure.
""" """
# Calculate total number of widgets # Calculate total number of widgets
total_widgets = len(self.widgets) total_widgets = len(self._widgets)
if max_columns: if max_columns:
# Calculate the required number of rows based on max_columns # Calculate the required number of rows based on max_columns
@ -549,7 +576,7 @@ class BECFigure(BECConnector, pg.GraphicsLayoutWidget):
# Populate the new grid with widgets' IDs # Populate the new grid with widgets' IDs
current_idx = 0 current_idx = 0
for widget_id, widget in self.widgets.items(): for widget_id, widget in self._widgets.items():
row = current_idx // len(new_grid[0]) row = current_idx // len(new_grid[0])
col = current_idx % len(new_grid[0]) col = current_idx % len(new_grid[0])
new_grid[row][col] = widget_id new_grid[row][col] = widget_id
@ -565,10 +592,10 @@ class BECFigure(BECConnector, pg.GraphicsLayoutWidget):
def clear_all(self): def clear_all(self):
"""Clear all widgets from the figure and reset to default state""" """Clear all widgets from the figure and reset to default state"""
for widget in self.widgets.values(): for widget in self._widgets.values():
widget.cleanup() widget.cleanup()
self.clear() self.clear()
self.widgets = defaultdict(dict) self._widgets = defaultdict(dict)
self.grid = [] self.grid = []
theme = self.config.theme theme = self.config.theme
self.config = FigureConfig( self.config = FigureConfig(

View File

@ -219,7 +219,6 @@ class BECWaveform1D(BECPlotBase):
"remove_curve", "remove_curve",
"scan_history", "scan_history",
"curves", "curves",
# "curves_data",
"get_curve", "get_curve",
"get_curve_config", "get_curve_config",
"apply_config", "apply_config",
@ -356,19 +355,6 @@ class BECWaveform1D(BECPlotBase):
def curves(self, value: list[BECCurve]): def curves(self, value: list[BECCurve]):
self._curves = value self._curves = value
@property
def curves_data(self) -> dict[str, dict[str, BECCurve]]:
"""
Get the curves data of the plot widget as a dictionary
Returns:
dict: Dictionary of curves data.
"""
return self._curves_data
@curves_data.setter
def curves_data(self, value: dict[str, dict[str, BECCurve]]):
self._curves_data = value
def get_curve(self, identifier) -> BECCurve: def get_curve(self, identifier) -> BECCurve:
""" """
Get the curve by its index or ID. Get the curve by its index or ID.

View File

@ -36,7 +36,7 @@ def test_bec_figure_init_with_config(mocked_client):
def test_bec_figure_add_remove_plot(bec_figure): def test_bec_figure_add_remove_plot(bec_figure):
initial_count = len(bec_figure.widgets) initial_count = len(bec_figure._widgets)
# Adding 3 widgets - 2 WaveformBase and 1 PlotBase # Adding 3 widgets - 2 WaveformBase and 1 PlotBase
w0 = bec_figure.add_plot() w0 = bec_figure.add_plot()
@ -44,13 +44,13 @@ def test_bec_figure_add_remove_plot(bec_figure):
w2 = bec_figure.add_widget(widget_id="test_plot", widget_type="PlotBase") w2 = bec_figure.add_widget(widget_id="test_plot", widget_type="PlotBase")
# Check if the widgets were added # Check if the widgets were added
assert len(bec_figure.widgets) == initial_count + 3 assert len(bec_figure._widgets) == initial_count + 3
assert "widget_1" in bec_figure.widgets assert "widget_1" in bec_figure._widgets
assert "test_plot" in bec_figure.widgets assert "test_plot" in bec_figure._widgets
assert "test_waveform" in bec_figure.widgets assert "test_waveform" in bec_figure._widgets
assert bec_figure.widgets["widget_1"].config.widget_class == "BECWaveform1D" assert bec_figure._widgets["widget_1"].config.widget_class == "BECWaveform1D"
assert bec_figure.widgets["test_plot"].config.widget_class == "BECPlotBase" assert bec_figure._widgets["test_plot"].config.widget_class == "BECPlotBase"
assert bec_figure.widgets["test_waveform"].config.widget_class == "BECWaveform1D" assert bec_figure._widgets["test_waveform"].config.widget_class == "BECWaveform1D"
# Check accessing positions by the grid in figure # Check accessing positions by the grid in figure
assert bec_figure[0, 0] == w0 assert bec_figure[0, 0] == w0
@ -59,10 +59,10 @@ def test_bec_figure_add_remove_plot(bec_figure):
# Removing 1 widget - PlotBase # Removing 1 widget - PlotBase
bec_figure.remove(widget_id="test_plot") bec_figure.remove(widget_id="test_plot")
assert len(bec_figure.widgets) == initial_count + 2 assert len(bec_figure._widgets) == initial_count + 2
assert "test_plot" not in bec_figure.widgets assert "test_plot" not in bec_figure._widgets
assert "test_waveform" in bec_figure.widgets assert "test_waveform" in bec_figure._widgets
assert bec_figure.widgets["test_waveform"].config.widget_class == "BECWaveform1D" assert bec_figure._widgets["test_waveform"].config.widget_class == "BECWaveform1D"
def test_access_widgets_access_errors(bec_figure): def test_access_widgets_access_errors(bec_figure):
@ -116,21 +116,21 @@ def test_remove_plots(bec_figure):
# remove by coordinates # remove by coordinates
bec_figure[0, 0].remove() bec_figure[0, 0].remove()
assert "test_waveform_1" not in bec_figure.widgets assert "test_waveform_1" not in bec_figure._widgets
# remove by widget_id # remove by widget_id
bec_figure.remove(widget_id="test_waveform_2") bec_figure.remove(widget_id="test_waveform_2")
assert "test_waveform_2" not in bec_figure.widgets assert "test_waveform_2" not in bec_figure._widgets
# remove by widget object # remove by widget object
w3.remove() w3.remove()
assert "test_waveform_3" not in bec_figure.widgets assert "test_waveform_3" not in bec_figure._widgets
# check the remaining widget 4 # check the remaining widget 4
assert bec_figure[0, 0] == w4 assert bec_figure[0, 0] == w4
assert bec_figure["test_waveform_4"] == w4 assert bec_figure["test_waveform_4"] == w4
assert "test_waveform_4" in bec_figure.widgets assert "test_waveform_4" in bec_figure._widgets
assert len(bec_figure.widgets) == 1 assert len(bec_figure._widgets) == 1
def test_remove_plots_by_coordinates_ints(bec_figure): def test_remove_plots_by_coordinates_ints(bec_figure):
@ -138,10 +138,10 @@ def test_remove_plots_by_coordinates_ints(bec_figure):
w2 = bec_figure.add_plot(widget_id="test_waveform_2", row=0, col=1) w2 = bec_figure.add_plot(widget_id="test_waveform_2", row=0, col=1)
bec_figure.remove(0, 0) bec_figure.remove(0, 0)
assert "test_waveform_1" not in bec_figure.widgets assert "test_waveform_1" not in bec_figure._widgets
assert "test_waveform_2" in bec_figure.widgets assert "test_waveform_2" in bec_figure._widgets
assert bec_figure[0, 0] == w2 assert bec_figure[0, 0] == w2
assert len(bec_figure.widgets) == 1 assert len(bec_figure._widgets) == 1
def test_remove_plots_by_coordinates_tuple(bec_figure): def test_remove_plots_by_coordinates_tuple(bec_figure):
@ -149,10 +149,10 @@ def test_remove_plots_by_coordinates_tuple(bec_figure):
w2 = bec_figure.add_plot(widget_id="test_waveform_2", row=0, col=1) w2 = bec_figure.add_plot(widget_id="test_waveform_2", row=0, col=1)
bec_figure.remove(coordinates=(0, 0)) bec_figure.remove(coordinates=(0, 0))
assert "test_waveform_1" not in bec_figure.widgets assert "test_waveform_1" not in bec_figure._widgets
assert "test_waveform_2" in bec_figure.widgets assert "test_waveform_2" in bec_figure._widgets
assert bec_figure[0, 0] == w2 assert bec_figure[0, 0] == w2
assert len(bec_figure.widgets) == 1 assert len(bec_figure._widgets) == 1
def test_remove_plot_by_id_error(bec_figure): def test_remove_plot_by_id_error(bec_figure):
@ -222,5 +222,5 @@ def test_clear_all(bec_figure):
bec_figure.clear_all() bec_figure.clear_all()
assert len(bec_figure.widgets) == 0 assert len(bec_figure._widgets) == 0
assert np.shape(bec_figure.grid) == (0,) assert np.shape(bec_figure.grid) == (0,)