Files
dap/dap/zmqsocks.py

77 lines
2.3 KiB
Python

import numpy as np
import zmq
# Flags passed to every recv/send call below. 0 sets no flags (in particular
# not zmq.NOBLOCK), so those calls block until they can complete.
FLAGS = 0
class ZMQSocketsAccumulator:
    """Receiving end of the worker -> accumulator channel.

    Binds a PULL socket on ``accumulator_addr`` and reads the JSON messages
    pushed to it (e.g. by ``ZMQSocketsWorker.send_accumulator``).
    """

    def __init__(self, accumulator_addr):
        """Create a ZMQ context, bind the PULL socket and register it for polling.

        :param accumulator_addr: ZMQ endpoint to bind, e.g. "tcp://*:5555".
        """
        zmq_context = zmq.Context(io_threads=4)
        self.poller = poller = zmq.Poller()

        # receive from workers:
        self.accumulator_socket = accumulator_socket = zmq_context.socket(zmq.PULL)
        accumulator_socket.bind(accumulator_addr)
        poller.register(accumulator_socket, zmq.POLLIN)

    def has_data(self, timeout=10):
        """Return True if a worker message is ready to be read.

        :param timeout: how long to wait, in milliseconds (default 10 ms, the
            previously hard-coded value); ``None`` waits indefinitely.
        """
        events = dict(self.poller.poll(timeout))
        return self.accumulator_socket in events

    def get_data(self):
        """Receive one JSON message from the workers.

        Uses ``FLAGS`` (0), so the call blocks until a message arrives.
        """
        return self.accumulator_socket.recv_json(FLAGS)
class ZMQSocketsWorker:
    """Worker-side sockets.

    * ``backend_socket`` (PULL, connected): receives two-frame messages made
      of JSON metadata followed by the raw image bytes.
    * ``accumulator_socket`` (PUSH, connected): sends JSON results onward.
    * ``visualisation_socket`` (PUB, connected): publishes JSON results
      followed by a binary payload.
    """

    def __init__(self, backend_addr, accumulator_addr, visualisation_addr):
        """Create a ZMQ context and connect the three sockets to the given endpoints."""
        zmq_context = zmq.Context(io_threads=4)
        self.poller = poller = zmq.Poller()

        # receive from backend:
        self.backend_socket = backend_socket = zmq_context.socket(zmq.PULL)
        backend_socket.connect(backend_addr)
        poller.register(backend_socket, zmq.POLLIN)

        self.accumulator_socket = accumulator_socket = zmq_context.socket(zmq.PUSH)
        accumulator_socket.connect(accumulator_addr)

        self.visualisation_socket = visualisation_socket = zmq_context.socket(zmq.PUB)
        # In case of problems communicating with the visualisation, keep only a
        # few messages in the 0mq buffer. The HWM is set *before* connect():
        # libzmq documents socket options as applying to subsequent
        # bind/connect calls, so setting it afterwards is not guaranteed to
        # affect this connection.
        visualisation_socket.set_hwm(10)
        visualisation_socket.connect(visualisation_addr)

    def has_data(self, timeout=2000):
        """Return True if a backend message is ready to be read.

        :param timeout: how long to wait, in milliseconds (default 2000 ms, the
            previously hard-coded value); ``None`` waits indefinitely.
        """
        events = dict(self.poller.poll(timeout))
        return self.backend_socket in events

    def get_data(self):
        """Receive one ``(image, metadata)`` pair from the backend.

        The first frame is JSON metadata, which must contain "type" (a numpy
        dtype specifier) and "shape". The second frame holds the raw image
        bytes. It is received with copy=False and wrapped by np.frombuffer, so
        the returned array shares memory with the received frame instead of
        copying it.
        """
        metadata = self.backend_socket.recv_json(FLAGS)
        frame = self.backend_socket.recv(FLAGS, copy=False, track=False)
        image = np.frombuffer(frame, dtype=metadata["type"]).reshape(metadata["shape"])
        return image, metadata

    def send_accumulator(self, results):
        """Send ``results`` as JSON to the accumulator."""
        self.accumulator_socket.send_json(results, FLAGS)

    def send_visualisation(self, results, data):
        """Publish ``results`` (JSON) and ``data`` (bytes-like) as one two-frame message."""
        self.visualisation_socket.send_json(results, FLAGS | zmq.SNDMORE)
        # pyzmq ignores ``track`` when copy=True, so no tracker is requested.
        self.visualisation_socket.send(data, FLAGS, copy=True)
def make_address(host, port):
    """Return the ZMQ TCP endpoint string ``tcp://<host>:<port>``.

    >>> make_address("localhost", 5555)
    'tcp://localhost:5555'
    """
    return f"tcp://{host}:{port}"