Improve address and connection handling

* improve address checking
+ add ipv6 capabilities

Change-Id: I5369336bec449c27d79d857018f319266dfd4d0e
Reviewed-on: https://forge.frm2.tum.de/review/c/secop/frappy/+/30885
Tested-by: Jenkins Automated Tests <pedersen+jenkins@frm2.tum.de>
Reviewed-by: Alexander Zaft <a.zaft@fz-juelich.de>
This commit is contained in:
Alexander Zaft
2023-04-12 14:15:52 +02:00
parent a49d64953c
commit f491625dd1
4 changed files with 118 additions and 60 deletions

View File

@ -21,9 +21,9 @@
# *****************************************************************************
"""Define helpers"""
import re
import importlib
import linecache
import re
import socket
import sys
import threading
@ -295,50 +295,68 @@ def formatException(cut=0, exc_info=None, verbose=False):
return ''.join(res)
HOSTNAMEPAT = re.compile(r'[a-z0-9_.-]+$', re.IGNORECASE) # roughly checking for a valid hostname or ip address
HOSTNAMEPART = re.compile(r'^(?!-)[a-z0-9-]{1,63}(?<!-)$', re.IGNORECASE)
def validate_hostname(host):
"""checks if the rules for valid hostnames are adhered to"""
if len(host) > 255:
return False
for part in host.split('.'):
if not HOSTNAMEPART.match(part):
return False
return True
def parseHostPort(host, defaultport):
"""Parse host[:port] string and tuples
def validate_ipv4(addr):
"""check if v4 address is valid."""
try:
socket.inet_aton(addr)
except OSError:
return False
return True
Specify 'host[:port]' or a (host, port) tuple for the mandatory argument.
If the port specification is missing, the value of the defaultport is used.
raises TypeError in case host is neither a string nor an iterable
raises ValueError in other cases of invalid arguments
def validate_ipv6(addr):
"""check if v6 address is valid."""
try:
socket.inet_pton(socket.AF_INET6, addr)
except OSError:
return False
return True
def parse_ipv6_host_and_port(addr, defaultport=10767):
""" Parses IPv6 addresses with optional port. See parse_host_port for valid formats"""
if ']' in addr:
host, port = addr.rsplit(':', 1)
return host[1:-1], int(port)
if '.' in addr:
host, port = addr.rsplit('.', 1)
return host, int(port)
return (host, defaultport)
def parse_host_port(host, defaultport=10767):
"""Parses hostnames and IP (4/6) addressses.
The accepted formats are:
- a standard hostname
- base IPv6 or 4 addresses
- 'hostname:port'
- IPv4 addresses in the form of 'IPv4:port'
- IPv6 addresses in the forms '[IPv6]:port' or 'IPv6.port'
"""
if isinstance(host, str):
host, sep, port = host.partition(':')
if sep:
port = int(port)
else:
port = defaultport
else:
host, port = host
if not HOSTNAMEPAT.match(host):
raise ValueError(f'illegal host name {host!r}')
if 0 < port < 65536:
return host, port
raise ValueError(f'illegal port number: {port!r}')
def tcpSocket(host, defaultport, timeout=None):
"""Helper for opening a TCP client socket to a remote server.
Specify 'host[:port]' or a (host, port) tuple for the mandatory argument.
If the port specification is missing, the value of the defaultport is used.
If timeout is set to a number, the timout of the connection is set to this
number, else the socket stays in blocking mode.
"""
host, port = parseHostPort(host, defaultport)
# open socket and set options
s = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
if timeout:
s.settimeout(timeout)
# connect
s.connect((host, int(port)))
return s
colons = host.count(':')
if colons == 0: # hostname/ipv4 wihtout port
port = defaultport
elif colons == 1: # hostname or ipv4 with port
host, port = host.split(':')
port = int(port)
else: # ipv6
host, port = parse_ipv6_host_and_port(host, defaultport)
if (validate_ipv4(host) or validate_hostname(host) or validate_ipv6(host)) \
and 0 < port < 65536:
return (host, port)
raise ValueError(f'invalid host {host!r} or port {port}')
# keep a reference to socket to avoid (interpreter) shut-down problems

View File

@ -35,7 +35,7 @@ import time
import re
from frappy.errors import CommunicationFailedError, ConfigError
from frappy.lib import closeSocket, parseHostPort, tcpSocket
from frappy.lib import closeSocket, parse_host_port
try:
from serial import Serial
@ -60,7 +60,7 @@ class AsynConn:
if not iocls:
# try tcp, if scheme not given
try:
parseHostPort(uri, 1) # check hostname only
parse_host_port(uri, 1) # check hostname only
except ValueError:
if 'COM' in uri:
raise ValueError("the correct uri for a COM port is: "
@ -175,7 +175,9 @@ class AsynTcp(AsynConn):
if uri.startswith('tcp://'):
uri = uri[6:]
try:
self.connection = tcpSocket(uri, self.default_settings.get('port'), self.timeout)
host, port = parse_host_port(uri, self.default_settings.get('port'))
self.connection = socket.create_connection((host, port), timeout=self.timeout)
except (ConnectionRefusedError, socket.gaierror, socket.timeout) as e:
# indicate that retrying might make sense
raise CommunicationFailedError(str(e)) from None

View File

@ -21,22 +21,22 @@
# *****************************************************************************
"""provides tcp interface to the SECoP Server"""
import errno
import os
import socket
import socketserver
import sys
import threading
import time
import errno
import os
from frappy.datatypes import BoolType, StringType
from frappy.errors import SECoPError
from frappy.lib import formatException, \
formatExtendedStack, formatExtendedTraceback
from frappy.lib import formatException, formatExtendedStack, \
formatExtendedTraceback
from frappy.properties import Property
from frappy.protocol.interface import decode_msg, encode_msg_frame, get_msg
from frappy.protocol.messages import ERRORPREFIX, \
HELPREPLY, HELPREQUEST, HelpMessage
from frappy.protocol.messages import ERRORPREFIX, HELPREPLY, HELPREQUEST, \
HelpMessage
DEF_PORT = 10767
MESSAGE_READ_SIZE = 1024
@ -57,7 +57,7 @@ class TCPRequestHandler(socketserver.BaseRequestHandler):
clientaddr = self.client_address
serverobj = self.server
self.log.info("handling new connection from %s:%d" % clientaddr)
self.log.info("handling new connection from %s", format_address(clientaddr))
data = b''
# notify dispatcher of us
@ -160,7 +160,7 @@ class TCPRequestHandler(socketserver.BaseRequestHandler):
def finish(self):
"""called when handle() terminates, i.e. the socket closed"""
self.log.info('closing connection from %s:%d' % self.client_address)
self.log.info('closing connection from %s', format_address(self.client_address))
# notify dispatcher
self.server.dispatcher.remove_connection(self)
# close socket
@ -171,8 +171,28 @@ class TCPRequestHandler(socketserver.BaseRequestHandler):
finally:
self.request.close()
class DualStackTCPServer(socketserver.ThreadingTCPServer):
"""Subclassed to provide IPv6 capabilities as socketserver only uses IPv4"""
def __init__(self, server_address, RequestHandlerClass, bind_and_activate=True, enable_ipv6=False):
super().__init__(
server_address, RequestHandlerClass, bind_and_activate=False)
class TCPServer(socketserver.ThreadingTCPServer):
# override default socket
if enable_ipv6:
self.address_family = socket.AF_INET6
self.socket = socket.socket(self.address_family,
self.socket_type)
self.socket.setsockopt(socket.IPPROTO_IPV6, socket.IPV6_V6ONLY, 0)
if bind_and_activate:
try:
self.server_bind()
self.server_activate()
except:
self.server_close()
raise
class TCPServer(DualStackTCPServer):
daemon_threads = True
# on windows, 'reuse_address' means that several servers might listen on
# the same port, on the other hand, a port is not blocked after closing
@ -191,13 +211,16 @@ class TCPServer(socketserver.ThreadingTCPServer):
self.name = name
self.log = logger
port = int(options.pop('uri').split('://', 1)[-1])
enable_ipv6 = options.pop('ipv6', False)
self.detailed_errors = options.pop('detailed_errors', False)
self.log.info("TCPServer %s binding to port %d", name, port)
for ntry in range(5):
try:
socketserver.ThreadingTCPServer.__init__(
self, ('0.0.0.0', port), TCPRequestHandler, bind_and_activate=True)
DualStackTCPServer.__init__(
self, ('', port), TCPRequestHandler,
bind_and_activate=True, enable_ipv6=enable_ipv6
)
break
except OSError as e:
if e.args[0] == errno.EADDRINUSE: # address already in use
@ -217,3 +240,11 @@ class TCPServer(socketserver.ThreadingTCPServer):
def __exit__(self, *args):
self.server_close()
def format_address(addr):
if len(addr) == 2:
return '%s:%d' % addr
address, port = addr[0:2]
if address.startswith('::ffff'):
return '%s:%d' % (address[7:], port)
return '[%s]:%d' % (address, port)

View File

@ -21,21 +21,28 @@
# *****************************************************************************
import pytest
from frappy.lib import parseHostPort
from frappy.lib import parse_host_port
@pytest.mark.parametrize('hostport, defaultport, result', [
(('box.psi.ch', 9999), 1, ('box.psi.ch', 9999)),
(('/dev/tty', 9999), 1, None),
('box.psi.ch:9999', 1, ('box.psi.ch', 9999)),
('/dev/tty:9999', 1, None),
('localhost:10767', 1, ('localhost', 10767)),
('www.psi.ch', 80, ('www.psi.ch', 80)),
('/dev/ttyx:2089', 10767, None),
('COM4:', 2089, None),
('underscore_valid.123.hyphen-valid.com', 80, ('underscore_valid.123.hyphen-valid.com', 80)),
('123.hyphen-valid.com', 80, ('123.hyphen-valid.com', 80)),
('underscore_invalid.123.hyphen-valid.com:10000', 80, None),
('::1.1111', 2, ('::1', 1111)),
('[2e::fe]:1', 50, ('2e::fe', 1)),
('127.0.0.1:50', 1337, ('127.0.0.1', 50)),
('234.40.128.3:13212', 1337, ('234.40.128.3', 13212)),
])
def test_parse_host(hostport, defaultport, result):
if result is None:
with pytest.raises(ValueError):
parseHostPort(hostport, defaultport)
parse_host_port(hostport, defaultport)
else:
assert result == parseHostPort(hostport, defaultport)
assert result == parse_host_port(hostport, defaultport)