#!/usr/bin/env python3
# -*- coding: utf-8 -*-
""" Deploy VMs with Zephir help """

import getpass
import socket
import base64
import json
import subprocess
import tempfile
import time
import xml.etree.ElementTree as XML
import sys
import os
import shutil
import re
from cryptography.fernet import Fernet
#from zephir.lib_zephir import *
from creole.client import CreoleClient
# Module-level Creole client; the rest of the script reads configuration
# variables through it.
client = CreoleClient()
# Feature switch: stop immediately (sys.exit() with no argument, i.e. exit
# status 0) when automatic deployment is disabled in the configuration.
if client.get_creole('activer_deploiement_automatique') == 'non':
    print("Déploiement automatique désactivé.")
    sys.exit()
from pyeole.ihm import print_line
from pyeole.ansiprint import print_red
# The Zephir server address comes from the local Zephir registration; per
# the error message below, failing to import it means this server is not
# registered on a Zephir server, so the script stops.
# NOTE(review): the bare except also catches errors other than ImportError.
try:
    from zephir.zephir_conf.zephir_conf import adresse_zephir
except:
    print_red("Le serveur n'est pas enregistrer sur un serveur Zéphir")
    sys.exit()
from zephir.lib_zephir import convert, xmlrpclib, EoleProxy, TransportEole, flushed_input

# Zephir server-group filter used by main(): servers whose configuration
# has VARIABLE_NAME set to VARIABLE_VALUE are selected for deployment.
VARIABLE_NAME = 'activer_modele_vm'
VARIABLE_VALUE = 'oui'
# Establishment number(s) (creole variable) used when querying Zephir.
RNE = client.get_creole('zephir_numero_etab')

# Working directory and the hidden files kept inside it.
DPL_ROOT_DIR = '/etc/eole/hapy-deploy'
KEY_FILE = os.path.join(DPL_ROOT_DIR, ".dpl.sc")        # Fernet key
CRD_FILE = os.path.join(DPL_ROOT_DIR, ".zephir.sc")     # saved Zephir credentials (encrypted)

# Deployment mode ('site', 'mixte' or 'liste manuelle', see main()).
MODE = client.get_creole('dp_mode')
STATUS_FILE = os.path.join(DPL_ROOT_DIR, ".hapy-deploy.status")  # "<name>:<STEP>:<RESULT>" lines
# OpenNebula image datastore that receives the marketplace appliances.
IMG_DS = client.get_creole('one_ds_image_name')
# NOTE(review): PROV_DIR is not referenced in the code visible here.
PROV_DIR = os.path.join(DPL_ROOT_DIR, "scripts")

# Names of the OpenNebula CONTEXT images attached to the VMs.
ZCREDS_NAME = "zcreds.sc"
ZCA_NAME = "zephir-ca.crt"

# Default timeout (seconds) for every new socket, presumably including the
# Zephir XML-RPC connections; the value is also reused as a polling limit.
TMOUT = 360

socket.setdefaulttimeout(TMOUT)

def silent_run(command):
    """Run an external command, capturing its output instead of printing it.

    :param command: either a command-line string, which is split on
        whitespace (so no single argument may contain spaces), or an
        already-split sequence of arguments, used as-is.
    :returns: the ``subprocess.CompletedProcess``; ``stdout`` and ``stderr``
        hold raw bytes. The return code is NOT checked.
    """
    # No shell is involved: the argument list goes straight to exec, so
    # shell metacharacters in arguments are not interpreted.
    args = command.split() if isinstance(command, str) else list(command)
    return subprocess.run(args, capture_output=True)

def crypt(message):
    """Encrypt a string with the local Fernet key.

    The key is read from KEY_FILE. If that file does not exist yet, a new
    key is generated and stored there, readable by its owner only (0600):
    anyone able to read the key can decrypt the stored credentials.

    :param message: text to encrypt.
    :returns: the Fernet token, as bytes.
    """
    if not os.path.exists(KEY_FILE):
        key = Fernet.generate_key()
        # Create the key file with restrictive permissions from the start;
        # a plain open() would use the process umask (often world-readable).
        fd = os.open(KEY_FILE, os.O_WRONLY | os.O_CREAT | os.O_TRUNC, 0o600)
        with os.fdopen(fd, "w") as ckf:
            ckf.write(key.decode())
    else:
        with open(KEY_FILE, "r") as ckf:
            # strip() tolerates a trailing newline added by hand.
            key = ckf.readline().strip()
    return Fernet(key).encrypt(message.encode())

def decrypt(message):
    """Decrypt a Fernet token produced by ``crypt``.

    :param message: the token, as bytes.
    :returns: the decrypted text (str).
    :raises FileNotFoundError: if KEY_FILE does not exist.
    Fernet also raises (cryptography's ``InvalidToken``) when the token was
    not produced with the key stored in KEY_FILE.
    """
    with open(KEY_FILE, "r") as ckf:
        # strip() tolerates a trailing newline added by hand.
        key = ckf.readline().strip()
    return Fernet(key).decrypt(message).decode()

def add_zephir_credential_file(addr, user, passwd):
    """Publish the Zephir credentials as an OpenNebula CONTEXT image.

    The line ``"<addr> <user> <passwd>"`` is written to a temporary file in
    /var/tmp/one/hapy-deploy. That file is registered as image ZCREDS_NAME
    in the first datastore of type ``fil``; any existing image with that
    name is deleted first. The plaintext temporary file is always removed
    before returning, whatever the outcome.

    :returns: True if the image was created, False otherwise.
    """
    line = addr + " " + user + " " + passwd

    # The file is handed to oneadmin so the OpenNebula tools can read it.
    fp = tempfile.NamedTemporaryFile(delete=False, dir="/var/tmp/one/hapy-deploy")
    try:
        fp.write(line.encode())
        fp.close()
        shutil.chown(fp.name, user="oneadmin")

        rs = silent_run("onedatastore list --csv --no-header -f TYPE=fil -l ID")
        ds_ids = rs.stdout.decode().split() if rs.returncode == 0 else []
        if not ds_ids:
            print("Aucun datastore de type fichier disponible")
            return False
        # Several file datastores may exist: use the first one listed.
        ds_id = ds_ids[0]

        if silent_run(f"oneimage show {ZCREDS_NAME}").returncode == 0:
            silent_run(f"oneimage delete {ZCREDS_NAME}")

        cmd = f"oneimage create --type CONTEXT --datastore {ds_id}"
        cmd += f" --name {ZCREDS_NAME}"
        cmd += f" --path {fp.name}"

        rs = silent_run(cmd)
        if rs.returncode != 0:
            print(rs.stderr.decode() + rs.stdout.decode())
            return False
        # Kept from the original: presumably gives OpenNebula time to copy
        # the file before it is deleted below.
        time.sleep(2)
        return True
    finally:
        # Never leave plaintext credentials on disk.
        fp.close()  # no-op if already closed
        try:
            os.remove(fp.name)
        except FileNotFoundError:
            pass

def get_pwd(addr, port):
    """Authenticate against the Zephir application.

    Credentials are taken from the encrypted CRD_FILE when it exists (tried
    only once). Otherwise, or after they are rejected, they are asked for
    interactively. Interactively entered credentials that are accepted may
    be saved (encrypted, mode 0600) for later runs. On success the
    credentials are also published to OpenNebula with
    ``add_zephir_credential_file``.

    :param addr: Zephir server address.
    :param port: Zephir XML-RPC port.
    :returns: ``(True, proxy)`` on success.
        ``(False, message)`` when the user enters an empty login, or after
        10 attempts.
        ``(False, None)`` when login succeeded but publishing the
        credentials failed.
    """
    try:
        import termios
    except ImportError:  # not a POSIX system: nothing to flush
        termios = None

    login_ok = False
    attempts = 0
    user = None
    passwd = None
    from_hst = False      # True when user/passwd came from CRD_FILE
    stored_tried = False  # CRD_FILE is read at most once
    proxy = None
    while not login_ok:
        if termios is not None:
            try:
                # Flush standard input in case the user pressed <enter>
                # during the upgrade.
                termios.tcflush(sys.stdin, termios.TCIOFLUSH)
            except (termios.error, OSError, ValueError):
                pass  # stdin is not a terminal
        if not stored_tried and os.path.exists(CRD_FILE):
            # Only try the saved credentials once: if they are rejected, fall
            # back to prompting instead of re-reading them on every attempt.
            stored_tried = True
            with open(CRD_FILE, 'r') as crd:
                creds = crd.read().split()
            if len(creds) >= 2:
                # Layout written below: encrypted password, then user.
                user = decrypt(creds[1].encode())
                passwd = decrypt(creds[0].encode())
                from_hst = True
        if not user:
            user = flushed_input("Entrez votre login zephir (rien pour sortir) : ")
        if user == "":
            return False, "! Abandon de la procédure !"
        if not passwd:
            passwd = getpass.getpass("Mot de passe zephir pour %s : " % user)
        # Create the proxy to the Zephir server.
        proxy = EoleProxy("https://%s:%s@%s:%s" % (user, passwd, addr, port),
                          transport=TransportEole())
        try:
            # Any authenticated call will do; a ProtocolError is treated as
            # rejected credentials.
            convert(proxy.get_permissions(user))
        except xmlrpclib.ProtocolError:
            from_hst = False
            user = None
            passwd = None
            print_line("\n Erreur d'authentification \n")
        else:
            login_ok = True
            if not from_hst:
                rep = flushed_input("Voulez-vous retenir ces identifiants "
                                    "pour les prochaines utilisations ? [non] :").rstrip()
                if rep in ("oui", "o"):
                    # crypt() writes KEY_FILE next to CRD_FILE, so make sure
                    # the directory exists first.
                    os.makedirs(os.path.dirname(CRD_FILE), exist_ok=True)
                    enc_passwd = crypt(passwd).decode()
                    enc_user = crypt(user).decode()
                    fd = os.open(CRD_FILE, os.O_WRONLY | os.O_CREAT | os.O_TRUNC, 0o600)
                    with os.fdopen(fd, "w") as crd:
                        crd.write(enc_passwd + "\n" + enc_user)
        attempts += 1
        if not login_ok and attempts > 9:
            return False, "! Nombre de tentative dépassée !"
    if add_zephir_credential_file(addr, user, passwd):
        return True, proxy
    return False, None

def get_config_from_zephir(id_list, zephir_proxy):
    """Download and store the configuration of each Zephir server.

    For each id, the configuration returned by ``serveurs.get_config`` is
    serialised to JSON and encrypted with ``crypt``. It is then written to
    ``DPL_ROOT_DIR/confs/<id>/config.eol.sc``, creating the directory if
    needed.

    :param id_list: Zephir server ids.
    :param zephir_proxy: authenticated Zephir XML-RPC proxy.
    :returns: True. Errors propagate as exceptions.
    """
    for id_serveur in id_list:
        raw_conf = zephir_proxy.serveurs.get_config(id_serveur)
        dest_dir = os.path.join(DPL_ROOT_DIR, "confs", str(id_serveur))
        os.makedirs(dest_dir, exist_ok=True)

        dest_file = os.path.join(dest_dir, "config.eol.sc")
        # NOTE(review): raw_conf looks like a (status, data) pair; only the
        # data part is stored and the status is not checked — confirm.
        with open(dest_file, "wb") as file_conf:
            file_conf.write(crypt(json.dumps(raw_conf[1])))
    return True

def get_vms_infos(id_list):
    """Build the VM descriptions from the stored server configurations.

    For each id, reads and decrypts
    ``DPL_ROOT_DIR/confs/<id>/config.eol.sc``, which holds the configuration
    as JSON.

    :returns: one dict per server, with these keys:
        name, id_zephir (str), index (``vm_index``, 0 when absent), market,
        app, cpu, vcpu, memory, ram (same value as memory), disk_size,
        conf (the whole configuration), and
        net (one dict per interface: name, ip, mask, gw, dns).
    """
    vms = []
    for vm in id_list:
        s_conf_file = os.path.join(DPL_ROOT_DIR, "confs", str(vm), "config.eol.sc")
        with open(s_conf_file, 'r') as scf:
            cnf = json.loads(decrypt(scf.read().encode()))

        vmi = {"name": cnf['nom_domaine_machine'],
               "id_zephir": str(vm),
               "index": cnf.get('vm_index', 0),
               "market": cnf['vm_marketplace'],
               "app": cnf['vm_app'],
               "cpu": cnf['vm_cpu'],
               "vcpu": cnf['vm_vcpu'],
               "memory": cnf['vm_memory'],
               "ram": cnf['vm_memory'],
               "disk_size": cnf['vm_disk_size'],
               "net": [],
               "conf": cnf,
               }
        for nic in range(int(cnf['nombre_interfaces'])):
            vmi['net'].append({"name": cnf[f'vm_vnet_name{nic}'],
                               "ip": cnf[f'adresse_ip_eth{nic}'],
                               "mask": cnf[f'adresse_netmask_eth{nic}'],
                               # Same gateway and first DNS server for every
                               # interface.
                               "gw": cnf['adresse_ip_gw'],
                               "dns": cnf['adresse_ip_dns'].split()[0]})
        vms.append(vmi)
    return vms

def get_total_cpu_ram():
    """Sum MAX_CPU and MAX_MEM over every OpenNebula host.

    The values come from the HOST_SHARE section of ``onehost show -x``.

    :returns: ``(cpu, mem)``, or ``(None, None)`` as soon as one onehost
        command fails.
    """
    listing = silent_run("onehost list -l ID --no-header")
    if listing.returncode != 0:
        return None, None

    cpu_total, mem_total = 0, 0
    for host_id in listing.stdout.decode().split():
        shown = silent_run("onehost show {0} -x".format(host_id))
        if shown.returncode != 0:
            return None, None
        root = XML.fromstring(shown.stdout)
        shares = (node for node in root if node.tag == "HOST_SHARE")
        for share in shares:
            for attr in share:
                if attr.tag == "MAX_CPU":
                    cpu_total += int(attr.text)
                if attr.tag == "MAX_MEM":
                    mem_total += int(attr.text)
    return cpu_total, mem_total

def check_resources(vm_list):
    """Check that the requested VMs fit in the cluster, and print a summary.

    Requests are summed over *vm_list* and compared with:
    - CPU: requested ``cpu * 100`` against the hosts' MAX_CPU total.
    - Memory: requested ``ram * 1024`` against the hosts' MAX_MEM total
      (both are displayed divided by 1024 as MB, so VM memory is presumably
      in MB and MAX_MEM in KB).
    - Disk: requested ``disk_size`` against the TOTAL_MB of datastore
      IMG_DS divided by 1024 (displayed as GB).

    :returns: True when every request is strictly below what is available.
        False otherwise, or when the datastore or host information cannot
        be read.
    """
    res = silent_run("onedatastore show {0} -x".format(IMG_DS))
    if res.returncode != 0:
        print("Erreur, impossible de lire le datastore {0}".format(IMG_DS))
        return False
    ds_size = 0
    for inf in XML.fromstring(res.stdout):
        if inf.tag == "TOTAL_MB":
            ds_size = int(int(inf.text) / 1024)  # Disk size is in GB

    # Total CPU and RAM of the cluster.
    cl_total_cpu, cl_total_ram = get_total_cpu_ram()
    if cl_total_cpu is None or cl_total_ram is None:
        print("Erreur, impossible de lire les ressources des hôtes")
        return False

    total_cpu = sum(int(float(vm["cpu"]) * 100) for vm in vm_list)
    total_ram = sum(int(vm["ram"]) * 1024 for vm in vm_list)
    total_disk = sum(int(vm["disk_size"]) for vm in vm_list if vm["disk_size"])

    print("Vérification des resources (Demandé/Disponible):")
    print("\tCPU:  {0}/{1} Un".format(total_cpu, cl_total_cpu))
    print("\tRAM:  {0}/{1} MB".format(int(total_ram/1024), int(cl_total_ram/1024)))
    print("\tDISK: {0}/{1} GB".format(total_disk, ds_size))

    return total_cpu < cl_total_cpu and total_ram < cl_total_ram and total_disk < ds_size

def check_vnet(vms):
    """Verify that every virtual network referenced by *vms* exists.

    :returns: ``(True, None)`` when every network exists.
        ``(False, None)`` when ``onevnet list`` fails.
        ``(False, missing)`` otherwise, where *missing* lists the unknown
        network names in VM/interface order, possibly with repetitions.
    """
    listing = silent_run("onevnet list -l NAME --no-header")
    if listing.returncode != 0:
        return False, None
    known = listing.stdout.decode().split()

    missing = [nic["name"]
               for machine in vms
               for nic in machine["net"]
               if nic["name"] not in known]

    return (False, missing) if missing else (True, None)

def _wait_for_image_ready(name):
    """Poll image *name* until its state is "rdy" or "err" (at most TMOUT polls, 1 s apart); return the last state."""
    state = ""
    for _ in range(TMOUT):
        state = silent_run(
            f"oneimage list -f NAME={name} -l STAT --no-header"
        ).stdout.decode().strip()
        if state in ("rdy", "err"):
            break
        time.sleep(1)
    return state


def _import_app(app):
    """Export appliance *app* from the marketplace into IMG_DS and return its status line."""
    res = silent_run(f"onemarketapp export {app} {app} --datastore {IMG_DS}")
    sys.stdout.write(f"   Importing {app} ")
    sys.stdout.flush()
    # NOTE(review): which stream carries the "already taken" error, and with
    # which exit code, is not established here; check both streams
    # regardless of the exit code.
    already_taken = any(
        re.match(r'^\[.*Error.*NAME is already taken by.*', out)
        for out in (res.stdout.decode(), res.stderr.decode())
    )
    if already_taken:
        sys.stdout.write("\t[EXISTS]\n")
        sys.stdout.flush()
        return f"{app}:IMPORT_APP:EXISTS"
    if res.returncode != 0:
        sys.stdout.write("\t[KO]\n")
        sys.stdout.flush()
        return f"{app}:IMPORT_APP:KO"
    if silent_run(f"oneimage show {app}").returncode != 0:
        sys.stdout.write("\t[KO]\n")
        print(f"Warning: import of appliance {app} failed !")
        print(f"   {res.stdout.decode()}")
        return f"{app}:IMPORT_APP:KO"

    state = _wait_for_image_ready(app)
    if state == "rdy":
        sys.stdout.write("\t[OK]\n")
        sys.stdout.flush()
        return f"{app}:IMPORT_APP:OK"
    if state == "err":
        sys.stdout.write("\t[KO]\n")
        sys.stdout.flush()
        return f"{app}:IMPORT_APP:KO"
    sys.stdout.write("\t[TIMEOUT]\n")
    sys.stdout.flush()
    return f"{app}:IMPORT_APP:TIMEOUT"


def import_apps_from_markets(vms):
    """Import the marketplace appliances used by *vms*, then clone their templates.

    Each distinct appliance (``vm["app"]``) is exported into IMG_DS, and
    the function waits for its image to become ready. Then, for each VM,
    the appliance template is cloned (``--recursive``) under the VM name.
    One status line ``<name>:<STEP>:<RESULT>`` per operation is appended
    to STATUS_FILE.

    :returns: True. Individual failures are only recorded in STATUS_FILE.
    """
    # Distinct appliances, in first-seen order.
    apps = []
    for vm in vms:
        if vm["app"] not in apps:
            apps.append(vm["app"])

    print("Import des appliances :")
    messages = [_import_app(app) for app in apps]

    print("Création des modèles de machines virtuelles:")
    for vm in vms:
        sys.stdout.write(f"   {vm['name']} ")
        sys.stdout.flush()
        cl = silent_run(f"onetemplate clone --recursive {vm['app']} {vm['name']}")
        result = "OK" if cl.returncode == 0 else "KO"
        messages.append(f"{vm['name']}:TEMPLATE_CLONE:{result}")
        sys.stdout.write(f"\t[{result}]\n")
        sys.stdout.flush()

    with open(STATUS_FILE, "a") as fd:
        fd.writelines(msg + "\n" for msg in messages)
    return True

def set_context(vm_id, vnets):
    """Build the OpenNebula ``CONTEXT`` section of a VM template.

    The CONTEXT images (``oneimage list -f TYPE=CX``) are scanned:
    - ZCA_NAME and ZCREDS_NAME are attached as context files.
    - Every image whose name starts with two digits and an underscore is
      attached as a file AND listed in INIT_SCRIPTS, sorted by name.

    Each entry of *vnets* adds IP/NETMASK/DNS variables for its interface;
    the gateway is only set on the first interface.

    :param vm_id: Zephir id of the server, exported as ZEPHIR_ID.
    :param vnets: list of dicts with keys ip, mask, dns and gw.
    :returns: the CONTEXT text, or "" when neither the ZCA_NAME nor the
        ZCREDS_NAME image exists.
    """
    files = {}
    init_files = {}
    res = silent_run("oneimage list -f TYPE=CX -l USER,NAME --no-header --csv")
    if res.returncode == 0:
        # One "owner,name" row per line (names may contain spaces).
        for row in res.stdout.decode().splitlines():
            if "," not in row:
                continue
            owner, name = row.split(",", 1)
            if re.match(r'^[0-9][0-9]_', name):
                init_files[name] = owner
            if name in (ZCA_NAME, ZCREDS_NAME):
                files[name] = owner

    if not files:
        return ""

    def file_ref(name, owner):
        # Escaped quotes: this text is nested in the quoted FILES_DS value.
        return '$FILE[IMAGE=\\"{0}\\", IMAGE_UNAME=\\"{1}\\"] '.format(name, owner)

    lines = ["CONTEXT = [\n"]
    # Context files first, then the init scripts.
    lines.append('FILES_DS = "'
                 + "".join(file_ref(n, files[n]) for n in sorted(files))
                 + "".join(file_ref(n, init_files[n]) for n in sorted(init_files))
                 + '",\n')
    # Scripts run at boot.
    lines.append('INIT_SCRIPTS = "' + "".join(n + ' ' for n in sorted(init_files)) + '",\n')
    lines.append('SSH_PUBLIC_KEY="$USER[SSH_PUBLIC_KEY]",\n')
    for idx, net in enumerate(vnets):
        lines.append(f'IP_ETH{idx}="{net["ip"]}",\n')
        lines.append(f'NETMASK_ETH{idx}="{net["mask"]}",\n')
        lines.append(f'DNS_ETH{idx}="{net["dns"]}",\n')
        if idx == 0:
            lines.append(f'GW_ETH{idx}="{net["gw"]}",\n')
    lines.append('ZEPHIR_ID="{0}",\n'.format(vm_id))
    lines.append('NETWORK="YES"\n')
    lines.append("]\n")
    return "".join(lines)

def update_vm_templates(vms):
    """Merge the per-VM settings from the server configuration into each cloned template.

    For each VM, the following are written to a temporary file:
    - CPU, VCPU and MEMORY;
    - one NIC per configured network;
    - the CONTEXT section built by ``set_context``.

    The file is then merged into template ``vm["name"]`` with
    ``onetemplate update --append``. Failures are printed, not raised, and
    the temporary file is always removed.
    """
    for vm in vms:
        parts = [f"CPU = {vm['cpu']}\n",
                 f"VCPU = {vm['vcpu']}\n",
                 f"MEMORY = {vm['memory']}\n"]
        # One NIC attribute per line, network name quoted.
        parts += [f"NIC = [ NETWORK = \"{net['name']}\", NETWORK_UNAME = \"oneadmin\" ]\n"
                  for net in vm["net"]]
        parts.append(set_context(vm["id_zephir"], vm["net"]))
        ntmpl = "".join(parts)

        fp = tempfile.NamedTemporaryFile(delete=False)
        try:
            fp.write(ntmpl.encode())
            fp.close()
            res = silent_run(f"onetemplate update --append {vm['name']} {fp.name}")
            if res.returncode != 0:
                # The template being updated is vm["name"], not the appliance.
                print("Error updating template {0}".format(vm["name"]))
                print("   {0}".format((res.stdout + res.stderr).decode()))
        finally:
            fp.close()  # no-op if already closed
            os.remove(fp.name)

def get_vm_status():
    """Return the state of every VM as ``{vm_name: state}``.

    Parses ``onevm list --csv --no-header``, taking the 4th column as the
    VM name and the 5th as its state.

    :returns: the mapping (empty when there is no VM), or None if the
        command fails.
    """
    res = silent_run("onevm list --csv --no-header")
    if res.returncode != 0:
        return None
    sts = {}
    for row in res.stdout.decode().splitlines():
        fields = row.split(',')
        # Skip blank lines (e.g. empty output when no VM exists).
        if len(fields) < 5:
            continue
        sts[fields[3]] = fields[4]
    return sts

def _wait_for_poff_or_error(vms):
    """Wait until every VM named in *vms* is in state "poff" or "err".

    Polls ``get_vm_status`` about once a second. There is no overall
    timeout: it waits as long as any listed VM is in another state.
    The caller's list is not modified.

    :param vms: VM names.
    :returns: one ``"VM-<name>:STATUS:<POFF|ERR>"`` line per VM, in the order
        the VMs reached their final state.
    """
    pending = list(vms)  # private copy: never mutate the caller's list
    messages = []
    while pending:
        # A failed listing (None) simply means "nothing finished yet".
        status = get_vm_status() or {}
        still_pending = []
        for vm in pending:
            state = status.get(vm)
            if state in ("poff", "err"):
                messages.append(f"VM-{vm}:STATUS:{state.upper()}")
            else:
                still_pending.append(vm)
        pending = still_pending
        if pending:
            time.sleep(1)
    return messages

def instance_vms(vms):
    """Instantiate one VM per template and wait for the new VMs to stop.

    For each VM, if no VM of that name exists yet,
    ``onetemplate instantiate <name> --name <name>`` is run. The function
    then waits (``_wait_for_poff_or_error``) until every newly created VM
    is in state "poff" or "err". Status lines are appended to STATUS_FILE,
    including when the function stops early or raises.

    :returns: the ``get_vm_status`` mapping once the new VMs have stopped,
        or None as soon as one instantiation fails.
    """
    new_vms = []
    messages = []
    print("Création des machines virtuelles :")
    try:
        for vm in vms:
            name = vm["name"]
            sys.stdout.write(f'   {name} ')
            sys.stdout.flush()
            if silent_run(f"onevm show {name}").returncode == 0:
                messages.append(f"{name}:INSTANCE_VM:EXISTS")
                sys.stdout.write("\t[EXISTS]\n")
                sys.stdout.flush()
                continue
            res = silent_run(f"onetemplate instantiate {name} --name {name}")
            if res.returncode != 0:
                sys.stdout.write("\t[KO]\n")
                sys.stdout.flush()
                # Collapse the command output to one line: STATUS_FILE holds
                # exactly one status per line.
                detail = " ".join((res.stdout.decode() + res.stderr.decode()).split())
                messages.append(f"{name}:INSTANCE_VM:KO:{detail}")
                return None
            sys.stdout.write("\t[OK]\n")
            sys.stdout.flush()
            messages.append(f"{name}:INSTANCE_VM:OK")
            new_vms.append(name)

        messages += _wait_for_poff_or_error(new_vms)
    finally:
        with open(STATUS_FILE, "a") as fd:
            fd.writelines(msg + "\n" for msg in messages)
    return get_vm_status()

def cleanup(vms):
    """Roll back a deployment. NOT IMPLEMENTED YET.

    Planned steps:
      * terminate every deployed VM;
      * remove the context options from the templates;
      * delete the Zephir credentials file;
      * deploy the VMs again without context options.

    For now this is a no-op that always reports success.
    """
    # TODO: implement the rollback described above; *vms* is unused so far.
    return True

def clean_up_vm_templates(vms):
    """Clean the VM templates after deployment. NOT IMPLEMENTED YET.

    For now this is a no-op that always reports success.
    """
    # TODO: implement; *vms* is unused so far.
    return True

def poweron_vms(vms):
    """Resume (power on) every VM of *vms* and record the outcome.

    Runs ``onevm resume <name>`` for each VM and appends
    ``<name>:RESUME:OK`` (exit code 0) or ``<name>:RESUME:KO`` to
    STATUS_FILE.

    :returns: True. Failures are only recorded in STATUS_FILE.
    """
    with open(STATUS_FILE, "a") as fd:
        for vm in vms:
            ok = silent_run(f"onevm resume {vm['name']}").returncode == 0
            fd.write(f"{vm['name']}:RESUME:{'OK' if ok else 'KO'}\n")
    return True

def main():
    """Deploy the Zephir-declared servers as OpenNebula VMs.

    Steps:
    1. Authenticate to Zephir.
    2. Collect the server ids, depending on MODE:
       - from the Zephir group (VARIABLE_NAME=VARIABLE_VALUE) of each RNE;
       - and/or from the manual list ``dp_server_id_list``.
    3. Download their configurations and check resources and virtual
       networks.
    4. Import the appliances, update the templates, instantiate the VMs and
       power them on.
    """
    # get_pwd() may save the key and credential files inside DPL_ROOT_DIR,
    # so the directory must exist before authenticating.
    os.makedirs(DPL_ROOT_DIR, exist_ok=True)

    # Authenticate with the common Zephir client functions.
    authentified, proxy = get_pwd(adresse_zephir, 7080)
    if not authentified:
        # get_pwd returns (False, None) when the credentials could not be
        # published to OpenNebula; otherwise the second item is a message.
        print(proxy if proxy is not None
              else "! Impossible d'enregistrer les identifiants Zéphir dans Hâpy !")
        sys.exit(0)

    # Fetch the list of servers.
    print('Recherche de la liste des serveurs (cette action peut être longue)')
    servers = []
    if MODE in ('site', 'mixte'):
        # Never iterate over the characters of a single-valued RNE.
        rnes = [RNE] if isinstance(RNE, str) else (RNE or [])
        for rne in rnes:
            try:
                dico = {u'rne': rne}
                groupe_vars = {u'var_1': (VARIABLE_NAME, VARIABLE_VALUE, False)}
                idx, liste_serveurs = convert(proxy.serveurs.groupe_serveur(dico,
                                                                            groupe_vars, False))
            except xmlrpclib.ProtocolError:
                print_red("""Erreur de permissions ou Zéphir non disponible""")
                sys.exit()
            except socket.error as e:
                print_red("""Erreur de connexion au serveur Zéphir (%s)""" % str(e))
                sys.exit()
            if idx != 1:
                print_red(liste_serveurs)
                sys.exit()
            for server in liste_serveurs:
                servers.append(server['id'])

    # Sort servers by zephir_id for "recovered servers".
    servers = sorted(servers)

    if MODE in ("liste manuelle", "mixte"):
        for srv in client.get_creole('dp_server_id_list') or []:
            servers.append(int(srv))

    # In "mixte" mode a server may come from both sources: keep it once.
    servers = list(dict.fromkeys(servers))

    print("Liste des serveurs : {}".format(servers))
    get_config_from_zephir(servers, proxy)

    vms = get_vms_infos(servers)

    if not check_resources(vms):
        print_red("Erreur, resources insuffisantes pour déployer "
                  "les machines virtuelles demandées")
        return

    res, net = check_vnet(vms)
    if not res:
        if net:
            print_red("Erreur, réseaux virtuels incohérents {0} "
                      "n'existe.nt pas dans Hâpy".format(net))
        else:
            print_red("Erreur, vérification de la cohérence "
                      "des réseaux virtuels impossible")
        return

    import_apps_from_markets(vms)
    update_vm_templates(vms)
    instance_vms(vms)
    # clean_up_vm_templates(vms) is not called: it is still a stub.
    poweron_vms(vms)
    cleanup(vms)

if __name__ == "__main__":
    # Deploy only when executed as a script, never on import.
    main()
