From 48759f440e81fae26c9746fe2cd334241c9412bb Mon Sep 17 00:00:00 2001 From: Dhanya Thattil Date: Mon, 18 Sep 2023 08:59:53 +0200 Subject: [PATCH] 703rc/fix port size (#802) * validate port numbers in client * validate port numbers created at virtual servers and receiver process as tcp ports --- .../slsDetectorServer/src/slsDetectorServer.c | 9 ++++ slsDetectorSoftware/include/sls/Detector.h | 2 +- slsDetectorSoftware/src/Detector.cpp | 27 ++++++++--- slsDetectorSoftware/src/DetectorImpl.cpp | 1 + slsDetectorSoftware/src/Module.cpp | 4 +- .../tests/test-CmdProxy-global.cpp | 25 ++++++++++ .../tests/test-CmdProxy-global.h | 8 ++++ .../tests/test-CmdProxy-rx.cpp | 23 +++++++++ slsDetectorSoftware/tests/test-CmdProxy.cpp | 48 ++++++++++++++++++- slsReceiverSoftware/src/ClientInterface.cpp | 1 + slsSupportLib/include/sls/network_utils.h | 2 +- slsSupportLib/src/network_utils.cpp | 9 ++++ 12 files changed, 148 insertions(+), 11 deletions(-) diff --git a/slsDetectorServers/slsDetectorServer/src/slsDetectorServer.c b/slsDetectorServers/slsDetectorServer/src/slsDetectorServer.c index 3542ba431..f893a34cb 100644 --- a/slsDetectorServers/slsDetectorServer/src/slsDetectorServer.c +++ b/slsDetectorServers/slsDetectorServer/src/slsDetectorServer.c @@ -13,6 +13,7 @@ #include "slsDetectorServer_funcs.h" #include +#include #include #include #include @@ -276,6 +277,14 @@ int main(int argc, char *argv[]) { LOG(logERROR, ("Could not set handler function for SIGINT")); } + // validate control and stop port number + if (0 >= portno || portno > USHRT_MAX || 0 >= (portno + 1) || + (portno + 1) > USHRT_MAX) { + LOG(logERROR, ("Invalid control server or stop server port " + "numbers (%d, %d). It must be in range 1 - %d", + portno, portno + 1, USHRT_MAX)); + return -1; + } if (sharedMemory_create(portno) == FAIL) { return -1; } diff --git a/slsDetectorSoftware/include/sls/Detector.h b/slsDetectorSoftware/include/sls/Detector.h index fb106df89..e4c0d30ed 100644 --- a/slsDetectorSoftware/include/sls/Detector.h +++ b/slsDetectorSoftware/include/sls/Detector.h @@ -2008,7 +2008,7 @@ class Detector { ///@} private: - std::vector getPortNumbers(int start_port); + std::vector getValidPortNumbers(int start_port); void updateRxRateCorrections(); void setNumberofUDPInterfaces_(int n, Positions pos); }; diff --git a/slsDetectorSoftware/src/Detector.cpp b/slsDetectorSoftware/src/Detector.cpp index 32e7ab047..6ab1d8911 100644 --- a/slsDetectorSoftware/src/Detector.cpp +++ b/slsDetectorSoftware/src/Detector.cpp @@ -108,6 +108,9 @@ void Detector::setHostname(const std::vector &hostname) { } void Detector::setVirtualDetectorServers(int numServers, int startingPort) { + for (int i = 0; i != numServers; ++i) { + validatePortNumber(startingPort + i * 2); + } pimpl->setVirtualDetectorServers(numServers, startingPort); } @@ -1087,12 +1090,13 @@ Result Detector::getDestinationUDPPort(Positions pos) const { void Detector::setDestinationUDPPort(int port, int module_id) { if (module_id == -1) { - std::vector port_list = getPortNumbers(port); + std::vector port_list = getValidPortNumbers(port); for (int idet = 0; idet < size(); ++idet) { pimpl->Parallel(&Module::setDestinationUDPPort, {idet}, port_list[idet]); } } else { + validatePortNumber(port); pimpl->Parallel(&Module::setDestinationUDPPort, {module_id}, port); } } @@ -1103,12 +1107,13 @@ Result Detector::getDestinationUDPPort2(Positions pos) const { void Detector::setDestinationUDPPort2(int port, int module_id) { if (module_id == -1) { - std::vector port_list = getPortNumbers(port); + std::vector port_list = getValidPortNumbers(port); for (int idet = 0; idet < size(); ++idet) { pimpl->Parallel(&Module::setDestinationUDPPort2, {idet}, port_list[idet]); } } else { + validatePortNumber(port); pimpl->Parallel(&Module::setDestinationUDPPort2, {module_id}, port); } } @@ -1220,9 +1225,11 @@ void Detector::setRxPort(int port, int module_id) { it = port++; } for (int idet = 0; idet < size(); ++idet) { + validatePortNumber(port_list[idet]); pimpl->Parallel(&Module::setReceiverPort, {idet}, port_list[idet]); } } else { + validatePortNumber(port); pimpl->Parallel(&Module::setReceiverPort, {module_id}, port); } } @@ -1420,12 +1427,13 @@ void Detector::setRxZmqPort(int port, int module_id) { bool previouslyReceiverStreaming = getRxZmqDataStream(std::vector{module_id}).squash(false); if (module_id == -1) { - std::vector port_list = getPortNumbers(port); + std::vector port_list = getValidPortNumbers(port); for (int idet = 0; idet < size(); ++idet) { pimpl->Parallel(&Module::setReceiverStreamingPort, {idet}, port_list[idet]); } } else { + validatePortNumber(port); pimpl->Parallel(&Module::setReceiverStreamingPort, {module_id}, port); } if (previouslyReceiverStreaming) { @@ -1454,12 +1462,13 @@ Result Detector::getClientZmqPort(Positions pos) const { void Detector::setClientZmqPort(int port, int module_id) { bool previouslyClientStreaming = pimpl->getDataStreamingToClient(); if (module_id == -1) { - std::vector port_list = getPortNumbers(port); + std::vector port_list = getValidPortNumbers(port); for (int idet = 0; idet < size(); ++idet) { pimpl->Parallel(&Module::setClientStreamingPort, {idet}, port_list[idet]); } } else { + validatePortNumber(port); pimpl->Parallel(&Module::setClientStreamingPort, {module_id}, port); } if (previouslyClientStreaming) { @@ -2463,6 +2472,7 @@ Result Detector::getControlPort(Positions pos) const { } void Detector::setControlPort(int value, Positions pos) { + validatePortNumber(value); pimpl->Parallel(&Module::setControlPort, pos, value); } @@ -2471,6 +2481,7 @@ Result Detector::getStopPort(Positions pos) const { } void Detector::setStopPort(int value, Positions pos) { + validatePortNumber(value); pimpl->Parallel(&Module::setStopPort, pos, value); } @@ -2505,13 +2516,17 @@ Result Detector::getMeasurementTime(Positions pos) const { std::string Detector::getUserDetails() const { return pimpl->getUserDetails(); } -std::vector Detector::getPortNumbers(int start_port) { +std::vector Detector::getValidPortNumbers(int start_port) { int num_sockets_per_detector = getNumberofUDPInterfaces({}).tsquash( "Number of UDP Interfaces is not consistent among modules"); std::vector res; res.reserve(size()); for (int idet = 0; idet < size(); ++idet) { - res.push_back(start_port + (idet * num_sockets_per_detector)); + int port = start_port + (idet * num_sockets_per_detector); + for (int i = 0; i != num_sockets_per_detector; ++i) { + validatePortNumber(port + i); + } + res.push_back(port); } return res; } diff --git a/slsDetectorSoftware/src/DetectorImpl.cpp b/slsDetectorSoftware/src/DetectorImpl.cpp index 15bde3697..0887ecc0a 100644 --- a/slsDetectorSoftware/src/DetectorImpl.cpp +++ b/slsDetectorSoftware/src/DetectorImpl.cpp @@ -288,6 +288,7 @@ void DetectorImpl::addModule(const std::string &hostname) { if (res.size() > 1) { host = res[0]; port = StringTo(res[1]); + validatePortNumber(port); } if (host != "localhost") { diff --git a/slsDetectorSoftware/src/Module.cpp b/slsDetectorSoftware/src/Module.cpp index 8249ef1e8..f24528f1a 100644 --- a/slsDetectorSoftware/src/Module.cpp +++ b/slsDetectorSoftware/src/Module.cpp @@ -1350,7 +1350,9 @@ void Module::setReceiverHostname(const std::string &receiverIP, auto res = split(host, ':'); if (res.size() > 1) { host = res[0]; - shm()->rxTCPPort = std::stoi(res[1]); + int port = StringTo(res[1]); + validatePortNumber(port); + shm()->rxTCPPort = port; } strcpy_safe(shm()->rxHostname, host.c_str()); shm()->useReceiverFlag = true; diff --git a/slsDetectorSoftware/tests/test-CmdProxy-global.cpp b/slsDetectorSoftware/tests/test-CmdProxy-global.cpp index e89c38aef..23a4e013c 100644 --- a/slsDetectorSoftware/tests/test-CmdProxy-global.cpp +++ b/slsDetectorSoftware/tests/test-CmdProxy-global.cpp @@ -11,6 +11,31 @@ namespace sls { using test::GET; using test::PUT; +void test_valid_port(const std::string &command, + const std::vector &arguments, int detector_id, + int action, int port_number) { + Detector det; + CmdProxy proxy(&det); + std::string string_port_number = std::to_string(port_number); + REQUIRE_THROWS_WITH(proxy.Call(command, arguments, detector_id, action), + "Invalid port number " + string_port_number + + ". It must be in range 1 - 65535"); +} + +void test_valid_port(const std::string &command, + const std::vector &arguments, int detector_id, + int action) { + std::vector arg(arguments); + arg.push_back("0"); + + int test_values[2] = {77797, -1}; + for (int i = 0; i != 2; ++i) { + int port_number = test_values[i]; + arg[arg.size() - 1] = std::to_string(port_number); + test_valid_port(command, arg, detector_id, action, port_number); + } +} + void test_dac(defs::dacIndex index, const std::string &dacname, int dacvalue) { Detector det; CmdProxy proxy(&det); diff --git a/slsDetectorSoftware/tests/test-CmdProxy-global.h b/slsDetectorSoftware/tests/test-CmdProxy-global.h index ab869529b..20b1882f3 100644 --- a/slsDetectorSoftware/tests/test-CmdProxy-global.h +++ b/slsDetectorSoftware/tests/test-CmdProxy-global.h @@ -5,6 +5,14 @@ namespace sls { +void test_valid_port(const std::string &command, + const std::vector &arguments, int detector_id, + int action, int port_number); + +void test_valid_port(const std::string &command, + const std::vector &arguments, int detector_id, + int action); + void test_dac(slsDetectorDefs::dacIndex index, const std::string &dacname, int dacvalue); void test_onchip_dac(slsDetectorDefs::dacIndex index, diff --git a/slsDetectorSoftware/tests/test-CmdProxy-rx.cpp b/slsDetectorSoftware/tests/test-CmdProxy-rx.cpp index 2cc550812..84e1d9fcb 100644 --- a/slsDetectorSoftware/tests/test-CmdProxy-rx.cpp +++ b/slsDetectorSoftware/tests/test-CmdProxy-rx.cpp @@ -5,6 +5,7 @@ #include "sls/Detector.h" #include "sls/Version.h" #include "sls/sls_detector_defs.h" +#include "test-CmdProxy-global.h" #include #include "sls/versionAPI.h" @@ -237,6 +238,22 @@ TEST_CASE("rx_tcpport", "[.cmd][.rx]") { proxy.Call("rx_tcpport", {}, i, GET, oss); REQUIRE(oss.str() == "rx_tcpport " + std::to_string(port + i) + '\n'); } + + test_valid_port("rx_tcpport", {}, -1, PUT); + test_valid_port("rx_tcpport", {}, 0, PUT); + // should fail for the second module + if (det.size() > 1) { + test_valid_port("rx_tcpport", {"65535"}, -1, PUT, 65536); + auto rxHostname = det.getRxHostname().squash("none"); + if (rxHostname != "none") { + std::ostringstream oss; + for (int i = 0; i != det.size(); ++i) { + oss << rxHostname << ":" << 65536 + i << "+"; + } + test_valid_port("rx_hostname", {oss.str()}, -1, PUT, 65536); + } + } + for (int i = 0; i != det.size(); ++i) { det.setRxPort(prev_val[i], i); } @@ -828,6 +845,12 @@ TEST_CASE("rx_zmqport", "[.cmd][.rx]") { std::to_string(port + i * socketsperdetector) + '\n'); } + test_valid_port("rx_zmqport", {}, -1, PUT); + test_valid_port("rx_zmqport", {}, 0, PUT); + // should fail for the second module + if (det.size() > 1) { + test_valid_port("rx_zmqport", {"65535"}, -1, PUT, 65536); + } for (int i = 0; i != det.size(); ++i) { det.setRxZmqPort(prev_val_zmqport[i], i); } diff --git a/slsDetectorSoftware/tests/test-CmdProxy.cpp b/slsDetectorSoftware/tests/test-CmdProxy.cpp index 2ba4c7a9c..5d6a9c1e7 100644 --- a/slsDetectorSoftware/tests/test-CmdProxy.cpp +++ b/slsDetectorSoftware/tests/test-CmdProxy.cpp @@ -5,6 +5,7 @@ #include "sls/Detector.h" #include "sls/file_utils.h" #include "sls/sls_detector_defs.h" +#include "test-CmdProxy-global.h" #include #include @@ -76,7 +77,13 @@ TEST_CASE("hostname", "[.cmd]") { REQUIRE_NOTHROW(proxy.Call("hostname", {}, -1, GET)); } -// virtual: not testing +TEST_CASE("virtual", "[.cmd]") { + Detector det; + CmdProxy proxy(&det); + REQUIRE_THROWS(proxy.Call("virtual", {}, -1, GET)); + test_valid_port("virtual", {"1"}, -1, PUT); + test_valid_port("virtual", {"3", "65534"}, -1, PUT, 65536); +} TEST_CASE("versions", "[.cmd]") { Detector det; @@ -2618,6 +2625,13 @@ TEST_CASE("udp_dstport", "[.cmd]") { proxy.Call("udp_dstport", {"50084"}, -1, PUT, oss); REQUIRE(oss.str() == "udp_dstport 50084\n"); } + test_valid_port("udp_dstport", {}, -1, PUT); + test_valid_port("udp_dstport", {}, 0, PUT); + // should fail for the second module + if (det.size() > 1) { + test_valid_port("udp_dstport", {"65535"}, -1, PUT, 65536); + } + for (int i = 0; i != det.size(); ++i) { det.setDestinationUDPPort(prev_val[i], {i}); } @@ -2702,8 +2716,18 @@ TEST_CASE("udp_dstport2", "[.cmd]") { proxy.Call("udp_dstport2", {"50084"}, -1, PUT, oss); REQUIRE(oss.str() == "udp_dstport2 50084\n"); } + + test_valid_port("udp_dstport2", {}, -1, PUT); + test_valid_port("udp_dstport2", {}, 0, PUT); + // should fail for the second module + if (det.size() > 1) { + test_valid_port("udp_dstport2", {"65535"}, -1, PUT, 65536); + } + for (int i = 0; i != det.size(); ++i) { - det.setDestinationUDPPort2(prev_val[i], {i}); + if (prev_val[i] != 0) { + det.setDestinationUDPPort2(prev_val[i], {i}); + } } } else { REQUIRE_THROWS(proxy.Call("udp_dstport2", {}, -1, GET)); @@ -2922,6 +2946,13 @@ TEST_CASE("zmqport", "[.cmd]") { std::to_string(port + i * socketsperdetector) + '\n'); } + test_valid_port("zmqport", {}, -1, PUT); + test_valid_port("zmqport", {}, 0, PUT); + // should fail for the second module + if (det.size() > 1) { + test_valid_port("zmqport", {"65535"}, -1, PUT, 65536); + } + if (det_type == defs::JUNGFRAU) { det.setNumberofUDPInterfaces(prev); } @@ -3266,6 +3297,13 @@ TEST_CASE("port", "[.cmd]") { proxy.Call("port", {}, 0, GET, oss); REQUIRE(oss.str() == "port 1942\n"); } + test_valid_port("port", {}, -1, PUT); + test_valid_port("port", {}, 0, PUT); + // should fail for the second module + if (det.size() > 1) { + test_valid_port("port", {"65536"}, -1, PUT, 65536); + } + det.setControlPort(prev_val, {0}); } @@ -3283,6 +3321,12 @@ TEST_CASE("stopport", "[.cmd]") { proxy.Call("stopport", {}, 0, GET, oss); REQUIRE(oss.str() == "stopport 1942\n"); } + test_valid_port("stopport", {}, -1, PUT); + test_valid_port("stopport", {}, 0, PUT); + // should fail for the second module + if (det.size() > 1) { + test_valid_port("stopport", {"65536"}, -1, PUT, 65536); + } det.setStopPort(prev_val, {0}); } diff --git a/slsReceiverSoftware/src/ClientInterface.cpp b/slsReceiverSoftware/src/ClientInterface.cpp index 457b46263..de4703db5 100644 --- a/slsReceiverSoftware/src/ClientInterface.cpp +++ b/slsReceiverSoftware/src/ClientInterface.cpp @@ -45,6 +45,7 @@ ClientInterface::ClientInterface(int portNumber) : detType(GOTTHARD), portNumber(portNumber > 0 ? portNumber : DEFAULT_TCP_RX_PORTNO), server(portNumber) { + validatePortNumber(portNumber); functionTable(); parentThreadId = gettid(); tcpThread = diff --git a/slsSupportLib/include/sls/network_utils.h b/slsSupportLib/include/sls/network_utils.h index e0e48c416..6721bc507 100644 --- a/slsSupportLib/include/sls/network_utils.h +++ b/slsSupportLib/include/sls/network_utils.h @@ -88,5 +88,5 @@ IpAddr HostnameToIp(const char *hostname); std::string IpToInterfaceName(const std::string &ip); MacAddr InterfaceNameToMac(const std::string &inf); IpAddr InterfaceNameToIp(const std::string &ifn); - +void validatePortNumber(int port); } // namespace sls diff --git a/slsSupportLib/src/network_utils.cpp b/slsSupportLib/src/network_utils.cpp index 3a44b6fef..13847d4c7 100644 --- a/slsSupportLib/src/network_utils.cpp +++ b/slsSupportLib/src/network_utils.cpp @@ -203,4 +203,13 @@ MacAddr InterfaceNameToMac(const std::string &inf) { return MacAddr(mac); } +void validatePortNumber(int port) { + if (0 >= port || port > std::numeric_limits::max()) { + std::ostringstream oss; + oss << "Invalid port number " << port << ". It must be in range 1 - " + << std::numeric_limits::max(); + throw RuntimeError(oss.str()); + } +} + } // namespace sls