mirror of
https://github.com/ivan-usov-org/bec.git
synced 2025-04-20 01:40:02 +02:00
feat: Scan metadata schema plugin
- adds a new submodule to the `scans` module in the plugin repo where models may be defined against which the metadata added to ScanQueueMessage should be validated
This commit is contained in:
parent
070e20d9aa
commit
8bfe544066
@ -7,7 +7,9 @@ from copy import deepcopy
|
||||
from typing import Any, ClassVar, Literal
|
||||
|
||||
import numpy as np
|
||||
from pydantic import BaseModel, ConfigDict, Field, field_validator, model_validator
|
||||
from pydantic import BaseModel, ConfigDict, Field, ValidationError, field_validator, model_validator
|
||||
|
||||
from bec_lib.metadata_schema import BasicScanMetadata, get_metadata_schema_for_scan
|
||||
|
||||
|
||||
class BECStatus(enum.Enum):
|
||||
@ -127,6 +129,20 @@ class ScanQueueMessage(BECMessage):
|
||||
parameter: dict
|
||||
queue: str = Field(default="primary")
|
||||
|
||||
@model_validator(mode="after")
|
||||
@classmethod
|
||||
def _validate_metadata(cls, data):
|
||||
"""Make sure the metadata conforms to the registered schema, but
|
||||
leave it as a dict"""
|
||||
schema = get_metadata_schema_for_scan(data.scan_type)
|
||||
try:
|
||||
schema.model_validate(data.metadata)
|
||||
except ValidationError as e:
|
||||
raise ValueError(
|
||||
f"Scan metadata {data.metadata} does not conform to registered schema {schema}. \n Errors: {str(e)}"
|
||||
) from e
|
||||
return data
|
||||
|
||||
|
||||
class ScanQueueHistoryMessage(BECMessage):
|
||||
"""Sent after removal from the active queue. Contains information about the scan.
|
||||
|
45
bec_lib/bec_lib/metadata_schema.py
Normal file
45
bec_lib/bec_lib/metadata_schema.py
Normal file
@ -0,0 +1,45 @@
|
||||
from functools import cache
|
||||
|
||||
from pydantic import BaseModel, ConfigDict
|
||||
|
||||
from bec_lib import plugin_helper
|
||||
from bec_lib.logger import bec_logger
|
||||
|
||||
logger = bec_logger.logger
|
||||
_METADATA_SCHEMA_REGISTRY = {}
|
||||
|
||||
|
||||
class BasicScanMetadata(BaseModel):
|
||||
"""Scan metadata base class which behaves like a dict, and will accept any keys,
|
||||
like the existing metadata field in messages, but can be extended to add required
|
||||
fields for specific scans."""
|
||||
|
||||
model_config = ConfigDict(extra="allow", validate_assignment=True)
|
||||
|
||||
|
||||
@cache
|
||||
def _get_metadata_schema_registry() -> dict[str, type[BasicScanMetadata]]:
|
||||
plugin_schema = plugin_helper.get_metadata_schema_registry()
|
||||
for name, schema in list(plugin_schema.items()):
|
||||
try:
|
||||
if not issubclass(schema, BasicScanMetadata):
|
||||
logger.warning(
|
||||
f"Schema {schema} for {name} in the plugin registry is not valid! It must subclass BasicScanMetadata"
|
||||
)
|
||||
del plugin_schema[name]
|
||||
except TypeError:
|
||||
logger.warning(
|
||||
f"Schema {schema} for {name} in the plugin registry is not a valid type!"
|
||||
)
|
||||
del plugin_schema[name]
|
||||
return _METADATA_SCHEMA_REGISTRY | plugin_schema
|
||||
|
||||
|
||||
def cache_clear():
|
||||
return _get_metadata_schema_registry.cache_clear()
|
||||
|
||||
|
||||
def get_metadata_schema_for_scan(scan_name: str):
|
||||
"""Return the pydantic model (must be a subclass of BasicScanMetadata)
|
||||
associated with the given scan. If none is found, returns BasicScanMetadata."""
|
||||
return _get_metadata_schema_registry().get(scan_name) or BasicScanMetadata
|
@ -82,6 +82,19 @@ def get_file_writer_plugins() -> dict:
|
||||
return loaded_plugins
|
||||
|
||||
|
||||
def get_metadata_schema_registry() -> dict:
|
||||
module = _get_available_plugins("bec.scans.metadata_schema")
|
||||
if len(module) == 0:
|
||||
logger.warning("No plugin metadata schema module found!")
|
||||
return {}
|
||||
try:
|
||||
registry_module = importlib.import_module(module[0].__name__ + ".metadata_schema_registry")
|
||||
return registry_module.METADATA_SCHEMA
|
||||
except Exception as e:
|
||||
logger.error(f"Error while loading metadata schema registry from plugins: {e}")
|
||||
return {}
|
||||
|
||||
|
||||
def get_ipython_client_startup_plugins(state: Literal["pre", "post"]) -> dict:
|
||||
"""
|
||||
Load all IPython client startup plugins.
|
||||
|
59
bec_lib/tests/test_metadata_schema.py
Normal file
59
bec_lib/tests/test_metadata_schema.py
Normal file
@ -0,0 +1,59 @@
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
from pydantic import ValidationError
|
||||
|
||||
from bec_lib import metadata_schema
|
||||
from bec_lib.messages import ScanQueueMessage
|
||||
from bec_lib.metadata_schema import BasicScanMetadata
|
||||
|
||||
TEST_DICT = {"foo": "bar", "baz": 123}
|
||||
|
||||
|
||||
class ChildMetadata(BasicScanMetadata):
|
||||
number_field: int
|
||||
|
||||
|
||||
TEST_REGISTRY = {
|
||||
"fake_scan_with_extra_metadata": ChildMetadata,
|
||||
"fake_scan_with_basic_metadata": BasicScanMetadata,
|
||||
}
|
||||
|
||||
|
||||
@pytest.fixture(scope="module", autouse=True)
|
||||
def clear_schema_registry_cache():
|
||||
metadata_schema.cache_clear()
|
||||
|
||||
|
||||
def test_required_fields_validate():
|
||||
with pytest.raises(ValidationError):
|
||||
test_metadata = ChildMetadata.model_validate(TEST_DICT)
|
||||
|
||||
test_metadata = ChildMetadata.model_validate(TEST_DICT | {"number_field": 123})
|
||||
assert test_metadata.number_field == 123
|
||||
test_metadata.number_field = 234
|
||||
assert test_metadata.number_field == 234
|
||||
|
||||
with pytest.raises(ValidationError):
|
||||
test_metadata.number_field = "string"
|
||||
|
||||
|
||||
def test_creating_scan_queue_message_validates_metadata():
|
||||
with patch.dict(metadata_schema._METADATA_SCHEMA_REGISTRY, TEST_REGISTRY, clear=True):
|
||||
with pytest.raises(ValidationError):
|
||||
ScanQueueMessage(scan_type="fake_scan_with_extra_metadata")
|
||||
with pytest.raises(ValidationError):
|
||||
ScanQueueMessage(
|
||||
scan_type="fake_scan_with_extra_metadata",
|
||||
parameter={},
|
||||
metadata={"number_field", "string"},
|
||||
)
|
||||
ScanQueueMessage(
|
||||
scan_type="fake_scan_with_extra_metadata", parameter={}, metadata={"number_field": 123}
|
||||
)
|
||||
msg_with_extra_keys = ScanQueueMessage(
|
||||
scan_type="fake_scan_with_extra_metadata",
|
||||
parameter={},
|
||||
metadata={"number_field": 123, "extra": "data"},
|
||||
)
|
||||
assert msg_with_extra_keys.metadata["extra"] == "data"
|
@ -2,33 +2,59 @@ import importlib
|
||||
import os
|
||||
import subprocess
|
||||
import sys
|
||||
from time import sleep
|
||||
|
||||
import pytest
|
||||
from pytest import TempPathFactory
|
||||
|
||||
from bec_lib import plugin_helper
|
||||
from bec_lib import metadata_schema, plugin_helper
|
||||
|
||||
|
||||
def install(package):
|
||||
subprocess.check_call([sys.executable, "-m", "pip", "install", package])
|
||||
subprocess.check_call([sys.executable, "-m", "pip", "-v", "install", package])
|
||||
|
||||
|
||||
def uninstall(package):
|
||||
subprocess.check_call([sys.executable, "-m", "pip", "uninstall", "-y", package])
|
||||
|
||||
|
||||
TEST_SCHEMA_FILE = """
|
||||
from bec_lib.metadata_schema import BasicScanMetadata
|
||||
|
||||
class ExampleSchema(BasicScanMetadata):
|
||||
treatment_description: str
|
||||
treatment_temperature_k: int
|
||||
|
||||
"""
|
||||
|
||||
TEST_SCHEMA_REGISTRY = """
|
||||
from .example_schema import ExampleSchema
|
||||
|
||||
METADATA_SCHEMA = {
|
||||
"test_scan_fail_on_type": "FailOnType",
|
||||
"example_scan": ExampleSchema,
|
||||
}
|
||||
|
||||
"""
|
||||
|
||||
TEST_SCAN_CLASS = """
|
||||
from bec_server.scan_server.scans import ScanBase
|
||||
class ScanForTesting(ScanBase):
|
||||
...
|
||||
"""
|
||||
|
||||
|
||||
class TestPluginSystem:
|
||||
|
||||
@pytest.fixture(scope="class", autouse=True)
|
||||
def setup_env(self, tmp_path_factory: TempPathFactory):
|
||||
print("\n\nSetting up plugin for tests: generating files...\n")
|
||||
TestPluginSystem._tmp_plugin_dir = tmp_path_factory.mktemp("test_plugin")
|
||||
TestPluginSystem._tmp_plugin_name = TestPluginSystem._tmp_plugin_dir.name
|
||||
TestPluginSystem._plugin_script = (
|
||||
os.path.dirname(os.path.abspath(__file__))
|
||||
+ "/../util_scripts/create_plugin_structure.py"
|
||||
)
|
||||
|
||||
print("Done. Modifying files with test code...\n")
|
||||
# run plugin generation script
|
||||
subprocess.check_call(
|
||||
[sys.executable, TestPluginSystem._plugin_script, str(TestPluginSystem._tmp_plugin_dir)]
|
||||
@ -40,24 +66,48 @@ class TestPluginSystem:
|
||||
/ f"{TestPluginSystem._tmp_plugin_name}/scans/__init__.py",
|
||||
"w+",
|
||||
) as f:
|
||||
f.writelines(
|
||||
[
|
||||
"from bec_server.scan_server.scans import ScanBase\n",
|
||||
"class ScanForTesting: ...\n",
|
||||
]
|
||||
)
|
||||
f.write(TEST_SCAN_CLASS)
|
||||
|
||||
with open(
|
||||
TestPluginSystem._tmp_plugin_dir
|
||||
/ f"{TestPluginSystem._tmp_plugin_name}/scans/metadata_schema/metadata_schema_registry.py",
|
||||
"w",
|
||||
) as f:
|
||||
f.write(TEST_SCHEMA_REGISTRY)
|
||||
|
||||
with open(
|
||||
TestPluginSystem._tmp_plugin_dir
|
||||
/ f"{TestPluginSystem._tmp_plugin_name}/scans/metadata_schema/example_schema.py",
|
||||
"w",
|
||||
) as f:
|
||||
f.write(TEST_SCHEMA_FILE)
|
||||
|
||||
print("\nDone. Installing into environment...\n")
|
||||
|
||||
# install into current environment
|
||||
install(TestPluginSystem._tmp_plugin_dir)
|
||||
importlib.invalidate_caches()
|
||||
plugin_helper._get_available_plugins.cache_clear()
|
||||
|
||||
piplist = subprocess.Popen(("pip", "list"), stdout=subprocess.PIPE)
|
||||
output = subprocess.check_output(("grep", "test_plugin"), stdin=piplist.stdout)
|
||||
piplist.wait()
|
||||
print("$ pip list | grep test_plugin: \n" + output.decode("utf-8"))
|
||||
|
||||
print("Done. Yielding to test class...\n")
|
||||
|
||||
yield
|
||||
|
||||
uninstall(TestPluginSystem._tmp_plugin_name)
|
||||
print("\n\nDone. Uninstalling test plugin:\n")
|
||||
|
||||
uninstall(TestPluginSystem._tmp_plugin_name)
|
||||
importlib.invalidate_caches()
|
||||
plugin_helper._get_available_plugins.cache_clear()
|
||||
metadata_schema._METADATA_SCHEMA_REGISTRY = {}
|
||||
del sys.modules["bec_lib.metadata_schema"]
|
||||
TestPluginSystem._tmp_plugin_dir = None
|
||||
|
||||
def test_files_in_plugin_deployment(self, setup_env):
|
||||
def test_generated_files_in_plugin_deployment(self):
|
||||
files = os.listdir(TestPluginSystem._tmp_plugin_dir)
|
||||
for file in [
|
||||
TestPluginSystem._tmp_plugin_name,
|
||||
@ -70,15 +120,49 @@ class TestPluginSystem:
|
||||
".gitlab-ci.yml",
|
||||
]:
|
||||
assert file in files
|
||||
files = os.listdir(TestPluginSystem._tmp_plugin_dir / TestPluginSystem._tmp_plugin_name)
|
||||
for file in ["scans"]:
|
||||
assert file in files
|
||||
|
||||
def test_plugin_module_import_from_file(self, setup_env):
|
||||
spec = importlib.util.spec_from_file_location(
|
||||
TestPluginSystem._tmp_plugin_name,
|
||||
str(TestPluginSystem._tmp_plugin_dir) + "/__init__.py",
|
||||
def test_test_created_files_in_plugin_deployment(self):
|
||||
files = os.listdir(
|
||||
TestPluginSystem._tmp_plugin_dir
|
||||
/ f"{TestPluginSystem._tmp_plugin_name}/scans/metadata_schema"
|
||||
)
|
||||
plugin_module = importlib.util.module_from_spec(spec)
|
||||
for file in ["example_schema.py", "metadata_schema_registry.py"]:
|
||||
assert file in files
|
||||
|
||||
def test_plugin_modules_import_from_file(self, setup_env):
|
||||
def test_plugin_module_import_from_generated_file(self):
|
||||
try:
|
||||
package_spec = importlib.util.spec_from_file_location(
|
||||
TestPluginSystem._tmp_plugin_name,
|
||||
TestPluginSystem._tmp_plugin_dir
|
||||
/ TestPluginSystem._tmp_plugin_name
|
||||
/ "__init__.py",
|
||||
)
|
||||
plugin_module = importlib.util.module_from_spec(package_spec)
|
||||
package_spec.loader.exec_module(plugin_module)
|
||||
|
||||
md_reg_mod_name = (
|
||||
TestPluginSystem._tmp_plugin_name
|
||||
+ ".scans.metadata_schema.metadata_schema_registry"
|
||||
)
|
||||
md_reg_spec = importlib.util.spec_from_file_location(
|
||||
md_reg_mod_name,
|
||||
TestPluginSystem._tmp_plugin_dir
|
||||
/ TestPluginSystem._tmp_plugin_name
|
||||
/ "scans/metadata_schema"
|
||||
/ "metadata_schema_registry.py",
|
||||
)
|
||||
md_reg_module = importlib.util.module_from_spec(md_reg_spec)
|
||||
md_reg_spec.loader.exec_module(md_reg_module)
|
||||
assert md_reg_module.METADATA_SCHEMA is not None
|
||||
finally:
|
||||
for mod in [TestPluginSystem._tmp_plugin_name, md_reg_mod_name]:
|
||||
if mod in sys.modules:
|
||||
del sys.modules[mod]
|
||||
|
||||
def test_plugin_modules_import_from_file(self):
|
||||
importlib.import_module(TestPluginSystem._tmp_plugin_name)
|
||||
for submod in [
|
||||
"scans",
|
||||
@ -90,6 +174,16 @@ class TestPluginSystem:
|
||||
]:
|
||||
importlib.import_module(TestPluginSystem._tmp_plugin_name + "." + submod)
|
||||
|
||||
def test_plugin_helper(self, setup_env):
|
||||
def test_plugin_helper_for_scans(self):
|
||||
plugin_scans_modules = plugin_helper._get_available_plugins("bec.scans")
|
||||
assert len(plugin_scans_modules) > 0
|
||||
scan_plugins = plugin_helper.get_scan_plugins()
|
||||
assert "ScanForTesting" in scan_plugins.keys()
|
||||
|
||||
def test_plugin_helper_for_metadata_schema(self):
|
||||
metadata_schema_plugin_module = plugin_helper._get_available_plugins(
|
||||
"bec.scans.metadata_schema"
|
||||
)
|
||||
assert len(metadata_schema_plugin_module) > 0
|
||||
metadata_registry = plugin_helper.get_metadata_schema_registry()
|
||||
assert set(["test_scan_fail_on_type", "example_scan"]) == set(metadata_registry.keys())
|
||||
|
@ -70,6 +70,9 @@ class PluginStructure:
|
||||
self.create_dir(f"{self.plugin_name}/scans")
|
||||
self.create_init_file(f"{self.plugin_name}/scans")
|
||||
|
||||
self.create_dir(f"{self.plugin_name}/scans/metadata_schema")
|
||||
self.create_init_file(f"{self.plugin_name}/scans/metadata_schema")
|
||||
|
||||
# copy scan_plugin_template.py
|
||||
scan_plugin_template_file = os.path.join(
|
||||
current_dir, "plugin_setup_files", "scan_plugin_template.py"
|
||||
@ -191,6 +194,7 @@ if __name__ == "__main__":
|
||||
struc.add_plugins()
|
||||
struc.copy_plugin_setup_files()
|
||||
struc.add_scans()
|
||||
struc.add_metadata_schema()
|
||||
struc.add_client()
|
||||
struc.add_devices()
|
||||
struc.add_device_configs()
|
||||
|
@ -0,0 +1,8 @@
|
||||
# from .schema_template import ExampleSchema
|
||||
|
||||
METADATA_SCHEMA = {
|
||||
# Add models which should be used to validate scan metadata here.
|
||||
# Make a model according to the template, and import it as above
|
||||
# Then associate it with a scan like so:
|
||||
# "example_scan": ExampleSchema
|
||||
}
|
@ -0,0 +1,34 @@
|
||||
# # By inheriting from BasicScanMetadata you can define a schema by which metadata
|
||||
# # supplied to a scan must be validated.
|
||||
# # This schema is a Pydantic model: https://docs.pydantic.dev/latest/concepts/models/
|
||||
# # but by default it will still allow you to add any arbitrary information to it.
|
||||
# # That is to say, when you run a scan with which such a model has been associated in the
|
||||
# # metadata_schema_registry, you can supply any python dictionary with strings as keys
|
||||
# # and built-in python types (strings, integers, floats) as values, and these will be
|
||||
# # added to the experiment metadata, but it *must* contain the keys and values of the
|
||||
# # types defined in the schema class.
|
||||
# #
|
||||
# #
|
||||
# # For example, say that you would like to enforce recording information about sample
|
||||
# # pretreatment, you could define the following:
|
||||
# #
|
||||
#
|
||||
# from bec_lib.metadata_schema import BasicScanMetadata
|
||||
#
|
||||
#
|
||||
# class ExampleSchema(BasicScanMetadata):
|
||||
# treatment_description: str
|
||||
# treatment_temperature_k: int
|
||||
#
|
||||
#
|
||||
# # If this was used according to the example in metadata_schema_registry.py,
|
||||
# # then when calling the scan, the user would need to write something like:
|
||||
# >>> scans.example_scan(
|
||||
# >>> motor,
|
||||
# >>> 1,
|
||||
# >>> 2,
|
||||
# >>> 3,
|
||||
# >>> metadata={"treatment_description": "oven overnight", "treatment_temperature_k": 575},
|
||||
# >>> )
|
||||
#
|
||||
# # And the additional metadata would be saved in the HDF5 file created for the scan.
|
@ -38,6 +38,9 @@ plugin_file_writer = "{template_name}.file_writer"
|
||||
[project.entry-points."bec.scans"]
|
||||
plugin_scans = "{template_name}.scans"
|
||||
|
||||
[project.entry-points."bec.scans.metadata_schema"]
|
||||
plugin_metadata_schema = "{template_name}.scans.metadata_schema"
|
||||
|
||||
[project.entry-points."bec.ipython_client_startup"]
|
||||
plugin_ipython_client_pre = "{template_name}.bec_ipython_client.startup.pre_startup"
|
||||
plugin_ipython_client_post = "{template_name}.bec_ipython_client.startup"
|
||||
|
@ -370,6 +370,44 @@ def test_remove_queue_item(queuemanager_mock):
|
||||
assert len(queue_manager.queues["primary"].queue) == 0
|
||||
|
||||
|
||||
def test_invalid_scan_specified_in_message(queuemanager_mock):
|
||||
queue_manager = queuemanager_mock()
|
||||
msg = messages.ScanQueueMessage(
|
||||
scan_type="fake test scan which does not exist!",
|
||||
parameter={"args": {"samx": (1,)}, "kwargs": {}},
|
||||
queue="primary",
|
||||
metadata={"RID": "something"},
|
||||
)
|
||||
with mock.patch.object(queue_manager, "connector") as connector:
|
||||
queue_manager.add_to_queue(scan_queue="dummy", msg=msg)
|
||||
connector.raise_alarm.assert_called_once_with(
|
||||
severity=Alarms.MAJOR,
|
||||
source=msg.content,
|
||||
msg="fake test scan which does not exist!",
|
||||
alarm_type="KeyError",
|
||||
metadata={"RID": "something"},
|
||||
)
|
||||
|
||||
|
||||
def test_invalid_scan_specified_in_message(queuemanager_mock):
|
||||
queue_manager = queuemanager_mock()
|
||||
msg = messages.ScanQueueMessage(
|
||||
scan_type="fake test scan which does not exist!",
|
||||
parameter={"args": {"samx": (1,)}, "kwargs": {}},
|
||||
queue="primary",
|
||||
metadata={"RID": "something"},
|
||||
)
|
||||
with mock.patch.object(queue_manager, "connector") as connector:
|
||||
queue_manager.add_to_queue(scan_queue="dummy", msg=msg)
|
||||
connector.raise_alarm.assert_called_once_with(
|
||||
severity=Alarms.MAJOR,
|
||||
source=msg.content,
|
||||
msg="fake test scan which does not exist!",
|
||||
alarm_type="KeyError",
|
||||
metadata={"RID": "something"},
|
||||
)
|
||||
|
||||
|
||||
def test_set_clear(queuemanager_mock):
|
||||
queue_manager = queuemanager_mock()
|
||||
msg = messages.ScanQueueMessage(
|
||||
|
Loading…
x
Reference in New Issue
Block a user