docs: add documentation to functions

This commit is contained in:
2023-08-10 14:53:54 +02:00
parent 7ee421e0d2
commit ba4b8f8f02

View File

@@ -13,71 +13,159 @@ from typing import Optional, Tuple
class StreamProcessorPx(StreamProcessor):
def __init__(self, connector: RedisConnector, config: dict) -> None:
    """Initialize the px-stream processor.

    Args:
        connector (RedisConnector): Redis connector.
        config (dict): Configuration for the processor.
    """
    super().__init__(connector, config)
    self.metadata_consumer = None  # consumer thread, created in start_metadata_consumer
    self.metadata = {}  # proj_nr -> metadata dict; bounded in _metadata_msg_handler
    # self._init_data_output()
    self.num_received_msgs = 0  # read offset into the redis data list, see _get_data
    self.queue = Queue()  # (proj_nr, metadata) tuples consumed by _run_forever
    self._init_metadata(endpoint="px_stream/proj_nr")
    self.start_metadata_consumer(endpoint="px_stream/projection_*/metadata")
def start_data_consumer(self):
# Intentional no-op override of the parent-class hook: this processor
# pulls data itself via _get_data inside _run_forever instead of running
# a separate data consumer.
pass
def _init_metadata(self, endpoint: str) -> None:
    """Initialize the metadata.

    Reads the current projection number from redis, fetches that
    projection's metadata and pushes both onto the process queue.

    Args:
        endpoint (str): Endpoint for redis topic.

    Returns:
        None
    """
    raw = self.producer.get(topic=endpoint)
    if raw is None:
        # Nothing published yet; start without initial metadata.
        return None
    proj_msg = BECMessage.DeviceMessage.loads(raw)
    proj_nr = proj_msg.content["signals"]["proj_nr"]
    # TODO hardcoded endpoint, possible to use a more general solution?
    meta_raw = self.producer.get(topic=f"px_stream/projection_{proj_nr}/metadata")
    meta_msg = BECMessage.DeviceMessage.loads(meta_raw)
    self._update_queue(meta_msg.content["signals"], proj_nr)
def _update_queue(self, metadata: dict, proj_nr: int) -> None:
    """Update the process queue.

    Records the metadata under its projection number and enqueues the
    pair for the worker loop.

    Args:
        metadata (dict): Metadata for the projection.
        proj_nr (int): Projection number.

    Returns:
        None
    """
    self.metadata[proj_nr] = metadata
    self.queue.put((proj_nr, metadata))
def start_metadata_consumer(self, endpoint: str) -> None:
    """Start the metadata consumer.

    Consumer is started with a callback function that updates the metadata.
    Any previously running consumer is shut down first, so the method is
    safe to call for a restart.

    Args:
        endpoint (str): Redis topic pattern to subscribe to.

    Returns:
        None
    """
    if self.metadata_consumer and self.metadata_consumer.is_alive():
        self.metadata_consumer.shutdown()
    self.metadata_consumer = self._connector.consumer(
        pattern=endpoint, cb=self._update_metadata_cb, parent=self
    )
    self.metadata_consumer.start()
@staticmethod
def _update_metadata_cb(msg: MessageObject, parent: StreamProcessorPx) -> None:
    """Callback function for the metadata consumer.

    Decodes the raw redis message and forwards it, together with the topic
    it arrived on, to the parent's metadata handler.

    Args:
        msg (MessageObject): Message object.
        parent (StreamProcessorPx): Parent class.

    Returns:
        None
    """
    msg_raw = BECMessage.DeviceMessage.loads(msg.value)
    parent._metadata_msg_handler(msg_raw, msg.topic.decode())
def _metadata_msg_handler(self, msg: BECMessage, topic) -> None:
    """Handle the metadata message.

    If self.metadata holds more than 10 entries, the oldest entry is removed.

    Args:
        msg (BECMessage): Message object.
        topic (str): Topic for the message.

    Returns:
        None
    """
    # Bound the metadata cache; dicts preserve insertion order, so the
    # first key is the oldest entry.
    if len(self.metadata) > 10:
        first_key = next(iter(self.metadata))
        self.metadata.pop(first_key)
    proj_nr = int(topic.split("px_stream/projection_")[1].split("/")[0])
    self._update_queue(msg.content["signals"], proj_nr)
def _init_data_output(self) -> None:
"""Initialize the data output.
Not yet used. Should be used to initialize the output for the processed data.
"""
# Placeholder: resets the output buffer; the call sites in __init__ and
# _run_forever are still commented out.
self.data = None
def start_data_consumer(self) -> None:
"""Function from the parent class that we don't want to use here."""
# Intentional no-op: this processor pulls data itself via _get_data.
pass
def _run_forever(self) -> None:
    """Loop that runs forever when the processor is started.

    Upon update of the queue, the data is loaded and processed.
    This processing continues as long as the queue is empty,
    and proceeds to the next projection when the queue is updated.

    Returns:
        None
    """
    # TODO: Check if we should skip entries in the queue at the beginning
    proj_nr, metadata = self.queue.get()
    self.num_received_msgs = 0
    # TODO initiate output, such that self.process only runs on new data
    # self._init_data_output()
    data = []
    while self.queue.empty():
        start = time.time()
        data_msgs = self._get_data(proj_nr)
        data.extend([msg.content["signals"]["data"] for msg in data_msgs if msg is not None])
        # print(f"Loading took {time.time() - start}")
        start = time.time()
        result = self.process(data, metadata)
        print(f"Processing took {time.time() - start}")
        if not result:
            # No result yet (e.g. no data received); keep polling.
            continue
        print(f"Length of data is {result[0][0]['z'].shape}")
        msg = BECMessage.ProcessedDataMessage(data=result[0][0], metadata=result[1]).dumps()
        print("Publishing result")
        self._publish_result(msg)
def _get_data(self, proj_nr: int) -> list:
"""Get data for given proj_nr from redis.
Args:
proj_nr (int): Projection number.
Returns:
list: List of azimuthal integrated data.
"""
# Fetch all list entries appended since the last poll; num_received_msgs
# is used as the start offset into the redis list for this projection.
msgs = self.producer.lrange(
f"px_stream/projection_{proj_nr}/data", self.num_received_msgs, -1
)
# NOTE(review): the diff hunk header below elides lines of this method
# (presumably the update of self.num_received_msgs) — not reconstructable
# from this chunk; confirm against the full file.
@@ -87,27 +175,36 @@ class StreamProcessorPx(StreamProcessor):
return [BECMessage.DeviceMessage.loads(msg) for msg in msgs]
def process(self, data: list, metadata: dict) -> Optional[Tuple[dict, dict]]:
"""Process the scanning SAXS data.
Args:
data (list): List of azimuthal integrated data.
metadata (dict): Metadata for the projection.
Returns:
Optional[Tuple[dict, dict]]: Processed data and metadata, or None if no data.
"""
if not data:
return None
# get the event data, hard coded
start = time.time()
azint_data = np.asarray(data)
print(f"Processing took {time.time() - start}")
norm_sum = metadata["norm_sum"]
q = metadata["q"]
out = []
#####################################
# Pick contrast 0:f1amp, 1:f2amp, 2:f2phase
contrast = self.config["parameters"]["contrast"]
# user input/LinearRegionROI, maybe move to metadata
qranges = self.config["parameters"]["qranges"]
#####################################
aziangles = self.config["parameters"]["aziangles"]
f1amp, f2amp, f2phase = self._colorfulplot(
qranges=qranges,
q=q,
norm_sum=norm_sum,
data=azint_data,
# NOTE(review): the next two lines are the removed/added sides of the
# same diff kwarg; only `aziangles=aziangles` is the final code.
aziangles=None,
aziangles=aziangles,
percentile_value=96,
)
if contrast == 0:
# NOTE(review): the hunk header below elides the contrast-selection
# code that builds `stream_output`; not reconstructable from this chunk.
@@ -126,46 +223,31 @@ class StreamProcessorPx(StreamProcessor):
return (stream_output, metadata)
def _bin_qrange(self, qranges, q, norm_sum, data):
"""
# NOTE(review): the two lines above are remnants of the pre-diff
# _bin_qrange definition; the lines below are the post-diff _colorfulplot
# signature. This region interleaves both diff sides and has elided hunks,
# so the final code cannot be fully reconstructed from this chunk.
def _colorfulplot(
self,
qranges: list,
q: np.ndarray,
norm_sum: np.ndarray,
data: np.ndarray,
aziangles: list,
percentile_value: int = 96,
) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
"""Compute data for sSAXS colorful 2D plot.
Pending: hsv_to_rgb conversion for colorful output
Args:
q ranges: list with indices of q edges, data is binned between neighbors.
q: all q
norm_sum: weights for q
data: full data
qranges (list): List with q edges for binning.
q (np.ndarray): q values.
norm_sum (np.ndarray): Normalization sum.
data (np.ndarray): Data to be binned.
aziangles (list, optional): List of azimuthal angles to shift f2phase. Defaults to None.
percentile_value (int, optional): Percentile value for removing outliers above threshold. Defaults to 96, range 0...100.
Returns:
Tuple[np.ndarray, np.ndarray, np.ndarray]: f1amp, f2amp, f2phase
"""
# NOTE(review): old _bin_qrange body (removed diff side) — loops over
# consecutive qrange pairs and weight-averages data per q interval.
output = np.zeros((*data.shape[:-1], len(qranges) - 1))
output_norm = np.zeros((data.shape[-2], len(qranges) - 1))
with np.errstate(divide="ignore", invalid="ignore"):
for ii, qval in enumerate(qranges[:-1]):
q_mask = np.logical_and(q >= q[qranges[ii]], q < q[qranges[ii + 1]])
output_norm[..., ii] = np.nansum(norm_sum[..., q_mask], axis=-1)
output[..., ii] = np.nansum(data[..., q_mask] * norm_sum[..., q_mask], axis=-1)
output[..., ii] = np.divide(
output[..., ii], output_norm[..., ii], out=np.zeros_like(output[..., ii])
)
return output, output_norm
# NOTE(review): old _sym_data (removed diff side); the documented
# replacement appears later in this chunk.
def _sym_data(self, data, norm_sum):
n_directions = norm_sum.shape[0] // 2
output = np.divide(
data[..., :n_directions, :] * norm_sum[:n_directions, :]
+ data[..., n_directions:, :] * norm_sum[n_directions:, :],
norm_sum[:n_directions, :] + norm_sum[n_directions:, :],
out=np.zeros_like(data[..., :n_directions, :]),
)
return output
# NOTE(review): old _colorfulplot signature (removed diff side).
def _colorfulplot(self, qranges, q, norm_sum, data, aziangles=None, percentile_value=96):
"""
Args:
q ranges: list with 2 indices for q edges
q: all q
norm_sum: weights for q
data: full data
"""
# Shared _colorfulplot body: bin by q range, symmetrize opposing
# directions, then Fourier-analyze over the azimuthal axis.
output, output_norm = self._bin_qrange(qranges=qranges, q=q, norm_sum=norm_sum, data=data)
output_sym = self._sym_data(data=output, norm_sum=output_norm)
output_sym = output_sym
@@ -179,10 +261,8 @@ class StreamProcessorPx(StreamProcessor):
f1amp = np.abs(fft_data[:, 0]) / output_sym.shape[2]
f2amp = 2 * np.abs(fft_data[:, 1]) / output_sym.shape[2]
# This still slightly confused me and to get mapping to colorwheel correct it needs to match
f2angle = np.angle(fft_data[:, 1]) + np.deg2rad(azi_angle)
# Unwrap phase and normalize between 0...1, output of rfft is between -pi and pi; it can be larger than 1 because of the addition of the zero angle!
f2phase = (f2angle + np.pi) / (2 * np.pi)
f2phase[f2phase > 1] = f2phase[f2phase > 1] - 1
@@ -191,7 +271,6 @@ class StreamProcessorPx(StreamProcessor):
f2angle = f2angle.reshape(shape)
f2phase = f2phase.reshape(shape)
# hsv output
h = f2phase
max_scale = np.percentile(f2amp, percentile_value)
@@ -203,22 +282,70 @@ class StreamProcessorPx(StreamProcessor):
v = v / max_scale
v[v > 1] = 1
hsv = np.stack((h, s, v), axis=2)
# hsv = np.stack((h, s, v), axis=2)
# comb_all = colors.hsv_to_rgb(hsv)
return f1amp, f2amp, f2phase # , comb_all
def _bin_qrange(self, qranges, q, norm_sum, data) -> Tuple[np.ndarray, np.ndarray]:
    """Reintegrate data for given q range.

    Weighted sum of *data* over the q interval spanned by the first two
    entries of *qranges*, using *norm_sum* as weights.

    Args:
        qranges (list): List with q edges for binning.
        q (np.ndarray): q values.
        norm_sum (np.ndarray): Normalization sum.
        data (np.ndarray): Data to be binned.

    Returns:
        np.ndarray: Binned data.
        np.ndarray: Binned normalization sum.
    """
    n_bins = len(qranges) - 1
    binned = np.zeros((*data.shape[:-1], n_bins))
    binned_norm = np.zeros((data.shape[-2], n_bins))
    # Suppress divide/invalid warnings: bins with zero weight divide by zero.
    with np.errstate(divide="ignore", invalid="ignore"):
        in_range = np.logical_and(q >= q[qranges[0]], q <= q[qranges[1]])
        binned_norm[..., 0] = np.nansum(norm_sum[..., in_range], axis=-1)
        binned[..., 0] = np.nansum(data[..., in_range] * norm_sum[..., in_range], axis=-1)
        binned[..., 0] = np.divide(
            binned[..., 0], binned_norm[..., 0], out=np.zeros_like(binned[..., 0])
        )
    return binned, binned_norm
def _sym_data(self, data, norm_sum) -> np.ndarray:
    """Symmetrize data by averaging over the two opposing directions.

    Helpful to remove detector gaps for x-ray detectors.

    Args:
        data (np.ndarray): Data to be symmetrized.
        norm_sum (np.ndarray): Normalization sum.

    Returns:
        np.ndarray: Symmetrized data.
    """
    half = norm_sum.shape[0] // 2
    # Split both arrays into the two opposing halves along the direction axis.
    front = data[..., :half, :]
    back = data[..., half:, :]
    w_front = norm_sum[:half, :]
    w_back = norm_sum[half:, :]
    # Weighted average of the opposing directions.
    return np.divide(
        front * w_front + back * w_back,
        w_front + w_back,
        out=np.zeros_like(front),
    )
if __name__ == "__main__":
    # Minimal launch configuration for running the processor standalone.
    config = {
        "template": Template("px_stream/projection_$proj/$channel"),
        "channels": ["data", "metadata", "q", "norm_sum"],
        "output": "px_dap_worker",
        # TODO these three inputs could be made available for change from the GUI
        "parameters": {
            "qranges": [20, 50],  # TODO this will be signal from ROI selector
            "contrast": 0,  # Pick contrast 0:f1amp, 1:f2amp, 2:f2phase
            "aziangles": None,
        },
    }
    dap_process = StreamProcessorPx.run(config=config, connector_host=["localhost:6379"])