mirror of
https://github.com/ivan-usov-org/bec.git
synced 2025-04-22 02:20:02 +02:00
refactor(bec_lib): removed async data handler
This commit is contained in:
parent
3c4fec4d6d
commit
c91666f383
@ -1,111 +0,0 @@
|
|||||||
"""
|
|
||||||
This module contains the AsyncDataHandler class which is used to receive and store async device data from the BEC.
|
|
||||||
"""
|
|
||||||
|
|
||||||
from __future__ import annotations
|
|
||||||
|
|
||||||
from typing import TYPE_CHECKING, Literal
|
|
||||||
|
|
||||||
import numpy as np
|
|
||||||
|
|
||||||
from bec_lib.endpoints import MessageEndpoints
|
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
|
||||||
from bec_lib import messages
|
|
||||||
from bec_lib.connector import ConnectorBase
|
|
||||||
|
|
||||||
|
|
||||||
class AsyncDataHandler:
|
|
||||||
def __init__(self, connector: ConnectorBase):
|
|
||||||
self.connector = connector
|
|
||||||
|
|
||||||
def get_async_data_for_scan(self, scan_id: str) -> dict[list]:
|
|
||||||
"""
|
|
||||||
Get the async data for a given scan.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
scan_id(str): the scan id to get the async data for
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
dict[list]: the async data for the scan sorted by device name
|
|
||||||
"""
|
|
||||||
async_device_keys = self.connector.keys(
|
|
||||||
MessageEndpoints.device_async_readback(scan_id, "*")
|
|
||||||
)
|
|
||||||
async_data = {}
|
|
||||||
for device_key in async_device_keys:
|
|
||||||
key = device_key.decode()
|
|
||||||
device_name = key.split(MessageEndpoints.device_async_readback(scan_id, "").endpoint)[
|
|
||||||
-1
|
|
||||||
].split(":")[0]
|
|
||||||
data = self.get_async_data_for_device(scan_id, device_name)
|
|
||||||
if not data:
|
|
||||||
continue
|
|
||||||
async_data[device_name] = data
|
|
||||||
return async_data
|
|
||||||
|
|
||||||
def get_async_data_for_device(self, scan_id: str, device_name: str) -> list:
|
|
||||||
"""
|
|
||||||
Get the async data for a given device in a scan.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
scan_id(str): the scan id to get the async data for
|
|
||||||
device_name(str): the device name to get the async data for
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
list: the async data for the device
|
|
||||||
"""
|
|
||||||
key = MessageEndpoints.device_async_readback(scan_id, device_name)
|
|
||||||
msgs = self.connector.xrange(key, min="-", max="+")
|
|
||||||
if not msgs:
|
|
||||||
return []
|
|
||||||
return self.process_async_data(msgs)
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def process_async_data(
|
|
||||||
msgs: list[dict[Literal["data"], messages.DeviceMessage]]
|
|
||||||
) -> dict | list[dict]:
|
|
||||||
"""
|
|
||||||
Process the async data.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
msgs(list[messages.DeviceMessage]): the async data to process
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
list: the processed async data
|
|
||||||
"""
|
|
||||||
concat_type = None
|
|
||||||
data = []
|
|
||||||
async_data = {}
|
|
||||||
for msg in msgs:
|
|
||||||
msg = msg["data"]
|
|
||||||
if not concat_type:
|
|
||||||
concat_type = msg.metadata.get("async_update", "append")
|
|
||||||
data.append(msg.content["signals"])
|
|
||||||
if len(data) == 1:
|
|
||||||
async_data = data[0]
|
|
||||||
return async_data
|
|
||||||
if concat_type == "extend":
|
|
||||||
# concatenate the dictionaries
|
|
||||||
for signal in data[0].keys():
|
|
||||||
async_data[signal] = {}
|
|
||||||
for key in data[0][signal].keys():
|
|
||||||
if hasattr(data[0][signal][key], "__iter__"):
|
|
||||||
async_data[signal][key] = np.concatenate([d[signal][key] for d in data])
|
|
||||||
else:
|
|
||||||
async_data[signal][key] = [d[signal][key] for d in data]
|
|
||||||
return async_data
|
|
||||||
if concat_type == "append":
|
|
||||||
# concatenate the lists
|
|
||||||
for key in data[0].keys():
|
|
||||||
async_data[key] = {"value": [], "timestamp": []}
|
|
||||||
for d in data:
|
|
||||||
async_data[key]["value"].append(d[key]["value"])
|
|
||||||
if "timestamp" in d[key]:
|
|
||||||
async_data[key]["timestamp"].append(d[key]["timestamp"])
|
|
||||||
return async_data
|
|
||||||
if concat_type == "replace":
|
|
||||||
# replace the dictionaries
|
|
||||||
async_data = data[-1]
|
|
||||||
return async_data
|
|
||||||
raise ValueError(f"Unknown async update type: {concat_type}")
|
|
@ -1,121 +0,0 @@
|
|||||||
from unittest import mock
|
|
||||||
|
|
||||||
import numpy as np
|
|
||||||
import pytest
|
|
||||||
|
|
||||||
from bec_lib import messages
|
|
||||||
from bec_lib.async_data import AsyncDataHandler
|
|
||||||
from bec_lib.endpoints import MessageEndpoints
|
|
||||||
|
|
||||||
# pylint: disable=protected-access
|
|
||||||
# pylint: disable=missing-function-docstring
|
|
||||||
# pylint: disable=missing-class-docstring
|
|
||||||
# pylint: disable=redefined-outer-name
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
|
||||||
def async_data():
|
|
||||||
producer = mock.MagicMock()
|
|
||||||
yield AsyncDataHandler(producer)
|
|
||||||
|
|
||||||
|
|
||||||
def test_process_async_data_replace(async_data):
|
|
||||||
data = [
|
|
||||||
{
|
|
||||||
"data": messages.DeviceMessage(
|
|
||||||
signals={"data": {"value": np.zeros((10, 10))}},
|
|
||||||
metadata={"async_update": "replace"},
|
|
||||||
)
|
|
||||||
}
|
|
||||||
for ii in range(10)
|
|
||||||
]
|
|
||||||
res = async_data.process_async_data(data)
|
|
||||||
assert res["data"]["value"].shape == (10, 10)
|
|
||||||
|
|
||||||
|
|
||||||
def test_process_async_multiple_signals(async_data):
|
|
||||||
data = [
|
|
||||||
{
|
|
||||||
"data": messages.DeviceMessage(
|
|
||||||
signals={
|
|
||||||
"signal1": {"value": np.zeros((10, 10))},
|
|
||||||
"signal2": {"value": np.zeros((20, 20))},
|
|
||||||
},
|
|
||||||
metadata={"async_update": "replace"},
|
|
||||||
)
|
|
||||||
}
|
|
||||||
for ii in range(10)
|
|
||||||
]
|
|
||||||
res = async_data.process_async_data(data)
|
|
||||||
assert res["signal1"]["value"].shape == (10, 10)
|
|
||||||
assert res["signal2"]["value"].shape == (20, 20)
|
|
||||||
|
|
||||||
|
|
||||||
def test_process_async_data_extend(async_data):
|
|
||||||
data = [
|
|
||||||
{
|
|
||||||
"data": messages.DeviceMessage(
|
|
||||||
signals={"data": {"value": np.zeros((10, 10))}}, metadata={"async_update": "extend"}
|
|
||||||
)
|
|
||||||
}
|
|
||||||
for ii in range(10)
|
|
||||||
]
|
|
||||||
res = async_data.process_async_data(data)
|
|
||||||
assert res["data"]["value"].shape == (100, 10)
|
|
||||||
|
|
||||||
|
|
||||||
def test_process_async_update_append(async_data):
|
|
||||||
data = [
|
|
||||||
{
|
|
||||||
"data": messages.DeviceMessage(
|
|
||||||
signals={"data": {"value": np.zeros((10, 10))}}, metadata={"async_update": "append"}
|
|
||||||
)
|
|
||||||
}
|
|
||||||
for ii in range(10)
|
|
||||||
]
|
|
||||||
res = async_data.process_async_data(data)
|
|
||||||
assert res["data"]["value"][0].shape == (10, 10)
|
|
||||||
assert len(res["data"]["value"]) == 10
|
|
||||||
|
|
||||||
|
|
||||||
def test_process_async_data_single(async_data):
|
|
||||||
data = [
|
|
||||||
{
|
|
||||||
"data": messages.DeviceMessage(
|
|
||||||
signals={"data": {"value": np.zeros((10, 10))}}, metadata={}
|
|
||||||
)
|
|
||||||
}
|
|
||||||
]
|
|
||||||
res = async_data.process_async_data(data)
|
|
||||||
assert res["data"]["value"].shape == (10, 10)
|
|
||||||
|
|
||||||
|
|
||||||
def test_get_async_data_for_scan():
|
|
||||||
producer = mock.MagicMock()
|
|
||||||
async_data = AsyncDataHandler(producer)
|
|
||||||
producer.keys.return_value = [
|
|
||||||
MessageEndpoints.device_async_readback("scan_id", "samx").endpoint.encode(),
|
|
||||||
MessageEndpoints.device_async_readback("scan_id", "samy").endpoint.encode(),
|
|
||||||
]
|
|
||||||
with mock.patch.object(async_data, "get_async_data_for_device") as mock_get:
|
|
||||||
async_data.get_async_data_for_scan("scan_id")
|
|
||||||
assert mock_get.call_count == 2
|
|
||||||
|
|
||||||
|
|
||||||
def test_get_async_data_for_device():
|
|
||||||
producer = mock.MagicMock()
|
|
||||||
async_data = AsyncDataHandler(producer)
|
|
||||||
producer.xrange.return_value = [
|
|
||||||
{
|
|
||||||
"data": messages.DeviceMessage(
|
|
||||||
signals={"data": {"value": np.zeros((10, 10))}}, metadata={}
|
|
||||||
)
|
|
||||||
}
|
|
||||||
]
|
|
||||||
res = async_data.get_async_data_for_device("scan_id", "samx")
|
|
||||||
assert res["data"]["value"].shape == (10, 10)
|
|
||||||
assert len(res) == 1
|
|
||||||
assert producer.xrange.call_count == 1
|
|
||||||
producer.xrange.assert_called_with(
|
|
||||||
MessageEndpoints.device_async_readback("scan_id", "samx"), min="-", max="+"
|
|
||||||
)
|
|
Loading…
x
Reference in New Issue
Block a user