feat: Scan metadata schema plugin

- adds a new submodule to the `scans` module in the plugin repo where
  models may be defined against which the metadata added to
  ScanQueueMessage should be validated
This commit is contained in:
perl_d 2025-01-28 17:22:55 +01:00
parent 070e20d9aa
commit 8bfe544066
10 changed files with 334 additions and 20 deletions

View File

@ -7,7 +7,9 @@ from copy import deepcopy
from typing import Any, ClassVar, Literal
import numpy as np
from pydantic import BaseModel, ConfigDict, Field, field_validator, model_validator
from pydantic import BaseModel, ConfigDict, Field, ValidationError, field_validator, model_validator
from bec_lib.metadata_schema import BasicScanMetadata, get_metadata_schema_for_scan
class BECStatus(enum.Enum):
@ -127,6 +129,20 @@ class ScanQueueMessage(BECMessage):
parameter: dict
queue: str = Field(default="primary")
@model_validator(mode="after")
@classmethod
def _validate_metadata(cls, data):
"""Make sure the metadata conforms to the registered schema, but
leave it as a dict"""
schema = get_metadata_schema_for_scan(data.scan_type)
try:
schema.model_validate(data.metadata)
except ValidationError as e:
raise ValueError(
f"Scan metadata {data.metadata} does not conform to registered schema {schema}. \n Errors: {str(e)}"
) from e
return data
class ScanQueueHistoryMessage(BECMessage):
"""Sent after removal from the active queue. Contains information about the scan.

View File

@ -0,0 +1,45 @@
from functools import cache
from pydantic import BaseModel, ConfigDict
from bec_lib import plugin_helper
from bec_lib.logger import bec_logger
logger = bec_logger.logger
_METADATA_SCHEMA_REGISTRY = {}
class BasicScanMetadata(BaseModel):
"""Scan metadata base class which behaves like a dict, and will accept any keys,
like the existing metadata field in messages, but can be extended to add required
fields for specific scans."""
model_config = ConfigDict(extra="allow", validate_assignment=True)
@cache
def _get_metadata_schema_registry() -> dict[str, type[BasicScanMetadata]]:
plugin_schema = plugin_helper.get_metadata_schema_registry()
for name, schema in list(plugin_schema.items()):
try:
if not issubclass(schema, BasicScanMetadata):
logger.warning(
f"Schema {schema} for {name} in the plugin registry is not valid! It must subclass BasicScanMetadata"
)
del plugin_schema[name]
except TypeError:
logger.warning(
f"Schema {schema} for {name} in the plugin registry is not a valid type!"
)
del plugin_schema[name]
return _METADATA_SCHEMA_REGISTRY | plugin_schema
def cache_clear():
return _get_metadata_schema_registry.cache_clear()
def get_metadata_schema_for_scan(scan_name: str):
"""Return the pydantic model (must be a subclass of BasicScanMetadata)
associated with the given scan. If none is found, returns BasicScanMetadata."""
return _get_metadata_schema_registry().get(scan_name) or BasicScanMetadata

View File

@ -82,6 +82,19 @@ def get_file_writer_plugins() -> dict:
return loaded_plugins
def get_metadata_schema_registry() -> dict:
module = _get_available_plugins("bec.scans.metadata_schema")
if len(module) == 0:
logger.warning("No plugin metadata schema module found!")
return {}
try:
registry_module = importlib.import_module(module[0].__name__ + ".metadata_schema_registry")
return registry_module.METADATA_SCHEMA
except Exception as e:
logger.error(f"Error while loading metadata schema registry from plugins: {e}")
return {}
def get_ipython_client_startup_plugins(state: Literal["pre", "post"]) -> dict:
"""
Load all IPython client startup plugins.

View File

@ -0,0 +1,59 @@
from unittest.mock import patch
import pytest
from pydantic import ValidationError
from bec_lib import metadata_schema
from bec_lib.messages import ScanQueueMessage
from bec_lib.metadata_schema import BasicScanMetadata
TEST_DICT = {"foo": "bar", "baz": 123}
class ChildMetadata(BasicScanMetadata):
number_field: int
TEST_REGISTRY = {
"fake_scan_with_extra_metadata": ChildMetadata,
"fake_scan_with_basic_metadata": BasicScanMetadata,
}
@pytest.fixture(scope="module", autouse=True)
def clear_schema_registry_cache():
metadata_schema.cache_clear()
def test_required_fields_validate():
with pytest.raises(ValidationError):
test_metadata = ChildMetadata.model_validate(TEST_DICT)
test_metadata = ChildMetadata.model_validate(TEST_DICT | {"number_field": 123})
assert test_metadata.number_field == 123
test_metadata.number_field = 234
assert test_metadata.number_field == 234
with pytest.raises(ValidationError):
test_metadata.number_field = "string"
def test_creating_scan_queue_message_validates_metadata():
with patch.dict(metadata_schema._METADATA_SCHEMA_REGISTRY, TEST_REGISTRY, clear=True):
with pytest.raises(ValidationError):
ScanQueueMessage(scan_type="fake_scan_with_extra_metadata")
with pytest.raises(ValidationError):
ScanQueueMessage(
scan_type="fake_scan_with_extra_metadata",
parameter={},
metadata={"number_field", "string"},
)
ScanQueueMessage(
scan_type="fake_scan_with_extra_metadata", parameter={}, metadata={"number_field": 123}
)
msg_with_extra_keys = ScanQueueMessage(
scan_type="fake_scan_with_extra_metadata",
parameter={},
metadata={"number_field": 123, "extra": "data"},
)
assert msg_with_extra_keys.metadata["extra"] == "data"

View File

@ -2,33 +2,59 @@ import importlib
import os
import subprocess
import sys
from time import sleep
import pytest
from pytest import TempPathFactory
from bec_lib import plugin_helper
from bec_lib import metadata_schema, plugin_helper
def install(package):
subprocess.check_call([sys.executable, "-m", "pip", "install", package])
subprocess.check_call([sys.executable, "-m", "pip", "-v", "install", package])
def uninstall(package):
subprocess.check_call([sys.executable, "-m", "pip", "uninstall", "-y", package])
TEST_SCHEMA_FILE = """
from bec_lib.metadata_schema import BasicScanMetadata
class ExampleSchema(BasicScanMetadata):
treatment_description: str
treatment_temperature_k: int
"""
TEST_SCHEMA_REGISTRY = """
from .example_schema import ExampleSchema
METADATA_SCHEMA = {
"test_scan_fail_on_type": "FailOnType",
"example_scan": ExampleSchema,
}
"""
TEST_SCAN_CLASS = """
from bec_server.scan_server.scans import ScanBase
class ScanForTesting(ScanBase):
...
"""
class TestPluginSystem:
@pytest.fixture(scope="class", autouse=True)
def setup_env(self, tmp_path_factory: TempPathFactory):
print("\n\nSetting up plugin for tests: generating files...\n")
TestPluginSystem._tmp_plugin_dir = tmp_path_factory.mktemp("test_plugin")
TestPluginSystem._tmp_plugin_name = TestPluginSystem._tmp_plugin_dir.name
TestPluginSystem._plugin_script = (
os.path.dirname(os.path.abspath(__file__))
+ "/../util_scripts/create_plugin_structure.py"
)
print("Done. Modifying files with test code...\n")
# run plugin generation script
subprocess.check_call(
[sys.executable, TestPluginSystem._plugin_script, str(TestPluginSystem._tmp_plugin_dir)]
@ -40,24 +66,48 @@ class TestPluginSystem:
/ f"{TestPluginSystem._tmp_plugin_name}/scans/__init__.py",
"w+",
) as f:
f.writelines(
[
"from bec_server.scan_server.scans import ScanBase\n",
"class ScanForTesting: ...\n",
]
)
f.write(TEST_SCAN_CLASS)
with open(
TestPluginSystem._tmp_plugin_dir
/ f"{TestPluginSystem._tmp_plugin_name}/scans/metadata_schema/metadata_schema_registry.py",
"w",
) as f:
f.write(TEST_SCHEMA_REGISTRY)
with open(
TestPluginSystem._tmp_plugin_dir
/ f"{TestPluginSystem._tmp_plugin_name}/scans/metadata_schema/example_schema.py",
"w",
) as f:
f.write(TEST_SCHEMA_FILE)
print("\nDone. Installing into environment...\n")
# install into current environment
install(TestPluginSystem._tmp_plugin_dir)
importlib.invalidate_caches()
plugin_helper._get_available_plugins.cache_clear()
piplist = subprocess.Popen(("pip", "list"), stdout=subprocess.PIPE)
output = subprocess.check_output(("grep", "test_plugin"), stdin=piplist.stdout)
piplist.wait()
print("$ pip list | grep test_plugin: \n" + output.decode("utf-8"))
print("Done. Yielding to test class...\n")
yield
uninstall(TestPluginSystem._tmp_plugin_name)
print("\n\nDone. Uninstalling test plugin:\n")
uninstall(TestPluginSystem._tmp_plugin_name)
importlib.invalidate_caches()
plugin_helper._get_available_plugins.cache_clear()
metadata_schema._METADATA_SCHEMA_REGISTRY = {}
del sys.modules["bec_lib.metadata_schema"]
TestPluginSystem._tmp_plugin_dir = None
def test_files_in_plugin_deployment(self, setup_env):
def test_generated_files_in_plugin_deployment(self):
files = os.listdir(TestPluginSystem._tmp_plugin_dir)
for file in [
TestPluginSystem._tmp_plugin_name,
@ -70,15 +120,49 @@ class TestPluginSystem:
".gitlab-ci.yml",
]:
assert file in files
files = os.listdir(TestPluginSystem._tmp_plugin_dir / TestPluginSystem._tmp_plugin_name)
for file in ["scans"]:
assert file in files
def test_plugin_module_import_from_file(self, setup_env):
spec = importlib.util.spec_from_file_location(
TestPluginSystem._tmp_plugin_name,
str(TestPluginSystem._tmp_plugin_dir) + "/__init__.py",
def test_test_created_files_in_plugin_deployment(self):
files = os.listdir(
TestPluginSystem._tmp_plugin_dir
/ f"{TestPluginSystem._tmp_plugin_name}/scans/metadata_schema"
)
plugin_module = importlib.util.module_from_spec(spec)
for file in ["example_schema.py", "metadata_schema_registry.py"]:
assert file in files
def test_plugin_modules_import_from_file(self, setup_env):
def test_plugin_module_import_from_generated_file(self):
try:
package_spec = importlib.util.spec_from_file_location(
TestPluginSystem._tmp_plugin_name,
TestPluginSystem._tmp_plugin_dir
/ TestPluginSystem._tmp_plugin_name
/ "__init__.py",
)
plugin_module = importlib.util.module_from_spec(package_spec)
package_spec.loader.exec_module(plugin_module)
md_reg_mod_name = (
TestPluginSystem._tmp_plugin_name
+ ".scans.metadata_schema.metadata_schema_registry"
)
md_reg_spec = importlib.util.spec_from_file_location(
md_reg_mod_name,
TestPluginSystem._tmp_plugin_dir
/ TestPluginSystem._tmp_plugin_name
/ "scans/metadata_schema"
/ "metadata_schema_registry.py",
)
md_reg_module = importlib.util.module_from_spec(md_reg_spec)
md_reg_spec.loader.exec_module(md_reg_module)
assert md_reg_module.METADATA_SCHEMA is not None
finally:
for mod in [TestPluginSystem._tmp_plugin_name, md_reg_mod_name]:
if mod in sys.modules:
del sys.modules[mod]
def test_plugin_modules_import_from_file(self):
importlib.import_module(TestPluginSystem._tmp_plugin_name)
for submod in [
"scans",
@ -90,6 +174,16 @@ class TestPluginSystem:
]:
importlib.import_module(TestPluginSystem._tmp_plugin_name + "." + submod)
def test_plugin_helper(self, setup_env):
def test_plugin_helper_for_scans(self):
plugin_scans_modules = plugin_helper._get_available_plugins("bec.scans")
assert len(plugin_scans_modules) > 0
scan_plugins = plugin_helper.get_scan_plugins()
assert "ScanForTesting" in scan_plugins.keys()
def test_plugin_helper_for_metadata_schema(self):
metadata_schema_plugin_module = plugin_helper._get_available_plugins(
"bec.scans.metadata_schema"
)
assert len(metadata_schema_plugin_module) > 0
metadata_registry = plugin_helper.get_metadata_schema_registry()
assert set(["test_scan_fail_on_type", "example_scan"]) == set(metadata_registry.keys())

View File

@ -70,6 +70,9 @@ class PluginStructure:
self.create_dir(f"{self.plugin_name}/scans")
self.create_init_file(f"{self.plugin_name}/scans")
self.create_dir(f"{self.plugin_name}/scans/metadata_schema")
self.create_init_file(f"{self.plugin_name}/scans/metadata_schema")
# copy scan_plugin_template.py
scan_plugin_template_file = os.path.join(
current_dir, "plugin_setup_files", "scan_plugin_template.py"
@ -191,6 +194,7 @@ if __name__ == "__main__":
struc.add_plugins()
struc.copy_plugin_setup_files()
struc.add_scans()
struc.add_metadata_schema()
struc.add_client()
struc.add_devices()
struc.add_device_configs()

View File

@ -0,0 +1,8 @@
# from .schema_template import ExampleSchema
METADATA_SCHEMA = {
# Add models which should be used to validate scan metadata here.
# Make a model according to the template, and import it as above
# Then associate it with a scan like so:
# "example_scan": ExampleSchema
}

View File

@ -0,0 +1,34 @@
# # By inheriting from BasicScanMetadata you can define a schema by which metadata
# # supplied to a scan must be validated.
# # This schema is a Pydantic model: https://docs.pydantic.dev/latest/concepts/models/
# # but by default it will still allow you to add any arbitrary information to it.
# # That is to say, when you run a scan with which such a model has been associated in the
# # metadata_schema_registry, you can supply any python dictionary with strings as keys
# # and built-in python types (strings, integers, floats) as values, and these will be
# # added to the experiment metadata, but it *must* contain the keys and values of the
# # types defined in the schema class.
# #
# #
# # For example, say that you would like to enforce recording information about sample
# # pretreatment, you could define the following:
# #
#
# from bec_lib.metadata_schema import BasicScanMetadata
#
#
# class ExampleSchema(BasicScanMetadata):
# treatment_description: str
# treatment_temperature_k: int
#
#
# # If this was used according to the example in metadata_schema_registry.py,
# # then when calling the scan, the user would need to write something like:
# >>> scans.example_scan(
# >>> motor,
# >>> 1,
# >>> 2,
# >>> 3,
# >>> metadata={"treatment_description": "oven overnight", "treatment_temperature_k": 575},
# >>> )
#
# # And the additional metadata would be saved in the HDF5 file created for the scan.

View File

@ -38,6 +38,9 @@ plugin_file_writer = "{template_name}.file_writer"
[project.entry-points."bec.scans"]
plugin_scans = "{template_name}.scans"
[project.entry-points."bec.scans.metadata_schema"]
plugin_metadata_schema = "{template_name}.scans.metadata_schema"
[project.entry-points."bec.ipython_client_startup"]
plugin_ipython_client_pre = "{template_name}.bec_ipython_client.startup.pre_startup"
plugin_ipython_client_post = "{template_name}.bec_ipython_client.startup"

View File

@ -370,6 +370,44 @@ def test_remove_queue_item(queuemanager_mock):
assert len(queue_manager.queues["primary"].queue) == 0
def test_invalid_scan_specified_in_message(queuemanager_mock):
queue_manager = queuemanager_mock()
msg = messages.ScanQueueMessage(
scan_type="fake test scan which does not exist!",
parameter={"args": {"samx": (1,)}, "kwargs": {}},
queue="primary",
metadata={"RID": "something"},
)
with mock.patch.object(queue_manager, "connector") as connector:
queue_manager.add_to_queue(scan_queue="dummy", msg=msg)
connector.raise_alarm.assert_called_once_with(
severity=Alarms.MAJOR,
source=msg.content,
msg="fake test scan which does not exist!",
alarm_type="KeyError",
metadata={"RID": "something"},
)
def test_invalid_scan_specified_in_message(queuemanager_mock):
queue_manager = queuemanager_mock()
msg = messages.ScanQueueMessage(
scan_type="fake test scan which does not exist!",
parameter={"args": {"samx": (1,)}, "kwargs": {}},
queue="primary",
metadata={"RID": "something"},
)
with mock.patch.object(queue_manager, "connector") as connector:
queue_manager.add_to_queue(scan_queue="dummy", msg=msg)
connector.raise_alarm.assert_called_once_with(
severity=Alarms.MAJOR,
source=msg.content,
msg="fake test scan which does not exist!",
alarm_type="KeyError",
metadata={"RID": "something"},
)
def test_set_clear(queuemanager_mock):
queue_manager = queuemanager_mock()
msg = messages.ScanQueueMessage(