#!/usr/local/bin/python
# @(#) $Id$ (LBL)
"""Communicate with acld

   Copyright (c) 2015, 2016, 2017, 2018, 2019, 2020, 2021, 2023
   The Regents of the University of California. All rights reserved.

   Redistribution and use in source and binary forms, with or without
   modification, are permitted provided that the following conditions are met:
       * Redistributions of source code must retain the above copyright
         notice, this list of conditions and the following disclaimer.
       * Redistributions in binary form must reproduce the above copyright
         notice, this list of conditions and the following disclaimer in the
         documentation and/or other materials provided with the distribution.
       * Neither the name of the University nor the names of its contributors
         may be used to endorse or promote products derived from this software
         without specific prior written permission.

   THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
   AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
   IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
   ARE DISCLAIMED. IN NO EVENT SHALL THE REGENTS OR CONTRIBUTORS BE LIABLE
   FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
   DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS
   OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION)
   HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT
   LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY
   OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF
   SUCH DAMAGE.
"""

import codecs
import errno
import getpass
import os
import select
import socket
import time

# Defaults used by the acld class
DEFAULTINFLIGHT = 100   # outstanding requests allowed in process()
DEFAULTTIMEOUT = 90     # connect/read timeout in seconds

# Default acld server address and port used by acld.open()
DEFAULTADDR, DEFAULTPORT = '127.0.0.1', 1965

# Suffix on a response keyword that marks the request as failed
FAILED_SUFFIX = '-failed'

# Commands listed by category. NOTE(review): the meaning of FILE_CMDS,
# SINGLE/DOUBLE/TRIPLE_CMDS and READONLY_CMDS is defined by callers
# outside this section (presumably: commands taking items from a file,
# commands by argument count, and commands that don't change state).
FILE_CMDS = (
    'addmacwhitelist', 'addwhitelist', 'drop', 'filter', 'nofilter',
    'nonullzero', 'nullzero', 'query', 'queryfilter', 'querymacwhitelist',
    'querynullzero', 'querywhitelist', 'remmacwhitelist', 'remwhitelist',
    'restore',
)

# Query command issued before the real command in check mode
CHECK_MAP = dict(
    filter='queryfilter',
    nullzero='querynullzero',
)

# Query command issued instead of the real command in dry run mode
DRYRUN_MAP = dict(
    drop='query',
    filter='queryfilter',
    nofilter='queryfilter',
    nonullzero='querynullzero',
    nullzero='querynullzero',
)

SINGLE_CMDS = (
    'listmac', 'listnullzero', 'listroute', 'macwhitelist', 'state',
    'whitelist',
)

DOUBLE_CMDS = (
    'blockmac', 'compact', 'listacl', 'listprefix', 'querymac',
    'restoremac',
)

TRIPLE_CMDS = (
    'blockhosthost', 'droptcpport', 'droptcpsynport', 'dropudpport',
    'noprefix', 'permittcpdsthostport', 'permitudpdsthostport', 'prefix',
    'restorehosthost', 'restoretcpport', 'restoretcpsynport',
    'restoreudpport', 'unpermittcpdsthostport', 'unpermitudpdsthostport',
)

READONLY_CMDS = (
    'listacl', 'listmac', 'listnullzero', 'listroute', 'query',
    'queryfilter', 'querymac', 'querymacwhitelist', 'querynullzero',
    'querywhitelist', 'state', 'whitelist',
)

def commentrequired(command):
    """Return True if a comment is required for this command

       That is compact, filter, nullzero and prefix, plus every
       command whose name starts with add, block or drop."""
    explicit = ('compact', 'filter', 'nullzero', 'prefix')
    return command in explicit or command.startswith(('add', 'block', 'drop'))

class acld:
    """Communicate with acld

       open() must be called before sending commands; close() sends
       "exit" and closes the connection."""
    def __init__(self, check=False, user=None, host=None, debug=False,
        dryrun=False, inflight=DEFAULTINFLIGHT, rate=None,
        timeout=DEFAULTTIMEOUT, debugfile=None):
        """Create an acld instance

           check: issue the CHECK_MAP query first and only send the
               real command when the query fails
           user, host: used to build self.ident ("user@host"), which is
               sent with commands; default to the login name and the
               local hostname
           debug: stored in self.debug
           dryrun: only issue the DRYRUN_MAP query commands
           inflight: process() reaps a response whenever this many
               requests are outstanding
           rate: if set, limit sendline() calls to about this many
               per second
           timeout: seconds to wait when connecting and reading
           debugfile: if set, a file that receives a transcript of the
               lines sent ("=> ") and received ("<= ")"""
        self._resetinput()
        if user:
            self.user = user
        else:
            try:
                self.user = os.getlogin()
            except OSError:
                # os.getlogin() fails without a controlling terminal
                # (e.g. cron); fall back to the environment/passwd entry
                self.user = getpass.getuser()
        self.host = host if host else socket.gethostname()
        self.ident = f'{self.user}@{self.host}'
        self.check = check
        self.debug = debug
        self.dryrun = dryrun
        self.inflight = inflight
        self.nsent = 0
        self.rate = rate
        self.startts = None
        self.timeout = timeout
        # Extra seconds to wait for the first configured server once
        # another configured server has connected
        self.prefertimeout = 0.250
        self.debugfile = debugfile

        self.s = None
        self.cookie = 0
        self.inflightrequests = []

    def _resetinput(self):
        """Discard buffered input and any partial UTF-8 decoder state"""
        self.ibuf = ''
        # Incremental decoder: a multibyte UTF-8 character may be split
        # across two recv() calls
        self._decoder = codecs.getincrementaldecoder('utf-8')()

    def _nextcookie(self):
        """Return the next request cookie as a string"""
        scookie = str(self.cookie)
        self.cookie += 1
        return scookie

    def _altcommand(self, command, dryrunlabel):
        """Return (command2, errmsg) for the check/dry run modes

           command2 is the query to issue instead of (dry run) or
           before (check) command, or None in normal mode. errmsg is
           set when a mode is active but has no query for command;
           dryrunlabel names dry run mode in that message."""
        if self.dryrun:
            command2, detail = DRYRUN_MAP.get(command), dryrunlabel
        elif self.check:
            command2, detail = CHECK_MAP.get(command), 'check'
        else:
            return None, None
        if not command2:
            return None, f"Can't find {detail} command for {command}"
        return command2, None

    def close(self):
        """Send "exit" to acld and close the socket (no-op if not open)

           The socket is closed even if sending "exit" or reading its
           response fails; a send error is re-raised after closing."""
        if not self.s:
            return
        try:
            cmds = ['exit', self._nextcookie()]
            self.sendline([' '.join(cmds)])
            self.inflightrequests.append(cmds[0:2])
            # The response (or error) doesn't matter, we're closing
            self.getresponse()
        finally:
            self.s.close()
            self.s = None

    def getacldline(self):
        """Get a line from the acld socket and an optional error message

           Returns (line, None) with the CRLF stripped, or (None, errmsg)
           on a select/recv error, a read timeout (self.timeout seconds
           without data) or when acld closes the connection."""
        while True:
            line, sep, rest = self.ibuf.partition('\r\n')
            if sep:
                self.ibuf = rest
                return line, None

            # Read more
            timeout = self.timeout
            try:
                rlist, _wlist, _xlist = select.select([self.s], [], [],
                    timeout)
            except OSError as e:
                return None, f'getacldline: select: {e.strerror}'

            # Check for timeout
            if not rlist:
                return (None,
                    f'getacldline: acld read timeout ({timeout:.1f} seconds)')

            try:
                buf = self.s.recv(1024)
            except OSError as e:
                return None, f'getacldline: {e.strerror}'

            # An empty read means acld closed the connection; without
            # this check we would spin forever on a readable EOF
            if not buf:
                return None, 'getacldline: acld closed connection'

            text = self._decoder.decode(buf)
            if text:
                if self.debugfile:
                    for dline in text.rstrip('\r\n').split('\r\n'):
                        print(f'<= {dline}', file=self.debugfile)
                self.ibuf += text

    def getresponse(self):
        """Reap response, return status, last_tup and payload
           If last_tup is empty, the payload is an error message

           status is 1 on failure (including a "-failed" response
           keyword) and 0 otherwise. last_tup is the matching request,
           removed from self.inflightrequests. On success payload is
           the continuation lines joined with newlines, or None."""
        line, errmsg = self.getacldline()
        if errmsg:
            return 1, [], errmsg

        tup = line.split()
        n = len(tup)
        if n < 3 or n > 4:
            return 1, [], f'getresponse: Invalid response: {line}'
        cookie = tup[1]

        if not self.inflightrequests:
            return 1, [], 'Missing inflightrequests!'

        # Find the (normally oldest) request with this cookie
        index = next((i for i, req in enumerate(self.inflightrequests)
            if req[1] == cookie), None)
        if index is None:
            errmsg = f'getresponse: Unexpected cookie: {cookie}'
            if n == 4:
                # Discard the unknown response's payload; stop on a read
                # error instead of retrying forever
                while True:
                    line, errmsg2 = self.getacldline()
                    if errmsg2:
                        errmsg = '\n'.join([errmsg, errmsg2])
                        break
                    if line == '.':
                        break
            return 1, [], errmsg
        last_tup = self.inflightrequests.pop(index)

        # Read optional payload
        payloads = []
        if n == 4:
            if tup[3] != '-':
                return 1, [], f'getresponse: Missing continuation: {line}'
            while True:
                line, errmsg = self.getacldline()
                if errmsg:
                    # The response is incomplete, report the read error
                    return 1, [], errmsg
                if line == '.':
                    break
                payloads.append(line)

        # Don't check what against extended_what, latter might be "unknown"
        ret = 1 if tup[2].endswith(FAILED_SUFFIX) else 0
        payload = '\n'.join(payloads) if payloads else None
        return ret, last_tup, payload

    def makeconnection(self, servers):
        """Takes list of server address and port pairs and returns
           a socket of the first server to connect and an optional
           errno and error message

           All servers are connected to in parallel; a server counts as
           connected when its socket becomes readable with no pending
           socket error. The first configured server is preferred: when
           another one is ready first, wait up to self.prefertimeout
           more seconds for it. Returns (socket, 0, None) on success,
           otherwise (None, lowest errno seen, all error messages).
           Raises ValueError for an address that is neither IPv4 nor
           IPv6."""
        sockets = []    # connecting sockets, in configured order
        idents = []     # matching "addr.port" strings for messages
        errmsgs = []    # all error messages
        errnums = []    # all non-zero errno's

        # Keep track of the first configured server socket
        first_s = None
        for addr, port in servers:
            ident = f'{addr}.{int(port)}'
            # Test for ':' first; IPv4-mapped IPv6 addresses such as
            # ::ffff:192.0.2.1 contain dots too
            if ':' in addr:
                fam = socket.AF_INET6
            elif '.' in addr:
                fam = socket.AF_INET
            else:
                raise ValueError(f'Invalid ip address: {addr}')
            s = socket.socket(fam, socket.SOCK_STREAM)
            # Make connect() non-blocking so all servers are tried in
            # parallel, then restore blocking mode
            s.setblocking(False)
            code = s.connect_ex((addr, int(port)))
            s.setblocking(True)

            if code not in (0, errno.EINPROGRESS):
                errmsgs.append(f'{ident}: {os.strerror(code)}')
                errnums.append(code)
                s.close()
                continue

            sockets.append(s)
            idents.append(ident)
            if first_s is None:
                first_s = s

        # Loop on select() until we connect or timeout
        s = None        # chosen socket
        s2 = None       # fallback: a non-preferred server that is ready
        timeout = float(self.timeout)
        while s is None and timeout > 0.0 and sockets:
            ts = time.time()
            rlist, _wlist, _xlist = select.select(sockets[:], [], [],
                timeout)
            timeout = max(timeout - (time.time() - ts), 0.0)
            if not rlist:
                errmsgs.append('timeout')
                errnums.append(errno.ETIMEDOUT)
                continue

            # Check for connect errors, remove failed candidates
            ready = []
            for s3 in rlist:
                code = s3.getsockopt(socket.SOL_SOCKET, socket.SO_ERROR)
                if code == 0:
                    ready.append(s3)
                    continue
                j = sockets.index(s3)
                errmsgs.append(f'{idents[j]}: {os.strerror(code)}')
                errnums.append(code)
                s3.close()
                del sockets[j]
                del idents[j]

            if not ready:
                continue

            candidate = ready[0]
            j = sockets.index(candidate)
            del sockets[j]
            del idents[j]
            if s2 is not None:
                # An earlier fallback is being replaced; close it
                # rather than leak it
                s2.close()
            s2 = candidate

            # Done if this is the first configured server
            # or no more sockets left
            if candidate is first_s or not sockets:
                s = s2
                s2 = None
                break

            # Wait a bit longer for the first configured server
            timeout = min(timeout, self.prefertimeout)

        # Clean up leftover sockets
        for s3 in sockets:
            s3.close()

        # Use the non-preferred server that connected?
        if s is None:
            s = s2

        # If we succeeded just ignore errors accumulated along the way
        if s is not None:
            return s, 0, None

        # Not sure if this can happen
        if not errmsgs:
            errmsgs.append('failed to open any')
            errnums.append(errno.EDOM)

        # Return the lowest errno collected and all of the error messages
        return None, min(errnums), ', '.join(errmsgs)

    def open(self, addr=None, port=None, servers=None):
        """Open a socket to acld, return an optional error message
           With no arguments, opens DEFAULTADDR/DEFAULTPORT
           With addr or port argument, fill in from defaults
           With servers argument only first addr/port pair to connect

           Connection refused/EPERM errors are retried (up to 3 tries,
           with a doubling delay) only when several servers are given.
           The acld greeting is read and checked before returning."""
        # Forwards compatibility allowing servers as only argument
        if isinstance(addr, list):
            if port:
                raise ValueError('addr can not be a list')
            if not servers:
                servers = addr
                addr = None

        if addr or port:
            # Disallow addr/port and servers
            if servers:
                raise ValueError('Cannot specify addr or port and servers')
            # Fill in from defaults for single server connect
            servers = [[addr or DEFAULTADDR, port or DEFAULTPORT]]
        elif not servers:
            # No args backwards compatibility
            servers = [[DEFAULTADDR, DEFAULTPORT]]
        numservers = len(servers)

        tries = 3
        delay = 0.5
        errmsg = None
        while tries > 0:
            tries -= 1

            # Try to make a connection
            self.s, errnum, errmsg = self.makeconnection(servers)

            # Retry for some errors
            # EPERM means the source port was in use
            if errmsg and errnum in (errno.ECONNREFUSED, errno.EPERM):
                if numservers > 1 and tries > 0:
                    time.sleep(delay)
                    delay *= 2
                    continue
                return errmsg
            break

        if errmsg:
            return errmsg
        self.s.setblocking(False)

        # Data buffered from a previous connection must not be parsed
        # as part of this one
        self._resetinput()
        self.cookie = 0
        self.inflightrequests = [['acld', '0']]
        ret, _last_tup, payload = self.getresponse()
        if ret:
            errmsgs = ['bad greeting from acld']
            if payload:
                errmsgs.append(payload)
            self.s.close()
            self.s = None
            return ': '.join(errmsgs)

        return None

    def process(self, command, items, last=None, comments=None,
        callback=None, data=None):
        """Process list of items (addresses or ports)
           If we're checking, always issue the check command first
           When we get a check result, optionally issue the real command
           Return a return status, a list of messages and an optional
           list of error messages
           If callback is provided, it is called with the result
           of each acld transaction with "data" as an argument and
           an empty payload (msgs) will be returned
           Optionally append "last" to each command to allow multiple
           port blocks and host port exceptions per commit
           Up to self.inflight requests are kept outstanding"""
        msgs = []
        self.inflightrequests = []
        ret = 0

        # Handle check and dryrun modes
        command2, errmsg = self._altcommand(command, 'dry run')
        if errmsg:
            return 1, msgs, [errmsg]

        # Extra argument appended to every command, including the real
        # commands issued after a check result
        extra = [last] if last else []

        def sendcommand(cmd, item):
            """Send cmd for item and record it as in flight"""
            cmds = [cmd, self._nextcookie(), item, *extra, '-']
            tup = [' '.join(cmds), self.ident]
            if comments:
                tup.extend(comments)
            tup.append('.')
            self.sendline(tup)
            self.inflightrequests.append(cmds[0:3])

        def reap():
            """Reap one response; return an error message or None"""
            nonlocal ret
            tret, last_tup, payload = self.getresponse()

            # If we didn't get a tuple, the payload is an error message
            if not last_tup:
                return payload

            # Is this a check (or dry run) result?
            if command2 and last_tup[0] == command2:
                if self.dryrun:
                    # XXX What about filter?
                    if payload and payload.startswith('N '):
                        msgs.append(
                            f'Note: {last_tup[2]} is already nullzero routed')
                elif tret:
                    # Check failed (not present yet), issue the real command
                    sendcommand(command, last_tup[2])
                return None

            # Command result
            if tret:
                ret = 1
                # Make sure the item appears in the message; a failed
                # response may have no payload at all
                if len(last_tup) > 2:
                    item = last_tup[2]
                    if not payload:
                        payload = item
                    elif item not in payload:
                        payload = f'{payload} {item}'
            if payload:
                if callback:
                    callback(payload, data)
                else:
                    msgs.append(payload)
            return None

        # Loop on items from the list
        self.cookie = 0
        for item in items:
            sendcommand(command2 or command, item)

            # Reap one response once the in-flight window is full
            if len(self.inflightrequests) >= self.inflight:
                errmsg = reap()
                if errmsg is not None:
                    return 1, [], [errmsg]

        # Reap the stragglers (reap() may queue follow-up commands)
        while self.inflightrequests:
            errmsg = reap()
            if errmsg is not None:
                return 1, [], [errmsg]

        return ret, msgs, []

    def processone(self, args, comments=None):
        """Process a single command
           args is the command name followed by its arguments
           Return a return status, a list of messages and an optional
           list of error messages
           The status is 2 when the check/dry run query reports the
           object already exists or acld replies with "already"."""
        msgs = []
        errmsgs = []
        self.cookie = 0
        command = args[0]

        # Handle check and dryrun modes
        command2, errmsg = self._altcommand(command, 'dryrun')
        if errmsg:
            return 1, msgs, [errmsg]

        # Run the check command
        if command2:
            cmds = [command2, self._nextcookie(), *args[1:]]
            self.sendline([' '.join(cmds)])
            self.inflightrequests = [cmds]
            tret, last_tup, payload = self.getresponse()

            # If we didn't get a tuple, the payload is an error message
            if not last_tup:
                return 1, [], [payload]

            if self.dryrun:
                if command in ('nofilter', 'nonullzero'):
                    # Return same queryfilter/querynullzero result
                    return tret, msgs, errmsgs

                # filter/nullzero cases: report that it already exists,
                # or pretend we added it
                return (2 if tret == 0 else 0), msgs, errmsgs

            # Normal case
            if tret == 0:
                msgs.append(f"Skipping {' '.join(args)}: ({command} is yes)")
                return 2, msgs, errmsgs

        # Dry run mode always returned above; send the real command
        cmds = [command, self._nextcookie(), *args[1:], '-']
        self.inflightrequests = [cmds]
        tup = [' '.join(cmds), self.ident]
        if comments:
            tup.extend(comments)
        tup.append('.')
        self.sendline(tup)
        ret, last_tup, payload = self.getresponse()

        # If we didn't get a tuple, the payload is an error message
        if not last_tup:
            return 1, [], [payload]

        present = False
        notpresent = False
        if payload:
            for line in payload.splitlines():
                if line == last_tup[2]:
                    present = True
                if 'not found' in line:
                    notpresent = True

                # Return the special "already exists" code
                if 'already' in line:
                    ret = 2
                msgs.append(line)

        # Only add the inconsistent state message if it is not already there
        # (present implies a non-empty payload)
        if present and notpresent and 'inconsistent state' not in payload:
            errmsgs.append(f'{last_tup[2]} inconsistent state')
            if ret == 0:
                ret = 1

        return ret, msgs, errmsgs

    def sendline(self, tup):
        """Add \r\n's to each string in the list and concatenate
           into a string and deal with non blocking I/O

           Loops until the whole buffer is written, since send() on the
           non-blocking socket may accept only part of it. If self.rate
           is set, sleeps as needed to keep sendline() calls at about
           self.rate per second."""
        if self.debugfile:
            for line in tup:
                print(f'=> {line}', file=self.debugfile)
        data = ''.join([x + '\r\n' for x in tup]).encode()

        if self.rate:
            if self.startts is None:
                self.startts = time.time()
            else:
                while True:
                    elapsed = time.time() - self.startts
                    # Guard against a zero elapsed time on coarse clocks
                    if elapsed > 0 and self.nsent // elapsed <= self.rate:
                        break
                    time.sleep(0.1)

        view = memoryview(data)
        while len(view) > 0:
            _rlist, wlist, _xlist = select.select([], [self.s], [], 1.0)
            if not wlist:
                continue
            try:
                sent = self.s.send(view)
            except BlockingIOError:
                continue
            view = view[sent:]
        self.nsent += 1
