#!/usr/bin/env python2
# -*- coding: utf-8 -*-
###########################################################################
# Eole NG - 2012
# Copyright Pole de Competence Eole  (Ministere Education - Academie Dijon)
# Licence CeCill  cf /root/LicenceEole.txt
# eole@ac-dijon.fr
###########################################################################
""" EoleDB the new database manager for EOLE
"""
import argparse
from os import listdir
from os.path import isfile, join
from IPy import IP
from termcolor import colored
import yaml
import re
from itertools import product

from eoledb.eoledbconnector import EoleDbConnector
from eoledb.eoledberrors import UnsupportedDatabase

# File names ending in ".yml"; used by main() to pick the configuration files
# found in the database configuration directory.
CONF_FILE_RE = re.compile(r'^.*\.yml$')
# Tab-separated layout of one generated pg_hba.conf line.
HBA_PATTERN = "{protocol}\t{dbname}\t{dbuser}\t{source}\t{auth_method}"
# Template of the expression matching one pg_hba.conf rule; the {placeholders}
# are filled with the sub-patterns of ``hba_options`` below. The address part
# (bare IP, IP/prefix or IP followed by a netmask) is optional, as is whatever
# follows the authentication method.
HBA_RE = r"^(?P<protocol>{protocol})\s+(?P<dbname>{dbname})\s+(?P<dbuser>{dbuser})(?:\s+(?P<source>(?P<ip>{ip})(?:(?:/(?P<ip_class>{ip_class}))|(?:\s+(?P<ip_netmask>{ip_netmask})))?))?\s+(?P<auth_method>{auth_method})(?:\s+(?P<auth_option>{auth_option}))?$"
# Sub-pattern accepted for each named group of HBA_RE.
hba_options = {'protocol': r"(?:host|local)",
               'dbname': r"[\S]+",
               'dbuser': r"[\S]+",
               'ip': r"[0-9]{1,3}\.[0-9]{1,3}\.[0-9]{1,3}\.[0-9]{1,3}",
               'ip_class': r"[0-9]{1,2}",
               'ip_netmask': r"[0-9]{1,3}\.[0-9]{1,3}\.[0-9]{1,3}\.[0-9]{1,3}",
               'auth_method': r"(?:peer|ident|md5|password)",
               'auth_option': r".*"
               }
# From here on HBA_RE is the compiled expression, no longer the template.
HBA_RE = re.compile(HBA_RE.format(**hba_options))
# Location of pg_hba.conf relative to the PostgreSQL container root
# (joined with 'container_path_postgresql' in DBConf.__init__).
PG_HBA_END_PATH = 'etc/postgresql/9.5/main/pg_hba.conf'
# Leading "*" of a privilege target; DBConf.expand_db_names() replaces it
# with the "|"-joined names of the known PostgreSQL databases.
ALL_DB_RE = re.compile(r'^\*')

def parse_pg_hba_file(pg_hba):
    """Return the access rules configured in a pg_hba.conf file.

    Only the lines located between the ``# distant_users`` and
    ``# end distant_users`` markers are considered; when one of the markers
    is missing, no rule is returned. Lines that do not match ``HBA_RE``
    (the start marker itself, comments, ...) are ignored.

    :param pg_hba: path of the pg_hba.conf file
    :type pg_hba: str
    :return: one dict of named ``HBA_RE`` groups per matching line
    :rtype: list of dict
    """
    with open(pg_hba, 'r') as pg_hba_file:
        lines = [line.strip() for line in pg_hba_file.readlines()]
    try:
        start = lines.index('# distant_users')
        stop = lines.index('# end distant_users')
    except ValueError:
        # At least one marker is absent: nothing to extract.
        return []
    matches = (HBA_RE.match(line) for line in lines[start:stop])
    return [match.groupdict() for match in matches if match]


def weight_pg_hba_rule(rule):
    """Return the sort weight of an access rule.

    The weight grows with the breadth of the targeted database and user, so
    that sorting by weight pushes the "all" rules to the end of the list.

    :param rule: dictionary describing the rule ('dbname' and 'dbuser' keys)
    :type rule: dict
    :return: weight between 0 (specific rule) and 5 (all databases, all users)
    :rtype: int
    """
    dbname_weights = {'all': 3, 'samerole': 2, 'sameuser': 1}
    weight = dbname_weights.get(rule['dbname'], 0)

    dbuser = rule['dbuser']
    if dbuser == 'all':
        return weight + 2
    if dbuser.startswith('+'):
        # "+name" designates a group of users.
        return weight + 1
    return weight


def normalise_rule(rule):
    """Normalise an access rule in place and return it.

    Two kinds of rules reach this function: rules parsed from pg_hba.conf,
    which carry every ``HBA_RE`` group ('ip', 'ip_class', 'ip_netmask', ...,
    possibly set to None), and rules generated from the databases
    configuration, which only carry 'dbname', 'dbuser' and 'source'.

    * A source of 'localhost', '127.0.0.1' or None becomes a rule with an
      empty source and a default protocol of 'local'.
    * Any other non-empty source becomes "<ip>\\t<netmask>" with a default
      protocol of 'host'. When the rule has no 'ip' entry the address is
      read from 'source' itself ("ip", "ip/prefix" or "ip netmask"); without
      netmask nor prefix, a single-host (/32) netmask is used.
    * 'auth_method' defaults to 'md5' and 'auth_option' to ''.

    :param rule: parameters of an access rule
    :type rule: dict
    :return: the same dict, normalised
    :rtype: dict
    """
    source = rule.get('source')
    if source in ('localhost', '127.0.0.1', None):
        rule.setdefault('protocol', 'local')
        rule['source'] = ''
    elif source != '':
        rule.setdefault('protocol', 'host')
        address = rule.get('ip')
        netmask = rule.get('ip_netmask')
        prefix = rule.get('ip_class')
        if address is None:
            # Generated rule: no pre-split groups, decode 'source' directly.
            fields = source.split()
            address, _, cidr = fields[0].partition('/')
            prefix = prefix or cidr or None
            if netmask is None and len(fields) > 1:
                netmask = fields[1]
        if netmask is None:
            # 'ip_class' may be present but None (unmatched regex group),
            # hence the explicit fallback instead of a dict default.
            netmask = IP('/'.join([address, prefix or '32'])).strNetmask()
        rule['source'] = '\t'.join([address, netmask])
    rule.setdefault('auth_method', 'md5')
    rule.setdefault('auth_option', '')
    return rule


def clean_rules(rules):
    """Return the normalised rules, deduplicated and sorted by weight.

    Two rules sharing the same (protocol, dbname, dbuser) triple are
    considered duplicates; the last one of the input wins. The result is
    ordered by ``weight_pg_hba_rule`` so that the broadest rules come last.

    :param rules: host based access rules parameters
    :type rules: iterable of dict
    :return: deduplicated and sorted rules
    :rtype: list of dict
    """
    normalised = [normalise_rule(raw_rule) for raw_rule in rules]
    unique = {}
    for rule in normalised:
        unique[(rule['protocol'], rule['dbname'], rule['dbuser'])] = rule
    return sorted(unique.values(), key=weight_pg_hba_rule)


def sort_confs(confs):
    """Return the keys of confs with templates first, each group sorted.

    Databases flagged ``is_template`` are listed before the others so that
    they exist before the databases that depend on them (postgresql
    templating).

    :param confs: databases configuration, keyed by name
    :type confs: dict
    :return: configuration names, templates first
    :rtype: list
    """
    templates = []
    databases = []
    for name in confs:
        flag = confs[name].get('is_template', False)
        # Equality tests on purpose: a value equal to neither True nor
        # False leaves the configuration out of both groups.
        if flag == True:
            templates.append(name)
        elif flag == False:
            databases.append(name)
    templates.sort()
    databases.sort()
    return templates + databases


class DBConf(object):
    """Configuration context aggregating databases and roles settings.

    The object collects, by database name, the configurations loaded from
    YAML files (``db_confs``) on top of shared defaults (``default_conf``).
    Used as a context manager, it writes the PostgreSQL pg_hba.conf file on
    exit when at least one 'postgres' database has been registered.
    """
    # Values of 'dbtype' recognised by get_confs_by_backend().
    backends = ('mysql', 'postgres', 'sqlite')

    def __init__(self, conf_path):
        """Load the main configuration file.

        :param conf_path: path of the main YAML configuration file
        :type conf_path: str
        """
        # Both containers must exist before load_conf() runs, since it may
        # register databases or roles found in the main file.
        self.default_conf = {}
        self.db_confs = {}
        # load_conf() returns None when the file registered databases/roles
        # instead of providing defaults.
        self.default_conf = self.load_conf(conf_path) or {}
        self.pg_hba_path = join(self.default_conf.get('container_path_postgresql', '/'),
                                PG_HBA_END_PATH)

    def __enter__(self):
        """Return the configuration object itself."""
        return self

    def load_conf(self, cpath):
        """Load a YAML configuration file and register what it declares.

        * A mapping with a 'dbname' key is registered as a database.
        * A mapping with 'additional_db' and/or 'additional_roles' registers
          the 'postgres' database, then each listed database and role.
        * Any other mapping is returned untouched (default settings).

        A missing, unreadable, unparsable or empty file is treated as an
        empty configuration.

        :param cpath: path of the YAML file
        :type cpath: str
        :return: the parsed mapping when it only holds default settings,
                 ``{}`` for an empty/unreadable file, None otherwise
        :rtype: dict or None
        :raises Exception: when the file content is not a mapping
        """
        try:
            with open(cpath, 'r') as yml_conf:
                # NOTE(review): depending on the installed PyYAML version,
                # yaml.load() without an explicit Loader may build arbitrary
                # Python objects; consider yaml.safe_load() if these files
                # can be written by untrusted users.
                conf = yaml.load(yml_conf)
        except Exception:
            # Deliberate best effort: the file is handled as if empty.
            conf = None

        if conf is None:
            return {}
        if not isinstance(conf, dict):
            msg = "[ERROR] Configuration file format error !"
            msg += "\n{0} is not a valid YAML file".format(cpath)
            raise Exception(msg)

        if 'dbname' in conf:
            self.add_db_conf(conf)
        elif 'additional_db' in conf or 'additional_roles' in conf:
            self.add_db_conf({'dbname': 'postgres',
                              'dbtype': 'postgres',
                              'dbuser': 'postgres'})
            # "or {}" covers a key declared with an empty value.
            for db_conf in (conf.get('additional_db') or {}).values():
                self.add_db_conf(db_conf)
            for role_conf in (conf.get('additional_roles') or {}).values():
                self.add_role_conf(role_conf)
        else:
            return conf
        return None

    def expand_db_names(self, conf):
        """Expand the privilege targets of a role, grouped by database.

        Each item of ``conf['privileges']`` has an 'objet' (dot separated
        target whose segments may list alternatives with "|"; a leading "*"
        stands for every known PostgreSQL database) and a 'privilege'
        (whitespace separated privilege names). Every targeted database
        also gets the 'CONNECT' privilege.

        :param conf: dictionary with a 'privileges' key
        :type conf: dict
        :return: {database: {expanded target: set of privileges}}
        :rtype: dict
        """
        # The set of known databases does not change during the expansion.
        all_db_names = '|'.join([name for name, _db_conf
                                 in self.get_confs_by_backend('postgres')])
        privileges = {}
        for privilege in conf['privileges']:
            segments = ALL_DB_RE.sub(all_db_names, privilege['objet']).split('.')
            for priv in product(*[segment.split('|') for segment in segments]):
                privileges.setdefault(priv[0], set(['CONNECT']))
                target = '.'.join(priv)
                privileges.setdefault(target, set())
                privileges[target].update(privilege['privilege'].split())

        db_privileges = {}
        for target, granted in privileges.items():
            db_privileges.setdefault(target.split('.')[0], {})[target] = granted
        return db_privileges

    def add_db_conf(self, conf):
        """Add a database configuration to the pool of this instance.

        Missing settings are filled from built-in defaults overridden by
        ``default_conf``. Roles already attached to the database (by
        add_role_conf()) are preserved.

        :param conf: database configuration (must hold 'dbname')
        :type conf: dict
        """
        mandatory_vars = {'dbtype': None,
                          'updatescripts': [],
                          'sqlscripts': [],
                          'pwd_files': [],
                          'in_cont': False,
                          'client_hosts': []}
        mandatory_vars.update(self.default_conf)

        for var, default in mandatory_vars.items():
            conf.setdefault(var, default)

        registered = self.db_confs.setdefault(conf['dbname'], {})
        if 'additional_roles' in registered:
            # Merge the roles registered so far (they take precedence) so
            # the update below does not drop them.
            roles = conf.setdefault('additional_roles', {})
            roles.update(registered['additional_roles'])
        registered.update(conf)

    def add_role_conf(self, conf):
        """Add a role configuration to the pool of this instance.

        A role with 'privileges' is attached to every database targeted by
        those privileges; a role without privileges is attached to the
        'postgres' database.

        :param conf: role configuration (must hold 'role')
        :type conf: dict
        """
        # NOTE(review): in both branches ``conf.update(<additional_roles>)``
        # copies the roles already registered (keyed by role name) into this
        # role's own configuration. This looks unintended (merging with the
        # previous configuration of the *same* role may have been meant);
        # behaviour kept as is until confirmed.
        if 'privileges' in conf:
            db_privileges = self.expand_db_names(conf)
            del conf['privileges']
            for db in db_privileges:
                db_conf = self.db_confs.setdefault(db, {})
                additional_roles = db_conf.setdefault('additional_roles', {})
                conf.update(additional_roles)
                role_conf = conf.copy()
                role_conf.setdefault('privileges', {})
                role_conf['privileges'].update(db_privileges[db])
                additional_roles[conf['role']] = role_conf
        else:
            postgres_conf = self.db_confs.setdefault('postgres', {})
            if 'additional_roles' in postgres_conf:
                conf.update(postgres_conf['additional_roles'])
            else:
                postgres_conf['additional_roles'] = {}
            postgres_conf['additional_roles'][conf['role']] = conf

    def get_db_confs(self):
        """Iterate over the (name, configuration) pairs of the databases.

        Names are ordered by ``sort_confs`` (templates first).
        """
        sorted_confs = sort_confs(self.db_confs)
        return ((conf, self.db_confs[conf]) for conf in sorted_confs)

    def get_confs_by_backend(self, backend):
        """Return the (name, configuration) pairs using the given backend.

        :param backend: one of ``backends``; anything else yields no result
        :type backend: str
        :rtype: list of tuple
        """
        if backend not in self.backends:
            return []
        # Entries created only through role privileges have no 'dbtype'.
        return [conf for conf in self.get_db_confs()
                if conf[1].get('dbtype') == backend]

    def gen_pg_access_conf(self):
        """Generate the content of the pg_hba.conf file.

        The first line grants 'local' md5 access to user 'postgres' on all
        databases. It is followed by one rule per PostgreSQL database and
        source (its 'client_hosts' plus the local/container address), merged
        with the custom rules read from the existing pg_hba.conf, then
        deduplicated and sorted by ``clean_rules``.

        :return: content of the pg_hba.conf file
        :rtype: str
        """
        if isfile(self.pg_hba_path):
            custom_accesses = parse_pg_hba_file(self.pg_hba_path)
        else:
            custom_accesses = []

        if self.default_conf.get('in_cont') is True:
            source = '192.0.2.1'
        else:
            source = '127.0.0.1'

        access_confs = []
        for _name, pg_conf in self.get_confs_by_backend('postgres'):
            for src in pg_conf['client_hosts'] + [source]:
                access_confs.append({'dbname': pg_conf['dbname'],
                                     'dbuser': pg_conf['dbuser'],
                                     'source': src})
        # Custom rules come last so they win over generated duplicates.
        access_confs.extend(custom_accesses)
        access_confs = clean_rules(access_confs)

        access_str = HBA_PATTERN.format(protocol='local', dbname='all',
                                        dbuser='postgres', source='',
                                        auth_method='md5\n')
        access_str += '\n'.join([HBA_PATTERN.format(**access_rule)
                                 for access_rule in access_confs])
        return access_str

    def write_pg_conf(self):
        """Write the postgresql configuration file (pg_hba.conf)."""
        access_conf = self.gen_pg_access_conf()
        with open(self.pg_hba_path, 'w') as pg_hba:
            pg_hba.write(access_conf)

    def __exit__(self, type, value, traceback):
        """Finalise the configuration.

        pg_hba.conf is written when at least one PostgreSQL database is
        registered. Exceptions raised inside the ``with`` block are not
        suppressed.
        """
        if self.get_confs_by_backend('postgres'):
            self.write_pg_conf()


def usage():
    """Print the command help message on standard output."""
    for line in ("Usage:", "\t -c file # Configuration File"):
        print(line)


def run_change_password(conn, local_conf, bdir):
    """ Run all the changing password opérations
    """
    print "\t>>> Passwords",
    if conn.change_passwords(local_conf, bdir):
        print "\t[{0}]".format(colored('OK', 'green'))
    else:
        print "\t[{0}]".format(colored('NA', 'blue'))


def create_db(connection, context=None):
    """Exécute et affiche le résultat des opérations de création
    :param connection: connexion à la base de données
    :type connection: EoleDbConnector
    :param context: contexte d'exécution de la fonction
    :type context: str
    """
    creation_res = connection.instance_db()
    if context is 'reconfigure':
        print "\t>>> Create ",
        if creation_res:
            print "\t[{0}]".format(colored('OK', 'green'))
        else:
            print "\t[{0}]".format(colored('NA', 'blue'))
    return creation_res


def update_db(connection, context=None):
    """Exécute et affiche le résultat des opérations de mise à jour
    :param connection: connexion à la base de données
    :type connection: EoleDbConnector
    :param context: contexte d'exécution de la fonction
    :type context: str
    """
    update_res = connection.update_db()
    if context is 'reconfigure':
        print "\t>>> Update ",
        if update_res:
            print "\t[{0}]".format(colored('OK', 'green'))
        else:
            print "\t[{0}]".format(colored('NA', 'blue'))
    return update_res


def main():
    """EoleDB main program.

    Parse the command line, load the main configuration file and every
    ``*.yml`` file of the configuration directory, then, for each configured
    database: change the passwords, create the database and update it.
    Leaving the ``DBConf`` context finalises the shared configuration.
    """
    parser = argparse.ArgumentParser(description='Eole Database generator.')
    parser.add_argument('-c', '--config', metavar='CONFIG_FILE',
                        default='/etc/eole/eole-db.conf',
                        help='eole database configuration file)')
    parser.add_argument('-d', '--dbdir', metavar='DB_CONFIG_DIR',
                        default='/etc/eole/eole-db.d/',
                        help='eole database configuration directory)')
    parser.add_argument('-i', '--interactive', action='store_true',
                        default=False, help='eole database manager interactive mode')
    parser.add_argument('-b', '--backup-dir', metavar="PW_BACKUP_DIR",
                        default='/var/backups/eoledb',
                        help='eole database directory to store backups of changed files')
    args = parser.parse_args()

    if args.interactive:
        print("Not implemented yet")

    if not args.config:
        # parser.error() prints the message and exits.
        parser.error("options -c is mandatory")

    with DBConf(args.config) as global_db_conf:
        # listdir() returns names in arbitrary order while the merge of the
        # configurations depends on the loading order: sort the names to
        # get a reproducible result.
        for cfile in sorted(listdir(args.dbdir)):
            if CONF_FILE_RE.match(cfile):
                global_db_conf.load_conf(join(args.dbdir, cfile))
        for _dbname, conf in global_db_conf.get_db_confs():
            try:
                conn = EoleDbConnector(conf)()
                if conn:
                    print("{0} : ".format(conn.dbname.upper()))
                    # An empty backup directory is passed on as None.
                    run_change_password(conn, None, args.backup_dir or None)
                    create_db(conn, context='reconfigure')
                    update_db(conn, context='reconfigure')
            except (UnsupportedDatabase, AttributeError) as err:
                # Report the problem and carry on with the next database.
                print(err)


if __name__ == "__main__":
    main()
