#!/usr/bin/env python3
# -*- coding: utf8 -*-
"""
utilitaire active_rvp sur l'Amon
"""
import sys
import xmlrpc.client
import re
from getpass import getpass
from os import system, sep, unlink, mkdir, listdir
from os.path import isfile, isdir, dirname, join
from shutil import copy
from IPy import IP

from pyeole.ihm import question_ouinon
from pyeole.ansiprint import print_orange, print_red
from creole import config
from creole.client import CreoleClient
from creole.cert import get_subject
from zephir.lib_zephir import TransportEole
try:
    from zephir.zephir_conf.zephir_conf import id_serveur, adresse_zephir
    ZEPHIR = True
except:
    ZEPHIR = False
from arv.config import strongswan_database, ipsec_tmp_path
from arv.lib.util import decrypt_privkey, password_OK, ipsec_restart,\
                         ipsec_running, fill_file, clean_directory
from arv.lib import cmd2 as cmd
from arv.db.strongswandb import (PeerConfigs, PrivateKey, PrivateKeyIdentity,
                                 SharedSecrets, SharedSecretIdentity, Identity,
                                 CertificateIdentity, Certificates, initialise)
from arv.config import ipsec_conf_file, ipsec_include_conf_files

def get_password():
    """Read a passphrase once the extracted RVP archive is known to exist.

    The archive counts as present when ``ipsec.db`` is a file, or ``CA`` is a
    directory, under ``ipsec_tmp_path``. When neither exists an error is
    printed in red and the process ends through ``sys.exit()``.

    Returns:
        The string read by ``getpass``. The prompt is empty: the callers in
        this file print their own prompt before calling.
    """
    database_copy = join(ipsec_tmp_path, "ipsec.db")
    ca_directory = join(ipsec_tmp_path, "CA")
    archive_present = isfile(str(database_copy)) or isdir(ca_directory)
    if not archive_present:
        print_red("archive inexistante")
        sys.exit()
    return getpass("")


#______________________________________________________________________________
# command line interpreter
class Cli(cmd.Cmd):
    """Base command line interpreter for the ActiveRVP console.

    Sets the help headers, ruler, prompt and intro banner, provides the
    exit/quit/EOF commands and the fallbacks for empty or unknown input.
    The ``do_*`` docstrings double as the text shown by ``help <command>``.
    """

    def __init__(self):
        super().__init__()
        self.doc_header = "Commandes documentées"
        self.undoc_header = "Commandes non documentées"
        self.ruler = '-'
        self.prompt = "#ActiveRVP> "
        self.intro = (
            "Console ActiveRVP :\n"
            "  tapez help (ou \"?\") pour avoir la liste des commandes,\n"
            "  tapez help <command> pour avoir une aide sur chaque commande"
        )

    def emptyline(self):
        # Print a usage hint; also reused by default() for unknown input.
        hint = "Type 'exit' to finish with the session or type ? for help."
        print(hint)

    def default(self, line):
        # Report the unrecognised input, then show the same hint as an
        # empty line.
        print("unknown command prefix")
        print(f"*** unknown syntax : {line} (type 'help' for help for a list of valid commands)")
        self.emptyline()

    def do_exit(self, line):
        """Exits from the console"""
        return True

    # Deliberately left without a docstring, as in the original: adding one
    # would change how the command is listed by `help`.
    def do_quit(self, line):
        return True

    def do_EOF(self, args):
        """Exit on system end of file character"""
        return True

#    def postloop(self):
#        # put here your post actions
#        print "exit cli"


#_____________________________________________________________________________
# active_rvp command line interpreter
class ActiveRVPCli(Cli):
    """Active RVP command line interpreter.

    The statements below run once, when the class body is executed: they
    load the Creole configuration and record which kind of module is
    running. Because of the ``global`` declarations, ``conf_eole`` and
    ``module`` are bound in the module namespace (not on the class); the
    ``do_*`` commands read them from there.
    """
    global conf_eole
    global module
    conf_eole = CreoleClient().get_creole()
    # Check whether we are on Amon/Amonecole or not: `module` is "amon" when
    # the 'type_amon' lookup succeeds, "autre" when it raises.
    try:
        # zone_rvp only probes the key; it ends up as a class attribute and
        # is not read anywhere else in the visible code.
        zone_rvp = conf_eole['type_amon']
        module = "amon"
    except:
        module = "autre"

    def do_init(self, line):
        """Mode d'activation rvp
        """
        # Choose how the RVP gets activated (the docstring above is the
        # user-facing help text, hence left untouched).
        #  - RVP disabled in the configuration: warn and stop.
        #  - server not registered on a Zephir: offer the manual mode.
        #  - otherwise: let the operator pick manual / zephir / quit.
        # Always returns True, the same value do_exit returns.
        if conf_eole['install_rvp'] != 'oui':
            print_orange("Le RVP n'est pas activé dans la configuration du serveur.")
            return True

        if not ZEPHIR:
            print_orange("Le serveur n'est pas enregistré sur un Zéphir")
            answer = question_ouinon("configurer en mode manuel", default='non')
            if answer == 'oui':
                self.onecmd("manuel")
            else:
                print("Abandon de la configuration RVP")
            return True

        choice = self.select("manuel zephir quitter", "choisissez le mode : ")
        # Anything other than the two activation modes means "quit".
        next_command = choice if choice in ("zephir", "manuel") else "exit"
        self.onecmd(next_command)
        return True

    def do_manuel(self, line):
        """mode manuel"""
        # Manual mode: remind the operator where the archive content must be
        # copied, then run the start_ipsec command. Always returns True.
        # FIXME (carried over from the original, which gives no detail).
        # NOTE(review): start_ipsec runs right after the reminder, without
        # waiting for the copy to happen -- confirm this is intended.
        reminder = f"copier le contenu de l'archive dans {ipsec_tmp_path}"
        print_red(reminder)
        self.onecmd("start_ipsec")
        return True

    def do_zephir(self, line):
        """activation du rvp par zephir
        """
        # Zephir mode: download the archive through get_rvp_from_zephir and,
        # only when that command reports success, run start_ipsec.
        # Always returns True, whatever the outcome.
        if ZEPHIR:
            archive_fetched = self.onecmd("get_rvp_from_zephir")
            if archive_fetched:
                self.onecmd("start_ipsec")
        else:
            print_red("serveur non enregistré sur Zéphir")
        return True

    def do_delete(self, line):
        """Désactivation rvp
        """
        # Disable the RVP after an explicit confirmation (the docstring above
        # is the user-facing help text, hence left untouched):
        #  - stop the ipsec service,
        #  - remove the ipsec configuration, the strongSwan database, the
        #    secrets file and the test-rvp script when they exist,
        #  - empty the cacerts, certs and private directories of /etc/ipsec.d.
        # Always returns True.
        rep = question_ouinon("Etes-vous certain de vouloir désactiver le VPN ?", default='non')
        if rep != 'oui':
            print("Abandon de la suppression du RVP")
            return True

        system("service ipsec stop")

        installed_files = (
            "/etc/ipsec.conf",
            strongswan_database,
            "/usr/share/eole/test-rvp",
            "/etc/ipsec.secrets",
        )
        for path in installed_files:
            if isfile(path):
                unlink(path)

        credential_dirs = (
            "/etc/ipsec.d/cacerts",
            "/etc/ipsec.d/certs",
            "/etc/ipsec.d/private",
        )
        for directory in credential_dirs:
            # A directory that was never created has nothing to clean;
            # listdir() would raise FileNotFoundError on it.
            if not isdir(directory):
                continue
            for entry in listdir(directory):
                unlink(join(directory, entry))
        return True

    def do_get_rvp_from_zephir(self, line):
        """récupération de l'archive rvp sur le Zephir"""
        # Download the RVP archive from the Zephir server and unpack it into
        # ipsec_tmp_path (the docstring above is the user-facing help text,
        # hence left untouched).
        #
        # Prompts for the Zephir login, password and the Zephir identifier of
        # the ARV (Sphynx) server, empties ipsec_tmp_path, calls
        # uucp.sphynx_get over XML-RPC, writes the decoded payload to
        # vpn_<id>.tar.gz, extracts it with tar and deletes the tarball.
        #
        # Returns True when the archive was fetched and extracted, False on
        # an XML-RPC protocol error, a negative server answer or a failed
        # extraction.
        import base64
        import subprocess
        from os import makedirs

        login = input("Entrez le login Zephir : ")
        passwd = getpass("Entrez le mot de passe Zephir : ")
        id_sphynx = input("Entrez l'identifiant Zephir du serveur ARV (Sphynx) : ")

        # Creates missing parents too and accepts an existing directory.
        makedirs(ipsec_tmp_path, exist_ok=True)
        # remove temporary files left by a previous run
        clean_directory(ipsec_tmp_path)

        # NOTE(review): login and password are inserted verbatim in the URL;
        # characters such as '@', '/' or '%' may need percent-encoding,
        # depending on how TransportEole decodes credentials -- confirm.
        zephir_proxy = xmlrpc.client.ServerProxy(
            "https://%s:%s@%s:7080" % (login, passwd, adresse_zephir),
            transport=TransportEole())
        try:
            # fetch the vpn configuration archive
            ret, contenu_b64 = zephir_proxy.uucp.sphynx_get(id_sphynx, id_serveur)
        except xmlrpc.client.ProtocolError:
            print("Erreur d'authentification zephir !")
            return False
        if not ret:
            # On failure the second element is printed as the error detail.
            print("Erreur :", contenu_b64)
            return False

        contenu = base64.decodebytes(contenu_b64.data)
        # write the tar file
        archive = join(ipsec_tmp_path, 'vpn_{0}.tar.gz'.format(id_sphynx))
        with open(archive, "wb") as archive_file:
            archive_file.write(contenu)
        # unpack the archive; an argument list keeps the operator-supplied
        # identifier away from shell interpretation
        try:
            extraction = subprocess.run(
                ["/bin/tar", "xzf", archive, "-C", ipsec_tmp_path])
        finally:
            unlink(archive)
        if extraction.returncode != 0:
            print("Erreur : échec de l'extraction de l'archive (code {0})".format(
                extraction.returncode))
            return False
        print("#->archive récupérée")
        return True

    def do_generate_passthrough(self, line):
        """
        Generate passthrough connections based on ipsec configuration and local networks
        """
        # Writes the file "passthrough" in the directory of
        # ipsec_include_conf_files (created when missing).
        #
        # A pair of local subnets (A, B) gets passthrough sections in both
        # directions when A lies inside the left subnet of a connection of
        # /etc/ipsec.conf and B lies inside one of the right subnets of that
        # same connection. Each ordered pair is written once.
        # Always returns True; errors (missing /etc/ipsec.conf, unparsable
        # subnet, ...) propagate as exceptions.

        # Passthrough connection
        # Example :
        # ipsec_conf_passthrough_connection.format(\
        #               "10.1.1.0/24",\
        #               "10.1.2.0/24, 10.1.3.0/24")
        ipsec_conf_passthrough_connection = """#DEB:passthrough-{0}-{1}
conn "passthrough-{0}-{1}"
    leftsubnet = {0}
    rightsubnet = {1}
    type = passthrough
    auto = route
#FIN:passthrough-{0}-{1}

"""

        include_dir = dirname(ipsec_include_conf_files)
        if not isdir(include_dir):
            mkdir(include_dir)

        passthrough_vertices = self._read_tunnel_subnets("/etc/ipsec.conf")
        local_subnets = self._list_local_subnets()

        passthrough_connections = []
        for local_subnet in local_subnets:
            for local_vertex, remote_subnets in passthrough_vertices.items():
                if local_subnet not in local_vertex:
                    continue
                for remote_subnet in remote_subnets:
                    for local_dst_subnet in local_subnets:
                        if local_dst_subnet not in remote_subnet:
                            continue
                        # Both directions are needed; the membership test
                        # must look at the list being built so that an
                        # ordered pair is never written twice.
                        for pair in ((local_subnet, local_dst_subnet),
                                     (local_dst_subnet, local_subnet)):
                            if pair not in passthrough_connections:
                                passthrough_connections.append(pair)

        with open(include_dir + sep + 'passthrough', "w") as passthrough:
            for local_subnet, local_dst_subnet in passthrough_connections:
                passthrough.write(ipsec_conf_passthrough_connection.format(local_subnet, local_dst_subnet))
        return True

    def _read_tunnel_subnets(self, conf_path):
        """Map each left subnet of the ipsec configuration to its right subnets.

        Lines before the first ``conn`` line are skipped. After a
        ``leftsubnet`` line, the next ``rightsubnet`` line completes the
        pair; a ``rightsourceip`` line (roadwarrior, no passthrough needed)
        or a new ``conn`` line discards the pending left subnet. Comma
        separated lists produce one entry per (left, right) combination.

        Returns:
            dict mapping ``IP`` (left subnet) to a list of ``IP`` (right
            subnets, in file order, duplicates kept).
        """
        conn_pattern = re.compile(r'^conn (.*)')
        leftsubnet_pattern = re.compile(r'^ *leftsubnet = (.*)')
        rightsubnet_pattern = re.compile(r'^ *rightsubnet = (.*)')
        rightsourceip_pattern = re.compile(r'^ *rightsourceip = .*')

        vertices = {}
        in_connections = False   # becomes True once the first conn is seen
        pending_left = None      # leftsubnet value waiting for its rightsubnet
        with open(conf_path, "r") as conf_lines:
            for conf_line in conf_lines:
                if not in_connections:
                    in_connections = conn_pattern.match(conf_line) is not None
                    continue
                if pending_left is None:
                    left_match = leftsubnet_pattern.match(conf_line)
                    if left_match:
                        pending_left = left_match.group(1)
                    continue
                right_match = rightsubnet_pattern.match(conf_line)
                if right_match:
                    rightsubnet = right_match.group(1)
                    for localsubnet in pending_left.replace('"', '').split(','):
                        for remotesubnet in rightsubnet.replace('"', '').split(','):
                            # strip() tolerates spaces around the commas
                            local_net = IP(localsubnet.strip())
                            vertices.setdefault(local_net, []).append(IP(remotesubnet.strip()))
                    pending_left = None
                elif rightsourceip_pattern.match(conf_line) or conn_pattern.match(conf_line):
                    pending_left = None
        return vertices

    def _list_local_subnets(self):
        """Return the local subnets declared in the Creole configuration.

        For every interface: its own network, then its alias networks and
        its vlan networks when enabled. Static routes are added when routes
        are enabled and the route is flagged ``route_in_vpn == "non"``.

        Returns:
            list of ``IP`` objects, in configuration order.
        """
        subnets = []
        for no_int in range(int(conf_eole['nombre_interfaces'])):
            eth = 'eth' + str(no_int)
            subnets.append(IP(conf_eole['adresse_network_' + eth] + "/" + conf_eole['adresse_netmask_' + eth]))
            # (activation prefix, name of the group holding the lists)
            for kind, group in (('alias', 'alias_ip'), ('vlan', 'vlan_id')):
                if conf_eole[f'{kind}_{eth}'] != "oui":
                    continue
                networks = conf_eole[f'{group}_{eth}.{kind}_network_{eth}']
                netmasks = conf_eole[f'{group}_{eth}.{kind}_netmask_{eth}']
                for network, netmask in zip(networks, netmasks):
                    subnets.append(IP(network + "/" + netmask))
        if conf_eole['activer_route'] == "oui":
            networks = conf_eole['route_adresse']
            netmasks = conf_eole['route_adresse.route_netmask']
            in_vpn_flags = conf_eole['route_adresse.route_in_vpn']
            for network, netmask, in_vpn in zip(networks, netmasks, in_vpn_flags):
                if in_vpn == "non":
                    subnets.append(IP(network + "/" + netmask))
        return subnets

    def do_start_ipsec(self, line):
        """Mise en place de la configuration ipsec"""
        # Install the IPsec configuration found under ipsec_tmp_path (the
        # docstring above is the user-facing help text, hence left untouched):
        #  1. decrypt and install the private keys, from the strongSwan
        #     database of the archive or from flat files, depending on
        #     conf_eole['sw_database_mode'];
        #  2. install /etc/ipsec.conf (a minimal one when the archive has none);
        #  3. run the generate_passthrough command;
        #  4. restart or reload ipsec.
        # Returns True on success. On any Exception the error is printed and
        # the result of the exit command is returned so the console stops.
        # Three wrong passphrases end the process with status 1.
        try:
            if conf_eole['sw_database_mode'] == 'oui':
                self._install_database_archive()
            else:
                self._install_flat_archive()
            self._install_ipsec_conf()
            self.onecmd("generate_passthrough")
            self._reload_ipsec()
            return True
        except Exception as err:
            print_red(f"RVP non configuré : {err}")
            return self.onecmd("exit")

    def _read_archive_passphrase(self, prompt):
        """Print *prompt* and read a passphrase through get_password()."""
        print(prompt, end=' ')
        return get_password()

    def _ask_passphrase(self, key_path, read_secret, first_prompt, attempts=3):
        """Ask for the passphrase of *key_path* until password_OK accepts it.

        *read_secret* is called with the prompt text and returns the typed
        passphrase. Returns the accepted passphrase, or None after
        *attempts* rejected tries.
        """
        prompt = first_prompt
        for _ in range(attempts):
            passwd = read_secret(prompt)
            if password_OK(key_path, passwd):
                return passwd
            prompt = "erreur passphrase, veuillez réessayer : "
        return None

    def _install_database_archive(self):
        """Decrypt the private keys stored in the archive database, then
        copy that database over the strongSwan one."""
        import tempfile

        archive_db = join(ipsec_tmp_path, "ipsec.db")
        session = initialise("sqlite:///" + archive_db)
        # Private keys bound to the local identity of a peer configuration.
        # NOTE(review): identity type 9 is presumably strongSwan's
        # ID_DER_ASN1_DN -- confirm.
        local_keys = session.query(PrivateKey, Identity).\
            filter(PrivateKey.id == PrivateKeyIdentity.private_key).\
            filter(PeerConfigs.local_id == PrivateKeyIdentity.identity).\
            filter(PeerConfigs.local_id == Identity.id).\
            filter(Identity.type == 9).all()
        for privkey, ident in local_keys:
            certif, _ = session.query(Certificates.data,
                                      CertificateIdentity.certificate).\
                filter(Certificates.id == CertificateIdentity.certificate).\
                filter(CertificateIdentity.identity == ident.id).first()
            certname = str(get_subject(cert=certif)[1], 'utf-8')
            print()
            # password_OK works on a file: dump the encrypted key in a
            # private, unpredictable temporary file and always remove it.
            key_file = tempfile.NamedTemporaryFile(suffix=".pem", delete=False)
            try:
                with key_file:
                    key_file.write(privkey.data)
                passwd = self._ask_passphrase(
                    key_file.name,
                    self._read_archive_passphrase,
                    "Passphrase du certificat {0} : ".format(certname))
            finally:
                unlink(key_file.name)
            if passwd is None:
                print("Passphrase erronée, trop de tentatives !!!")
                sys.exit(1)
            privkey.data = decrypt_privkey(privkey.data, passwd)
            session.commit()
        session.flush()
        # Replace the content of the strongSwan database with the archive's
        fill_file(archive_db, strongswan_database)

    def _install_flat_archive(self):
        """Install keys, secrets, certificates and included configuration
        files from the flat-file layout of the archive."""
        tmp_include_conf_path = join(ipsec_tmp_path, "ipsec.d/conf")
        tmp_private_path = join(ipsec_tmp_path, "ipsec.d/private")
        tmp_cacerts_path = join(ipsec_tmp_path, "ipsec.d/cacerts")
        tmp_certs_path = join(ipsec_tmp_path, "ipsec.d/certs")

        for privkey_filename in listdir(tmp_private_path):
            certname = privkey_filename.replace('priv', '').replace('.pem', '')
            copy(join(tmp_private_path, privkey_filename), "/etc/ipsec.d/private/")
            dst_privkey_filename = join("/etc/ipsec.d/private", privkey_filename)
            print()
            passwd = self._ask_passphrase(
                dst_privkey_filename,
                getpass,
                "Passphrase du certificat {0} : ".format(certname))
            if passwd is None:
                print("Passphrase erronée, trop de tentatives !!!")
                sys.exit(1)
            with open(dst_privkey_filename, "r") as privkey_handler:
                privkey_string = privkey_handler.read()
            # Decrypt before reopening for writing, so that a decryption
            # failure does not leave a truncated key file behind.
            decrypted_privkey = decrypt_privkey(privkey_string, passwd)
            with open(dst_privkey_filename, "wb") as privkey_handler:
                privkey_handler.write(decrypted_privkey)

        copy(join(ipsec_tmp_path, "ipsec.secrets"), "/etc/ipsec.secrets")
        for ca_filename in listdir(tmp_cacerts_path):
            copy(join(tmp_cacerts_path, ca_filename), "/etc/ipsec.d/cacerts/")
        for cert_filename in listdir(tmp_certs_path):
            copy(join(tmp_certs_path, cert_filename), "/etc/ipsec.d/certs/")
        if isdir(tmp_include_conf_path):
            for include_conf_filename in listdir(tmp_include_conf_path):
                copy(join(tmp_include_conf_path, include_conf_filename),
                     dirname(ipsec_include_conf_files))

    def _install_ipsec_conf(self):
        """Copy ipsec.conf from the archive, or write a minimal one when the
        archive does not provide it."""
        archive_conf = join(ipsec_tmp_path, "ipsec.conf")
        if isfile(str(archive_conf)):
            copy(archive_conf, "/etc/ipsec.conf")
            return
        print("Cette archive est générée par un serveur ARV non à jour !!!")
        print("Fichier ipsec.conf absent de l'archive. Création d'un fichier /etc/ipsec.conf minimum pour le mode VPN database.")
        ipsec_conf_content = """config setup
    uniqueids = yes
    cachecrls = yes
    strictcrlpolicy = no
"""
        with open("/etc/ipsec.conf", "w") as conf_handler:
            conf_handler.write(ipsec_conf_content)

    def _reload_ipsec(self):
        """Restart ipsec, or reload a running instance outside Amon."""
        if module == "amon":
            copy(join(ipsec_tmp_path, "test-rvp"), "/usr/share/eole/test-rvp")
            ipsec_restart()
        elif not ipsec_running():
            ipsec_restart()
        else:
            system("""ipsec rereadall
                              ipsec update""")
#_____________________________________________________________________________
# launch interpreter
# The console gets its own name: binding it to `cmd` would shadow the `cmd`
# module alias imported at the top of the file.
console = ActiveRVPCli()
console.cmdloop()

