diff --git a/bec_plugins/data_processing/__init__.py b/bec_plugins/data_processing/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/bec_plugins/data_processing/px_example.py b/bec_plugins/data_processing/px_example.py new file mode 100644 index 0000000..4f4dadb --- /dev/null +++ b/bec_plugins/data_processing/px_example.py @@ -0,0 +1,210 @@ +from __future__ import annotations +import numpy as np +from string import Template +from queue import Queue +import time + +# import bec_lib +from data_processing.stream_processor import StreamProcessor +from bec_lib.core import BECMessage, MessageEndpoints +from bec_lib.core.redis_connector import MessageObject, RedisConnector +from typing import Optional, Tuple + + +class StreamProcessorPx(StreamProcessor): + def __init__(self, connector: RedisConnector, config: dict) -> None: + """ + Initialize the LmfitProcessor class. + + Args: + connector (RedisConnector): Redis connector. + config (dict): Configuration for the processor. + """ + super().__init__(connector, config) + self.metadata_consumer = None + self.metadata = {} + self.queue = Queue() + self.start_metadata_consumer() + + def start_data_consumer(self): + pass + + def start_metadata_consumer(self): + if self.metadata_consumer and self.metadata_consumer.is_alive(): + self.metadata_consumer.shutdown() + self.metadata_consumer = self._connector.consumer( + pattern="px_stream/projection_*/metadata", cb=self._update_metadata, parent=self + ) + self.metadata_consumer.start() + + @staticmethod + def _update_metadata(msg: MessageObject, parent: StreamProcessorPx) -> None: + msg_raw = BECMessage.DeviceMessage.loads(msg.value) + parent._metadata_msg_handler(msg_raw, msg.topic.decode()) + + def _metadata_msg_handler(self, msg, topic): + if len(self.metadata) > 10: + first_key = next(iter(self.metadata)) + self.metadata.pop(first_key) + proj_nr = int(topic.split("px_stream/projection_")[1].split("/")[0]) + self.metadata.update({proj_nr : msg.content['signals']}) + self.queue.put((proj_nr, msg.content['signals'])) + + def _run_forever(self): + """""" + #TODO: Check if should skip entries in queue at beginning + proj_nr, metadata = self.queue.get() + while self.queue.empty(): + data_msgs = self._get_data(proj_nr) + data = [msg.content['signals']['data'] for msg in data_msgs if msg is not None] + result = self.process(data, metadata) + if not result: + continue + msg = BECMessage.ProcessedDataMessage(data=result[0], metadata=result[1]).dumps() + self._publish_result(msg) + + def _get_data(self, proj_nr: int) -> list: + msgs = self.producer.lrange(f'px_stream/projection_{proj_nr}/data', 0, -1) + if not msgs: + return [] + return [BECMessage.DeviceMessage.loads(msg) for msg in msgs] + + + def process(self, data: list, metadata: dict) -> Optional[Tuple[dict, dict]]: + + if not data: + return None + # get the event data, hard coded + azint_data = np.asarray(data) + norm_sum = metadata["norm_sum"] + q = metadata["q"] + + ##################################### + # Pick contrast 0:f1amp, 1:f2amp, 2:f2phase + contrast = self.config["parameters"]["contrast"] + # user input/LinearRegionROI, maybe move to metadata + qranges = self.config["parameters"]["qranges"] + ##################################### + + f1amp, f2amp, f2phase = self._colorfulplot( + qranges=qranges, + q=q, + norm_sum=norm_sum, + data=azint_data, + aziangles=None, + percentile_value=96, + ) + if contrast == 0: + out = f1amp + elif contrast == 1: + out = f2amp + elif contrast == 2: + out = f2phase + + stream_output = { + self.config["output"]: { + # 0: {"x": np.asarray(x), "y": np.asarray(y), "z": np.asarray(out)}, + }, + # "input": self.config["input_xy"], + } + metadata["grid_scan"] = out.shape + + return (stream_output, metadata) + + def _bin_qrange(self, qranges, q, norm_sum, data): + """ + Args: + q ranges: list with indices of q edges, data is binned between neighbors. + q: all q + norm_sum: weights for q + data: full data + """ + output = np.zeros((*data.shape[:-1], len(qranges) - 1)) + output_norm = np.zeros((data.shape[-2], len(qranges) - 1)) + + with np.errstate(divide="ignore", invalid="ignore"): + for ii, qval in enumerate(qranges[:-1]): + q_mask = np.logical_and(q >= q[qranges[ii]], q < q[qranges[ii + 1]]) + output_norm[..., ii] = np.nansum(norm_sum[..., q_mask], axis=-1) + output[..., ii] = np.nansum(data[..., q_mask] * norm_sum[..., q_mask], axis=-1) + output[..., ii] = np.divide( + output[..., ii], output_norm[..., ii], out=np.zeros_like(output[..., ii]) + ) + + return output, output_norm + + def _sym_data(self, data, norm_sum): + n_directions = norm_sum.shape[0] // 2 + output = np.divide( + data[..., :n_directions, :] * norm_sum[:n_directions, :] + + data[..., n_directions:, :] * norm_sum[n_directions:, :], + norm_sum[:n_directions, :] + norm_sum[n_directions:, :], + out=np.zeros_like(data[..., :n_directions, :]), + ) + return output + + def _colorfulplot(self, qranges, q, norm_sum, data, aziangles=None, percentile_value=96): + """ + Args: + q ranges: list with 2 indices for q edges + q: all q + norm_sum: weights for q + data: full data + """ + output, output_norm = self._bin_qrange(qranges=qranges, q=q, norm_sum=norm_sum, data=data) + output_sym = self._sym_data(data=output, norm_sum=output_norm) + output_sym = output_sym + shape = output_sym.shape[0:2] + + fft_data = np.fft.rfft(output_sym.reshape((-1, output_sym.shape[-2])), axis=1) + if aziangles is None: + azi_angle = 0 + else: + azi_angle = aziangles[0] + + f1amp = np.abs(fft_data[:, 0]) / output_sym.shape[2] + f2amp = 2 * np.abs(fft_data[:, 1]) / output_sym.shape[2] + # This still slightly confused me and to get mapping to colorwheel correct it needs to match + f2angle = np.angle(fft_data[:, 1]) + np.deg2rad(azi_angle) + + # Unwrap phaseand normalize between 0...1, output of rfft is in between -pi and pi, it can be larger then 1 because of addition of zero angle! + f2phase = (f2angle + np.pi) / (2 * np.pi) + f2phase[f2phase > 1] = f2phase[f2phase > 1] - 1 + + f1amp = f1amp.reshape(shape) + f2amp = f2amp.reshape(shape) + f2angle = f2angle.reshape(shape) + f2phase = f2phase.reshape(shape) + + # hsv output + h = f2phase + + max_scale = np.percentile(f2amp, percentile_value) + s = f2amp / max_scale + s[s > 1] = 1 + + max_scale = np.percentile(f1amp, percentile_value) + v = f1amp + v = v / max_scale + v[v > 1] = 1 + + hsv = np.stack((h, s, v), axis=2) + # comb_all = colors.hsv_to_rgb(hsv) + + return f1amp, f2amp, f2phase # , comb_all + + +if __name__ == "__main__": + config = { + "template": Template("px_stream/projection_$proj/$channel"), + "channels": ["data", "metadata", "q", "norm_sum"], + "output": "px_dap_worker", + "parameters": { + "qranges" : [20,50], + "roi_stream": "px_roi_stream", + "contrast": 0, # "contrast_stream" : 'px_contrast_stream', + }, + } + dap_process = StreamProcessorPx.run(config=config, connector_host=["localhost:6379"]) + # dap_process = StreamProcessorPx(config=config, connector_host=["localhost:6379"]) + # dap_process.start_met diff --git a/bec_plugins/data_processing/px_streamer.py b/bec_plugins/data_processing/px_streamer.py new file mode 100644 index 0000000..08f8aa2 --- /dev/null +++ b/bec_plugins/data_processing/px_streamer.py @@ -0,0 +1,79 @@ +import os +import h5py +import numpy as np +import time +import json + +from bec_lib.core import RedisConnector, BECMessage + + +def load_data() -> tuple: + """ + + Returns + """ + proj_nr = 180 + basedir = f"/das/work/units/pem/p19745/online_data/analysis/radial_integration_eiger/projection_{proj_nr:06d}/" + + metadata_name = f"/das/work/units/pem/p19745/online_data/metadata/projection_{proj_nr:06d}.json" + with open(metadata_name) as file: + metadata = json.load(file) + + filenames = [fname for fname in os.listdir(basedir) if fname.endswith(".h5")] + filenames.sort() + + for ii, fname in enumerate(filenames): + with h5py.File(os.path.join(basedir, fname), "r") as h5file: + if ii == 0: + q = h5file["q"][...].T.squeeze() + norm_sum = h5file["norm_sum"][...] + data = np.zeros((len(filenames), *h5file["I_all"][...].shape)) + data[ii, ...] = h5file["I_all"][...] + + return data, q, norm_sum, metadata + + +def _get_projection_keys(producer): + keys = producer.keys("px_stream/projection_*") + if not keys: + return [] + return keys + + +def send_data(data, q, norm_sum, bec_producer, metadata, proj_nr) -> None: + """""" + start = time.time() + + keys = _get_projection_keys(bec_producer) + pipe = bec_producer.pipeline() + proj_numbers = set(key.decode().split("px_stream/projection_")[1].split("/")[0] for key in keys) + if len(proj_numbers) > 5: + for entry in sorted(proj_numbers)[0:-5]: + for key in bec_producer.keys(f"px_stream/projection_{entry}/*"): + bec_producer.delete(topic=key, pipe=pipe) + print(f"Deleting {key}") + + # Add new data + return_dict = {"metadata": metadata, "q": q, "norm_sum": norm_sum} + msg = BECMessage.DeviceMessage(signals=return_dict).dumps() + bec_producer.set_and_publish(f"px_stream/projection_{proj_nr}/metadata", msg=msg, pipe=pipe) + + pipe.execute() + for line in range(data.shape[0]): + return_dict = {"data": data[line, ...]} + msg = BECMessage.DeviceMessage(signals=return_dict).dumps() + print(f"Sending line {line}") + bec_producer.rpush(topic=f"px_stream/projection_{proj_nr}/data", msgs=msg) + print(f"Time to send {time.time()-start} seconds") + print(f"Rate {data.shape[0]/(time.time()-start)} Hz") + print(f"Data volume {data.nbytes/1e6} MB") + + +if __name__ == "__main__": + data, q, norm_sum, metadata = load_data() + bec_producer = RedisConnector(["localhost:6379"]).producer() + proj_nr = 180 + while True: + send_data(data, q, norm_sum, bec_producer, metadata, proj_nr=proj_nr) + time.sleep(1) + proj_nr = proj_nr + 1