15 Commits

3 changed files with 586 additions and 0 deletions

View File

@@ -0,0 +1 @@
from .saxs_imaging_processor import SaxsImagingProcessor

View File

@@ -0,0 +1,430 @@
from __future__ import annotations
import numpy as np
import time
from queue import Queue
from typing import Optional, Tuple
from data_processing.stream_processor import StreamProcessor
from bec_lib.core import BECMessage
from bec_lib.core.redis_connector import MessageObject, RedisConnector
class SaxsImagingProcessor(StreamProcessor):
    """Stream processor for scanning SAXS imaging.

    Consumes GUI parameter updates and per-projection metadata from redis,
    loads the azimuthally integrated frames of the newest projection and
    publishes a 2D contrast image (f1amp, f2amp or f2phase) downstream.
    """

    def __init__(self, connector: RedisConnector, config: dict) -> None:
        """Set up parameters, the metadata cache and the redis consumers.

        Args:
            connector (RedisConnector): Connector to the redis server.
            config (dict): Processor configuration, forwarded to StreamProcessor.
        """
        super().__init__(connector, config)
        self.metadata_consumer = None
        self.parameter_consumer = None
        # Cache of {proj_nr: metadata}; trimmed in _metadata_msg_handler.
        self.metadata = {}
        self.num_received_msgs = 0
        self.queue = Queue()
        self._init_parameter(endpoint="px_stream/gui_event")
        self.start_parameter_consumer(endpoint="px_stream/gui_event")
        self._init_metadata_and_proj_nr(endpoint="px_stream/proj_nr")
        self.start_metadata_consumer(endpoint="px_stream/projection_*/metadata")

    def _init_parameter(self, endpoint: str) -> None:
        """Initialize the parameters azi_angle, contrast and horiz_roi.

        Defaults are set first; if a gui_event message is already stored in
        redis, its values override the defaults.

        Args:
            endpoint (str): Endpoint for redis topic.

        Returns:
            None
        """
        self.azi_angle = None
        self.horiz_roi = [20, 50]
        self.contrast = 0
        msg = self.producer.get(topic=endpoint)
        if msg is None:
            return None
        msg_raw = BECMessage.DeviceMessage.loads(msg)
        self._parameter_msg_handler(msg_raw)

    def start_parameter_consumer(self, endpoint: str) -> None:
        """Start (or restart) the consumer for gui_event parameters.

        The consumer callback updates the parameters azi_angle, contrast
        and horiz_roi.

        Args:
            endpoint (str): Endpoint for redis topic.

        Returns:
            None
        """
        if self.parameter_consumer and self.parameter_consumer.is_alive():
            self.parameter_consumer.shutdown()
        self.parameter_consumer = self._connector.consumer(
            pattern=endpoint, cb=self._update_parameter_cb, parent=self
        )
        self.parameter_consumer.start()

    @staticmethod
    def _update_parameter_cb(msg: MessageObject, parent: SaxsImagingProcessor) -> None:
        """Callback for the parameter consumer; forwards to the parent handler.

        Args:
            msg (MessageObject): Message object.
            parent (SaxsImagingProcessor): Owning processor instance.

        Returns:
            None
        """
        msg_raw = BECMessage.DeviceMessage.loads(msg.value)
        parent._parameter_msg_handler(msg_raw)

    def _parameter_msg_handler(self, msg: BECMessage) -> None:
        """Handle a parameter message.

        A single message may update any of the three parameters:
        azi_angle, contrast and horiz_roi.

        Args:
            msg (BECMessage): Message object.

        Returns:
            None
        """
        signals = msg.content["signals"]
        if signals.get("horiz_roi") is not None:
            self.horiz_roi = signals["horiz_roi"]
        # BUGFIX: the original checked the key "azi_angles" but then read
        # "azi_angle"; use one consistent key. TODO(review): confirm the key
        # name against the GUI publisher.
        if signals.get("azi_angle") is not None:
            self.azi_angle = signals["azi_angle"]
        if signals.get("contrast") is not None:
            self.contrast = signals["contrast"]

    def _init_metadata_and_proj_nr(self, endpoint: str) -> None:
        """Initialize proj_nr and the metadata of that projection from redis.

        Args:
            endpoint (str): Endpoint for redis topic.

        Returns:
            None
        """
        msg = self.producer.get(topic=endpoint)
        if msg is None:
            self.proj_nr = None
            return None
        msg_raw = BECMessage.DeviceMessage.loads(msg)
        self.proj_nr = msg_raw.content["signals"]["proj_nr"]
        # TODO hardcoded endpoint, possible to use more general solution?
        msg = self.producer.get(topic=f"px_stream/projection_{self.proj_nr}/metadata")
        if msg is None:
            # BUGFIX: proj_nr was announced but its metadata is not (yet)
            # stored; the original passed None into loads() and crashed.
            return None
        msg_raw = BECMessage.DeviceMessage.loads(msg)
        self._update_queue(msg_raw.content["signals"], self.proj_nr)

    def _update_queue(self, metadata: dict, proj_nr: int) -> None:
        """Cache the metadata and enqueue the projection for processing.

        Args:
            metadata (dict): Metadata for the projection.
            proj_nr (int): Projection number.

        Returns:
            None
        """
        self.metadata.update({proj_nr: metadata})
        self.queue.put((proj_nr, metadata))

    def start_metadata_consumer(self, endpoint: str) -> None:
        """Start (or restart) the metadata consumer.

        The consumer callback updates the metadata cache and the queue.

        Args:
            endpoint (str): Endpoint for redis topic.

        Returns:
            None
        """
        if self.metadata_consumer and self.metadata_consumer.is_alive():
            self.metadata_consumer.shutdown()
        self.metadata_consumer = self._connector.consumer(
            pattern=endpoint, cb=self._update_metadata_cb, parent=self
        )
        self.metadata_consumer.start()

    @staticmethod
    def _update_metadata_cb(msg: MessageObject, parent: SaxsImagingProcessor) -> None:
        """Callback for the metadata consumer; forwards to the parent handler.

        Args:
            msg (MessageObject): Message object.
            parent (SaxsImagingProcessor): Owning processor instance.

        Returns:
            None
        """
        msg_raw = BECMessage.DeviceMessage.loads(msg.value)
        parent._metadata_msg_handler(msg_raw, msg.topic.decode())

    def _metadata_msg_handler(self, msg: BECMessage, topic: str) -> None:
        """Handle a metadata message.

        If the metadata cache holds more than 10 entries, the oldest one is
        evicted before the new projection is enqueued.

        Args:
            msg (BECMessage): Message object.
            topic (str): Topic of the message; encodes the projection number.

        Returns:
            None
        """
        if len(self.metadata) > 10:
            # dicts preserve insertion order, so the first key is the oldest.
            first_key = next(iter(self.metadata))
            self.metadata.pop(first_key)
        self.proj_nr = int(topic.split("px_stream/projection_")[1].split("/")[0])
        self._update_queue(msg.content["signals"], self.proj_nr)

    def start_data_consumer(self) -> None:
        """Parent-class hook that is intentionally disabled for this processor."""

    def _run_forever(self) -> None:
        """Main processing loop, runs while the processor is started.

        Waits for a projection on the queue, then repeatedly loads newly
        arrived data lines and (re)processes the projection until a newer
        projection is queued, publishing each intermediate result.

        Returns:
            None
        """
        proj_nr, metadata = self.queue.get()
        self.num_received_msgs = 0
        self.data = None
        while self.queue.empty():
            self._get_data(proj_nr, metadata)
            start = time.time()
            result = self.process(self.data, metadata)
            print(f"Processing took {time.time() - start}")
            if result is None:
                continue
            print(f"Length of data is {result[0][0]['z'].shape}")
            msg = BECMessage.ProcessedDataMessage(data=result[0][0], metadata=result[1]).dumps()
            print("Publishing result")
            self._publish_result(msg)

    def _get_data(self, proj_nr: int, metadata: dict) -> None:
        """Fetch newly arrived data lines for proj_nr from redis into self.data.

        Args:
            proj_nr (int): Projection number.
            metadata (dict): Projection metadata; provides the grid dimensions.

        Returns:
            None
        """
        start = time.time()
        # Only fetch the list entries that have not been consumed yet.
        msgs = self.producer.lrange(
            f"px_stream/projection_{proj_nr}/data", self.num_received_msgs, -1
        )
        print(f"Loading of {len(msgs)} took {time.time() - start}")
        if not msgs:
            return None
        frame_shape = BECMessage.DeviceMessage.loads(msgs[0]).content["signals"]["data"].shape[-2:]
        if self.data is None:
            start = time.time()
            # Allocate the full (rows, cols, *frame) output buffer once.
            self.data = np.empty(
                (
                    metadata["metadata"]["number_of_rows"],
                    metadata["metadata"]["number_of_columns"],
                    *frame_shape,
                )
            )
            print(f"Init output took {time.time() - start}")
        start = time.time()
        for msg in msgs:
            # Each list entry is one scan line; fill the next row of the buffer.
            self.data[
                self.num_received_msgs : self.num_received_msgs + 1, ...
            ] = BECMessage.DeviceMessage.loads(msg).content["signals"]["data"]
            self.num_received_msgs += 1
        print(f"Casting data to array took {time.time() - start}")

    def process(self, data: np.ndarray, metadata: dict) -> Optional[Tuple[dict, dict]]:
        """Process the scanning SAXS data.

        Args:
            data (np.ndarray): Azimuthally integrated data; first axis are the
                received scan lines.
            metadata (dict): Projection metadata (needs "norm_sum" and "q").

        Returns:
            Optional[Tuple[dict, dict]]: Processed data and metadata, or None
            if no data has arrived yet.
        """
        if data is None:
            return None
        # TODO np.asarray is responsible for 95% of the processing time for function.
        azint_data = data[0 : self.num_received_msgs, ...]
        norm_sum = metadata["norm_sum"]
        q = metadata["q"]
        # Snapshot parameters so a concurrent GUI update cannot change them mid-run.
        contrast = self.contrast
        horiz_roi = self.horiz_roi
        azi_angle = self.azi_angle
        if azi_angle is None:
            azi_angle = 0
        f1amp, f2amp, f2phase = self._colorfulplot(
            horiz_roi=horiz_roi,
            q=q,
            norm_sum=norm_sum,
            data=azint_data,
            azi_angle=azi_angle,
        )
        # BUGFIX: the original left `out = []` for unknown contrast values,
        # which crashed on `out.shape` below; fall back to f1amp instead.
        out = {0: f1amp, 1: f2amp, 2: f2phase}.get(contrast, f1amp)
        stream_output = {
            0: {"z": np.asarray(out)},
        }
        metadata["grid_scan"] = out.shape
        return (stream_output, metadata)

    def _colorfulplot(
        self,
        horiz_roi: list,
        q: np.ndarray,
        norm_sum: np.ndarray,
        data: np.ndarray,
        azi_angle: float,
        percentile_value: int = 96,
    ) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
        """Compute data for sSAXS colorful 2D plot.

        Pending: hsv_to_rgb conversion for colorful output.

        Args:
            horiz_roi (list): List with q edges for binning.
            q (np.ndarray): q values.
            norm_sum (np.ndarray): Normalization sum.
            data (np.ndarray): Data to be binned.
            azi_angle (float): Azimuthal angle for first segment, shifts f2phase.
            percentile_value (int, optional): Percentile value for removing
                outliers above threshold. Defaults to 96, range 0...100.

        Returns:
            Tuple[np.ndarray, np.ndarray, np.ndarray]: f1amp, f2amp, f2phase
        """
        output, output_norm = self._bin_qrange(
            horiz_roi=horiz_roi, q=q, norm_sum=norm_sum, data=data
        )
        output_sym = self._sym_data(data=output, norm_sum=output_norm)
        shape = output_sym.shape[0:2]
        # FFT over the azimuthal directions: component 0 is the isotropic
        # amplitude, component 1 the 2-fold (orientation) amplitude and phase.
        fft_data = np.fft.rfft(output_sym.reshape((-1, output_sym.shape[-2])), axis=1)
        f1amp = np.abs(fft_data[:, 0]) / output_sym.shape[2]
        f2amp = 2 * np.abs(fft_data[:, 1]) / output_sym.shape[2]
        f2angle = np.angle(fft_data[:, 1]) + np.deg2rad(azi_angle)
        # Map the angle from (-pi, pi] to (0, 1], wrapping values above 1.
        f2phase = (f2angle + np.pi) / (2 * np.pi)
        f2phase[f2phase > 1] = f2phase[f2phase > 1] - 1
        f1amp = f1amp.reshape(shape)
        f2amp = f2amp.reshape(shape)
        f2phase = f2phase.reshape(shape)
        # h, s, v are staged for the pending hsv_to_rgb conversion (see above);
        # they are not returned yet.
        h = f2phase
        max_scale = np.percentile(f2amp, percentile_value)
        s = f2amp / max_scale
        s[s > 1] = 1
        max_scale = np.percentile(f1amp, percentile_value)
        v = f1amp / max_scale
        v[v > 1] = 1
        # hsv = np.stack((h, s, v), axis=2)
        # comb_all = colors.hsv_to_rgb(hsv)
        return f1amp, f2amp, f2phase  # , comb_all

    def _bin_qrange(self, horiz_roi, q, norm_sum, data) -> Tuple[np.ndarray, np.ndarray]:
        """Reintegrate data for given q range.

        Weighted sum for data using norm_sum as weights.

        Args:
            horiz_roi (list): List with indices into q serving as bin edges.
            q (np.ndarray): q values.
            norm_sum (np.ndarray): Normalization sum.
            data (np.ndarray): Data to be binned.

        Returns:
            Tuple[np.ndarray, np.ndarray]: Binned data and binned
            normalization sum.
        """
        output = np.zeros((*data.shape[:-1], len(horiz_roi) - 1))
        output_norm = np.zeros((data.shape[-2], len(horiz_roi) - 1))
        with np.errstate(divide="ignore", invalid="ignore"):
            # NOTE: only the first bin [horiz_roi[0], horiz_roi[1]] is filled;
            # horiz_roi currently holds exactly two edges.
            q_mask = np.logical_and(q >= q[horiz_roi[0]], q <= q[horiz_roi[1]])
            output_norm[..., 0] = np.nansum(norm_sum[..., q_mask], axis=-1)
            output[..., 0] = np.nansum(data[..., q_mask] * norm_sum[..., q_mask], axis=-1)
            denom = output_norm[..., 0]
            # BUGFIX: `out=` alone does not keep zeros where the denominator
            # is zero; `where=` preserves the zero-fill the original intended.
            output[..., 0] = np.divide(
                output[..., 0], denom, out=np.zeros_like(output[..., 0]), where=denom != 0
            )
        return output, output_norm

    def _sym_data(self, data, norm_sum) -> np.ndarray:
        """Symmetrize data by averaging over the two opposing directions.

        Helpful to remove detector gaps for x-ray detectors.

        Args:
            data (np.ndarray): Data to be symmetrized.
            norm_sum (np.ndarray): Normalization sum.

        Returns:
            np.ndarray: Symmetrized data with half the number of directions.
        """
        n_directions = norm_sum.shape[0] // 2
        denom = norm_sum[:n_directions, :] + norm_sum[n_directions:, :]
        # BUGFIX: `out=` alone does not protect against zero denominators;
        # `where=` keeps the zero from `out` wherever denom == 0.
        output = np.divide(
            data[..., :n_directions, :] * norm_sum[:n_directions, :]
            + data[..., n_directions:, :] * norm_sum[n_directions:, :],
            denom,
            out=np.zeros_like(data[..., :n_directions, :]),
            where=denom != 0,
        )
        return output
if __name__ == "__main__":
    # Run the SAXS imaging processor as a standalone service against the
    # local redis instance, publishing results to the px_dap_worker output.
    config = {"output": "px_dap_worker"}
    dap_process = SaxsImagingProcessor.run(config=config, connector_host=["localhost:6379"])

View File

@@ -0,0 +1,155 @@
import os
import h5py
import numpy as np
import time
import json
import argparse
from bec_lib.core import RedisConnector, BECMessage
def load_data(datadir: str, metadata_path: str) -> tuple:
    """Load projection data from disk.

    Args:
        datadir (str): Path to the data directory with data for projection (h5 files)
        metadata_path (str): Path to the metadata file

    Returns:
        tuple: data, q, norm_sum, metadata — ``data`` stacks the "I_all"
        dataset of every h5 file along a new first axis; ``q`` and
        ``norm_sum`` are taken from the first file.

    Raises:
        FileNotFoundError: If datadir contains no .h5 files.
    """
    with open(metadata_path) as file:
        metadata = json.load(file)
    filenames = sorted(fname for fname in os.listdir(datadir) if fname.endswith(".h5"))
    # BUGFIX: with no h5 files the original hit a NameError on `data` at the
    # return; fail with a clear error instead.
    if not filenames:
        raise FileNotFoundError(f"No .h5 files found in {datadir}")
    for ii, fname in enumerate(filenames):
        with h5py.File(os.path.join(datadir, fname), "r") as h5file:
            if ii == 0:
                q = h5file["q"][...].T.squeeze()
                norm_sum = h5file["norm_sum"][...]
                # Use the dataset's shape attribute; the original read the
                # whole "I_all" dataset a second time just to get its shape.
                data = np.zeros((len(filenames), *h5file["I_all"].shape))
            data[ii, ...] = h5file["I_all"][...]
    return data, q, norm_sum, metadata
def _get_projection_keys(producer) -> list:
"""Get all keys for projections with endpoint px_stream/projection_* in redis
Args:
producer (RedisProducer): Redis producer
Returns:
list: List of keys or [] if no keys are found"""
keys = producer.keys("px_stream/projection_*")
if not keys:
return []
return keys
def send_data(
    data: np.ndarray,
    q: np.ndarray,
    norm_sum: np.ndarray,
    bec_producer: RedisConnector.producer,
    metadata: dict,
    proj_nr: int,
) -> None:
    """Publish one projection to redis and prune data older than 5 projections.

    Args:
        data (np.ndarray): Data to send; one redis list entry per scan line.
        q (np.ndarray): q values.
        norm_sum (np.ndarray): Normalization sum.
        bec_producer (RedisProducer): Redis producer.
        metadata (dict): Metadata.
        proj_nr (int): Projection number.

    Returns:
        None
    """
    start = time.time()
    existing_keys = _get_projection_keys(bec_producer)
    pipe = bec_producer.pipeline()
    # Projection numbers currently present in redis.
    proj_numbers = {
        key.decode().split("px_stream/projection_")[1].split("/")[0] for key in existing_keys
    }
    if len(proj_numbers) > 5:
        # Queue deletion of everything but the five newest projections.
        for stale in sorted(proj_numbers)[:-5]:
            for stale_key in bec_producer.keys(f"px_stream/projection_{stale}/*"):
                bec_producer.delete(topic=stale_key, pipe=pipe)
                print(f"Deleting {stale_key}")
    meta_msg = BECMessage.DeviceMessage(
        signals={"metadata": metadata, "q": q, "norm_sum": norm_sum}
    ).dumps()
    bec_producer.set_and_publish(
        f"px_stream/projection_{proj_nr}/metadata", msg=meta_msg, pipe=pipe
    )
    nr_msg = BECMessage.DeviceMessage(signals={"proj_nr": proj_nr}).dumps()
    bec_producer.set_and_publish(f"px_stream/proj_nr", msg=nr_msg, pipe=pipe)
    pipe.execute()
    for line, frame in enumerate(data):
        line_msg = BECMessage.DeviceMessage(signals={"data": frame}).dumps()
        print(f"Sending line {line}")
        bec_producer.rpush(topic=f"px_stream/projection_{proj_nr}/data", msgs=line_msg)
    print(f"Time to send {time.time()-start} seconds")
    print(f"Rate {data.shape[0]/(time.time()-start)} Hz")
    print(f"Data volume {data.nbytes/1e6} MB")
if __name__ == "__main__":
    """Start the stream simulator, defaults to px_stream/projection_* in redis on localhost:6379
    Example usage:
    >>> python saxs_imaging_streamsimulator.py -d ~/datadir/ -m ~/metadatafile.json -p 180 -w 30 -r localhost:6379
    """
    # NOTE: the usage example above originally passed the delay as a second
    # "-d"; the delay flag is "-w" (wait_delay).
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "-d",
        "--datadir",
        type=str,
        help="filepath to datadir for projection files (in h5 format)",
        required=True,
    )
    parser.add_argument(
        "-m",
        "--metadata",
        type=str,
        help="filepath to metadata json file",
        required=True,
    )
    parser.add_argument(
        "-p",
        "--proj_nr",
        type=int,
        help="Projection number matching the data",
        required=True,
    )
    parser.add_argument(
        "-w",
        "--wait_delay",
        type=int,
        help="delay between sending data in seconds (int)",
        default=30,
    )
    parser.add_argument(
        "-r",
        "--redis",
        type=str,
        help="Redis_host:port",
        default="localhost:6379",
    )
    values = parser.parse_args()
    data, q, norm_sum, metadata = load_data(datadir=values.datadir, metadata_path=values.metadata)
    bec_producer = RedisConnector([values.redis]).producer()
    proj_nr = values.proj_nr
    delay = values.wait_delay
    try:
        # Re-send the projection forever so late-joining consumers see data.
        while True:
            send_data(data, q, norm_sum, bec_producer, metadata, proj_nr=proj_nr)
            time.sleep(delay)
    except KeyboardInterrupt:
        # BUGFIX: this cleanup was unreachable after the infinite loop in the
        # original; run it when the simulator is stopped with Ctrl-C.
        bec_producer.delete(topic=f"px_stream/projection_{proj_nr}/data:val")