0
0
mirror of https://github.com/bec-project/bec_widgets.git synced 2025-07-14 03:31:50 +02:00
Files
bec_widgets/tests/unit_tests/test_generate_cli_client.py

260 lines
7.8 KiB
Python

from textwrap import dedent
from unittest import mock
import black
import isort
import pytest
from bec_widgets.cli.generate_cli import ClientGenerator
from bec_widgets.utils.plugin_utils import BECClassContainer, BECClassInfo
# pylint: disable=missing-function-docstring
# Mock classes to test the generator
class MockBECWaveform1D:
    # Stand-in widget class with two annotated methods listed in USER_ACCESS.
    # NOTE(review): the generator test in this file expects these signatures and
    # one-line docstrings to reappear in the generated client, so keep them verbatim.
    USER_ACCESS = ["set_frequency", "set_amplitude"]

    def set_frequency(self, frequency: float) -> list:
        """Set the frequency of the waveform."""
        wrapped = [frequency]
        return wrapped

    def set_amplitude(self, amplitude: float) -> tuple[float, float]:
        """Set the amplitude of the waveform."""
        pair = (amplitude, amplitude)
        return pair
class MockBECFigure:
    # Stand-in widget class whose two USER_ACCESS methods have docstring-only bodies.
    # NOTE(review): the generator test in this file expects these signatures and
    # one-line docstrings to reappear in the generated client, so keep them verbatim.
    USER_ACCESS = ["add_plot", "remove_plot"]

    def add_plot(self, plot_id: str):
        """Add a plot to the figure."""
        return None

    def remove_plot(self, plot_id: str):
        """Remove a plot from the figure."""
        return None
# End-to-end check of ClientGenerator: register the two mock classes in a
# BECClassContainer, run generate_client(), and compare header + content with a
# hand-written expected module after both sides went through black and isort.
def test_client_generator_with_black_formatting():
# base=True is passed to the generator under test.
generator = ClientGenerator(base=True)
container = BECClassContainer()
# MockBECWaveform1D is registered with is_plugin=False ...
container.add_class(
BECClassInfo(
name="MockBECWaveform1D",
module="test_module",
file="test_file",
obj=MockBECWaveform1D,
is_connector=True,
is_widget=True,
is_plugin=False,
)
)
# ... while MockBECFigure is registered with is_plugin=True. The expected text
# below lists only "MockBECFigure" in _Widgets but defines an RPC class for both.
# NOTE(review): presumably the container's class selection drives this - confirm.
container.add_class(
BECClassInfo(
name="MockBECFigure",
module="test_module",
file="test_file",
obj=MockBECFigure,
is_connector=True,
is_widget=True,
is_plugin=True,
)
)
generator.generate_client(container)
# Expected module text. It is not compared verbatim: both this text and the
# generator output are normalised with black and isort further down.
# The literal is not a raw string, so "\\n" inside it becomes a backslash
# followed by "n" in the expected text, and the trailing backslash after the
# opening quotes suppresses the first newline.
expected_output = dedent(
'''\
# This file was automatically generated by generate_cli.py
# type: ignore
from __future__ import annotations
import enum
import inspect
import traceback
from functools import reduce
from operator import add
from typing import Literal, Optional
from bec_lib.logger import bec_logger
from bec_widgets.cli.rpc.rpc_base import RPCBase, rpc_call
from bec_widgets.utils.bec_plugin_helper import (get_all_plugin_widgets,
get_plugin_client_module)
logger = bec_logger.logger
# pylint: skip-file
class _WidgetsEnumType(str, enum.Enum):
"""Enum for the available widgets, to be generated programatically"""
...
_Widgets = {
"MockBECFigure": "MockBECFigure",
}
try:
_plugin_widgets = get_all_plugin_widgets()
plugin_client = get_plugin_client_module()
Widgets = _WidgetsEnumType("Widgets", {name: name for name in _plugin_widgets} | _Widgets)
if (_overlap := _Widgets.keys() & _plugin_widgets.keys()) != set():
for _widget in _overlap:
logger.warning(f"Detected duplicate widget {_widget} in plugin repo file: {inspect.getfile(_plugin_widgets[_widget])} !")
for plugin_name, plugin_class in inspect.getmembers(plugin_client, inspect.isclass):
if issubclass(plugin_class, RPCBase) and plugin_class is not RPCBase:
if plugin_name in globals():
conflicting_file = (
inspect.getfile(_plugin_widgets[plugin_name])
if plugin_name in _plugin_widgets
else f"{plugin_client}"
)
logger.warning(
f"Plugin widget {plugin_name} from {conflicting_file} conflicts with a built-in class!"
)
continue
if plugin_name not in _overlap:
globals()[plugin_name] = plugin_class
except ImportError as e:
logger.error(f"Failed loading plugins: \\n{reduce(add, traceback.format_exception(e))}")
class MockBECFigure(RPCBase):
@rpc_call
def add_plot(self, plot_id: str):
"""
Add a plot to the figure.
"""
@rpc_call
def remove_plot(self, plot_id: str):
"""
Remove a plot from the figure.
"""
class MockBECWaveform1D(RPCBase):
@rpc_call
def set_frequency(self, frequency: float) -> list:
"""
Set the frequency of the waveform.
"""
@rpc_call
def set_amplitude(self, amplitude: float) -> tuple[float, float]:
"""
Set the amplitude of the waveform.
"""
'''
)
# Normalise the expected text with black (line length 100); leading whitespace
# is stripped on this side only.
expected_output_formatted = black.format_str(
expected_output, mode=black.FileMode(line_length=100)
).lstrip()
# The generated module is the header and the content joined by a newline,
# formatted with the same black mode.
generated_output_formatted = black.format_str(
generator.header + "\n" + generator.content, mode=black.FileMode(line_length=100)
)
# Run isort over both sides so import ordering/grouping cannot cause a mismatch.
generated_output_formatted = isort.code(generated_output_formatted)
expected_output_formatted = isort.code(expected_output_formatted)
assert expected_output_formatted == generated_output_formatted
def test_client_generator_init():
    """A freshly constructed ClientGenerator has the banner header and empty content."""
    expected_banner = "# This file was automatically generated by generate_cli.py"
    fresh_generator = ClientGenerator()

    header_text = fresh_generator.header
    assert header_text.startswith(expected_banner)
    assert fresh_generator.content == ""
def test_generate_client():
    """
    generate_client() lists the top-level class in the widget mapping and emits
    an RPC class for the connector class of a mocked container.
    """
    top_level_cls = mock.MagicMock(RPC=True, __name__="TestClass1")
    connector_cls = mock.MagicMock(RPC=True, __name__="TestClass2")

    # Mocked container restricted to the BECClassContainer attribute surface.
    container = mock.MagicMock(spec=BECClassContainer)
    container.rpc_top_level_classes = [top_level_cls]
    container.connector_classes = [connector_cls]

    client_generator = ClientGenerator()
    client_generator.generate_client(container)

    produced = client_generator.content
    assert '"TestClass1": "TestClass1"' in produced
    assert "class TestClass2(RPCBase):" in produced
@pytest.mark.parametrize("plugin", (True, False))
def test_write_client_enum(plugin):
    """
    write_client_enum() always lists the published class names; the enum class
    definition must be present exactly when the generator was built with a true
    ``base`` flag.
    """
    # NOTE(review): the parametrized value is named ``plugin`` but is forwarded
    # as ``base`` - the name is kept as-is here.
    class_names = ("TestClass1", "TestClass2")
    published_classes = [mock.MagicMock(__name__=name) for name in class_names]

    generator = ClientGenerator(base=plugin)
    generator.write_client_enum(published_classes)
    produced = generator.content

    has_enum_class = "class _WidgetsEnumType(str, enum.Enum):" in produced
    assert has_enum_class is plugin
    for name in class_names:
        assert f'"{name}": "{name}",' in produced
def test_generate_content_for_class():
    """
    generate_content_for_class() emits an RPC class, a method definition and the
    method docstring for a mocked class exposing one USER_ACCESS method.
    """
    fake_method = mock.MagicMock()
    fake_method.__name__ = "method1"
    fake_method.__doc__ = "Test method"

    fake_cls = mock.MagicMock(__name__="TestClass", USER_ACCESS=["method1"])
    fake_cls.method1 = fake_method

    generator = ClientGenerator()
    # inspect.signature is patched to hand back the plain string "(self)"; the
    # assertion on "def method1(self):" below relies on that value.
    with mock.patch("inspect.signature", return_value="(self)"):
        generator.generate_content_for_class(fake_cls)

    produced = generator.content
    for fragment in ("class TestClass(RPCBase):", "def method1(self):", "Test method"):
        assert fragment in produced
def test_write_is_black_formatted(tmp_path):
    """
    write() must store the generator content in black-formatted form.

    The content is seeded with valid but unformatted Python (no spaces around
    operators, a single blank line after the function). The written file is
    then expected to contain the black-formatted equivalent.

    Args:
        tmp_path: pytest fixture providing a per-test temporary directory.
    """
    generator = ClientGenerator()
    # The snippet has to be syntactically valid, i.e. the function body must be
    # indented, otherwise a formatter cannot parse it at all.
    generator.content = """
def test_content():
    pass

a=1
b=2
c=a+b
"""

    # Black's output for the snippet above: spaces around operators and two
    # blank lines between the top-level function and the following statements.
    corrected = """def test_content():
    pass


a = 1
b = 2
c = a + b"""

    file_name = tmp_path / "test_client.py"
    generator.write(str(file_name))

    content = file_name.read_text(encoding="utf-8")
    assert corrected in content