pxiii_bec/bec_plugins/data_processing/px_example.py

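"""Example BEC data-processing (DAP) plugin for the px stream.

Collects azimuthally integrated projection data and the corresponding metadata
(q vector and normalization weights) from Redis, computes FFT-based contrasts
over the azimuthal directions (f1 amplitude, f2 amplitude, f2 phase), and
publishes the selected contrast map as a processed-data message.
"""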
from __future__ import annotations

import time
from queue import Queue
from string import Template
from typing import Optional, Tuple

import numpy as np

from bec_lib.core import BECMessage, MessageEndpoints
from bec_lib.core.redis_connector import MessageObject, RedisConnector
from data_processing.stream_processor import StreamProcessor


class StreamProcessorPx(StreamProcessor):
    """Stream processor that collects azimuthally integrated projection data from Redis
    and publishes FFT-based contrast maps (f1 amplitude, f2 amplitude, f2 phase)."""

    def __init__(self, connector: RedisConnector, config: dict) -> None:
        """
        Initialize the StreamProcessorPx class.

        Args:
            connector (RedisConnector): Redis connector.
            config (dict): Configuration for the processor.
        """
        super().__init__(connector, config)
        self.metadata_consumer = None
        self.metadata = {}
        self.num_received_msgs = 0
        self.queue = Queue()
        self.start_metadata_consumer()

    def start_data_consumer(self):
        # Data is pulled on demand via lrange in _get_data; no dedicated data consumer is needed.
        pass

    def start_metadata_consumer(self):
        """(Re)start the consumer that listens for projection metadata messages."""
        if self.metadata_consumer and self.metadata_consumer.is_alive():
            self.metadata_consumer.shutdown()
        self.metadata_consumer = self._connector.consumer(
            pattern="px_stream/projection_*/metadata", cb=self._update_metadata, parent=self
        )
        self.metadata_consumer.start()

    @staticmethod
    def _update_metadata(msg: MessageObject, parent: StreamProcessorPx) -> None:
        """Consumer callback: decode the metadata message and hand it to the parent."""
        msg_raw = BECMessage.DeviceMessage.loads(msg.value)
        parent._metadata_msg_handler(msg_raw, msg.topic.decode())

    def _metadata_msg_handler(self, msg, topic):
        """Cache the metadata per projection (keeping at most the last 10 projections)
        and enqueue it for the processing loop."""
        if len(self.metadata) > 10:
            first_key = next(iter(self.metadata))
            self.metadata.pop(first_key)
        proj_nr = int(topic.split("px_stream/projection_")[1].split("/")[0])
        self.metadata.update({proj_nr: msg.content["signals"]})
        self.queue.put((proj_nr, msg.content["signals"]))

    def _run_forever(self):
        """Wait for the next projection, then repeatedly pull newly arrived data messages for it
        and publish the processed result until the next projection's metadata arrives."""
        # TODO: Check whether entries in the queue should be skipped at the beginning
        proj_nr, metadata = self.queue.get()
        self.num_received_msgs = 0
        data = []
        while self.queue.empty():
            start = time.time()
            data_msgs = self._get_data(proj_nr)
            print(f"Processing took {time.time() - start}")
            data.extend([msg.content["signals"]["data"] for msg in data_msgs if msg is not None])
            # if len(data) > :
            result = self.process(data, metadata)
            if not result:
                continue
            msg = BECMessage.ProcessedDataMessage(data=result[0][0], metadata=result[1]).dumps()
            print("Publishing result")
            self._publish_result(msg)

    def _get_data(self, proj_nr: int) -> list:
        """Fetch all data messages appended to the projection's Redis list since the last call."""
        msgs = self.producer.lrange(
            f"px_stream/projection_{proj_nr}/data", self.num_received_msgs, -1
        )
        if not msgs:
            return []
        self.num_received_msgs += len(msgs)
        return [BECMessage.DeviceMessage.loads(msg) for msg in msgs]

    def process(self, data: list, metadata: dict) -> Optional[Tuple[dict, dict]]:
        """Compute the selected contrast map from the accumulated azimuthal-integration data."""
        if not data:
            return None
        azint_data = np.asarray(data)
        norm_sum = metadata["norm_sum"]
        q = metadata["q"]
        # Contrast selection: 0 = f1 amplitude, 1 = f2 amplitude, 2 = f2 phase
        contrast = self.config["parameters"]["contrast"]
        # q ranges come from user input (e.g. a LinearRegionROI); may move to metadata later
        qranges = self.config["parameters"]["qranges"]
        f1amp, f2amp, f2phase = self._colorfulplot(
            qranges=qranges,
            q=q,
            norm_sum=norm_sum,
            data=azint_data,
            aziangles=None,
            percentile_value=96,
        )
        if contrast == 0:
            out = f1amp
        elif contrast == 1:
            out = f2amp
        elif contrast == 2:
            out = f2phase
        else:
            raise ValueError(f"Unknown contrast selection: {contrast}")
        stream_output = {
            # 0: {"x": np.asarray(x), "y": np.asarray(y), "z": np.asarray(out)},
            0: {"z": np.asarray(out)},
            # "input": self.config["input_xy"],
        }
        metadata["grid_scan"] = out.shape
        return (stream_output, metadata)

    def _bin_qrange(self, qranges, q, norm_sum, data):
        """
        Bin the data between neighboring q-range edges, weighted by norm_sum.

        Args:
            qranges: list with indices of q edges; data is binned between neighboring edges.
            q: full q vector
            norm_sum: weights for q
            data: full data
        """
        output = np.zeros((*data.shape[:-1], len(qranges) - 1))
        output_norm = np.zeros((data.shape[-2], len(qranges) - 1))
        with np.errstate(divide="ignore", invalid="ignore"):
            for ii, _qval in enumerate(qranges[:-1]):
                q_mask = np.logical_and(q >= q[qranges[ii]], q < q[qranges[ii + 1]])
                output_norm[..., ii] = np.nansum(norm_sum[..., q_mask], axis=-1)
                output[..., ii] = np.nansum(data[..., q_mask] * norm_sum[..., q_mask], axis=-1)
                # Keep the preallocated zeros wherever the normalization is zero
                output[..., ii] = np.divide(
                    output[..., ii],
                    output_norm[..., ii],
                    out=np.zeros_like(output[..., ii]),
                    where=output_norm[..., ii] != 0,
                )
        return output, output_norm

    def _sym_data(self, data, norm_sum):
        """Symmetrize the azimuthal directions by combining each direction with its opposite,
        weighted by norm_sum."""
        n_directions = norm_sum.shape[0] // 2
        denom = norm_sum[:n_directions, :] + norm_sum[n_directions:, :]
        output = np.divide(
            data[..., :n_directions, :] * norm_sum[:n_directions, :]
            + data[..., n_directions:, :] * norm_sum[n_directions:, :],
            denom,
            out=np.zeros_like(data[..., :n_directions, :]),
            where=denom != 0,
        )
        return output

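    # The contrast computation below takes the real FFT of the symmetrized intensity along the
    # azimuthal axis for every pixel: the 0th Fourier component gives the mean intensity over
    # the azimuthal directions (f1amp), while the 1st component gives the amplitude (f2amp) and
    # phase (f2phase) of the dominant azimuthal modulation, which together form an HSV map.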
    def _colorfulplot(self, qranges, q, norm_sum, data, aziangles=None, percentile_value=96):
        """
        Compute f1/f2 amplitude and phase contrasts from the azimuthal intensity distribution.

        Args:
            qranges: list with 2 indices for the q edges
            q: full q vector
            norm_sum: weights for q
            data: full data
            aziangles: optional azimuthal start angle(s); the first entry offsets the f2 phase
            percentile_value: percentile used to scale the HSV saturation and value channels
        """
        output, output_norm = self._bin_qrange(qranges=qranges, q=q, norm_sum=norm_sum, data=data)
        output_sym = self._sym_data(data=output, norm_sum=output_norm)
        shape = output_sym.shape[0:2]
        fft_data = np.fft.rfft(output_sym.reshape((-1, output_sym.shape[-2])), axis=1)
        if aziangles is None:
            azi_angle = 0
        else:
            azi_angle = aziangles[0]
        f1amp = np.abs(fft_data[:, 0]) / output_sym.shape[2]
        f2amp = 2 * np.abs(fft_data[:, 1]) / output_sym.shape[2]
        # Offset the phase by the azimuthal start angle so that it maps correctly onto the color wheel
        f2angle = np.angle(fft_data[:, 1]) + np.deg2rad(azi_angle)
        # Normalize the phase to 0...1: np.angle returns values in [-pi, pi], but the result can
        # exceed 1 after adding the azimuthal offset, so wrap it back into range
        f2phase = (f2angle + np.pi) / (2 * np.pi)
        f2phase[f2phase > 1] = f2phase[f2phase > 1] - 1
        f1amp = f1amp.reshape(shape)
        f2amp = f2amp.reshape(shape)
        f2angle = f2angle.reshape(shape)
        f2phase = f2phase.reshape(shape)
        # HSV output: hue = f2 phase, saturation = f2 amplitude, value = f1 amplitude,
        # each scaled by the given percentile
        h = f2phase
        max_scale = np.percentile(f2amp, percentile_value)
        s = f2amp / max_scale
        s[s > 1] = 1
        max_scale = np.percentile(f1amp, percentile_value)
        v = f1amp / max_scale
        v[v > 1] = 1
        hsv = np.stack((h, s, v), axis=2)
        # comb_all = colors.hsv_to_rgb(hsv)
        return f1amp, f2amp, f2phase  # , comb_all


if __name__ == "__main__":
    config = {
        "template": Template("px_stream/projection_$proj/$channel"),
        "channels": ["data", "metadata", "q", "norm_sum"],
        "output": "px_dap_worker",
        "parameters": {
            "qranges": [20, 50],  # TODO: this will be a signal from the ROI selector
            "contrast": 0,  # "contrast_stream": 'px_contrast_stream',
        },
    }
    dap_process = StreamProcessorPx.run(config=config, connector_host=["localhost:6379"])
    # dap_process = StreamProcessorPx(config=config, connector_host=["localhost:6379"])
    # dap_process.start_met
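
# Expected Redis key layout, derived from the consumer pattern and the config template above:
#   px_stream/projection_<proj>/metadata  - DeviceMessage whose signals include "q" and "norm_sum"
#   px_stream/projection_<proj>/data      - list of DeviceMessages carrying the "data" signal
# Results are published as ProcessedDataMessages, presumably on the configured "output" endpoint
# (handled by the StreamProcessor base class).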