703rc/fix port size (#802)

* validate port numbers in client

* validate port numbers created at virtual servers and receiver process as tcp ports
This commit is contained in:
maliakal_d 2023-09-18 08:59:53 +02:00 committed by GitHub
parent b367b7e431
commit 48759f440e
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
12 changed files with 148 additions and 11 deletions

View File

@ -13,6 +13,7 @@
#include "slsDetectorServer_funcs.h" #include "slsDetectorServer_funcs.h"
#include <getopt.h> #include <getopt.h>
#include <limits.h>
#include <signal.h> #include <signal.h>
#include <string.h> #include <string.h>
#include <unistd.h> #include <unistd.h>
@ -276,6 +277,14 @@ int main(int argc, char *argv[]) {
LOG(logERROR, ("Could not set handler function for SIGINT")); LOG(logERROR, ("Could not set handler function for SIGINT"));
} }
// validate control and stop port number
if (0 >= portno || portno > USHRT_MAX || 0 >= (portno + 1) ||
(portno + 1) > USHRT_MAX) {
LOG(logERROR, ("Invalid control server or stop server port "
"numbers (%d, %d). It must be in range 1 - %d",
portno, portno + 1, USHRT_MAX));
return -1;
}
if (sharedMemory_create(portno) == FAIL) { if (sharedMemory_create(portno) == FAIL) {
return -1; return -1;
} }

View File

@ -2008,7 +2008,7 @@ class Detector {
///@} ///@}
private: private:
std::vector<int> getPortNumbers(int start_port); std::vector<int> getValidPortNumbers(int start_port);
void updateRxRateCorrections(); void updateRxRateCorrections();
void setNumberofUDPInterfaces_(int n, Positions pos); void setNumberofUDPInterfaces_(int n, Positions pos);
}; };

View File

@ -108,6 +108,9 @@ void Detector::setHostname(const std::vector<std::string> &hostname) {
} }
void Detector::setVirtualDetectorServers(int numServers, int startingPort) { void Detector::setVirtualDetectorServers(int numServers, int startingPort) {
for (int i = 0; i != numServers; ++i) {
validatePortNumber(startingPort + i * 2);
}
pimpl->setVirtualDetectorServers(numServers, startingPort); pimpl->setVirtualDetectorServers(numServers, startingPort);
} }
@ -1087,12 +1090,13 @@ Result<int> Detector::getDestinationUDPPort(Positions pos) const {
void Detector::setDestinationUDPPort(int port, int module_id) { void Detector::setDestinationUDPPort(int port, int module_id) {
if (module_id == -1) { if (module_id == -1) {
std::vector<int> port_list = getPortNumbers(port); std::vector<int> port_list = getValidPortNumbers(port);
for (int idet = 0; idet < size(); ++idet) { for (int idet = 0; idet < size(); ++idet) {
pimpl->Parallel(&Module::setDestinationUDPPort, {idet}, pimpl->Parallel(&Module::setDestinationUDPPort, {idet},
port_list[idet]); port_list[idet]);
} }
} else { } else {
validatePortNumber(port);
pimpl->Parallel(&Module::setDestinationUDPPort, {module_id}, port); pimpl->Parallel(&Module::setDestinationUDPPort, {module_id}, port);
} }
} }
@ -1103,12 +1107,13 @@ Result<int> Detector::getDestinationUDPPort2(Positions pos) const {
void Detector::setDestinationUDPPort2(int port, int module_id) { void Detector::setDestinationUDPPort2(int port, int module_id) {
if (module_id == -1) { if (module_id == -1) {
std::vector<int> port_list = getPortNumbers(port); std::vector<int> port_list = getValidPortNumbers(port);
for (int idet = 0; idet < size(); ++idet) { for (int idet = 0; idet < size(); ++idet) {
pimpl->Parallel(&Module::setDestinationUDPPort2, {idet}, pimpl->Parallel(&Module::setDestinationUDPPort2, {idet},
port_list[idet]); port_list[idet]);
} }
} else { } else {
validatePortNumber(port);
pimpl->Parallel(&Module::setDestinationUDPPort2, {module_id}, port); pimpl->Parallel(&Module::setDestinationUDPPort2, {module_id}, port);
} }
} }
@ -1220,9 +1225,11 @@ void Detector::setRxPort(int port, int module_id) {
it = port++; it = port++;
} }
for (int idet = 0; idet < size(); ++idet) { for (int idet = 0; idet < size(); ++idet) {
validatePortNumber(port_list[idet]);
pimpl->Parallel(&Module::setReceiverPort, {idet}, port_list[idet]); pimpl->Parallel(&Module::setReceiverPort, {idet}, port_list[idet]);
} }
} else { } else {
validatePortNumber(port);
pimpl->Parallel(&Module::setReceiverPort, {module_id}, port); pimpl->Parallel(&Module::setReceiverPort, {module_id}, port);
} }
} }
@ -1420,12 +1427,13 @@ void Detector::setRxZmqPort(int port, int module_id) {
bool previouslyReceiverStreaming = bool previouslyReceiverStreaming =
getRxZmqDataStream(std::vector<int>{module_id}).squash(false); getRxZmqDataStream(std::vector<int>{module_id}).squash(false);
if (module_id == -1) { if (module_id == -1) {
std::vector<int> port_list = getPortNumbers(port); std::vector<int> port_list = getValidPortNumbers(port);
for (int idet = 0; idet < size(); ++idet) { for (int idet = 0; idet < size(); ++idet) {
pimpl->Parallel(&Module::setReceiverStreamingPort, {idet}, pimpl->Parallel(&Module::setReceiverStreamingPort, {idet},
port_list[idet]); port_list[idet]);
} }
} else { } else {
validatePortNumber(port);
pimpl->Parallel(&Module::setReceiverStreamingPort, {module_id}, port); pimpl->Parallel(&Module::setReceiverStreamingPort, {module_id}, port);
} }
if (previouslyReceiverStreaming) { if (previouslyReceiverStreaming) {
@ -1454,12 +1462,13 @@ Result<int> Detector::getClientZmqPort(Positions pos) const {
void Detector::setClientZmqPort(int port, int module_id) { void Detector::setClientZmqPort(int port, int module_id) {
bool previouslyClientStreaming = pimpl->getDataStreamingToClient(); bool previouslyClientStreaming = pimpl->getDataStreamingToClient();
if (module_id == -1) { if (module_id == -1) {
std::vector<int> port_list = getPortNumbers(port); std::vector<int> port_list = getValidPortNumbers(port);
for (int idet = 0; idet < size(); ++idet) { for (int idet = 0; idet < size(); ++idet) {
pimpl->Parallel(&Module::setClientStreamingPort, {idet}, pimpl->Parallel(&Module::setClientStreamingPort, {idet},
port_list[idet]); port_list[idet]);
} }
} else { } else {
validatePortNumber(port);
pimpl->Parallel(&Module::setClientStreamingPort, {module_id}, port); pimpl->Parallel(&Module::setClientStreamingPort, {module_id}, port);
} }
if (previouslyClientStreaming) { if (previouslyClientStreaming) {
@ -2463,6 +2472,7 @@ Result<int> Detector::getControlPort(Positions pos) const {
} }
void Detector::setControlPort(int value, Positions pos) { void Detector::setControlPort(int value, Positions pos) {
validatePortNumber(value);
pimpl->Parallel(&Module::setControlPort, pos, value); pimpl->Parallel(&Module::setControlPort, pos, value);
} }
@ -2471,6 +2481,7 @@ Result<int> Detector::getStopPort(Positions pos) const {
} }
void Detector::setStopPort(int value, Positions pos) { void Detector::setStopPort(int value, Positions pos) {
validatePortNumber(value);
pimpl->Parallel(&Module::setStopPort, pos, value); pimpl->Parallel(&Module::setStopPort, pos, value);
} }
@ -2505,13 +2516,17 @@ Result<ns> Detector::getMeasurementTime(Positions pos) const {
std::string Detector::getUserDetails() const { return pimpl->getUserDetails(); } std::string Detector::getUserDetails() const { return pimpl->getUserDetails(); }
std::vector<int> Detector::getPortNumbers(int start_port) { std::vector<int> Detector::getValidPortNumbers(int start_port) {
int num_sockets_per_detector = getNumberofUDPInterfaces({}).tsquash( int num_sockets_per_detector = getNumberofUDPInterfaces({}).tsquash(
"Number of UDP Interfaces is not consistent among modules"); "Number of UDP Interfaces is not consistent among modules");
std::vector<int> res; std::vector<int> res;
res.reserve(size()); res.reserve(size());
for (int idet = 0; idet < size(); ++idet) { for (int idet = 0; idet < size(); ++idet) {
res.push_back(start_port + (idet * num_sockets_per_detector)); int port = start_port + (idet * num_sockets_per_detector);
for (int i = 0; i != num_sockets_per_detector; ++i) {
validatePortNumber(port + i);
}
res.push_back(port);
} }
return res; return res;
} }

View File

@ -288,6 +288,7 @@ void DetectorImpl::addModule(const std::string &hostname) {
if (res.size() > 1) { if (res.size() > 1) {
host = res[0]; host = res[0];
port = StringTo<int>(res[1]); port = StringTo<int>(res[1]);
validatePortNumber(port);
} }
if (host != "localhost") { if (host != "localhost") {

View File

@ -1350,7 +1350,9 @@ void Module::setReceiverHostname(const std::string &receiverIP,
auto res = split(host, ':'); auto res = split(host, ':');
if (res.size() > 1) { if (res.size() > 1) {
host = res[0]; host = res[0];
shm()->rxTCPPort = std::stoi(res[1]); int port = StringTo<int>(res[1]);
validatePortNumber(port);
shm()->rxTCPPort = port;
} }
strcpy_safe(shm()->rxHostname, host.c_str()); strcpy_safe(shm()->rxHostname, host.c_str());
shm()->useReceiverFlag = true; shm()->useReceiverFlag = true;

View File

@ -11,6 +11,31 @@ namespace sls {
using test::GET; using test::GET;
using test::PUT; using test::PUT;
void test_valid_port(const std::string &command,
const std::vector<std::string> &arguments, int detector_id,
int action, int port_number) {
Detector det;
CmdProxy proxy(&det);
std::string string_port_number = std::to_string(port_number);
REQUIRE_THROWS_WITH(proxy.Call(command, arguments, detector_id, action),
"Invalid port number " + string_port_number +
". It must be in range 1 - 65535");
}
void test_valid_port(const std::string &command,
const std::vector<std::string> &arguments, int detector_id,
int action) {
std::vector<std::string> arg(arguments);
arg.push_back("0");
int test_values[2] = {77797, -1};
for (int i = 0; i != 2; ++i) {
int port_number = test_values[i];
arg[arg.size() - 1] = std::to_string(port_number);
test_valid_port(command, arg, detector_id, action, port_number);
}
}
void test_dac(defs::dacIndex index, const std::string &dacname, int dacvalue) { void test_dac(defs::dacIndex index, const std::string &dacname, int dacvalue) {
Detector det; Detector det;
CmdProxy proxy(&det); CmdProxy proxy(&det);

View File

@ -5,6 +5,14 @@
namespace sls { namespace sls {
void test_valid_port(const std::string &command,
const std::vector<std::string> &arguments, int detector_id,
int action, int port_number);
void test_valid_port(const std::string &command,
const std::vector<std::string> &arguments, int detector_id,
int action);
void test_dac(slsDetectorDefs::dacIndex index, const std::string &dacname, void test_dac(slsDetectorDefs::dacIndex index, const std::string &dacname,
int dacvalue); int dacvalue);
void test_onchip_dac(slsDetectorDefs::dacIndex index, void test_onchip_dac(slsDetectorDefs::dacIndex index,

View File

@ -5,6 +5,7 @@
#include "sls/Detector.h" #include "sls/Detector.h"
#include "sls/Version.h" #include "sls/Version.h"
#include "sls/sls_detector_defs.h" #include "sls/sls_detector_defs.h"
#include "test-CmdProxy-global.h"
#include <sstream> #include <sstream>
#include "sls/versionAPI.h" #include "sls/versionAPI.h"
@ -237,6 +238,22 @@ TEST_CASE("rx_tcpport", "[.cmd][.rx]") {
proxy.Call("rx_tcpport", {}, i, GET, oss); proxy.Call("rx_tcpport", {}, i, GET, oss);
REQUIRE(oss.str() == "rx_tcpport " + std::to_string(port + i) + '\n'); REQUIRE(oss.str() == "rx_tcpport " + std::to_string(port + i) + '\n');
} }
test_valid_port("rx_tcpport", {}, -1, PUT);
test_valid_port("rx_tcpport", {}, 0, PUT);
// should fail for the second module
if (det.size() > 1) {
test_valid_port("rx_tcpport", {"65535"}, -1, PUT, 65536);
auto rxHostname = det.getRxHostname().squash("none");
if (rxHostname != "none") {
std::ostringstream oss;
for (int i = 0; i != det.size(); ++i) {
oss << rxHostname << ":" << 65536 + i << "+";
}
test_valid_port("rx_hostname", {oss.str()}, -1, PUT, 65536);
}
}
for (int i = 0; i != det.size(); ++i) { for (int i = 0; i != det.size(); ++i) {
det.setRxPort(prev_val[i], i); det.setRxPort(prev_val[i], i);
} }
@ -828,6 +845,12 @@ TEST_CASE("rx_zmqport", "[.cmd][.rx]") {
std::to_string(port + i * socketsperdetector) + std::to_string(port + i * socketsperdetector) +
'\n'); '\n');
} }
test_valid_port("rx_zmqport", {}, -1, PUT);
test_valid_port("rx_zmqport", {}, 0, PUT);
// should fail for the second module
if (det.size() > 1) {
test_valid_port("rx_zmqport", {"65535"}, -1, PUT, 65536);
}
for (int i = 0; i != det.size(); ++i) { for (int i = 0; i != det.size(); ++i) {
det.setRxZmqPort(prev_val_zmqport[i], i); det.setRxZmqPort(prev_val_zmqport[i], i);
} }

View File

@ -5,6 +5,7 @@
#include "sls/Detector.h" #include "sls/Detector.h"
#include "sls/file_utils.h" #include "sls/file_utils.h"
#include "sls/sls_detector_defs.h" #include "sls/sls_detector_defs.h"
#include "test-CmdProxy-global.h"
#include <chrono> #include <chrono>
#include <sstream> #include <sstream>
@ -76,7 +77,13 @@ TEST_CASE("hostname", "[.cmd]") {
REQUIRE_NOTHROW(proxy.Call("hostname", {}, -1, GET)); REQUIRE_NOTHROW(proxy.Call("hostname", {}, -1, GET));
} }
// virtual: not testing TEST_CASE("virtual", "[.cmd]") {
Detector det;
CmdProxy proxy(&det);
REQUIRE_THROWS(proxy.Call("virtual", {}, -1, GET));
test_valid_port("virtual", {"1"}, -1, PUT);
test_valid_port("virtual", {"3", "65534"}, -1, PUT, 65536);
}
TEST_CASE("versions", "[.cmd]") { TEST_CASE("versions", "[.cmd]") {
Detector det; Detector det;
@ -2618,6 +2625,13 @@ TEST_CASE("udp_dstport", "[.cmd]") {
proxy.Call("udp_dstport", {"50084"}, -1, PUT, oss); proxy.Call("udp_dstport", {"50084"}, -1, PUT, oss);
REQUIRE(oss.str() == "udp_dstport 50084\n"); REQUIRE(oss.str() == "udp_dstport 50084\n");
} }
test_valid_port("udp_dstport", {}, -1, PUT);
test_valid_port("udp_dstport", {}, 0, PUT);
// should fail for the second module
if (det.size() > 1) {
test_valid_port("udp_dstport", {"65535"}, -1, PUT, 65536);
}
for (int i = 0; i != det.size(); ++i) { for (int i = 0; i != det.size(); ++i) {
det.setDestinationUDPPort(prev_val[i], {i}); det.setDestinationUDPPort(prev_val[i], {i});
} }
@ -2702,8 +2716,18 @@ TEST_CASE("udp_dstport2", "[.cmd]") {
proxy.Call("udp_dstport2", {"50084"}, -1, PUT, oss); proxy.Call("udp_dstport2", {"50084"}, -1, PUT, oss);
REQUIRE(oss.str() == "udp_dstport2 50084\n"); REQUIRE(oss.str() == "udp_dstport2 50084\n");
} }
test_valid_port("udp_dstport2", {}, -1, PUT);
test_valid_port("udp_dstport2", {}, 0, PUT);
// should fail for the second module
if (det.size() > 1) {
test_valid_port("udp_dstport2", {"65535"}, -1, PUT, 65536);
}
for (int i = 0; i != det.size(); ++i) { for (int i = 0; i != det.size(); ++i) {
det.setDestinationUDPPort2(prev_val[i], {i}); if (prev_val[i] != 0) {
det.setDestinationUDPPort2(prev_val[i], {i});
}
} }
} else { } else {
REQUIRE_THROWS(proxy.Call("udp_dstport2", {}, -1, GET)); REQUIRE_THROWS(proxy.Call("udp_dstport2", {}, -1, GET));
@ -2922,6 +2946,13 @@ TEST_CASE("zmqport", "[.cmd]") {
std::to_string(port + i * socketsperdetector) + std::to_string(port + i * socketsperdetector) +
'\n'); '\n');
} }
test_valid_port("zmqport", {}, -1, PUT);
test_valid_port("zmqport", {}, 0, PUT);
// should fail for the second module
if (det.size() > 1) {
test_valid_port("zmqport", {"65535"}, -1, PUT, 65536);
}
if (det_type == defs::JUNGFRAU) { if (det_type == defs::JUNGFRAU) {
det.setNumberofUDPInterfaces(prev); det.setNumberofUDPInterfaces(prev);
} }
@ -3266,6 +3297,13 @@ TEST_CASE("port", "[.cmd]") {
proxy.Call("port", {}, 0, GET, oss); proxy.Call("port", {}, 0, GET, oss);
REQUIRE(oss.str() == "port 1942\n"); REQUIRE(oss.str() == "port 1942\n");
} }
test_valid_port("port", {}, -1, PUT);
test_valid_port("port", {}, 0, PUT);
// should fail for the second module
if (det.size() > 1) {
test_valid_port("port", {"65536"}, -1, PUT, 65536);
}
det.setControlPort(prev_val, {0}); det.setControlPort(prev_val, {0});
} }
@ -3283,6 +3321,12 @@ TEST_CASE("stopport", "[.cmd]") {
proxy.Call("stopport", {}, 0, GET, oss); proxy.Call("stopport", {}, 0, GET, oss);
REQUIRE(oss.str() == "stopport 1942\n"); REQUIRE(oss.str() == "stopport 1942\n");
} }
test_valid_port("stopport", {}, -1, PUT);
test_valid_port("stopport", {}, 0, PUT);
// should fail for the second module
if (det.size() > 1) {
test_valid_port("stopport", {"65536"}, -1, PUT, 65536);
}
det.setStopPort(prev_val, {0}); det.setStopPort(prev_val, {0});
} }

View File

@ -45,6 +45,7 @@ ClientInterface::ClientInterface(int portNumber)
: detType(GOTTHARD), : detType(GOTTHARD),
portNumber(portNumber > 0 ? portNumber : DEFAULT_TCP_RX_PORTNO), portNumber(portNumber > 0 ? portNumber : DEFAULT_TCP_RX_PORTNO),
server(portNumber) { server(portNumber) {
validatePortNumber(portNumber);
functionTable(); functionTable();
parentThreadId = gettid(); parentThreadId = gettid();
tcpThread = tcpThread =

View File

@ -88,5 +88,5 @@ IpAddr HostnameToIp(const char *hostname);
std::string IpToInterfaceName(const std::string &ip); std::string IpToInterfaceName(const std::string &ip);
MacAddr InterfaceNameToMac(const std::string &inf); MacAddr InterfaceNameToMac(const std::string &inf);
IpAddr InterfaceNameToIp(const std::string &ifn); IpAddr InterfaceNameToIp(const std::string &ifn);
void validatePortNumber(int port);
} // namespace sls } // namespace sls

View File

@ -203,4 +203,13 @@ MacAddr InterfaceNameToMac(const std::string &inf) {
return MacAddr(mac); return MacAddr(mac);
} }
void validatePortNumber(int port) {
if (0 >= port || port > std::numeric_limits<uint16_t>::max()) {
std::ostringstream oss;
oss << "Invalid port number " << port << ". It must be in range 1 - "
<< std::numeric_limits<uint16_t>::max();
throw RuntimeError(oss.str());
}
}
} // namespace sls } // namespace sls