mirror of
https://github.com/bec-project/bec_widgets.git
synced 2025-07-14 03:31:50 +02:00
260 lines
7.8 KiB
Python
260 lines
7.8 KiB
Python
from textwrap import dedent
|
|
from unittest import mock
|
|
|
|
import black
|
|
import isort
|
|
import pytest
|
|
|
|
from bec_widgets.cli.generate_cli import ClientGenerator
|
|
from bec_widgets.utils.plugin_utils import BECClassContainer, BECClassInfo
|
|
|
|
# pylint: disable=missing-function-docstring
|
|
|
|
|
|
# Mock classes to test the generator
|
|
class MockBECWaveform1D:
    # Mock widget fed to ClientGenerator in the tests below. The methods listed in
    # USER_ACCESS are rendered into the generated client with their signatures and
    # docstrings, so these must stay in sync with the hand-written expected output in
    # test_client_generator_with_black_formatting.
    # NOTE(review): intentionally no class docstring — the expected output contains none
    # for this class; confirm how generate_cli treats class docstrings before adding one.
    USER_ACCESS = ["set_frequency", "set_amplitude"]

    def set_frequency(self, frequency: float) -> list:
        """Set the frequency of the waveform."""
        # The return value is irrelevant to the generator; only the signature
        # (including the `-> list` annotation) and the docstring show up in its output.
        return [frequency]

    def set_amplitude(self, amplitude: float) -> tuple[float, float]:
        """Set the amplitude of the waveform."""
        # Exercises a parametrized return annotation, which must be carried
        # through as `-> tuple[float, float]` in the generated client.
        return amplitude, amplitude
|
|
|
|
|
|
class MockBECFigure:
    # Second mock widget for the generator tests. It is registered with is_plugin=True in
    # test_client_generator_with_black_formatting, where it is the only entry expected in
    # the generated `_Widgets` mapping.
    # NOTE(review): intentionally no class docstring — the expected output contains none
    # for this class; confirm how generate_cli treats class docstrings before adding one.
    USER_ACCESS = ["add_plot", "remove_plot"]

    # Both methods deliberately have no return annotation: the expected output spells the
    # signatures as `def add_plot(self, plot_id: str):`, so adding `-> None` would break it.
    def add_plot(self, plot_id: str):
        """Add a plot to the figure."""

    def remove_plot(self, plot_id: str):
        """Remove a plot from the figure."""
|
|
|
|
|
|
def test_client_generator_with_black_formatting():
    """
    Generate a base client for the two mock classes and compare it with a hand-written
    expected module.

    Both the expected and the generated text are run through black (line length 100) and
    then isort before comparing, so most layout and import-order differences are
    normalized away and the comparison focuses on content.
    """
    generator = ClientGenerator(base=True)
    container = BECClassContainer()
    container.add_class(
        BECClassInfo(
            name="MockBECWaveform1D",
            module="test_module",
            file="test_file",
            obj=MockBECWaveform1D,
            is_connector=True,
            is_widget=True,
            is_plugin=False,
        )
    )
    container.add_class(
        BECClassInfo(
            name="MockBECFigure",
            module="test_module",
            file="test_file",
            obj=MockBECFigure,
            is_connector=True,
            is_widget=True,
            is_plugin=True,
        )
    )

    generator.generate_client(container)

    # Hand-written expected module. Of the two mocks, only MockBECFigure (registered with
    # is_plugin=True) is expected in `_Widgets`; both are expected as RPCBase client classes.
    # The `\\n` below becomes a literal `\n` escape inside the expected f-string.
    expected_output = dedent(
        '''\
        # This file was automatically generated by generate_cli.py
        # type: ignore

        from __future__ import annotations

        import enum
        import inspect
        import traceback
        from functools import reduce
        from operator import add
        from typing import Literal, Optional

        from bec_lib.logger import bec_logger

        from bec_widgets.cli.rpc.rpc_base import RPCBase, rpc_call
        from bec_widgets.utils.bec_plugin_helper import (get_all_plugin_widgets,
                                                         get_plugin_client_module)

        logger = bec_logger.logger

        # pylint: skip-file


        class _WidgetsEnumType(str, enum.Enum):
            """Enum for the available widgets, to be generated programatically"""

            ...

        _Widgets = {
            "MockBECFigure": "MockBECFigure",
        }


        try:
            _plugin_widgets = get_all_plugin_widgets()
            plugin_client = get_plugin_client_module()
            Widgets = _WidgetsEnumType("Widgets", {name: name for name in _plugin_widgets} | _Widgets)

            if (_overlap := _Widgets.keys() & _plugin_widgets.keys()) != set():
                for _widget in _overlap:
                    logger.warning(f"Detected duplicate widget {_widget} in plugin repo file: {inspect.getfile(_plugin_widgets[_widget])} !")
            for plugin_name, plugin_class in inspect.getmembers(plugin_client, inspect.isclass):
                if issubclass(plugin_class, RPCBase) and plugin_class is not RPCBase:
                    if plugin_name in globals():
                        conflicting_file = (
                            inspect.getfile(_plugin_widgets[plugin_name])
                            if plugin_name in _plugin_widgets
                            else f"{plugin_client}"
                        )
                        logger.warning(
                            f"Plugin widget {plugin_name} from {conflicting_file} conflicts with a built-in class!"
                        )
                        continue
                    if plugin_name not in _overlap:
                        globals()[plugin_name] = plugin_class
        except ImportError as e:
            logger.error(f"Failed loading plugins: \\n{reduce(add, traceback.format_exception(e))}")

        class MockBECFigure(RPCBase):
            @rpc_call
            def add_plot(self, plot_id: str):
                """
                Add a plot to the figure.
                """

            @rpc_call
            def remove_plot(self, plot_id: str):
                """
                Remove a plot from the figure.
                """


        class MockBECWaveform1D(RPCBase):
            @rpc_call
            def set_frequency(self, frequency: float) -> list:
                """
                Set the frequency of the waveform.
                """

            @rpc_call
            def set_amplitude(self, amplitude: float) -> tuple[float, float]:
                """
                Set the amplitude of the waveform.
                """
        '''
    )

    # Format the expected output with black so it is laid out the same way as the
    # generator output (which is black-formatted just below).
    expected_output_formatted = black.format_str(
        expected_output, mode=black.FileMode(line_length=100)
    ).lstrip()

    # The generated module is the generator's header followed by its content.
    generated_output_formatted = black.format_str(
        generator.header + "\n" + generator.content, mode=black.FileMode(line_length=100)
    )

    # Apply isort to both sides so import ordering cannot cause a spurious mismatch.
    generated_output_formatted = isort.code(generated_output_formatted)
    expected_output_formatted = isort.code(expected_output_formatted)

    assert expected_output_formatted == generated_output_formatted
|
|
|
|
|
|
def test_client_generator_init():
    """
    A freshly constructed ClientGenerator carries the auto-generated header and
    starts with empty content.
    """
    gen = ClientGenerator()
    header_prefix = "# This file was automatically generated by generate_cli.py"

    assert gen.header.startswith(header_prefix)
    assert gen.content == ""
|
|
|
|
|
|
def test_generate_client():
    """
    generate_client puts the top-level class into the widget mapping and renders the
    connector class as an RPCBase subclass.
    """
    gen = ClientGenerator()

    top_level_cls = mock.MagicMock(RPC=True, __name__="TestClass1")
    connector_cls = mock.MagicMock(RPC=True, __name__="TestClass2")

    container = mock.MagicMock(spec=BECClassContainer)
    container.rpc_top_level_classes = [top_level_cls]
    container.connector_classes = [connector_cls]

    gen.generate_client(container)

    rendered = gen.content
    assert '"TestClass1": "TestClass1"' in rendered
    assert "class TestClass2(RPCBase):" in rendered
|
|
|
|
|
|
@pytest.mark.parametrize("base", (True, False))
def test_write_client_enum(base):
    """
    Test the write_client_enum method of the ClientGenerator class.

    Every published class must appear in the generated widget mapping. The
    `_WidgetsEnumType` helper class is emitted only when generating the base client
    (base=True), not for a plugin client (base=False).
    """
    # The parameter is named after ClientGenerator's own `base` argument; the previous
    # name `plugin` read inverted (plugin=True actually generated the base client).
    generator = ClientGenerator(base=base)
    class_names = ("TestClass1", "TestClass2")
    published_classes = [mock.MagicMock(__name__=name) for name in class_names]

    generator.write_client_enum(published_classes)

    assert ("class _WidgetsEnumType(str, enum.Enum):" in generator.content) is base
    for name in class_names:
        assert f'"{name}": "{name}",' in generator.content
|
|
|
|
|
|
def test_generate_content_for_class():
    """
    generate_content_for_class renders the class as an RPCBase subclass containing the
    USER_ACCESS method, built from its (patched) signature and its docstring.
    """
    generator = ClientGenerator()

    fake_method = mock.MagicMock()
    fake_method.__name__ = "method1"
    fake_method.__doc__ = "Test method"

    fake_cls = mock.MagicMock(__name__="TestClass", USER_ACCESS=["method1"])
    fake_cls.method1 = fake_method

    # A MagicMock has no real signature, so inspect.signature is patched to a fixed string.
    with mock.patch("inspect.signature", return_value="(self)"):
        generator.generate_content_for_class(fake_cls)

    rendered = generator.content
    for snippet in ("class TestClass(RPCBase):", "def method1(self):", "Test method"):
        assert snippet in rendered
|
|
|
|
|
|
def test_write_is_black_formatted(tmp_path):
    """
    ClientGenerator.write must store black-formatted code: the deliberately unformatted
    content below has to come back with normalized spacing and blank lines.
    """
    generator = ClientGenerator()
    generator.content = """
def test_content():
    pass

a=1
b=2
c=a+b
"""
    expected_snippet = """def test_content():
    pass


a = 1
b = 2
c = a + b"""
    target = tmp_path / "test_client.py"

    generator.write(str(target))

    written = target.read_text(encoding="utf-8")
    assert expected_snippet in written
|