docs: add documentation to functions
This commit is contained in:
@@ -13,71 +13,159 @@ from typing import Optional, Tuple
|
||||
|
||||
class StreamProcessorPx(StreamProcessor):
|
||||
def __init__(self, connector: RedisConnector, config: dict) -> None:
|
||||
"""
|
||||
Initialize the LmfitProcessor class.
|
||||
|
||||
Args:
|
||||
connector (RedisConnector): Redis connector.
|
||||
config (dict): Configuration for the processor.
|
||||
"""
|
||||
""""""
|
||||
super().__init__(connector, config)
|
||||
self.metadata_consumer = None
|
||||
self.metadata = {}
|
||||
# self._init_data_output()
|
||||
self.num_received_msgs = 0
|
||||
self.queue = Queue()
|
||||
self.start_metadata_consumer()
|
||||
self._init_metadata(endpoint="px_stream/proj_nr")
|
||||
self.start_metadata_consumer(endpoint="px_stream/projection_*/metadata")
|
||||
|
||||
def start_data_consumer(self):
|
||||
pass
|
||||
def _init_metadata(self, endpoint: str) -> None:
|
||||
"""Initialize the metadata.
|
||||
|
||||
Args:
|
||||
endpoint (str): Endpoint for redis topic.
|
||||
|
||||
Returns:
|
||||
None
|
||||
|
||||
"""
|
||||
|
||||
msg = self.producer.get(topic=endpoint)
|
||||
if msg is None:
|
||||
return None
|
||||
msg_raw = BECMessage.DeviceMessage.loads(msg)
|
||||
proj_nr = msg_raw.content["signals"]["proj_nr"]
|
||||
# TODO hardcoded endpoint, possibe to use more general solution?
|
||||
msg = self.producer.get(topic=f"px_stream/projection_{proj_nr}/metadata")
|
||||
msg_raw = BECMessage.DeviceMessage.loads(msg)
|
||||
self._update_queue(msg_raw.content["signals"], proj_nr)
|
||||
|
||||
def _update_queue(self, metadata: dict, proj_nr: int) -> None:
|
||||
"""Update the process queue.
|
||||
|
||||
Args:
|
||||
metadata (dict): Metadata for the projection.
|
||||
proj_nr (int): Projection number.
|
||||
|
||||
Returns:
|
||||
None
|
||||
|
||||
"""
|
||||
|
||||
self.metadata.update({proj_nr: metadata})
|
||||
self.queue.put((proj_nr, metadata))
|
||||
|
||||
def start_metadata_consumer(self, endpoint: str) -> None:
|
||||
"""Start the metadata consumer.
|
||||
Consumer is started with a callback function that updates the metadata.
|
||||
|
||||
Args:
|
||||
endpoint (str): Endpoint for redis topic.
|
||||
|
||||
Returns:
|
||||
None
|
||||
|
||||
"""
|
||||
|
||||
def start_metadata_consumer(self):
|
||||
if self.metadata_consumer and self.metadata_consumer.is_alive():
|
||||
self.metadata_consumer.shutdown()
|
||||
self.metadata_consumer = self._connector.consumer(
|
||||
pattern="px_stream/projection_*/metadata", cb=self._update_metadata, parent=self
|
||||
pattern=endpoint, cb=self._update_metadata_cb, parent=self
|
||||
)
|
||||
self.metadata_consumer.start()
|
||||
|
||||
@staticmethod
|
||||
def _update_metadata(msg: MessageObject, parent: StreamProcessorPx) -> None:
|
||||
def _update_metadata_cb(msg: MessageObject, parent: StreamProcessorPx) -> None:
|
||||
"""Callback function for the metadata consumer.
|
||||
|
||||
Args:
|
||||
msg (MessageObject): Message object.
|
||||
parent (StreamProcessorPx): Parent class.
|
||||
|
||||
Returns:
|
||||
None
|
||||
|
||||
"""
|
||||
|
||||
msg_raw = BECMessage.DeviceMessage.loads(msg.value)
|
||||
parent._metadata_msg_handler(msg_raw, msg.topic.decode())
|
||||
|
||||
def _metadata_msg_handler(self, msg, topic):
|
||||
def _metadata_msg_handler(self, msg: BECMessage, topic) -> None:
|
||||
"""Handle the metadata message.
|
||||
If self.metadata is larger than 10, the oldest entry is removed.
|
||||
|
||||
Args:
|
||||
msg (BECMessage): Message object.
|
||||
topic (str): Topic for the message.
|
||||
|
||||
Returns:
|
||||
None
|
||||
|
||||
"""
|
||||
|
||||
if len(self.metadata) > 10:
|
||||
first_key = next(iter(self.metadata))
|
||||
self.metadata.pop(first_key)
|
||||
proj_nr = int(topic.split("px_stream/projection_")[1].split("/")[0])
|
||||
self.metadata.update({proj_nr: msg.content["signals"]})
|
||||
self.queue.put((proj_nr, msg.content["signals"]))
|
||||
self._update_queue(msg.content["signals"], proj_nr)
|
||||
|
||||
def _init_data_output(self) -> None:
|
||||
"""Initialize the data output.
|
||||
Not yet used. Should be used to initialize the output for the processed data.
|
||||
"""
|
||||
self.data = None
|
||||
|
||||
def start_data_consumer(self) -> None:
|
||||
"""function from the parent class that we don't want to use here"""
|
||||
pass
|
||||
|
||||
def _run_forever(self) -> None:
|
||||
"""Loop that runs forever when the processor is started.
|
||||
Upon update of the queue, the data is loaded and processed.
|
||||
This processing continues as long as the queue is empty,
|
||||
and proceeds to the next projection when the queue is updated.
|
||||
|
||||
Returns:
|
||||
None
|
||||
|
||||
"""
|
||||
|
||||
def _run_forever(self):
|
||||
""""""
|
||||
# TODO: Check if should skip entries in queue at beginning
|
||||
proj_nr, metadata = self.queue.get()
|
||||
self.num_received_msgs = 0
|
||||
# TODO initiate output, such that self.process only runs on new data
|
||||
# self._init_data_output()
|
||||
data = []
|
||||
while self.queue.empty():
|
||||
start = time.time()
|
||||
# TODO debug code for timing
|
||||
#
|
||||
data_msgs = self._get_data(proj_nr)
|
||||
data.extend([msg.content["signals"]["data"] for msg in data_msgs if msg is not None])
|
||||
#if len(data) > 80:
|
||||
# out = np.asarray(data)
|
||||
# result = ({0: {"z": np.sum(out, axis=(-1, -2))}}, {1: {}})
|
||||
#else:
|
||||
# continue
|
||||
print(f"Loading took {time.time() - start}")
|
||||
# print(f"Loading took {time.time() - start}")
|
||||
start = time.time()
|
||||
result = self.process(data, metadata)
|
||||
print(f"Processing took {time.time() - start}")
|
||||
if not result:
|
||||
continue
|
||||
print(f"Length of data is {result[0][0]['z'].shape}")
|
||||
msg = BECMessage.ProcessedDataMessage(
|
||||
data=result[0][0], metadata=result[1]).dumps()
|
||||
msg = BECMessage.ProcessedDataMessage(data=result[0][0], metadata=result[1]).dumps()
|
||||
print("Publishing result")
|
||||
self._publish_result(msg)
|
||||
|
||||
def _get_data(self, proj_nr: int) -> list:
|
||||
"""Get data for given proj_nr from redis.
|
||||
|
||||
Args:
|
||||
proj_nr (int): Projection number.
|
||||
|
||||
Returns:
|
||||
list: List of azimuthal integrated data.
|
||||
|
||||
"""
|
||||
|
||||
msgs = self.producer.lrange(
|
||||
f"px_stream/projection_{proj_nr}/data", self.num_received_msgs, -1
|
||||
)
|
||||
@@ -87,27 +175,36 @@ class StreamProcessorPx(StreamProcessor):
|
||||
return [BECMessage.DeviceMessage.loads(msg) for msg in msgs]
|
||||
|
||||
def process(self, data: list, metadata: dict) -> Optional[Tuple[dict, dict]]:
|
||||
"""Process the scanning SAXS data
|
||||
|
||||
Args:
|
||||
data (list): List of azimuthal integrated data.
|
||||
metadata (dict): Metadata for the projection.
|
||||
|
||||
Returns:
|
||||
Optional[Tuple[dict, dict]]: Processed data and metadata.
|
||||
|
||||
"""
|
||||
|
||||
if not data:
|
||||
return None
|
||||
# get the event data, hard coded
|
||||
start = time.time()
|
||||
azint_data = np.asarray(data)
|
||||
print(f"Processing took {time.time() - start}")
|
||||
norm_sum = metadata["norm_sum"]
|
||||
q = metadata["q"]
|
||||
out = []
|
||||
|
||||
#####################################
|
||||
# Pick contrast 0:f1amp, 1:f2amp, 2:f2phase
|
||||
contrast = self.config["parameters"]["contrast"]
|
||||
# user input/LinearRegionROI, maybe move to metadata
|
||||
qranges = self.config["parameters"]["qranges"]
|
||||
#####################################
|
||||
aziangles = self.config["parameters"]["aziangles"]
|
||||
|
||||
f1amp, f2amp, f2phase = self._colorfulplot(
|
||||
qranges=qranges,
|
||||
q=q,
|
||||
norm_sum=norm_sum,
|
||||
data=azint_data,
|
||||
aziangles=None,
|
||||
aziangles=aziangles,
|
||||
percentile_value=96,
|
||||
)
|
||||
if contrast == 0:
|
||||
@@ -126,46 +223,31 @@ class StreamProcessorPx(StreamProcessor):
|
||||
|
||||
return (stream_output, metadata)
|
||||
|
||||
def _bin_qrange(self, qranges, q, norm_sum, data):
|
||||
"""
|
||||
def _colorfulplot(
|
||||
self,
|
||||
qranges: list,
|
||||
q: np.ndarray,
|
||||
norm_sum: np.ndarray,
|
||||
data: np.ndarray,
|
||||
aziangles: list,
|
||||
percentile_value: int = 96,
|
||||
) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
|
||||
"""Compute data for sSAXS colorful 2D plot.
|
||||
Pending: hsv_to_rgb conversion for colorful output
|
||||
|
||||
Args:
|
||||
q ranges: list with indices of q edges, data is binned between neighbors.
|
||||
q: all q
|
||||
norm_sum: weights for q
|
||||
data: full data
|
||||
qranges (list): List with q edges for binning.
|
||||
q (np.ndarray): q values.
|
||||
norm_sum (np.ndarray): Normalization sum.
|
||||
data (np.ndarray): Data to be binned.
|
||||
aziangles (list, optional): List of azimuthal angles to shift f2phase. Defaults to None.
|
||||
percentile_value (int, optional): Percentile value for removing outliers above threshold. Defaults to 96, range 0...100.
|
||||
|
||||
Returns:
|
||||
Tuple[np.ndarray, np.ndarray, np.ndarray]: f1amp, f2amp, f2phase
|
||||
|
||||
"""
|
||||
output = np.zeros((*data.shape[:-1], len(qranges) - 1))
|
||||
output_norm = np.zeros((data.shape[-2], len(qranges) - 1))
|
||||
|
||||
with np.errstate(divide="ignore", invalid="ignore"):
|
||||
for ii, qval in enumerate(qranges[:-1]):
|
||||
q_mask = np.logical_and(q >= q[qranges[ii]], q < q[qranges[ii + 1]])
|
||||
output_norm[..., ii] = np.nansum(norm_sum[..., q_mask], axis=-1)
|
||||
output[..., ii] = np.nansum(data[..., q_mask] * norm_sum[..., q_mask], axis=-1)
|
||||
output[..., ii] = np.divide(
|
||||
output[..., ii], output_norm[..., ii], out=np.zeros_like(output[..., ii])
|
||||
)
|
||||
|
||||
return output, output_norm
|
||||
|
||||
def _sym_data(self, data, norm_sum):
|
||||
n_directions = norm_sum.shape[0] // 2
|
||||
output = np.divide(
|
||||
data[..., :n_directions, :] * norm_sum[:n_directions, :]
|
||||
+ data[..., n_directions:, :] * norm_sum[n_directions:, :],
|
||||
norm_sum[:n_directions, :] + norm_sum[n_directions:, :],
|
||||
out=np.zeros_like(data[..., :n_directions, :]),
|
||||
)
|
||||
return output
|
||||
|
||||
def _colorfulplot(self, qranges, q, norm_sum, data, aziangles=None, percentile_value=96):
|
||||
"""
|
||||
Args:
|
||||
q ranges: list with 2 indices for q edges
|
||||
q: all q
|
||||
norm_sum: weights for q
|
||||
data: full data
|
||||
"""
|
||||
output, output_norm = self._bin_qrange(qranges=qranges, q=q, norm_sum=norm_sum, data=data)
|
||||
output_sym = self._sym_data(data=output, norm_sum=output_norm)
|
||||
output_sym = output_sym
|
||||
@@ -179,10 +261,8 @@ class StreamProcessorPx(StreamProcessor):
|
||||
|
||||
f1amp = np.abs(fft_data[:, 0]) / output_sym.shape[2]
|
||||
f2amp = 2 * np.abs(fft_data[:, 1]) / output_sym.shape[2]
|
||||
# This still slightly confused me and to get mapping to colorwheel correct it needs to match
|
||||
f2angle = np.angle(fft_data[:, 1]) + np.deg2rad(azi_angle)
|
||||
|
||||
# Unwrap phaseand normalize between 0...1, output of rfft is in between -pi and pi, it can be larger then 1 because of addition of zero angle!
|
||||
f2phase = (f2angle + np.pi) / (2 * np.pi)
|
||||
f2phase[f2phase > 1] = f2phase[f2phase > 1] - 1
|
||||
|
||||
@@ -191,7 +271,6 @@ class StreamProcessorPx(StreamProcessor):
|
||||
f2angle = f2angle.reshape(shape)
|
||||
f2phase = f2phase.reshape(shape)
|
||||
|
||||
# hsv output
|
||||
h = f2phase
|
||||
|
||||
max_scale = np.percentile(f2amp, percentile_value)
|
||||
@@ -203,22 +282,70 @@ class StreamProcessorPx(StreamProcessor):
|
||||
v = v / max_scale
|
||||
v[v > 1] = 1
|
||||
|
||||
hsv = np.stack((h, s, v), axis=2)
|
||||
# hsv = np.stack((h, s, v), axis=2)
|
||||
# comb_all = colors.hsv_to_rgb(hsv)
|
||||
|
||||
return f1amp, f2amp, f2phase # , comb_all
|
||||
|
||||
def _bin_qrange(self, qranges, q, norm_sum, data) -> Tuple[np.ndarray, np.ndarray]:
|
||||
"""Reintegrate data for given q range.
|
||||
Weighted sum for data using norm_sum as weights
|
||||
|
||||
Args:
|
||||
qranges (list): List with q edges for binning.
|
||||
q (np.ndarray): q values.
|
||||
norm_sum (np.ndarray): Normalization sum.
|
||||
data (np.ndarray): Data to be binned.
|
||||
|
||||
Returns:
|
||||
np.ndarray: Binned data.
|
||||
np.ndarray: Binned normalization sum.
|
||||
"""
|
||||
|
||||
output = np.zeros((*data.shape[:-1], len(qranges) - 1))
|
||||
output_norm = np.zeros((data.shape[-2], len(qranges) - 1))
|
||||
|
||||
with np.errstate(divide="ignore", invalid="ignore"):
|
||||
q_mask = np.logical_and(q >= q[qranges[0]], q <= q[qranges[1]])
|
||||
output_norm[..., 0] = np.nansum(norm_sum[..., q_mask], axis=-1)
|
||||
output[..., 0] = np.nansum(data[..., q_mask] * norm_sum[..., q_mask], axis=-1)
|
||||
output[..., 0] = np.divide(
|
||||
output[..., 0], output_norm[..., 0], out=np.zeros_like(output[..., 0])
|
||||
)
|
||||
|
||||
return output, output_norm
|
||||
|
||||
def _sym_data(self, data, norm_sum) -> np.ndarray:
|
||||
"""Symmetrize data by averaging over the two opposing directions.
|
||||
Helpful to remove detector gaps for x-ray detectors
|
||||
|
||||
Args:
|
||||
data (np.ndarray): Data to be symmetrized.
|
||||
norm_sum (np.ndarray): Normalization sum.
|
||||
|
||||
Returns:
|
||||
np.ndarray: Symmetrized data.
|
||||
|
||||
"""
|
||||
|
||||
n_directions = norm_sum.shape[0] // 2
|
||||
output = np.divide(
|
||||
data[..., :n_directions, :] * norm_sum[:n_directions, :]
|
||||
+ data[..., n_directions:, :] * norm_sum[n_directions:, :],
|
||||
norm_sum[:n_directions, :] + norm_sum[n_directions:, :],
|
||||
out=np.zeros_like(data[..., :n_directions, :]),
|
||||
)
|
||||
return output
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
config = {
|
||||
"template": Template("px_stream/projection_$proj/$channel"),
|
||||
"channels": ["data", "metadata", "q", "norm_sum"],
|
||||
"output": "px_dap_worker",
|
||||
"parameters": {
|
||||
"qranges": [20, 50], # TODO this will be signal from ROI selector
|
||||
"contrast": 0, # "contrast_stream" : 'px_contrast_stream',
|
||||
# TODO these three inputs could be made available for change from the GUI
|
||||
"qranges": [20, 50],
|
||||
"contrast": 0,
|
||||
"aziangles": None,
|
||||
},
|
||||
}
|
||||
dap_process = StreamProcessorPx.run(config=config, connector_host=["localhost:6379"])
|
||||
# dap_process = StreamProcessorPx(config=config, connector_host=["localhost:6379"])
|
||||
# dap_process.start_met
|
||||
|
||||
Reference in New Issue
Block a user