moved zmq receiving/sending into separate class

This commit is contained in:
2024-07-30 15:57:57 +02:00
parent 57e22c01b5
commit 444bf1de06
2 changed files with 55 additions and 26 deletions

View File

@ -7,10 +7,10 @@ from time import sleep
import jungfrau_utils as ju
import numpy as np
import zmq
from peakfinder8_extension import peakfinder_8
from algos import calc_radial_integration, calc_apply_additional_mask
from zmqsocks import ZMQSockets
FLAGS = 0
@ -57,27 +57,11 @@ def work(backend_address, accumulator_host, accumulator_port, visualisation_host
ju_stream_adapter = ju.StreamAdapter()
zmq_context = zmq.Context(io_threads=4)
poller = zmq.Poller()
zmq_socks = ZMQSockets(backend_address, accumulator_host, accumulator_port, visualisation_host, visualisation_port)
# all the normal workers
worker = 1
# receive from backend:
backend_socket = zmq_context.socket(zmq.PULL)
backend_socket.connect(backend_address)
poller.register(backend_socket, zmq.POLLIN)
accumulator_socket = zmq_context.socket(zmq.PUSH)
accumulator_socket.connect(f"tcp://{accumulator_host}:{accumulator_port}")
visualisation_socket = zmq_context.socket(zmq.PUB)
visualisation_socket.connect(f"tcp://{visualisation_host}:{visualisation_port}")
# in case of problem with communication to visualisation, keep in 0mq buffer only few messages
visualisation_socket.set_hwm(10)
keep_pixels = None
r_radial_integration = None
center_radial_integration = None
@ -113,13 +97,10 @@ def work(backend_address, accumulator_host, accumulator_port, visualisation_host
except Exception as e:
print(f"({pulseid}) problem ({e}) to read peakfinder parameters file, worker : {worker}", flush=True)
events = dict(poller.poll(2000)) # check every 2 seconds in each worker
if backend_socket not in events:
if not zmq_socks.has_data():
continue
metadata = backend_socket.recv_json(FLAGS)
image = backend_socket.recv(FLAGS, copy=False, track=False)
image = np.frombuffer(image, dtype=metadata["type"]).reshape(metadata["shape"])
image, metadata = zmq_socks.get_data()
results = copy(metadata)
if results["shape"][0] == 2 and results["shape"][1] == 2:
@ -360,7 +341,7 @@ def work(backend_address, accumulator_host, accumulator_port, visualisation_host
results["shape"] = data.shape
accumulator_socket.send_json(results, FLAGS)
zmq_socks.send_accumulator(results)
send_empty_cond1 = (apply_aggregation and "aggregation_max" in results and not forceSendVisualisation)
@ -371,8 +352,7 @@ def work(backend_address, accumulator_host, accumulator_port, visualisation_host
results["type"] = str(data.dtype)
results["shape"] = data.shape
visualisation_socket.send_json(results, FLAGS | zmq.SNDMORE)
visualisation_socket.send(data, FLAGS, copy=True, track=True)
zmq_socks.send_visualisation(results, data)

49
dap/zmqsocks.py Normal file
View File

@ -0,0 +1,49 @@
import numpy as np
import zmq
FLAGS = 0
class ZMQSockets:
def __init__(self, backend_address, accumulator_host, accumulator_port, visualisation_host, visualisation_port):
zmq_context = zmq.Context(io_threads=4)
self.poller = poller = zmq.Poller()
# receive from backend:
self.backend_socket = backend_socket = zmq_context.socket(zmq.PULL)
backend_socket.connect(backend_address)
poller.register(backend_socket, zmq.POLLIN)
self.accumulator_socket = accumulator_socket = zmq_context.socket(zmq.PUSH)
accumulator_socket.connect(f"tcp://{accumulator_host}:{accumulator_port}")
self.visualisation_socket = visualisation_socket = zmq_context.socket(zmq.PUB)
visualisation_socket.connect(f"tcp://{visualisation_host}:{visualisation_port}")
# in case of problem with communication to visualisation, keep in 0mq buffer only few messages
visualisation_socket.set_hwm(10)
def has_data(self):
events = dict(self.poller.poll(2000)) # check every 2 seconds in each worker
return (self.backend_socket in events)
def get_data(self):
metadata = self.backend_socket.recv_json(FLAGS)
image = self.backend_socket.recv(FLAGS, copy=False, track=False)
image = np.frombuffer(image, dtype=metadata["type"]).reshape(metadata["shape"])
return image, metadata
def send_accumulator(self, results):
self.accumulator_socket.send_json(results, FLAGS)
def send_visualisation(self, results, data):
self.visualisation_socket.send_json(results, FLAGS | zmq.SNDMORE)
self.visualisation_socket.send(data, FLAGS, copy=True, track=True)