"""Unit tests for the RPC client code generator (``bec_widgets.cli.generate_cli``)."""

from textwrap import dedent
from unittest import mock

import black
import isort
import pytest

from bec_widgets.cli.generate_cli import ClientGenerator
from bec_widgets.utils.plugin_utils import BECClassContainer, BECClassInfo

# pylint: disable=missing-function-docstring


# Mock classes to test the generator.
# NOTE: deliberately no class docstrings here — the generated client is derived from these
# classes, so adding one could change the output compared against in the tests below.
class MockBECWaveform1D:
    USER_ACCESS = ["set_frequency", "set_amplitude"]

    def set_frequency(self, frequency: float) -> list:
        """Set the frequency of the waveform."""
        return [frequency]

    def set_amplitude(self, amplitude: float) -> tuple[float, float]:
        """Set the amplitude of the waveform."""
        return amplitude, amplitude


class MockBECFigure:
    USER_ACCESS = ["add_plot", "remove_plot"]

    def add_plot(self, plot_id: str):
        """Add a plot to the figure."""

    def remove_plot(self, plot_id: str):
        """Remove a plot from the figure."""


def test_client_generator_with_black_formatting():
    """
    Test that the base client generated for two mock classes matches the expected module
    source once both sides are normalized with black and isort.
    """
    generator = ClientGenerator(base=True)
    container = BECClassContainer()
    container.add_class(
        BECClassInfo(
            name="MockBECWaveform1D",
            module="test_module",
            file="test_file",
            obj=MockBECWaveform1D,
            is_connector=True,
            is_widget=True,
            is_plugin=False,
        )
    )
    container.add_class(
        BECClassInfo(
            name="MockBECFigure",
            module="test_module",
            file="test_file",
            obj=MockBECFigure,
            is_connector=True,
            is_widget=True,
            is_plugin=True,
        )
    )

    generator.generate_client(container)

    # Format the expected output with black (and isort, below) to ensure it matches the
    # generator output; both sides go through the same normalization before comparing.
    expected_output = dedent(
        '''\
        # This file was automatically generated by generate_cli.py
        # type: ignore

        from __future__ import annotations

        import enum
        import inspect
        import traceback
        from functools import reduce
        from operator import add
        from typing import Literal, Optional

        from bec_lib.logger import bec_logger

        from bec_widgets.cli.rpc.rpc_base import RPCBase, rpc_call
        from bec_widgets.utils.bec_plugin_helper import (get_all_plugin_widgets,
                                                         get_plugin_client_module)

        logger = bec_logger.logger

        # pylint: skip-file


        class _WidgetsEnumType(str, enum.Enum):
            """Enum for the available widgets, to be generated programatically"""

            ...

        _Widgets = {
            "MockBECFigure": "MockBECFigure",
        }


        try:
            _plugin_widgets = get_all_plugin_widgets()
            plugin_client = get_plugin_client_module()
            Widgets = _WidgetsEnumType("Widgets", {name: name for name in _plugin_widgets} | _Widgets)

            if (_overlap := _Widgets.keys() & _plugin_widgets.keys()) != set():
                for _widget in _overlap:
                    logger.warning(f"Detected duplicate widget {_widget} in plugin repo file: {inspect.getfile(_plugin_widgets[_widget])} !")
            for plugin_name, plugin_class in inspect.getmembers(plugin_client, inspect.isclass):
                if issubclass(plugin_class, RPCBase) and plugin_class is not RPCBase:
                    if plugin_name in globals():
                        conflicting_file = (
                            inspect.getfile(_plugin_widgets[plugin_name])
                            if plugin_name in _plugin_widgets
                            else f"{plugin_client}"
                        )
                        logger.warning(
                            f"Plugin widget {plugin_name} from {conflicting_file} conflicts with a built-in class!"
                        )
                        continue
                    if plugin_name not in _overlap:
                        globals()[plugin_name] = plugin_class
        except ImportError as e:
            logger.error(f"Failed loading plugins: \\n{reduce(add, traceback.format_exception(e))}")


        class MockBECFigure(RPCBase):
            @rpc_call
            def add_plot(self, plot_id: str):
                """
                Add a plot to the figure.
                """

            @rpc_call
            def remove_plot(self, plot_id: str):
                """
                Remove a plot from the figure.
                """


        class MockBECWaveform1D(RPCBase):
            @rpc_call
            def set_frequency(self, frequency: float) -> list:
                """
                Set the frequency of the waveform.
                """

            @rpc_call
            def set_amplitude(self, amplitude: float) -> tuple[float, float]:
                """
                Set the amplitude of the waveform.
                """
        '''
    )

    expected_output_formatted = black.format_str(
        expected_output, mode=black.FileMode(line_length=100)
    ).lstrip()

    generated_output_formatted = black.format_str(
        generator.header + "\n" + generator.content, mode=black.FileMode(line_length=100)
    )

    generated_output_formatted = isort.code(generated_output_formatted)
    expected_output_formatted = isort.code(expected_output_formatted)

    assert expected_output_formatted == generated_output_formatted


def test_client_generator_init():
    """
    Test the initialization of the ClientGenerator class: the header is pre-populated and
    the content starts empty.
    """
    generator = ClientGenerator()
    assert generator.header.startswith("# This file was automatically generated by generate_cli.py")
    assert generator.content == ""


def test_generate_client():
    """
    Test the generate_client method of the ClientGenerator class: top-level RPC classes
    end up in the widget mapping and connector classes get an RPCBase subclass.
    """
    generator = ClientGenerator()
    class_container = mock.MagicMock(spec=BECClassContainer)
    class_container.rpc_top_level_classes = [mock.MagicMock(RPC=True, __name__="TestClass1")]
    class_container.connector_classes = [mock.MagicMock(RPC=True, __name__="TestClass2")]

    generator.generate_client(class_container)

    assert '"TestClass1": "TestClass1"' in generator.content
    assert "class TestClass2(RPCBase):" in generator.content


@pytest.mark.parametrize("base", (True, False))
def test_write_client_enum(base):
    """
    Test the write_client_enum method of the ClientGenerator class.

    The ``_WidgetsEnumType`` helper class is emitted only for the base client
    (``base=True``); the ``_Widgets`` name mapping is emitted in both cases.
    """
    generator = ClientGenerator(base=base)
    published_classes = [
        mock.MagicMock(__name__="TestClass1"),
        mock.MagicMock(__name__="TestClass2"),
    ]

    generator.write_client_enum(published_classes)

    assert ("class _WidgetsEnumType(str, enum.Enum):" in generator.content) is base
    assert '"TestClass1": "TestClass1",' in generator.content
    assert '"TestClass2": "TestClass2",' in generator.content


def test_generate_content_for_class():
    """
    Test the generate_content_for_class method of the ClientGenerator class: the class
    header, the method signature and the method docstring appear in the output.
    """
    generator = ClientGenerator()
    cls = mock.MagicMock(__name__="TestClass", USER_ACCESS=["method1"])
    method = mock.MagicMock()
    method.__name__ = "method1"
    method.__doc__ = "Test method"
    method_signature = "(self)"
    cls.method1 = method

    # Patch inspect.signature so the MagicMock method reports a fixed "(self)" signature.
    with mock.patch("inspect.signature", return_value=method_signature):
        generator.generate_content_for_class(cls)

    assert "class TestClass(RPCBase):" in generator.content
    assert "def method1(self):" in generator.content
    assert "Test method" in generator.content


def test_write_is_black_formatted(tmp_path):
    """
    Test that the write method of the ClientGenerator class writes black-formatted code
    to the target file.
    """
    generator = ClientGenerator()
    generator.content = """
def test_content():
    pass

a=1
b=2
c=a+b
"""

    corrected = """def test_content():
    pass


a = 1
b = 2
c = a + b"""
    file_name = tmp_path / "test_client.py"

    generator.write(str(file_name))

    content = file_name.read_text(encoding="utf-8")

    assert corrected in content