#!/usr/bin/env python3
# *****************************************************************************
#
# This program is free software; you can redistribute it and/or modify it under
# the terms of the GNU General Public License as published by the Free Software
# Foundation; either version 2 of the License, or (at your option) any later
# version.
#
# This program is distributed in the hope that it will be useful, but WITHOUT
# ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS
# FOR A PARTICULAR PURPOSE.  See the GNU General Public License for more
# details.
#
# You should have received a copy of the GNU General Public License along with
# this program; if not, write to the Free Software Foundation, Inc.,
# 59 Temple Place, Suite 330, Boston, MA  02111-1307  USA
#
# Module authors:
#   Alexander Zaft <a.zaft@fz-juelich.de>
#
# *****************************************************************************

"""SEC node autodiscovery tool."""

import argparse
import json
import os
import select
import socket
import sys
from collections import namedtuple
from time import time as currenttime

# UDP port the discovery broadcast is sent to and announcements are received on.
UDP_PORT = 10767

# One decoded node announcement: the sender's address plus the values of the
# 'port', 'equipment_id', 'firmware' and 'description' keys of its message.
Answer = namedtuple(
    'Answer',
    ['address', 'port', 'equipment_id', 'firmware', 'description'],
)


def decode(msg, addr):
    """Turn a received datagram into an ``Answer``.

    ``msg`` is the raw payload (bytes) and ``addr`` the sender address as
    returned by ``socket.recvfrom``; only its first element (the host) is
    used, the sender's source port is ignored.

    Returns ``None`` for anything that is not a node announcement: payloads
    that are not UTF-8 encoded JSON, JSON values that are not objects, objects
    whose ``SECoP`` entry is not ``'node'`` and objects lacking one of the
    ``equipment_id``, ``firmware``, ``description`` or ``port`` keys.
    """
    try:
        # The payload comes straight from the network, so both the decoding
        # and the parsing step have to tolerate garbage.
        data = json.loads(msg.decode('utf-8'))
    except (ValueError, RecursionError):
        # ValueError covers UnicodeDecodeError and json.JSONDecodeError;
        # RecursionError is raised for excessively nested documents.
        return None
    if not isinstance(data, dict):
        return None
    if data.get('SECoP') != 'node':
        return None
    try:
        eq_id = data['equipment_id']
        fw = data['firmware']
        desc = data['description']
        port = data['port']
    except KeyError:
        return None
    # The advertised port comes from the message body, not from the socket.
    return Answer(addr[0], port, eq_id, fw, desc)


def print_answer(answer, *, short=False):
    """Print one discovered node to stdout.

    A reverse lookup of ``answer.address`` is attempted first. If it succeeds
    the host name is shown (with the numeric address in parentheses in the
    long format); if it fails the address is shown as received.

    With ``short`` set, a single ``<equipment_id> <address>:<port>`` line is
    printed. Otherwise a multi-line report followed by a separator line is
    printed.
    """
    raw_address = answer.address
    try:
        shown = socket.gethostbyaddr(raw_address)[0]
    except Exception:
        # No name available: fall back to the address we were given.
        shown = raw_address
        suffix = ''
    else:
        suffix = f' ({raw_address})'

    if short:
        # NOTE: keep this easily parseable!
        print(f'{answer.equipment_id} {shown}:{answer.port}')
        return

    print(f'Found {answer.equipment_id} at {shown}{suffix}:')
    print(f'  Port: {answer.port}')
    print(f'  Firmware: {answer.firmware}')
    # Indent continuation lines of a multi-line description.
    indented = answer.description.replace('\n', '\n    ')
    print(f'  Node description: {indented}')
    print('-' * 80)


def scan(max_wait=1.0):
    """Broadcast a discovery request and yield the nodes that answer.

    A ``{"SECoP": "discover"}`` JSON message is sent to the IPv4 broadcast
    address on ``UDP_PORT``. For ``max_wait`` seconds afterwards the socket is
    polled in 0.1 s slices and every datagram that ``decode`` accepts is
    yielded as an ``Answer``. Answers repeating an already seen
    (address, equipment_id, port) combination are dropped.

    A failure to send the broadcast is reported on stdout and does not abort
    the scan. The socket is closed once the generator is exhausted or closed.
    """
    # Function-scope import: a monotonic clock keeps the deadline correct
    # even if the wall clock is adjusted while scanning.
    from time import monotonic

    with socket.socket(socket.AF_INET, socket.SOCK_DGRAM) as s:
        s.setsockopt(socket.SOL_SOCKET, socket.SO_BROADCAST, 1)
        # send a general broadcast
        try:
            s.sendto(json.dumps(dict(SECoP='discover')).encode('utf-8'),
                     ('255.255.255.255', UDP_PORT))
        except OSError as e:
            print('could not send the broadcast:', e)
            # we still keep listening for self-announcements
        deadline = monotonic() + max_wait
        seen = set()
        while monotonic() < deadline:
            readable, _, _ = select.select([s], [], [], 0.1)
            if not readable:
                continue
            try:
                msg, addr = s.recvfrom(1024)
            except OSError:  # pragma: no cover
                continue
            answer = decode(msg, addr)
            if answer is None:
                continue
            key = (answer.address, answer.equipment_id, answer.port)
            if key in seen:
                continue
            seen.add(key)
            yield answer


def listen(*, short=False):
    """Print node announcements arriving on ``UDP_PORT`` until interrupted.

    A UDP socket is bound to all IPv4 interfaces with address/port reuse
    enabled (``SO_REUSEADDR`` on Windows, ``SO_REUSEPORT`` elsewhere). Every
    datagram that ``decode`` accepts is passed to ``print_answer`` with the
    given ``short`` flag. The function returns on ``KeyboardInterrupt``, no
    matter whether it arrives while waiting for data or while printing, and
    the socket is closed on the way out.
    """
    with socket.socket(socket.AF_INET, socket.SOCK_DGRAM) as s:
        if os.name == 'nt':
            s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
        else:
            s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEPORT, 1)
        s.bind(('0.0.0.0', UDP_PORT))
        try:
            while True:
                msg, addr = s.recvfrom(1024)
                answer = decode(msg, addr)
                if answer is not None:
                    print_answer(answer, short=short)
        except KeyboardInterrupt:
            # Ctrl-C is the regular way to stop listening.
            pass


if __name__ == '__main__':
    # Command-line entry point: run one scan, then optionally keep listening
    # for further announcements.
    cli = argparse.ArgumentParser()
    cli.add_argument(
        '-l', '--listen',
        action='store_true',
        help='Keep listening after the broadcast.',
    )
    cli.add_argument(
        '-s', '--short',
        action='store_true',
        help='Print short info (always on when listen).',
    )
    options = cli.parse_args(sys.argv[1:])

    # Listening implies the compact one-line output format.
    compact = options.short or options.listen
    if not compact:
        # Opening separator of the long report format.
        print('-' * 80)
    for found in scan():
        print_answer(found, short=compact)
    if options.listen:
        listen(short=compact)
