Issue
This content is from Stack Overflow. Question asked by danijar.
I would like to use ZMQ in a client/server model but where each server waits until it has received 100 requests, processes them jointly, and then sends out the 100 responses back to the correct clients. The reason behind this is that the server performs a computation on GPU that is only computationally efficient when performed on batches. How can this be done with ZMQ?
Below is what I tried, which unsurprisingly raises `zmq.error.ZMQError: Operation cannot be accomplished in current state`, because the server is trying to receive multiple requests in sequence without interleaving the `recv_pyobj()` calls with `send_pyobj()` responses.
import multiprocessing as mp
import numpy as np
import time
import zmq
def computation(inputs, overhead=1):
    """Simulate a batched GPU computation with a constant per-call cost.

    Args:
        inputs: Sized batch of inputs. Only its length is used.
        overhead: Seconds to sleep to mimic the fixed per-call overhead.
            Defaults to 1, the value that used to be hard-coded.

    Returns:
        A zero-filled array of shape ``(len(inputs), 8)``.
    """
    time.sleep(overhead)  # Simulate constant GPU overhead.
    return np.zeros((len(inputs), 8))
def server(port, batch=100):
    """Try to collect `batch` requests on a REP socket, then answer them all.

    This is the failing attempt from the question: a REP socket has to
    alternate recv and send, so the second consecutive ``recv_pyobj()``
    raises ``zmq.error.ZMQError: Operation cannot be accomplished in
    current state``.

    Args:
        port: TCP port to bind on all interfaces.
        batch: Number of requests to gather before running the computation.
    """
    context = zmq.Context()
    socket = context.socket(zmq.REP)
    socket.bind(f'tcp://*:{port}')
    # NOTE(security): recv_pyobj() unpickles whatever arrives, and binding to
    # tcp://* accepts peers from any host. Only run this on a trusted network.
    while True:
        # Size the buffer from `batch` instead of a literal 100 so that a
        # non-default batch size neither overflows the array nor feeds
        # uninitialised rows into the computation.
        inputs = np.empty((batch, 64))
        for i in range(batch):
            inputs[i] = socket.recv_pyobj()
        results = computation(inputs)
        for i in range(batch):
            socket.send_pyobj(results[i])
def client(ports):
    """Send zero-filled requests forever and wait for each reply.

    A single REQ socket is connected to every port in `ports` on localhost.
    Requests and replies are exchanged as pickled Python objects via
    send_pyobj() / recv_pyobj().

    Args:
        ports: Iterable of TCP ports the servers listen on.
    """
    ctx = zmq.Context()
    requester = ctx.socket(zmq.REQ)
    for port in ports:
        requester.connect(f'tcp://localhost:{port}')
    while True:
        request = np.zeros(64)
        requester.send_pyobj(request)
        # The reply is received only to complete the request/reply cycle.
        reply = requester.recv_pyobj()
if __name__ == '__main__':
    # Launch one process per server port, then the client processes.
    num_clients = 10
    num_servers = 3
    first_port = 5550
    ports = [first_port + offset for offset in range(num_servers)]
    for port in ports:
        server_proc = mp.Process(target=server, args=(port,))
        server_proc.start()
    for _ in range(num_clients):
        # Every client gets the full port list and connects to all servers.
        client_proc = mp.Process(target=client, args=(ports,))
        client_proc.start()
Solution
The `ROUTER` socket type combined with `recv_multipart()` and `send_multipart()` is what I was looking for. The most useful resource was the `rtreq` example under Advanced Request-Reply Patterns in the official guide. Besides that, I added `msgpack` serialization and reduced the batch size because it has to be smaller than the number of clients: each `REQ` client has only one request in flight at a time, so a server can never collect more pending requests than there are clients. The working version of the snippet from the question is below.
import msgpack
import multiprocessing as mp
import numpy as np
import time
import uuid
import zmq
def computation(inputs, overhead=1):
    """Stand-in for a GPU kernel that is only efficient on whole batches.

    Args:
        inputs: Sized batch of inputs. Only ``len(inputs)`` is used.
        overhead: Seconds to sleep to mimic the constant per-call cost.
            Defaults to 1, matching the previously hard-coded delay.

    Returns:
        A zero-filled array with one row of 8 values per input.
    """
    time.sleep(overhead)  # Simulate constant GPU overhead.
    batch_size = len(inputs)
    return np.zeros((batch_size, 8))
def server(port, batch=10):
    """Serve batched requests on a ROUTER socket, forever.

    Each round blocks until `batch` requests have arrived, runs one joint
    computation over them, and then sends every result back to the peer
    whose address frame came with the matching request.

    Args:
        port: TCP port to bind on all interfaces.
        batch: Number of requests gathered per computation.
    """
    ctx = zmq.Context.instance()
    router = ctx.socket(zmq.ROUTER)
    router.bind(f'tcp://*:{port}')
    while True:
        inputs = np.empty((batch, 64))
        identities = []
        for slot in range(batch):
            # Each message is unpacked as three frames:
            # [peer address, delimiter, payload].
            identity, _delimiter, body = router.recv_multipart()
            inputs[slot] = unpack(body)
            identities.append(identity)
        print('Collected request batch.', flush=True)
        results = computation(inputs)
        for slot, identity in enumerate(identities):
            # Reply with the same three-frame layout so the result goes to
            # the peer that sent request number `slot`.
            reply = pack(results[slot])
            router.send_multipart([identity, b'', reply])
        print('Send response batch.', flush=True)
def client(ports):
    """Issue zero-filled requests forever over a REQ socket.

    The socket is given a random UUID as its identity and is connected to
    every port in `ports` on localhost. Requests and replies are serialized
    with pack() / unpack().

    Args:
        ports: Iterable of TCP ports the servers listen on.
    """
    ctx = zmq.Context.instance()
    requester = ctx.socket(zmq.REQ)
    # The identity has to be set before connecting.
    identity = uuid.uuid4().bytes
    requester.setsockopt(zmq.IDENTITY, identity)
    for port in ports:
        requester.connect(f'tcp://localhost:{port}')
    while True:
        request = pack(np.zeros(64))
        requester.send(request)
        # Block until the reply arrives; REQ allows no new send before that.
        reply = unpack(requester.recv())
def pack(array):
    """Serialize a NumPy array to msgpack bytes.

    The message is the triple (shape, dtype name, raw buffer).
    """
    dims = array.shape
    type_name = str(array.dtype)
    raw = array.tobytes()
    return msgpack.packb((dims, type_name, raw))
def unpack(buffer):
    """Rebuild a NumPy array from a msgpack (shape, dtype, data) triple.

    Args:
        buffer: msgpack-encoded bytes holding the three fields.

    Returns:
        An array built over the decoded data with np.frombuffer (no copy is
        made here) and reshaped to the transmitted shape.
    """
    dims, type_name, raw = msgpack.unpackb(buffer)
    flat = np.frombuffer(raw, type_name)
    return flat.reshape(dims)
if __name__ == '__main__':
    # Start one server process per port, then the client processes.
    num_clients = 100
    num_servers = 3
    first_port = 5550
    ports = [first_port + offset for offset in range(num_servers)]
    for port in ports:
        server_proc = mp.Process(target=server, args=(port,))
        server_proc.start()
    for _ in range(num_clients):
        # Every client receives the full port list and connects to all servers.
        client_proc = mp.Process(target=client, args=(ports,))
        client_proc.start()
This question was asked on Stack Overflow by danijar and answered by danijar. It is licensed under the terms of CC BY-SA 2.5, CC BY-SA 3.0, CC BY-SA 4.0.