Refactor of ZmqSocket (#109)

Changed data type of address from char[1000] to string to reduce stack size of object.

Removed redundant calls to Close

Removed function exposing the socket descriptor

Using functions from network_utils instead of duplicates

Added tests
This commit is contained in:
Erik Fröjdh
2020-06-19 12:41:03 +02:00
committed by GitHub
parent 12b40a44a2
commit 5bf6b7a338
7 changed files with 154 additions and 163 deletions

View File

@ -1,35 +1,25 @@
#include "ZmqSocket.h"
#include "logger.h"
#include <arpa/inet.h> //inet_ntoa
#include <errno.h>
#include <iostream>
#include <netdb.h> //gethostbyname()
#include <string.h>
#include <unistd.h> //usleep in some machines
#include <vector>
#include <sstream>
#include <zmq.h>
#include "network_utils.h" //ip
using namespace rapidjson;
ZmqSocket::ZmqSocket(const char *const hostname_or_ip,
const uint32_t portnumber)
: portno(portnumber)
// headerMessage(0)
: portno(portnumber), sockfd(false)
{
char ip[MAX_STR_LENGTH] = "";
memset(ip, 0, MAX_STR_LENGTH);
// convert hostname to ip (not required, but a test that returns if failed)
struct addrinfo *result;
if ((ConvertHostnameToInternetAddress(hostname_or_ip, &result)) ||
(ConvertInternetAddresstoIpString(result, ip, MAX_STR_LENGTH)))
throw sls::ZmqSocketError("Could convert IP to string");
std::string sip(ip);
// construct address
sprintf(sockfd.serverAddress, "tcp://%s:%d", sip.c_str(), portno);
#ifdef VERBOSE
cprintf(BLUE, "address:%s\n", sockfd.serverAddress);
#endif
// Extra check that throws if conversion fails, could be removed
auto ipstr = sls::HostnameToIp(hostname_or_ip).str();
std::ostringstream oss;
oss << "tcp://" << ipstr << ":" << portno;
sockfd.serverAddress = oss.str();
LOG(logDEBUG) << "zmq address: " << sockfd.serverAddress;
// create context
sockfd.contextDescriptor = zmq_ctx_new();
@ -40,7 +30,6 @@ ZmqSocket::ZmqSocket(const char *const hostname_or_ip,
sockfd.socketDescriptor = zmq_socket(sockfd.contextDescriptor, ZMQ_SUB);
if (sockfd.socketDescriptor == nullptr) {
PrintError();
Close();
throw sls::ZmqSocketError("Could not create socket");
}
@ -48,104 +37,57 @@ ZmqSocket::ZmqSocket(const char *const hostname_or_ip,
// an empty string implies receiving any messages
if (zmq_setsockopt(sockfd.socketDescriptor, ZMQ_SUBSCRIBE, "", 0)) {
PrintError();
Close();
throw sls::ZmqSocketError("Could set socket opt");
}
// ZMQ_LINGER default is already -1 means no messages discarded. use this
// options if optimizing required ZMQ_SNDHWM default is 0 means no limit.
// use this to optimize if optimizing required eg. int value = -1;
int value = 0;
const int value = 0;
if (zmq_setsockopt(sockfd.socketDescriptor, ZMQ_LINGER, &value,
sizeof(value))) {
PrintError();
Close();
throw sls::ZmqSocketError("Could not set ZMQ_LINGER");
}
}
ZmqSocket::ZmqSocket(const uint32_t portnumber, const char *ethip)
:
portno(portnumber)
// headerMessage(0)
:portno(portnumber), sockfd(true)
{
sockfd.server = true;
// create context
sockfd.contextDescriptor = zmq_ctx_new();
if (sockfd.contextDescriptor == nullptr)
throw sls::ZmqSocketError("Could not create contextDescriptor");
// create publisher
sockfd.socketDescriptor = zmq_socket(sockfd.contextDescriptor, ZMQ_PUB);
if (sockfd.socketDescriptor == nullptr) {
PrintError();
Close();
throw sls::ZmqSocketError("Could not create socket");
}
// Socket Options provided above
// construct address, can be refactored with libfmt
std::ostringstream oss;
oss << "tcp://" << ethip << ":" << portno;
sockfd.serverAddress = oss.str();
LOG(logDEBUG) << "zmq address: " << sockfd.serverAddress;
// construct addresss
sprintf(sockfd.serverAddress, "tcp://%s:%d", ethip, portno);
#ifdef VERBOSE
cprintf(BLUE, "address:%s\n", sockfd.serverAddress);
#endif
// bind address
if (zmq_bind(sockfd.socketDescriptor, sockfd.serverAddress) < 0) {
if (zmq_bind(sockfd.socketDescriptor, sockfd.serverAddress.c_str())) {
PrintError();
Close();
throw sls::ZmqSocketError("Could not bind socket");
}
// sleep for a few milliseconds to allow a slow-joiner
usleep(200 * 1000);
};
int ZmqSocket::Connect() {
if (zmq_connect(sockfd.socketDescriptor, sockfd.serverAddress) < 0) {
if (zmq_connect(sockfd.socketDescriptor, sockfd.serverAddress.c_str())) {
PrintError();
return 1;
}
return 0;
}
int ZmqSocket::ConvertHostnameToInternetAddress(const char *const hostname,
struct addrinfo **res) {
// criteria in selecting socket address structures returned by res
struct addrinfo hints;
memset(&hints, 0, sizeof(hints));
hints.ai_family = AF_INET;
hints.ai_socktype = SOCK_STREAM;
// get host info into res
int errcode = getaddrinfo(hostname, nullptr, &hints, res);
if (errcode != 0) {
LOG(logERROR) << "Error: Could not convert hostname " << hostname
<< " to internet address (zmq):" << gai_strerror(errcode);
} else {
if (*res == nullptr) {
LOG(logERROR) << "Could not convert hostname " << hostname
<< " to internet address (zmq): "
"gettaddrinfo returned null";
} else {
return 0;
}
}
LOG(logERROR) << "Could not convert hostname to internet address";
return 1;
};
int ZmqSocket::ConvertInternetAddresstoIpString(struct addrinfo *res, char *ip,
const int ipsize) {
if (inet_ntop(res->ai_family,
&((struct sockaddr_in *)res->ai_addr)->sin_addr, ip,
ipsize) != nullptr) {
freeaddrinfo(res);
return 0;
}
LOG(logERROR) << "Could not convert internet address to ip string";
return 1;
}
int ZmqSocket::SendHeader(int index, zmqHeader header) {
/** Json Header Format */
@ -182,8 +124,8 @@ int ZmqSocket::SendHeader(int index, zmqHeader header) {
"\"quad\":%u"
; //"}\n";
char buf[MAX_STR_LENGTH] = "";
sprintf(buf, jsonHeaderFormat, header.jsonversion, header.dynamicRange,
memset(header_buffer.get(),'\0',MAX_STR_LENGTH); //TODO! Do we need this
sprintf(header_buffer.get(), jsonHeaderFormat, header.jsonversion, header.dynamicRange,
header.fileIndex, header.ndetx, header.ndety, header.npixelsx,
header.npixelsy, header.imageSize, header.acqIndex,
header.frameIndex, header.progress, header.fname.c_str(),
@ -198,31 +140,31 @@ int ZmqSocket::SendHeader(int index, zmqHeader header) {
header.flippedDataX, header.quad);
if (header.addJsonHeader.size() > 0) {
strcat(buf, ", ");
strcat(buf, "\"addJsonHeader\": {");
strcat(header_buffer.get(), ", ");
strcat(header_buffer.get(), "\"addJsonHeader\": {");
for (auto it = header.addJsonHeader.begin();
it != header.addJsonHeader.end(); ++it) {
if (it != header.addJsonHeader.begin()) {
strcat(buf, ", ");
strcat(header_buffer.get(), ", ");
}
strcat(buf, "\"");
strcat(buf, it->first.c_str());
strcat(buf, "\":\"");
strcat(buf, it->second.c_str());
strcat(buf, "\"");
strcat(header_buffer.get(), "\"");
strcat(header_buffer.get(), it->first.c_str());
strcat(header_buffer.get(), "\":\"");
strcat(header_buffer.get(), it->second.c_str());
strcat(header_buffer.get(), "\"");
}
strcat(buf, " } ");
strcat(header_buffer.get(), " } ");
}
strcat(buf, "}\n");
int length = strlen(buf);
strcat(header_buffer.get(), "}\n");
int length = strlen(header_buffer.get());
#ifdef VERBOSE
// if(!index)
cprintf(BLUE, "%d : Streamer: buf: %s\n", index, buf);
#endif
if (zmq_send(sockfd.socketDescriptor, buf, length,
if (zmq_send(sockfd.socketDescriptor, header_buffer.get(), length,
header.data ? ZMQ_SNDMORE : 0) < 0) {
PrintError();
return 0;
@ -243,18 +185,17 @@ int ZmqSocket::SendData(char *buf, int length) {
int ZmqSocket::ReceiveHeader(const int index, zmqHeader &zHeader,
uint32_t version) {
std::vector<char> buffer(MAX_STR_LENGTH);
int len =
zmq_recv(sockfd.socketDescriptor, buffer.data(), buffer.size(), 0);
if (len > 0) {
const int bytes_received =
zmq_recv(sockfd.socketDescriptor, header_buffer.get(), MAX_STR_LENGTH, 0);
if (bytes_received > 0) {
#ifdef ZMQ_DETAIL
cprintf(BLUE, "Header %d [%d] Length: %d Header:%s \n", index, portno,
len, buffer.data());
bytes_received, buffer.data());
#endif
if (ParseHeader(index, len, buffer.data(), zHeader, version)) {
if (ParseHeader(index, bytes_received, header_buffer.get(), zHeader, version)) {
#ifdef ZMQ_DETAIL
cprintf(RED, "Parsed Header %d [%d] Length: %d Header:%s \n", index,
portno, len, buffer.data());
portno, bytes_received, buffer.data());
#endif
if (!zHeader.data) {
#ifdef ZMQ_DETAIL
@ -278,13 +219,10 @@ int ZmqSocket::ParseHeader(const int index, int length, char *buff,
if (document.Parse(buff, length).HasParseError()) {
LOG(logERROR) << index << " Could not parse. len:" << length
<< ": Message:" << buff;
fflush(stdout);
// char* buf = (char*) zmq_msg_data (&message);
for (int i = 0; i < length; ++i) {
cprintf(RED, "%02x ", buff[i]);
}
printf("\n");
fflush(stdout);
std::cout << std::endl;
return 0;
}
@ -435,17 +373,17 @@ void ZmqSocket::PrintError() {
}
// Nested class to do RAII handling of socket descriptors
ZmqSocket::mySocketDescriptors::mySocketDescriptors()
: server(false), contextDescriptor(nullptr), socketDescriptor(nullptr){};
ZmqSocket::mySocketDescriptors::mySocketDescriptors(bool server)
: server(server), contextDescriptor(nullptr), socketDescriptor(nullptr){};
ZmqSocket::mySocketDescriptors::~mySocketDescriptors() {
Disconnect();
Close();
}
void ZmqSocket::mySocketDescriptors::Disconnect() {
if (server)
zmq_unbind(socketDescriptor, serverAddress);
zmq_unbind(socketDescriptor, serverAddress.c_str());
else
zmq_disconnect(socketDescriptor, serverAddress);
zmq_disconnect(socketDescriptor, serverAddress.c_str());
};
void ZmqSocket::mySocketDescriptors::Close() {
if (socketDescriptor != nullptr) {