Files
pxiii_bec/bec_plugins/data_processing/px_example.py

352 lines
11 KiB
Python

from __future__ import annotations
import numpy as np
from string import Template
from queue import Queue
import time
# import bec_lib
from data_processing.stream_processor import StreamProcessor
from bec_lib.core import BECMessage, MessageEndpoints
from bec_lib.core.redis_connector import MessageObject, RedisConnector
from typing import Optional, Tuple
class StreamProcessorPx(StreamProcessor):
    """Stream processor for the scanning SAXS ("px") data stream.

    Projection metadata is received through a Redis consumer and placed on an
    internal queue. ``_run_forever`` takes one projection from that queue,
    polls its data list on Redis, reduces the accumulated frames with
    ``process`` and publishes the result.

    NOTE(review): ``self.producer``, ``self._connector``, ``self.config`` and
    ``self._publish_result`` are not defined in this class; they are presumably
    provided by ``StreamProcessor`` -- confirm against the parent class.
    """

    # Number of metadata entries above which the oldest entry is evicted.
    _MAX_METADATA_ENTRIES = 10
    # Pause (seconds) between Redis polls while no new frames are available.
    _IDLE_POLL_INTERVAL = 0.01

    def __init__(self, connector: RedisConnector, config: dict) -> None:
        """Initialize internal state and subscribe to projection metadata.

        Args:
            connector (RedisConnector): Connector forwarded to the parent class.
            config (dict): Processor configuration, forwarded to the parent
                class. ``process`` reads ``config["parameters"]``.

        Returns:
            None
        """
        super().__init__(connector, config)
        self.metadata_consumer = None
        self.metadata = {}
        self.num_received_msgs = 0
        self.queue = Queue()
        # Pick up the projection that is already announced on Redis (if any)
        # before subscribing to future metadata updates.
        self._init_metadata(endpoint="px_stream/proj_nr")
        self.start_metadata_consumer(endpoint="px_stream/projection_*/metadata")

    def _init_metadata(self, endpoint: str) -> None:
        """Initialize the metadata.

        Reads the current projection number from ``endpoint`` and, if the
        metadata for that projection is available, queues it for processing.
        Nothing is queued when either message is missing.

        Args:
            endpoint (str): Endpoint for redis topic.

        Returns:
            None
        """
        msg = self.producer.get(topic=endpoint)
        if msg is None:
            return None
        msg_raw = BECMessage.DeviceMessage.loads(msg)
        proj_nr = msg_raw.content["signals"]["proj_nr"]
        # TODO hardcoded endpoint, possible to use more general solution?
        msg = self.producer.get(topic=f"px_stream/projection_{proj_nr}/metadata")
        if msg is None:
            # The projection number is known but its metadata is not there
            # (yet); the metadata consumer delivers it once it is published.
            return None
        msg_raw = BECMessage.DeviceMessage.loads(msg)
        self._update_queue(msg_raw.content["signals"], proj_nr)
        return None

    def _update_queue(self, metadata: dict, proj_nr: int) -> None:
        """Update the process queue.

        Stores the metadata under its projection number and puts the
        ``(proj_nr, metadata)`` pair on the queue read by ``_run_forever``.

        Args:
            metadata (dict): Metadata for the projection.
            proj_nr (int): Projection number.

        Returns:
            None
        """
        self.metadata.update({proj_nr: metadata})
        self.queue.put((proj_nr, metadata))

    def start_metadata_consumer(self, endpoint: str) -> None:
        """Start the metadata consumer.

        Consumer is started with a callback function that updates the metadata.
        A consumer that is still alive is shut down before the new one starts.

        Args:
            endpoint (str): Endpoint for redis topic.

        Returns:
            None
        """
        if self.metadata_consumer and self.metadata_consumer.is_alive():
            self.metadata_consumer.shutdown()
        self.metadata_consumer = self._connector.consumer(
            pattern=endpoint, cb=self._update_metadata_cb, parent=self
        )
        self.metadata_consumer.start()

    @staticmethod
    def _update_metadata_cb(msg: MessageObject, parent: StreamProcessorPx) -> None:
        """Callback function for the metadata consumer.

        Args:
            msg (MessageObject): Message object.
            parent (StreamProcessorPx): Parent class.

        Returns:
            None
        """
        msg_raw = BECMessage.DeviceMessage.loads(msg.value)
        parent._metadata_msg_handler(msg_raw, msg.topic.decode())

    def _metadata_msg_handler(self, msg: BECMessage, topic: str) -> None:
        """Handle the metadata message.

        If self.metadata is larger than 10, the oldest entry is removed.
        The projection number is parsed from the topic name.

        Args:
            msg (BECMessage): Message object.
            topic (str): Topic for the message.

        Returns:
            None
        """
        if len(self.metadata) > self._MAX_METADATA_ENTRIES:
            # dicts preserve insertion order, so the first key is the oldest.
            first_key = next(iter(self.metadata))
            self.metadata.pop(first_key)
        proj_nr = int(topic.split("px_stream/projection_")[1].split("/")[0])
        self._update_queue(msg.content["signals"], proj_nr)

    def _init_data_output(self) -> None:
        """Initialize the data output.

        Not yet used. Should be used to initialize the output for the processed data.
        """
        self.data = None

    def start_data_consumer(self) -> None:
        """function from the parent class that we don't want to use here"""

    def _run_forever(self) -> None:
        """Loop that runs forever when the processor is started.

        Blocks until a projection is available on the queue, then polls the
        data of that projection for as long as the queue stays empty. Each
        time new frames arrive, all frames accumulated so far are processed
        and the result is published. The method returns once the queue holds
        a new projection.

        Returns:
            None
        """
        proj_nr, metadata = self.queue.get()
        self.num_received_msgs = 0
        data = []
        while self.queue.empty():
            data_msgs = self._get_data(proj_nr)
            new_frames = [
                msg.content["signals"]["data"] for msg in data_msgs if msg is not None
            ]
            if not new_frames:
                # Nothing new: reprocessing would only publish the same result
                # again, so wait briefly instead of spinning on Redis.
                time.sleep(self._IDLE_POLL_INTERVAL)
                continue
            data.extend(new_frames)
            start = time.time()
            result = self.process(data, metadata)
            print(f"Processing took {time.time() - start}")
            if not result:
                continue
            stream_output, result_metadata = result
            print(f"Length of data is {stream_output[0]['z'].shape}")
            msg = BECMessage.ProcessedDataMessage(
                data=stream_output[0], metadata=result_metadata
            ).dumps()
            print("Publishing result")
            self._publish_result(msg)

    def _get_data(self, proj_nr: int) -> list:
        """Get data for given proj_nr from redis.

        Only entries after the ``self.num_received_msgs`` already fetched are
        requested; the counter is advanced by the number of entries returned.

        Args:
            proj_nr (int): Projection number.

        Returns:
            list: List of azimuthal integrated data.
        """
        msgs = self.producer.lrange(
            f"px_stream/projection_{proj_nr}/data", self.num_received_msgs, -1
        )
        if not msgs:
            return []
        self.num_received_msgs += len(msgs)
        return [BECMessage.DeviceMessage.loads(msg) for msg in msgs]

    def process(self, data: list, metadata: dict) -> Optional[Tuple[dict, dict]]:
        """Process the scanning SAXS data

        Args:
            data (list): List of azimuthal integrated data.
            metadata (dict): Metadata for the projection. Must contain the keys
                "norm_sum" and "q". The dict itself is not modified.

        Returns:
            Optional[Tuple[dict, dict]]: Processed data and metadata, or None
            if ``data`` is empty. The returned metadata is a shallow copy of
            ``metadata`` extended by the key "grid_scan" (shape of the output).

        Raises:
            ValueError: If ``config["parameters"]["contrast"]`` is not 0, 1 or 2.
        """
        if not data:
            return None
        # NOTE: this timing covers only the list -> ndarray conversion.
        start = time.time()
        azint_data = np.asarray(data)
        print(f"Processing took {time.time() - start}")

        parameters = self.config["parameters"]
        contrast = parameters["contrast"]
        f1amp, f2amp, f2phase = self._colorfulplot(
            qranges=parameters["qranges"],
            q=metadata["q"],
            norm_sum=metadata["norm_sum"],
            data=azint_data,
            aziangles=parameters["aziangles"],
            percentile_value=96,
        )
        if contrast == 0:
            out = f1amp
        elif contrast == 1:
            out = f2amp
        elif contrast == 2:
            out = f2phase
        else:
            raise ValueError(
                f"Unsupported contrast {contrast!r}; expected 0 (f1amp), 1 (f2amp) or 2 (f2phase)."
            )

        stream_output = {0: {"z": np.asarray(out)}}
        # Copy instead of mutating: the incoming dict is the same object that
        # is stored in self.metadata and handed around through the queue.
        result_metadata = {**metadata, "grid_scan": out.shape}
        return (stream_output, result_metadata)

    def _colorfulplot(
        self,
        qranges: list,
        q: np.ndarray,
        norm_sum: np.ndarray,
        data: np.ndarray,
        aziangles: list,
        percentile_value: int = 96,
    ) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
        """Compute data for sSAXS colorful 2D plot.

        The data is binned in q, symmetrized, and Fourier transformed along the
        direction axis. The first two FFT components give f1amp, f2amp and the
        phase f2phase, each reshaped to the first two dimensions of the
        symmetrized data.

        Pending: hsv_to_rgb conversion for colorful output

        Args:
            qranges (list): List with q edges for binning.
            q (np.ndarray): q values.
            norm_sum (np.ndarray): Normalization sum.
            data (np.ndarray): Data to be binned.
            aziangles (list): List of azimuthal angles (degrees); the first
                entry shifts f2phase. None or an empty list means no shift.
            percentile_value (int, optional): Percentile value for removing
                outliers above threshold. Defaults to 96, range 0...100.
                Currently unused: it only matters for the pending HSV output.

        Returns:
            Tuple[np.ndarray, np.ndarray, np.ndarray]: f1amp, f2amp, f2phase
        """
        output, output_norm = self._bin_qrange(qranges=qranges, q=q, norm_sum=norm_sum, data=data)
        output_sym = self._sym_data(data=output, norm_sum=output_norm)
        shape = output_sym.shape[0:2]
        n_directions = output_sym.shape[2]
        # One row per scan point, FFT over the direction axis.
        fft_data = np.fft.rfft(output_sym.reshape((-1, output_sym.shape[-2])), axis=1)

        if aziangles is None or len(aziangles) == 0:
            azi_angle = 0
        else:
            azi_angle = aziangles[0]

        f1amp = np.abs(fft_data[:, 0]) / n_directions
        f2amp = 2 * np.abs(fft_data[:, 1]) / n_directions
        f2angle = np.angle(fft_data[:, 1]) + np.deg2rad(azi_angle)
        # Map the angle to a phase; values pushed above 1 by the azimuthal
        # shift are wrapped back by one period.
        f2phase = (f2angle + np.pi) / (2 * np.pi)
        f2phase[f2phase > 1] -= 1

        return f1amp.reshape(shape), f2amp.reshape(shape), f2phase.reshape(shape)

    def _bin_qrange(self, qranges, q, norm_sum, data) -> Tuple[np.ndarray, np.ndarray]:
        """Reintegrate data for given q range.

        Weighted sum for data using norm_sum as weights. One bin is computed
        for every pair of consecutive entries in ``qranges``; both edges of a
        bin are inclusive. Bins whose summed weight is zero are set to 0.

        Args:
            qranges (list): List with q edges for binning, given as indices into q.
            q (np.ndarray): q values.
            norm_sum (np.ndarray): Normalization sum.
            data (np.ndarray): Data to be binned.

        Returns:
            np.ndarray: Binned data.
            np.ndarray: Binned normalization sum.
        """
        q = np.asarray(q)
        norm_sum = np.asarray(norm_sum)
        n_bins = len(qranges) - 1
        output = np.zeros((*data.shape[:-1], n_bins))
        output_norm = np.zeros((data.shape[-2], n_bins))
        with np.errstate(divide="ignore", invalid="ignore"):
            for idx, (lower, upper) in enumerate(zip(qranges[:-1], qranges[1:])):
                q_mask = np.logical_and(q >= q[lower], q <= q[upper])
                bin_norm = np.nansum(norm_sum[..., q_mask], axis=-1)
                weighted = np.nansum(data[..., q_mask] * norm_sum[..., q_mask], axis=-1)
                output_norm[..., idx] = bin_norm
                # `where` keeps the zero fill of `out` wherever the weight is
                # zero instead of producing NaN/inf.
                output[..., idx] = np.divide(
                    weighted,
                    bin_norm,
                    out=np.zeros_like(weighted, dtype=float),
                    where=bin_norm != 0,
                )
        return output, output_norm

    def _sym_data(self, data, norm_sum) -> np.ndarray:
        """Symmetrize data by averaging over the two opposing directions.

        Helpful to remove detector gaps for x-ray detectors. The average is
        weighted by norm_sum; entries where both weights are zero are set to 0.

        Args:
            data (np.ndarray): Data to be symmetrized.
            norm_sum (np.ndarray): Normalization sum.

        Returns:
            np.ndarray: Symmetrized data.
        """
        n_directions = norm_sum.shape[0] // 2
        first_half = data[..., :n_directions, :]
        second_half = data[..., n_directions:, :]
        first_weight = norm_sum[:n_directions, :]
        second_weight = norm_sum[n_directions:, :]
        total_weight = first_weight + second_weight
        return np.divide(
            first_half * first_weight + second_half * second_weight,
            total_weight,
            out=np.zeros_like(first_half, dtype=float),
            where=total_weight != 0,
        )
if __name__ == "__main__":
    # Example entry point: run the processor against a local Redis instance.
    config = dict(
        output="px_dap_worker",
        # TODO these three inputs could be made available for change from the GUI
        parameters=dict(qranges=[20, 50], contrast=0, aziangles=None),
    )
    dap_process = StreamProcessorPx.run(
        config=config,
        connector_host=["localhost:6379"],
    )