# Source file: pxiii_bec/bec_plugins/data_processing/saxs_imaging_processor.py
# Snapshot: 2023-08-11 09:46:57 +02:00 (421 lines, 14 KiB, Python)

from __future__ import annotations
import numpy as np
import time
from queue import Queue
from typing import Optional, Tuple
from data_processing.stream_processor import StreamProcessor
from bec_lib.core import BECMessage
from bec_lib.core.redis_connector import MessageObject, RedisConnector
class SaxsImagingProcessor(StreamProcessor):
    """Stream processor producing scanning-SAXS imaging maps.

    GUI parameters (``horiz_roi``, ``azi_angle``, ``contrast``) are read from
    ``px_stream/gui_event``, per-projection metadata from
    ``px_stream/projection_<n>/metadata`` and the azimuthally integrated data
    from the redis list ``px_stream/projection_<n>/data``. For the current
    projection the data is reduced to a 2D map (selected by ``contrast``) and
    published as a ``BECMessage.ProcessedDataMessage``.
    """

    # Seconds to wait between redis polls in _run_forever when neither new data
    # nor new GUI parameters arrived (avoids a CPU-burning busy loop).
    _poll_interval = 0.1
    # Incremented on every GUI parameter message; _run_forever compares it with
    # the version it last processed to decide whether a recompute is needed.
    _parameter_version = 0
    # Once more than this many metadata entries are cached, the oldest is dropped.
    _max_cached_metadata = 10

    def __init__(self, connector: RedisConnector, config: dict) -> None:
        """Set up state, read the current parameters/metadata and start consumers.

        Args:
            connector (RedisConnector): Redis connector, forwarded to StreamProcessor.
            config (dict): Processor configuration, forwarded to StreamProcessor.
        """
        super().__init__(connector, config)
        self.metadata_consumer = None
        self.parameter_consumer = None
        # proj_nr -> metadata dict of the most recently announced projections.
        self.metadata = {}
        # Number of entries of the current projection's data list already read.
        self.num_received_msgs = 0
        # (proj_nr, metadata) items waiting to be picked up by _run_forever.
        self.queue = Queue()
        self._init_parameter(endpoint="px_stream/gui_event")
        self.start_parameter_consumer(endpoint="px_stream/gui_event")
        self._init_metadata_and_proj_nr(endpoint="px_stream/proj_nr")
        self.start_metadata_consumer(endpoint="px_stream/projection_*/metadata")

    def _init_parameter(self, endpoint: str) -> None:
        """Initialize the parameters azi_angle, contrast and horiz_roi.

        Defaults are set first and then overridden by the last message stored
        under ``endpoint``, if there is one.

        Args:
            endpoint (str): Endpoint for redis topic.

        Returns:
            None
        """
        self.azi_angle = None
        self.horiz_roi = [20, 50]
        self.contrast = 0
        msg = self.producer.get(topic=endpoint)
        if msg is None:
            return None
        self._parameter_msg_handler(BECMessage.DeviceMessage.loads(msg))

    def _restart_consumer(self, consumer, endpoint: str, cb):
        """Shut down ``consumer`` if it is running and start a new one.

        Args:
            consumer: Previously started consumer, or None.
            endpoint (str): Redis pattern the new consumer subscribes to.
            cb: Callback invoked as ``cb(msg, parent=self)`` for each message.

        Returns:
            The newly created and started consumer.
        """
        if consumer and consumer.is_alive():
            consumer.shutdown()
        new_consumer = self._connector.consumer(pattern=endpoint, cb=cb, parent=self)
        new_consumer.start()
        return new_consumer

    def start_parameter_consumer(self, endpoint: str) -> None:
        """Initialize the consumer for gui_event parameters.

        The consumer's callback updates the parameters azi_angle, contrast
        and horiz_roi. A running consumer is shut down first.

        Args:
            endpoint (str): Endpoint for redis topic.

        Returns:
            None
        """
        self.parameter_consumer = self._restart_consumer(
            self.parameter_consumer, endpoint, self._update_parameter_cb
        )

    @staticmethod
    def _update_parameter_cb(msg: MessageObject, parent: SaxsImagingProcessor) -> None:
        """Callback function for the parameter consumer.

        Args:
            msg (MessageObject): Message object.
            parent (SaxsImagingProcessor): Parent class.

        Returns:
            None
        """
        msg_raw = BECMessage.DeviceMessage.loads(msg.value)
        parent._parameter_msg_handler(msg_raw)

    def _parameter_msg_handler(self, msg: BECMessage) -> None:
        """Handle a parameter message.

        Any of azi_angle, contrast and horiz_roi present (and not None) in
        ``msg.content["signals"]`` replaces the current value. Every message
        bumps the parameter version so that _run_forever reprocesses the
        current projection with the new settings.

        Args:
            msg (BECMessage): Message object.

        Returns:
            None
        """
        signals = msg.content["signals"]
        if signals.get("horiz_roi") is not None:
            self.horiz_roi = signals["horiz_roi"]
        # Accept both spellings: the key actually sent by the GUI is not visible
        # here. TODO confirm the GUI key and drop the other one.
        azi_angle = signals.get("azi_angle")
        if azi_angle is None:
            azi_angle = signals.get("azi_angles")
        if azi_angle is not None:
            self.azi_angle = azi_angle
        if signals.get("contrast") is not None:
            self.contrast = signals["contrast"]
        self._parameter_version += 1

    def _init_metadata_and_proj_nr(self, endpoint: str) -> None:
        """Initialize proj_nr and the metadata of the current projection.

        Reads the current projection number from ``endpoint``. If metadata for
        that projection is stored, it is cached and queued for processing,
        so processing starts without waiting for the next metadata message.

        Args:
            endpoint (str): Endpoint for redis topic.

        Returns:
            None
        """
        self.proj_nr = None
        msg = self.producer.get(topic=endpoint)
        if msg is None:
            return None
        msg_raw = BECMessage.DeviceMessage.loads(msg)
        self.proj_nr = msg_raw.content["signals"]["proj_nr"]
        # TODO hardcoded endpoint, possible to use more general solution?
        msg = self.producer.get(topic=f"px_stream/projection_{self.proj_nr}/metadata")
        if msg is None:
            return None
        msg_raw = BECMessage.DeviceMessage.loads(msg)
        self._update_queue(msg_raw.content["signals"], self.proj_nr)

    def _update_queue(self, metadata: dict, proj_nr: int) -> None:
        """Cache the metadata of a projection and queue it for processing.

        If more than ``_max_cached_metadata`` entries are cached, the oldest
        (first inserted) entry is removed before the new one is stored.

        Args:
            metadata (dict): Metadata for the projection.
            proj_nr (int): Projection number.

        Returns:
            None
        """
        if len(self.metadata) > self._max_cached_metadata:
            self.metadata.pop(next(iter(self.metadata)))
        self.metadata[proj_nr] = metadata
        self.queue.put((proj_nr, metadata))

    def start_metadata_consumer(self, endpoint: str) -> None:
        """Start the metadata consumer.

        The consumer's callback caches and queues the metadata of newly
        announced projections. A running consumer is shut down first.

        Args:
            endpoint (str): Endpoint for redis topic.

        Returns:
            None
        """
        self.metadata_consumer = self._restart_consumer(
            self.metadata_consumer, endpoint, self._update_metadata_cb
        )

    @staticmethod
    def _update_metadata_cb(msg: MessageObject, parent: SaxsImagingProcessor) -> None:
        """Callback function for the metadata consumer.

        Args:
            msg (MessageObject): Message object.
            parent (SaxsImagingProcessor): Parent class.

        Returns:
            None
        """
        msg_raw = BECMessage.DeviceMessage.loads(msg.value)
        parent._metadata_msg_handler(msg_raw, msg.topic.decode())

    def _metadata_msg_handler(self, msg: BECMessage, topic: str) -> None:
        """Handle a metadata message.

        The projection number is parsed from the topic
        (``px_stream/projection_<proj_nr>/metadata``) and the metadata is
        cached and queued via _update_queue.

        Args:
            msg (BECMessage): Message object.
            topic (str): Topic for the message.

        Returns:
            None
        """
        self.proj_nr = int(topic.split("px_stream/projection_")[1].split("/")[0])
        self._update_queue(msg.content["signals"], self.proj_nr)

    def _init_data_output(self) -> None:
        """Initialize the data output.

        Not yet used. Should be used to initialize the output for the processed data.
        """
        self.data = None

    def start_data_consumer(self) -> None:
        """Override of the parent method that is intentionally a no-op here.

        Data is pulled from redis lists in _get_data instead.
        """

    def _run_forever(self) -> None:
        """Process one projection until a newer one is queued.

        Blocks until a (proj_nr, metadata) item is available, then repeatedly
        pulls newly arrived data for that projection and publishes the
        processed result. A pass is only recomputed when new data arrived or
        the GUI parameters changed since the last pass. Otherwise the loop
        sleeps ``_poll_interval`` seconds. Returns once the queue is non-empty
        again, i.e. a new projection was announced. Presumably the parent
        StreamProcessor calls this repeatedly; that is not visible here.

        Returns:
            None
        """
        proj_nr, metadata = self.queue.get()
        self.num_received_msgs = 0
        data = []
        # None guarantees one processing pass with the current parameters.
        processed_version = None
        while self.queue.empty():
            data_msgs = self._get_data(proj_nr)
            if data_msgs is not None:
                data.extend(
                    msg.content["signals"]["data"] for msg in data_msgs if msg is not None
                )
            parameters_changed = processed_version != self._parameter_version
            if data_msgs is None and not parameters_changed:
                # Nothing new: avoid recomputing/republishing an identical result.
                time.sleep(self._poll_interval)
                continue
            processed_version = self._parameter_version
            result = self.process(data, metadata)
            if result is None:
                continue
            stream_output, out_metadata = result
            print(f"Length of data is {stream_output[0]['z'].shape}")
            msg = BECMessage.ProcessedDataMessage(
                data=stream_output[0], metadata=out_metadata
            ).dumps()
            print("Publishing result")
            self._publish_result(msg)

    def _get_data(self, proj_nr: int) -> Optional[list]:
        """Get the not yet read data for the given proj_nr from redis.

        Args:
            proj_nr (int): Projection number.

        Returns:
            Optional[list]: Newly read, deserialized messages, or None if no
            new entries are available.
        """
        msgs = self.producer.lrange(
            f"px_stream/projection_{proj_nr}/data", self.num_received_msgs, -1
        )
        if not msgs:
            return None
        # Advance the read offset so the next call only returns new entries.
        self.num_received_msgs += len(msgs)
        return [BECMessage.DeviceMessage.loads(msg) for msg in msgs]

    def process(self, data: list, metadata: dict) -> Optional[Tuple[dict, dict]]:
        """Process the scanning SAXS data.

        Args:
            data (list): List of azimuthal integrated data.
            metadata (dict): Metadata for the projection; must contain
                "norm_sum" and "q".

        Returns:
            Optional[Tuple[dict, dict]]: ``({0: {"z": map}}, metadata)`` where
            the returned metadata is a copy extended with ``grid_scan`` (the
            map shape), or None if ``data`` is empty. The map is f1amp for
            contrast 0 (also used for unknown contrast values), f2amp for 1
            and f2phase for 2.
        """
        if not data:
            return None
        # Snapshot the parameters so one pass uses a consistent set even if the
        # parameter consumer updates them concurrently.
        contrast = self.contrast
        horiz_roi = self.horiz_roi
        azi_angle = 0 if self.azi_angle is None else self.azi_angle
        # TODO np.asarray is responsible for ~95% of the processing time of this function.
        azint_data = np.asarray(data)
        f1amp, f2amp, f2phase = self._colorfulplot(
            horiz_roi=horiz_roi,
            q=metadata["q"],
            norm_sum=metadata["norm_sum"],
            data=azint_data,
            azi_angle=azi_angle,
        )
        out = {0: f1amp, 1: f2amp, 2: f2phase}.get(contrast, f1amp)
        stream_output = {0: {"z": np.asarray(out)}}
        # Copy instead of mutating the dict shared with self.metadata / the queue.
        out_metadata = {**metadata, "grid_scan": out.shape}
        return (stream_output, out_metadata)

    def _colorfulplot(
        self,
        horiz_roi: list,
        q: np.ndarray,
        norm_sum: np.ndarray,
        data: np.ndarray,
        azi_angle: float,
        percentile_value: int = 96,
    ) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
        """Compute data for the sSAXS colorful 2D plot.

        The data is reduced to a single q bin, symmetrized over opposing
        azimuthal segments, and Fourier transformed along the azimuthal axis
        per pixel. Pending: hsv_to_rgb conversion for colorful output.

        Args:
            horiz_roi (list): Two indices into ``q``; their q values delimit
                the integrated q range.
            q (np.ndarray): q values.
            norm_sum (np.ndarray): Normalization sum.
            data (np.ndarray): Data to be binned; expected 4-D with the
                azimuthal segments on axis -2 and q on axis -1.
            azi_angle (float): Azimuthal angle (degrees) of the first segment,
                shifts f2phase.
            percentile_value (int, optional): Reserved for the pending HSV
                composition (outlier clipping percentile); currently unused.
                Defaults to 96.

        Returns:
            Tuple[np.ndarray, np.ndarray, np.ndarray]: f1amp (|FFT component 0|
            / n), f2amp (2 * |FFT component 1| / n) and f2phase (phase of
            component 1 shifted by azi_angle, wrapped to [0, 1)), each shaped
            like the first two axes of ``data``.
        """
        output, output_norm = self._bin_qrange(
            horiz_roi=horiz_roi, q=q, norm_sum=norm_sum, data=data
        )
        output_sym = self._sym_data(data=output, norm_sum=output_norm)
        shape = output_sym.shape[0:2]
        n_directions = output_sym.shape[2]
        # One azimuthal profile per pixel; relies on the single q bin, so the
        # trailing axis has length 1.
        fft_data = np.fft.rfft(output_sym.reshape((-1, output_sym.shape[-2])), axis=1)
        f1amp = (np.abs(fft_data[:, 0]) / n_directions).reshape(shape)
        f2amp = (2 * np.abs(fft_data[:, 1]) / n_directions).reshape(shape)
        f2angle = np.angle(fft_data[:, 1]) + np.deg2rad(azi_angle)
        # np.mod wraps any azi_angle (negative or > 360 deg) into [0, 1).
        f2phase = np.mod((f2angle + np.pi) / (2 * np.pi), 1.0).reshape(shape)
        return f1amp, f2amp, f2phase

    def _bin_qrange(self, horiz_roi, q, norm_sum, data) -> Tuple[np.ndarray, np.ndarray]:
        """Reintegrate data over the q range given by horiz_roi.

        Weighted average over q using norm_sum as weights. Only the first
        bin (``horiz_roi[0]`` .. ``horiz_roi[1]``) is filled. Segments whose
        total weight is zero (e.g. detector gaps) yield 0 instead of NaN, so
        the weighted symmetrization can fill them from the opposite segment.

        Args:
            horiz_roi (list): Indices into ``q``; their q values delimit the bin.
            q (np.ndarray): q values.
            norm_sum (np.ndarray): Normalization sum, shape (n_segments, n_q).
            data (np.ndarray): Data to be binned, q on the last axis.

        Returns:
            np.ndarray: Binned data, ``data.shape[:-1] + (len(horiz_roi) - 1,)``.
            np.ndarray: Binned normalization sum, ``(n_segments, len(horiz_roi) - 1)``.
        """
        q = np.asarray(q)
        norm_sum = np.asarray(norm_sum)
        output = np.zeros((*data.shape[:-1], len(horiz_roi) - 1))
        output_norm = np.zeros((data.shape[-2], len(horiz_roi) - 1))
        with np.errstate(divide="ignore", invalid="ignore"):
            q_mask = np.logical_and(q >= q[horiz_roi[0]], q <= q[horiz_roi[1]])
            output_norm[..., 0] = np.nansum(norm_sum[..., q_mask], axis=-1)
            weighted = np.nansum(data[..., q_mask] * norm_sum[..., q_mask], axis=-1)
            norm = output_norm[..., 0]
            output[..., 0] = np.divide(
                weighted, norm, out=np.zeros_like(weighted), where=norm != 0
            )
        return output, output_norm

    def _sym_data(self, data, norm_sum) -> np.ndarray:
        """Symmetrize data by averaging over the two opposing directions.

        Segment i is combined with segment i + n/2, weighted by norm_sum.
        This helps to fill detector gaps of x-ray detectors. Pairs with zero
        total weight yield 0. Assumes an even number of segments.

        Args:
            data (np.ndarray): Data to be symmetrized, segments on axis -2.
            norm_sum (np.ndarray): Normalization sum, shape (n_segments, n_bins).

        Returns:
            np.ndarray: Symmetrized data with half the segments on axis -2.
        """
        n_directions = norm_sum.shape[0] // 2
        norm_first = norm_sum[:n_directions, :]
        norm_second = norm_sum[n_directions:, :]
        weighted = (
            data[..., :n_directions, :] * norm_first + data[..., n_directions:, :] * norm_second
        )
        total_norm = norm_first + norm_second
        output = np.zeros_like(data[..., :n_directions, :])
        np.divide(weighted, total_norm, out=output, where=total_norm != 0)
        return output
if __name__ == "__main__":
    # Standalone entry point: run the processor against a local redis instance.
    # config["output"] presumably names the endpoint StreamProcessor publishes
    # results to -- TODO confirm against StreamProcessor.
    redis_hosts = ["localhost:6379"]
    config = {"output": "px_dap_worker"}
    dap_process = SaxsImagingProcessor.run(config=config, connector_host=redis_hosts)