refactor(bec_lib): removed async data handler
parent 3c4fec4d6d
commit c91666f383
@@ -1,111 +0,0 @@
"""
This module contains the AsyncDataHandler class which is used to receive and store async device data from the BEC.
"""

from __future__ import annotations

from typing import TYPE_CHECKING, Literal

import numpy as np

from bec_lib.endpoints import MessageEndpoints

if TYPE_CHECKING:
    from bec_lib import messages
    from bec_lib.connector import ConnectorBase


class AsyncDataHandler:
    def __init__(self, connector: ConnectorBase):
        self.connector = connector

    def get_async_data_for_scan(self, scan_id: str) -> dict[list]:
        """
        Get the async data for a given scan.

        Args:
            scan_id(str): the scan id to get the async data for

        Returns:
            dict[list]: the async data for the scan sorted by device name
        """
        async_device_keys = self.connector.keys(
            MessageEndpoints.device_async_readback(scan_id, "*")
        )
        async_data = {}
        for device_key in async_device_keys:
            key = device_key.decode()
            device_name = key.split(MessageEndpoints.device_async_readback(scan_id, "").endpoint)[
                -1
            ].split(":")[0]
            data = self.get_async_data_for_device(scan_id, device_name)
            if not data:
                continue
            async_data[device_name] = data
        return async_data

    def get_async_data_for_device(self, scan_id: str, device_name: str) -> list:
        """
        Get the async data for a given device in a scan.

        Args:
            scan_id(str): the scan id to get the async data for
            device_name(str): the device name to get the async data for

        Returns:
            list: the async data for the device
        """
        key = MessageEndpoints.device_async_readback(scan_id, device_name)
        msgs = self.connector.xrange(key, min="-", max="+")
        if not msgs:
            return []
        return self.process_async_data(msgs)

    @staticmethod
    def process_async_data(
        msgs: list[dict[Literal["data"], messages.DeviceMessage]]
    ) -> dict | list[dict]:
        """
        Process the async data.

        Args:
            msgs(list[messages.DeviceMessage]): the async data to process

        Returns:
            list: the processed async data
        """
        concat_type = None
        data = []
        async_data = {}
        for msg in msgs:
            msg = msg["data"]
            if not concat_type:
                concat_type = msg.metadata.get("async_update", "append")
            data.append(msg.content["signals"])
        if len(data) == 1:
            async_data = data[0]
            return async_data
        if concat_type == "extend":
            # concatenate the dictionaries
            for signal in data[0].keys():
                async_data[signal] = {}
                for key in data[0][signal].keys():
                    if hasattr(data[0][signal][key], "__iter__"):
                        async_data[signal][key] = np.concatenate([d[signal][key] for d in data])
                    else:
                        async_data[signal][key] = [d[signal][key] for d in data]
            return async_data
        if concat_type == "append":
            # concatenate the lists
            for key in data[0].keys():
                async_data[key] = {"value": [], "timestamp": []}
                for d in data:
                    async_data[key]["value"].append(d[key]["value"])
                    if "timestamp" in d[key]:
                        async_data[key]["timestamp"].append(d[key]["timestamp"])
            return async_data
        if concat_type == "replace":
            # replace the dictionaries
            async_data = data[-1]
            return async_data
        raise ValueError(f"Unknown async update type: {concat_type}")
@@ -1,121 +0,0 @@
from unittest import mock

import numpy as np
import pytest

from bec_lib import messages
from bec_lib.async_data import AsyncDataHandler
from bec_lib.endpoints import MessageEndpoints

# pylint: disable=protected-access
# pylint: disable=missing-function-docstring
# pylint: disable=missing-class-docstring
# pylint: disable=redefined-outer-name


@pytest.fixture
def async_data():
    producer = mock.MagicMock()
    yield AsyncDataHandler(producer)


def test_process_async_data_replace(async_data):
    data = [
        {
            "data": messages.DeviceMessage(
                signals={"data": {"value": np.zeros((10, 10))}},
                metadata={"async_update": "replace"},
            )
        }
        for ii in range(10)
    ]
    res = async_data.process_async_data(data)
    assert res["data"]["value"].shape == (10, 10)


def test_process_async_multiple_signals(async_data):
    data = [
        {
            "data": messages.DeviceMessage(
                signals={
                    "signal1": {"value": np.zeros((10, 10))},
                    "signal2": {"value": np.zeros((20, 20))},
                },
                metadata={"async_update": "replace"},
            )
        }
        for ii in range(10)
    ]
    res = async_data.process_async_data(data)
    assert res["signal1"]["value"].shape == (10, 10)
    assert res["signal2"]["value"].shape == (20, 20)


def test_process_async_data_extend(async_data):
    data = [
        {
            "data": messages.DeviceMessage(
                signals={"data": {"value": np.zeros((10, 10))}}, metadata={"async_update": "extend"}
            )
        }
        for ii in range(10)
    ]
    res = async_data.process_async_data(data)
    assert res["data"]["value"].shape == (100, 10)


def test_process_async_update_append(async_data):
    data = [
        {
            "data": messages.DeviceMessage(
                signals={"data": {"value": np.zeros((10, 10))}}, metadata={"async_update": "append"}
            )
        }
        for ii in range(10)
    ]
    res = async_data.process_async_data(data)
    assert res["data"]["value"][0].shape == (10, 10)
    assert len(res["data"]["value"]) == 10


def test_process_async_data_single(async_data):
    data = [
        {
            "data": messages.DeviceMessage(
                signals={"data": {"value": np.zeros((10, 10))}}, metadata={}
            )
        }
    ]
    res = async_data.process_async_data(data)
    assert res["data"]["value"].shape == (10, 10)


def test_get_async_data_for_scan():
    producer = mock.MagicMock()
    async_data = AsyncDataHandler(producer)
    producer.keys.return_value = [
        MessageEndpoints.device_async_readback("scan_id", "samx").endpoint.encode(),
        MessageEndpoints.device_async_readback("scan_id", "samy").endpoint.encode(),
    ]
    with mock.patch.object(async_data, "get_async_data_for_device") as mock_get:
        async_data.get_async_data_for_scan("scan_id")
        assert mock_get.call_count == 2


def test_get_async_data_for_device():
    producer = mock.MagicMock()
    async_data = AsyncDataHandler(producer)
    producer.xrange.return_value = [
        {
            "data": messages.DeviceMessage(
                signals={"data": {"value": np.zeros((10, 10))}}, metadata={}
            )
        }
    ]
    res = async_data.get_async_data_for_device("scan_id", "samx")
    assert res["data"]["value"].shape == (10, 10)
    assert len(res) == 1
    assert producer.xrange.call_count == 1
    producer.xrange.assert_called_with(
        MessageEndpoints.device_async_readback("scan_id", "samx"), min="-", max="+"
    )