from textwrap import dedent

import black
import isort

from bec_widgets.cli.generate_cli import ClientGenerator

# pylint: disable=missing-function-docstring


# Mock classes to test the generator.
#
# NOTE: these classes are fed directly to ClientGenerator. The method names listed in
# USER_ACCESS, their signatures/annotations and their docstrings all reappear verbatim in
# `expected_output` in test_client_generator_with_black_formatting. Any change here
# (including adding docstrings, return annotations or extra methods) must be mirrored
# there, otherwise the comparison fails.
class MockBECWaveform1D:
    # Methods the generated RPC client is expected to expose (see expected_output).
    USER_ACCESS = ["set_frequency", "set_amplitude"]

    def set_frequency(self, frequency: float) -> list:
        """Set the frequency of the waveform."""
        return [frequency]

    def set_amplitude(self, amplitude: float) -> tuple[float, float]:
        """Set the amplitude of the waveform."""
        return amplitude, amplitude


# Second mock; it is also passed as a top-level class, which is why it is the only
# member of the `Widgets` enum in expected_output.
class MockBECFigure:
    # Methods the generated RPC client is expected to expose (see expected_output).
    USER_ACCESS = ["add_plot", "remove_plot"]

    # No return annotation on purpose: the expected generated stubs have none either.
    def add_plot(self, plot_id: str):
        """Add a plot to the figure."""

    def remove_plot(self, plot_id: str):
        """Remove a plot from the figure."""


def test_client_generator_with_black_formatting():
    """Generate a client for the mock classes and compare it with a hand-written expectation.

    Both the expected and the generated source are passed through black (line length 100)
    before comparison, so the test checks content rather than exact whitespace. The
    generated source is additionally normalized with isort.
    """
    generator = ClientGenerator()
    # Both mocks get RPC client classes; only MockBECFigure is also a top-level class.
    rpc_classes = {
        "connector_classes": [MockBECWaveform1D, MockBECFigure],
        "top_level_classes": [MockBECFigure],
    }
    generator.generate_client(rpc_classes)

    # Format the expected output with black to ensure it matches the generator output
    expected_output = dedent(
        '''\
        # This file was automatically generated by generate_cli.py

        import enum
        from typing import Literal, Optional, overload

        from bec_widgets.cli.client_utils import BECGuiClientMixin, RPCBase, rpc_call

        # pylint: skip-file


        class Widgets(str, enum.Enum):
            """
            Enum for the available widgets.
            """

            MockBECFigure = "MockBECFigure"


        class MockBECWaveform1D(RPCBase):
            @rpc_call
            def set_frequency(self, frequency: float) -> list:
                """
                Set the frequency of the waveform.
                """

            @rpc_call
            def set_amplitude(self, amplitude: float) -> tuple[float, float]:
                """
                Set the amplitude of the waveform.
                """


        class MockBECFigure(RPCBase):
            @rpc_call
            def add_plot(self, plot_id: str):
                """
                Add a plot to the figure.
                """

            @rpc_call
            def remove_plot(self, plot_id: str):
                """
                Remove a plot from the figure.
                """
        '''
    )

    # Strip any leading whitespace left by the literal so it lines up with the generated text.
    expected_output_formatted = black.format_str(
        expected_output, mode=black.FileMode(line_length=100)
    ).lstrip()

    # The generator exposes its output as a header plus the class bodies; join and normalize.
    generated_output_formatted = black.format_str(
        generator.header + "\n" + generator.content, mode=black.FileMode(line_length=100)
    )
    # Only the generated side is run through isort; the expected imports are written
    # already sorted.
    generated_output_formatted = isort.code(generated_output_formatted)

    assert expected_output_formatted == generated_output_formatted


def test_client_generator_classes():
    """Discover RPC classes in the real ``bec_widgets`` package and check the split.

    Verifies the result's key order and that BECFigure/BECWaveform are connector
    classes, while only BECDockArea and BECFigure (not BECWaveform) are top-level.
    """
    generator = ClientGenerator()
    out = generator.get_rpc_classes("bec_widgets")

    # Key order is asserted too, not just membership.
    assert list(out.keys()) == ["connector_classes", "top_level_classes"]
    connector_cls_names = [cls.__name__ for cls in out["connector_classes"]]
    top_level_cls_names = [cls.__name__ for cls in out["top_level_classes"]]

    assert "BECFigure" in connector_cls_names
    assert "BECWaveform" in connector_cls_names
    assert "BECDockArea" in top_level_cls_names
    assert "BECFigure" in top_level_cls_names
    # BECWaveform is expected to be a connector class only, never top-level.
    assert "BECWaveform" not in top_level_cls_names