feat: Scan metadata schema plugin

- adds a new submodule to the `scans` module in the plugin repo where
  models may be defined against which the metadata added to
  ScanQueueMessage should be validated
This commit is contained in:
perl_d 2025-01-28 17:22:55 +01:00
parent 070e20d9aa
commit 8bfe544066
10 changed files with 334 additions and 20 deletions

View File

@ -7,7 +7,9 @@ from copy import deepcopy
from typing import Any, ClassVar, Literal from typing import Any, ClassVar, Literal
import numpy as np import numpy as np
from pydantic import BaseModel, ConfigDict, Field, field_validator, model_validator from pydantic import BaseModel, ConfigDict, Field, ValidationError, field_validator, model_validator
from bec_lib.metadata_schema import BasicScanMetadata, get_metadata_schema_for_scan
class BECStatus(enum.Enum): class BECStatus(enum.Enum):
@ -127,6 +129,20 @@ class ScanQueueMessage(BECMessage):
parameter: dict parameter: dict
queue: str = Field(default="primary") queue: str = Field(default="primary")
@model_validator(mode="after")
def _validate_metadata(self):
    """Check the message metadata against the schema registered for its scan type.

    The metadata is only validated, never converted: it stays a plain dict
    on the message.

    Returns:
        The message instance itself.

    Raises:
        ValueError: if the metadata does not conform to the schema registered
            for ``scan_type``. Raised from inside a pydantic validator, so
            callers constructing the message see it as a validation error.
    """
    # "after" validators run on the already-built instance, so this is a plain
    # instance method rather than a classmethod.
    schema = get_metadata_schema_for_scan(self.scan_type)
    try:
        schema.model_validate(self.metadata)
    except ValidationError as e:
        raise ValueError(
            f"Scan metadata {self.metadata} does not conform to registered schema {schema}. \n Errors: {str(e)}"
        ) from e
    return self
class ScanQueueHistoryMessage(BECMessage): class ScanQueueHistoryMessage(BECMessage):
"""Sent after removal from the active queue. Contains information about the scan. """Sent after removal from the active queue. Contains information about the scan.

View File

@ -0,0 +1,45 @@
from functools import cache
from pydantic import BaseModel, ConfigDict
from bec_lib import plugin_helper
from bec_lib.logger import bec_logger
logger = bec_logger.logger
# Scan name -> metadata schema entries defined in this module. They are merged
# with the plugin-provided registry when the registry is built; on a name clash
# the plugin entry wins because it is the right-hand operand of the merge.
_METADATA_SCHEMA_REGISTRY = {}
class BasicScanMetadata(BaseModel):
    """Base model for scan metadata.

    Arbitrary extra keys are accepted, mirroring the free-form metadata dict
    already carried by messages. Subclass it to declare fields that are
    required for a specific scan. Values are re-validated on assignment.
    """

    model_config = ConfigDict(validate_assignment=True, extra="allow")
@cache
def _get_metadata_schema_registry() -> dict[str, type[BasicScanMetadata]]:
    """Build the scan-name -> metadata-schema registry (cached after the first call).

    Entries from the plugin registry are kept only if they are subclasses of
    BasicScanMetadata; anything else is logged and left out. The dict returned
    by the plugin helper is not modified: a filtered copy is built instead, so
    the plugin's own registry object keeps all of its entries.

    Returns:
        The built-in registry merged with the valid plugin entries; plugin
        entries take precedence on duplicate scan names.
    """
    valid_plugin_schema = {}
    for name, schema in plugin_helper.get_metadata_schema_registry().items():
        try:
            is_valid = issubclass(schema, BasicScanMetadata)
        except TypeError:
            # issubclass() raises TypeError when `schema` is not a class at all
            logger.warning(
                f"Schema {schema} for {name} in the plugin registry is not a valid type!"
            )
            continue
        if not is_valid:
            logger.warning(
                f"Schema {schema} for {name} in the plugin registry is not valid! It must subclass BasicScanMetadata"
            )
            continue
        valid_plugin_schema[name] = schema
    return _METADATA_SCHEMA_REGISTRY | valid_plugin_schema
def cache_clear():
    """Discard the cached schema registry so the next lookup rebuilds it."""
    cleared = _get_metadata_schema_registry.cache_clear()
    return cleared
def get_metadata_schema_for_scan(scan_name: str):
    """Look up the metadata schema registered for a scan.

    Args:
        scan_name: name of the scan as used in the registry.

    Returns:
        The pydantic model registered for ``scan_name`` (a subclass of
        BasicScanMetadata), or BasicScanMetadata itself when nothing is
        registered for that scan.
    """
    registered = _get_metadata_schema_registry().get(scan_name)
    if registered:
        return registered
    return BasicScanMetadata

View File

@ -82,6 +82,19 @@ def get_file_writer_plugins() -> dict:
return loaded_plugins return loaded_plugins
def get_metadata_schema_registry() -> dict:
    """Load the scan metadata schema registry exposed by an installed plugin.

    Only the first module found for the "bec.scans.metadata_schema" plugin
    group is used.

    Returns:
        dict: the ``METADATA_SCHEMA`` object of that module's
        ``metadata_schema_registry`` submodule, or an empty dict when no
        plugin module is found or loading it fails (the failure is logged).
    """
    schema_modules = _get_available_plugins("bec.scans.metadata_schema")
    if len(schema_modules) == 0:
        logger.warning("No plugin metadata schema module found!")
        return {}

    try:
        registry_name = schema_modules[0].__name__ + ".metadata_schema_registry"
        loaded = importlib.import_module(registry_name)
        return loaded.METADATA_SCHEMA
    except Exception as e:
        logger.error(f"Error while loading metadata schema registry from plugins: {e}")
        return {}
def get_ipython_client_startup_plugins(state: Literal["pre", "post"]) -> dict: def get_ipython_client_startup_plugins(state: Literal["pre", "post"]) -> dict:
""" """
Load all IPython client startup plugins. Load all IPython client startup plugins.

View File

@ -0,0 +1,59 @@
from unittest.mock import patch
import pytest
from pydantic import ValidationError
from bec_lib import metadata_schema
from bec_lib.messages import ScanQueueMessage
from bec_lib.metadata_schema import BasicScanMetadata
# Arbitrary metadata that does not contain the field required by ChildMetadata.
TEST_DICT = {"foo": "bar", "baz": 123}


# Schema with one required integer field on top of the free-form base model.
class ChildMetadata(BasicScanMetadata):
    number_field: int


# Registry contents patched into metadata_schema for the message tests below.
TEST_REGISTRY = {
    "fake_scan_with_extra_metadata": ChildMetadata,
    "fake_scan_with_basic_metadata": BasicScanMetadata,
}
@pytest.fixture(scope="module", autouse=True)
def clear_schema_registry_cache():
    """Reset the cached schema registry before and after this test module.

    The registry is cached on first use. Clearing it beforehand lets the tests
    build it from the patched registry; clearing it again afterwards keeps the
    test-only entries from leaking into other test modules.
    """
    metadata_schema.cache_clear()
    yield
    metadata_schema.cache_clear()
def test_required_fields_validate():
    """A schema subclass enforces its required field on creation and on assignment."""
    # the required field is missing entirely
    with pytest.raises(ValidationError):
        ChildMetadata.model_validate(TEST_DICT)

    complete = TEST_DICT | {"number_field": 123}
    validated = ChildMetadata.model_validate(complete)
    assert validated.number_field == 123

    # assigning a valid value is accepted ...
    validated.number_field = 234
    assert validated.number_field == 234

    # ... while assigning a value of the wrong type is rejected
    with pytest.raises(ValidationError):
        validated.number_field = "string"
def test_creating_scan_queue_message_validates_metadata():
    """Creating a ScanQueueMessage validates its metadata against the registered schema.

    Every failing case passes a valid ``parameter`` and a dict (or no) metadata,
    so that the only possible cause of the error is the metadata schema check.
    """
    with patch.dict(metadata_schema._METADATA_SCHEMA_REGISTRY, TEST_REGISTRY, clear=True):
        # make sure the registry is rebuilt from the patched contents
        metadata_schema.cache_clear()

        # required metadata field missing
        with pytest.raises(ValidationError):
            ScanQueueMessage(scan_type="fake_scan_with_extra_metadata", parameter={})

        # required metadata field present but of the wrong type
        with pytest.raises(ValidationError):
            ScanQueueMessage(
                scan_type="fake_scan_with_extra_metadata",
                parameter={},
                metadata={"number_field": "string"},
            )

        ScanQueueMessage(
            scan_type="fake_scan_with_extra_metadata", parameter={}, metadata={"number_field": 123}
        )

        # extra keys are allowed and the metadata stays a plain dict
        msg_with_extra_keys = ScanQueueMessage(
            scan_type="fake_scan_with_extra_metadata",
            parameter={},
            metadata={"number_field": 123, "extra": "data"},
        )
        assert msg_with_extra_keys.metadata["extra"] == "data"

View File

@ -2,33 +2,59 @@ import importlib
import os import os
import subprocess import subprocess
import sys import sys
from time import sleep
import pytest import pytest
from pytest import TempPathFactory from pytest import TempPathFactory
from bec_lib import plugin_helper from bec_lib import metadata_schema, plugin_helper
def install(package): def install(package):
subprocess.check_call([sys.executable, "-m", "pip", "install", package]) subprocess.check_call([sys.executable, "-m", "pip", "-v", "install", package])
def uninstall(package): def uninstall(package):
subprocess.check_call([sys.executable, "-m", "pip", "uninstall", "-y", package]) subprocess.check_call([sys.executable, "-m", "pip", "uninstall", "-y", package])
# Source of the example schema module written into the generated plugin.
TEST_SCHEMA_FILE = """
from bec_lib.metadata_schema import BasicScanMetadata


class ExampleSchema(BasicScanMetadata):
    treatment_description: str
    treatment_temperature_k: int
"""

# Source of the plugin's schema registry module. The first entry is a plain
# string instead of a model class.
TEST_SCHEMA_REGISTRY = """
from .example_schema import ExampleSchema

METADATA_SCHEMA = {
    "test_scan_fail_on_type": "FailOnType",
    "example_scan": ExampleSchema,
}
"""

# Source written to the plugin's scans/__init__.py.
TEST_SCAN_CLASS = """
from bec_server.scan_server.scans import ScanBase


class ScanForTesting(ScanBase):
    ...
"""
class TestPluginSystem: class TestPluginSystem:
@pytest.fixture(scope="class", autouse=True) @pytest.fixture(scope="class", autouse=True)
def setup_env(self, tmp_path_factory: TempPathFactory): def setup_env(self, tmp_path_factory: TempPathFactory):
print("\n\nSetting up plugin for tests: generating files...\n")
TestPluginSystem._tmp_plugin_dir = tmp_path_factory.mktemp("test_plugin") TestPluginSystem._tmp_plugin_dir = tmp_path_factory.mktemp("test_plugin")
TestPluginSystem._tmp_plugin_name = TestPluginSystem._tmp_plugin_dir.name TestPluginSystem._tmp_plugin_name = TestPluginSystem._tmp_plugin_dir.name
TestPluginSystem._plugin_script = ( TestPluginSystem._plugin_script = (
os.path.dirname(os.path.abspath(__file__)) os.path.dirname(os.path.abspath(__file__))
+ "/../util_scripts/create_plugin_structure.py" + "/../util_scripts/create_plugin_structure.py"
) )
print("Done. Modifying files with test code...\n")
# run plugin generation script # run plugin generation script
subprocess.check_call( subprocess.check_call(
[sys.executable, TestPluginSystem._plugin_script, str(TestPluginSystem._tmp_plugin_dir)] [sys.executable, TestPluginSystem._plugin_script, str(TestPluginSystem._tmp_plugin_dir)]
@ -40,24 +66,48 @@ class TestPluginSystem:
/ f"{TestPluginSystem._tmp_plugin_name}/scans/__init__.py", / f"{TestPluginSystem._tmp_plugin_name}/scans/__init__.py",
"w+", "w+",
) as f: ) as f:
f.writelines( f.write(TEST_SCAN_CLASS)
[
"from bec_server.scan_server.scans import ScanBase\n", with open(
"class ScanForTesting: ...\n", TestPluginSystem._tmp_plugin_dir
] / f"{TestPluginSystem._tmp_plugin_name}/scans/metadata_schema/metadata_schema_registry.py",
) "w",
) as f:
f.write(TEST_SCHEMA_REGISTRY)
with open(
TestPluginSystem._tmp_plugin_dir
/ f"{TestPluginSystem._tmp_plugin_name}/scans/metadata_schema/example_schema.py",
"w",
) as f:
f.write(TEST_SCHEMA_FILE)
print("\nDone. Installing into environment...\n")
# install into current environment # install into current environment
install(TestPluginSystem._tmp_plugin_dir) install(TestPluginSystem._tmp_plugin_dir)
importlib.invalidate_caches() importlib.invalidate_caches()
plugin_helper._get_available_plugins.cache_clear()
piplist = subprocess.Popen(("pip", "list"), stdout=subprocess.PIPE)
output = subprocess.check_output(("grep", "test_plugin"), stdin=piplist.stdout)
piplist.wait()
print("$ pip list | grep test_plugin: \n" + output.decode("utf-8"))
print("Done. Yielding to test class...\n")
yield yield
uninstall(TestPluginSystem._tmp_plugin_name) print("\n\nDone. Uninstalling test plugin:\n")
uninstall(TestPluginSystem._tmp_plugin_name)
importlib.invalidate_caches()
plugin_helper._get_available_plugins.cache_clear()
metadata_schema._METADATA_SCHEMA_REGISTRY = {}
del sys.modules["bec_lib.metadata_schema"]
TestPluginSystem._tmp_plugin_dir = None TestPluginSystem._tmp_plugin_dir = None
def test_files_in_plugin_deployment(self, setup_env): def test_generated_files_in_plugin_deployment(self):
files = os.listdir(TestPluginSystem._tmp_plugin_dir) files = os.listdir(TestPluginSystem._tmp_plugin_dir)
for file in [ for file in [
TestPluginSystem._tmp_plugin_name, TestPluginSystem._tmp_plugin_name,
@ -70,15 +120,49 @@ class TestPluginSystem:
".gitlab-ci.yml", ".gitlab-ci.yml",
]: ]:
assert file in files assert file in files
files = os.listdir(TestPluginSystem._tmp_plugin_dir / TestPluginSystem._tmp_plugin_name)
for file in ["scans"]:
assert file in files
def test_plugin_module_import_from_file(self, setup_env): def test_test_created_files_in_plugin_deployment(self):
spec = importlib.util.spec_from_file_location( files = os.listdir(
TestPluginSystem._tmp_plugin_name, TestPluginSystem._tmp_plugin_dir
str(TestPluginSystem._tmp_plugin_dir) + "/__init__.py", / f"{TestPluginSystem._tmp_plugin_name}/scans/metadata_schema"
) )
plugin_module = importlib.util.module_from_spec(spec) for file in ["example_schema.py", "metadata_schema_registry.py"]:
assert file in files
def test_plugin_module_import_from_generated_file(self):
    """Load the plugin package and its schema registry directly from the generated files."""
    plugin_name = TestPluginSystem._tmp_plugin_name
    package_dir = TestPluginSystem._tmp_plugin_dir / plugin_name
    # Computed before the try block so that the cleanup in `finally` can never
    # hit an unbound name and hide the original failure.
    md_reg_mod_name = plugin_name + ".scans.metadata_schema.metadata_schema_registry"
    try:
        package_spec = importlib.util.spec_from_file_location(
            plugin_name, package_dir / "__init__.py"
        )
        plugin_module = importlib.util.module_from_spec(package_spec)
        package_spec.loader.exec_module(plugin_module)

        md_reg_spec = importlib.util.spec_from_file_location(
            md_reg_mod_name,
            package_dir / "scans/metadata_schema" / "metadata_schema_registry.py",
        )
        md_reg_module = importlib.util.module_from_spec(md_reg_spec)
        md_reg_spec.loader.exec_module(md_reg_module)
        assert md_reg_module.METADATA_SCHEMA is not None
    finally:
        # drop anything registered under these names so later tests import afresh
        for mod in [plugin_name, md_reg_mod_name]:
            sys.modules.pop(mod, None)
def test_plugin_modules_import_from_file(self):
importlib.import_module(TestPluginSystem._tmp_plugin_name) importlib.import_module(TestPluginSystem._tmp_plugin_name)
for submod in [ for submod in [
"scans", "scans",
@ -90,6 +174,16 @@ class TestPluginSystem:
]: ]:
importlib.import_module(TestPluginSystem._tmp_plugin_name + "." + submod) importlib.import_module(TestPluginSystem._tmp_plugin_name + "." + submod)
def test_plugin_helper(self, setup_env): def test_plugin_helper_for_scans(self):
plugin_scans_modules = plugin_helper._get_available_plugins("bec.scans")
assert len(plugin_scans_modules) > 0
scan_plugins = plugin_helper.get_scan_plugins() scan_plugins = plugin_helper.get_scan_plugins()
assert "ScanForTesting" in scan_plugins.keys() assert "ScanForTesting" in scan_plugins.keys()
def test_plugin_helper_for_metadata_schema(self):
    """The plugin helper finds the schema entry point and returns the raw plugin registry."""
    entry_point_modules = plugin_helper._get_available_plugins("bec.scans.metadata_schema")
    assert len(entry_point_modules) > 0

    registry = plugin_helper.get_metadata_schema_registry()
    expected_scans = {"test_scan_fail_on_type", "example_scan"}
    assert expected_scans == set(registry.keys())

View File

@ -70,6 +70,9 @@ class PluginStructure:
self.create_dir(f"{self.plugin_name}/scans") self.create_dir(f"{self.plugin_name}/scans")
self.create_init_file(f"{self.plugin_name}/scans") self.create_init_file(f"{self.plugin_name}/scans")
self.create_dir(f"{self.plugin_name}/scans/metadata_schema")
self.create_init_file(f"{self.plugin_name}/scans/metadata_schema")
# copy scan_plugin_template.py # copy scan_plugin_template.py
scan_plugin_template_file = os.path.join( scan_plugin_template_file = os.path.join(
current_dir, "plugin_setup_files", "scan_plugin_template.py" current_dir, "plugin_setup_files", "scan_plugin_template.py"
@ -191,6 +194,7 @@ if __name__ == "__main__":
struc.add_plugins() struc.add_plugins()
struc.copy_plugin_setup_files() struc.copy_plugin_setup_files()
struc.add_scans() struc.add_scans()
struc.add_metadata_schema()
struc.add_client() struc.add_client()
struc.add_devices() struc.add_devices()
struc.add_device_configs() struc.add_device_configs()

View File

@ -0,0 +1,8 @@
# from .schema_template import ExampleSchema

METADATA_SCHEMA = {
    # Register the models used to validate scan metadata here.
    # Write a model following the schema template, import it as shown above,
    # then associate it with a scan name like so:
    # "example_scan": ExampleSchema
}

View File

@ -0,0 +1,34 @@
# # By inheriting from BasicScanMetadata you can define a schema by which metadata
# # supplied to a scan must be validated.
# # This schema is a Pydantic model: https://docs.pydantic.dev/latest/concepts/models/
# # but by default it will still allow you to add any arbitrary information to it.
# # That is to say, when you run a scan with which such a model has been associated in the
# # metadata_schema_registry, you can supply any python dictionary with strings as keys
# # and built-in python types (strings, integers, floats) as values, and these will be
# # added to the experiment metadata, but it *must* contain the keys and values of the
# # types defined in the schema class.
# #
# #
# # For example, say that you would like to enforce recording information about sample
# # pretreatment, you could define the following:
# #
#
# from bec_lib.metadata_schema import BasicScanMetadata
#
#
# class ExampleSchema(BasicScanMetadata):
# treatment_description: str
# treatment_temperature_k: int
#
#
# # If this was used according to the example in metadata_schema_registry.py,
# # then when calling the scan, the user would need to write something like:
# >>> scans.example_scan(
# >>> motor,
# >>> 1,
# >>> 2,
# >>> 3,
# >>> metadata={"treatment_description": "oven overnight", "treatment_temperature_k": 575},
# >>> )
#
# # And the additional metadata would be saved in the HDF5 file created for the scan.

View File

@ -38,6 +38,9 @@ plugin_file_writer = "{template_name}.file_writer"
[project.entry-points."bec.scans"] [project.entry-points."bec.scans"]
plugin_scans = "{template_name}.scans" plugin_scans = "{template_name}.scans"
[project.entry-points."bec.scans.metadata_schema"]
plugin_metadata_schema = "{template_name}.scans.metadata_schema"
[project.entry-points."bec.ipython_client_startup"] [project.entry-points."bec.ipython_client_startup"]
plugin_ipython_client_pre = "{template_name}.bec_ipython_client.startup.pre_startup" plugin_ipython_client_pre = "{template_name}.bec_ipython_client.startup.pre_startup"
plugin_ipython_client_post = "{template_name}.bec_ipython_client.startup" plugin_ipython_client_post = "{template_name}.bec_ipython_client.startup"

View File

@ -370,6 +370,44 @@ def test_remove_queue_item(queuemanager_mock):
assert len(queue_manager.queues["primary"].queue) == 0 assert len(queue_manager.queues["primary"].queue) == 0
def test_invalid_scan_specified_in_message(queuemanager_mock):
    """Queueing a message for an unknown scan type raises a single MAJOR KeyError alarm."""
    queue_manager = queuemanager_mock()
    msg = messages.ScanQueueMessage(
        scan_type="fake test scan which does not exist!",
        parameter={"args": {"samx": (1,)}, "kwargs": {}},
        queue="primary",
        metadata={"RID": "something"},
    )
    with mock.patch.object(queue_manager, "connector") as connector:
        queue_manager.add_to_queue(scan_queue="dummy", msg=msg)
        connector.raise_alarm.assert_called_once_with(
            severity=Alarms.MAJOR,
            source=msg.content,
            msg="fake test scan which does not exist!",
            alarm_type="KeyError",
            metadata={"RID": "something"},
        )
def test_set_clear(queuemanager_mock): def test_set_clear(queuemanager_mock):
queue_manager = queuemanager_mock() queue_manager = queuemanager_mock()
msg = messages.ScanQueueMessage( msg = messages.ScanQueueMessage(