mirror of
https://github.com/ivan-usov-org/bec.git
synced 2025-04-22 02:20:02 +02:00
262 lines
8.9 KiB
Python
262 lines
8.9 KiB
Python
import asyncio
|
|
import logging
|
|
import uuid
|
|
from contextlib import ContextDecorator
|
|
|
|
import msgpack
|
|
from bec_utils import BECMessage as BMessage
|
|
from bec_utils import MessageEndpoints
|
|
from bec_utils.connector import ConsumerConnector
|
|
from cytoolz import partition
|
|
|
|
from .callbacks import live_updates_readback, live_updates_table
|
|
from .devicemanager_client import Device
|
|
from .scan_queue import ScanReport
|
|
|
|
logger = logging.getLogger("scans")
|
|
|
|
|
|
class ScanObject:
|
|
def __init__(self, scan_name: str, scan_info: dict, parent=None) -> None:
|
|
self.scan_name = scan_name
|
|
self.scan_info = scan_info
|
|
self.parent = parent
|
|
|
|
# run must be an anonymous function to allow for multiple doc strings
|
|
self.run = lambda *args, **kwargs: asyncio.run(self._run(*args, **kwargs))
|
|
|
|
async def _run(self, *args, **kwargs):
|
|
with self.parent._sighandler:
|
|
# handle reserved kwargs:
|
|
hide_report_kwarg = False
|
|
if "hide_report" in kwargs:
|
|
hide_report_kwarg = kwargs["hide_report"]
|
|
hide_report = hide_report_kwarg or self.parent.scans._hide_report
|
|
|
|
if self.parent.scans._scan_group is not None:
|
|
if "md" not in kwargs:
|
|
kwargs["md"] = {}
|
|
kwargs["md"]["queue_group"] = self.parent.scans._scan_group
|
|
if self.parent.scans._scan_def_id is not None:
|
|
if "md" not in kwargs:
|
|
kwargs["md"] = {}
|
|
kwargs["md"]["scan_def_id"] = self.parent.scans._scan_def_id
|
|
|
|
request = Scans._prepare_scan_request(self.scan_name, self.scan_info, *args, **kwargs)
|
|
requestID = str(uuid.uuid4()) # TODO: move this to the API server
|
|
request.metadata["RID"] = requestID
|
|
|
|
scan_report = ScanReport.from_request(request, client=self.parent)
|
|
|
|
scan_report_type = self._get_scan_report_type(hide_report)
|
|
if scan_report_type == "readback":
|
|
consumer = self._start_consumer(request)
|
|
self._send_scan_request(request)
|
|
await live_updates_readback(
|
|
self.parent.devicemanager,
|
|
request.content["parameter"]["args"],
|
|
consumer,
|
|
)
|
|
consumer.shutdown()
|
|
elif scan_report_type == "table":
|
|
self._send_scan_request(request)
|
|
if (
|
|
self.parent.scans._scan_def_id is None
|
|
or request.content["scan_type"] == "close_scan_def"
|
|
):
|
|
await live_updates_table(self.parent, request)
|
|
else:
|
|
self._send_scan_request(request)
|
|
return scan_report
|
|
|
|
def _get_scan_report_type(self, hide_report) -> str:
|
|
if hide_report:
|
|
return None
|
|
else:
|
|
return self.scan_info.get("scan_report_hint")
|
|
|
|
def _start_consumer(self, request: BMessage.ScanQueueMessage) -> ConsumerConnector:
|
|
consumer = self.parent.devicemanager.connector.consumer(
|
|
[
|
|
MessageEndpoints.device_readback(dev)
|
|
for dev in request.content["parameter"]["args"].keys()
|
|
],
|
|
threaded=False,
|
|
cb=(lambda msg: msg),
|
|
)
|
|
return consumer
|
|
|
|
def _send_scan_request(self, request: BMessage.ScanQueueMessage) -> None:
|
|
self.parent.devicemanager.producer.send(
|
|
MessageEndpoints.scan_queue_request(), request.dumps()
|
|
)
|
|
|
|
|
|
class Scans:
|
|
def __init__(self, parent):
|
|
self.parent = parent
|
|
self._available_scans = {}
|
|
self._import_scans()
|
|
self._scan_group = None
|
|
self._scan_def_id = None
|
|
self._scan_group_ctx = ScanGroup(parent=self)
|
|
self._scan_def_ctx = ScanDef(parent=self)
|
|
self._hide_report = None
|
|
self._hide_report_ctx = HideReport(parent=self)
|
|
|
|
def _import_scans(self):
|
|
|
|
available_scans = msgpack.loads(
|
|
self.parent.producer.get(MessageEndpoints.available_scans())
|
|
)
|
|
for scan_name, scan_info in available_scans.items():
|
|
self._available_scans[scan_name] = ScanObject(scan_name, scan_info, parent=self.parent)
|
|
setattr(
|
|
self,
|
|
scan_name,
|
|
self._available_scans[scan_name].run,
|
|
)
|
|
setattr(getattr(self, scan_name), "__doc__", scan_info.get("doc"))
|
|
|
|
@staticmethod
|
|
def get_arg_type(in_type: str):
|
|
if in_type == "float":
|
|
return (float, int)
|
|
elif in_type == "int":
|
|
return int
|
|
elif in_type == "list":
|
|
return list
|
|
elif in_type == "boolean":
|
|
return bool
|
|
elif in_type == "str":
|
|
return str
|
|
elif in_type == "dict":
|
|
return dict
|
|
elif in_type == "device":
|
|
return Device
|
|
else:
|
|
TypeError(f"Unknown type {in_type}")
|
|
|
|
@staticmethod
|
|
def _prepare_scan_request(
|
|
scan_name: str, scan_info: dict, *args, **kwargs
|
|
) -> BMessage.ScanQueueMessage:
|
|
"""Prepare scan request message with given scan arguments
|
|
|
|
Args:
|
|
scan_name (str): scan name (matching a scan name on the scan server)
|
|
scan_info (dict): dictionary describing the scan (e.g. doc string, required kwargs etc.)
|
|
|
|
Raises:
|
|
TypeError: Raised if not all required keyword arguments have been specified.
|
|
TypeError: Raised if the number of args do fit into the required bundling pattern.
|
|
TypeError: Raised if an argument is not of the required type as specified in scan_info.
|
|
|
|
Returns:
|
|
BMessage.ScanQueueMessage: _description_
|
|
"""
|
|
arg_input = scan_info.get("arg_input")
|
|
if arg_input is not None:
|
|
arg_bundle_size = len(arg_input)
|
|
if len(arg_input) > 0:
|
|
if len(args) % len(arg_input) != 0:
|
|
raise TypeError(
|
|
f"{scan_info.get('doc')}\n {scan_name} takes multiples of {len(arg_input)} arguments ({len(args)} given).",
|
|
)
|
|
if not all(req_kwarg in kwargs for req_kwarg in scan_info.get("required_kwargs")):
|
|
raise TypeError(
|
|
f"{scan_info.get('doc')}\n Not all required keyword arguments have been specified. The required arguments are: {scan_info.get('required_kwargs')}"
|
|
)
|
|
for ii, arg in enumerate(args):
|
|
if not isinstance(arg, Scans.get_arg_type(arg_input[ii % len(arg_input)])):
|
|
raise TypeError(
|
|
f"{scan_info.get('doc')}\n Argument {ii} must be of type {arg_input[ii%len(arg_input)]}, not {type(arg).__name__}."
|
|
)
|
|
else:
|
|
logger.warning("Could not check arguments against scan input types.")
|
|
arg_bundle_size = len(args)
|
|
md = {}
|
|
if "md" in kwargs:
|
|
md = kwargs.pop("md")
|
|
params = {
|
|
"args": Scans._parameter_bundler(args, arg_bundle_size),
|
|
"kwargs": kwargs,
|
|
}
|
|
return BMessage.ScanQueueMessage(
|
|
scan_type=scan_name, parameter=params, queue="primary", metadata=md
|
|
)
|
|
|
|
@staticmethod
|
|
def _parameter_bundler(args, bundle_size):
|
|
"""
|
|
|
|
Args:
|
|
args:
|
|
bundle_size: number of parameters per bundle
|
|
|
|
Returns:
|
|
|
|
"""
|
|
params = {}
|
|
for cmds in partition(bundle_size, args):
|
|
cmds_serialized = [cmd.name if hasattr(cmd, "name") else cmd for cmd in cmds]
|
|
params[cmds_serialized[0]] = cmds_serialized[1:]
|
|
return params
|
|
|
|
@property
|
|
def scan_group(self):
|
|
return self._scan_group_ctx
|
|
|
|
@property
|
|
def scan_def(self):
|
|
return self._scan_def_ctx
|
|
|
|
@property
|
|
def hide_report(self):
|
|
return self._hide_report_ctx
|
|
|
|
|
|
class ScanGroup(ContextDecorator):
|
|
def __init__(self, parent=None) -> None:
|
|
super().__init__()
|
|
self.parent = parent
|
|
|
|
def __enter__(self):
|
|
group_id = str(uuid.uuid4())
|
|
self.parent._scan_group = group_id
|
|
return self
|
|
|
|
def __exit__(self, *exc):
|
|
self.parent.close_scan_group()
|
|
self.parent._scan_group = None
|
|
|
|
|
|
class ScanDef(ContextDecorator):
|
|
def __init__(self, parent=None) -> None:
|
|
super().__init__()
|
|
self.parent = parent
|
|
|
|
def __enter__(self):
|
|
scan_def_id = str(uuid.uuid4())
|
|
self.parent._scan_def_id = scan_def_id
|
|
self.parent.open_scan_def()
|
|
return self
|
|
|
|
def __exit__(self, *exc):
|
|
self.parent.close_scan_def()
|
|
self.parent._scan_def_id = None
|
|
|
|
|
|
class HideReport(ContextDecorator):
|
|
def __init__(self, parent=None) -> None:
|
|
super().__init__()
|
|
self.parent = parent
|
|
|
|
def __enter__(self):
|
|
if self.parent._hide_report is None:
|
|
self.parent._hide_report = True
|
|
return self
|
|
|
|
def __exit__(self, *exc):
|
|
self.parent._hide_report = None
|