# -*- coding: utf-8 -*-
###########################################################################
# Eole NG - 2012
# Copyright Pole de Competence Eole  (Ministere Education - Academie Dijon)
# Licence CeCill  cf /root/LicenceEole.txt
# eole@ac-dijon.fr
###########################################################################

import psycopg2
import sys
from eoledb import EoleDb
from password import EolePgPwd
from subprocess import check_call
import os
import pwd
import re

# One line of the libpq password file: hostname:port:database:username:password
PGPASS_PATTERN = "{dbhost}:{dbport}:{dbname}:{dbuser}:{dbpass}"
PGPASS_FILE = '/root/.pgpass'
# Privilege names accepted for each kind of database object; anything else
# requested in a privileges mapping is filtered out before GRANT.
ALLOWED_PRIVILEGES = {'database': set(['ALL', 'CONNECT', 'CREATE', 'TEMPORARY', 'TEMP']),
                      'schema': set(['ALL', 'CREATE', 'USAGE']),
                      'table': set(['ALL', 'SELECT', 'INSERT', 'UPDATE', 'DELETE',
                                    'TRUNCATE', 'REFERENCES', 'TRIGGER'])}
# Splits "db[.schema[.table]]" into named groups; both separators are literal
# dots (the table separator was previously an unescaped "." matching any char).
DB_OBJECT_RE = re.compile(r'(?P<db>[^\.]+)(\.(?P<schema>[^\.]+))?(\.(?P<table>[^\.]+))?')

def demote(user):
    """Return a callable that switches the current process to account *user*.

    Meant to be used as a subprocess ``preexec_fn`` (it runs in the child
    before the new program is executed). The account is looked up when the
    callable runs, not when demote() is called. Supplementary groups are
    reset to those of *user* before gid and uid change, so the child does
    not keep the caller's (typically root's) supplementary groups.

    :param user: name of the system account to switch to
    :type user: str
    :return: zero-argument callable performing the switch
    :raises KeyError: (from the returned callable) if *user* does not exist
    """
    def set_ids():
        entry = pwd.getpwnam(user)
        # Groups and gid must be changed while the process is still privileged,
        # i.e. before setuid() gives up root.
        os.initgroups(user, entry.pw_gid)
        os.setgid(entry.pw_gid)
        os.setuid(entry.pw_uid)
    return set_ids


class EoleDbPg(EoleDb):
    """Connection to a PostgreSQL database.

    PostgreSQL backend of EoleDb: database creation/update, roles,
    privileges and password renewal. SQL goes through psycopg2; role
    passwords are changed with the ``psql`` command-line client.
    """
    # True once change_passwords() has renewed the superuser password;
    # class-level, so it is shared by every instance.
    rtpwdChanged = False
    # .pgpass lines recorded by set_pgpass(); class-level, so the written
    # file gathers the entries of every instance handled in this process.
    pgpass = set()

    def __init__(self, obj):
        EoleDb.__init__(self, obj)
        # set_conf and in_cont come from EoleDb; the host depends on in_cont
        # (presumably "server runs in a container" — TODO confirm).
        self.set_conf({'dbroot': 'postgres',
                       'dbcont': 'postgresql',
                       'dbhost': '192.0.2.11' if self.in_cont else '127.0.0.1',
                       'dbport': 5432})

    @staticmethod
    def _quote_ident(name):
        """Return *name* as a double-quoted SQL identifier (inner quotes doubled)."""
        return '"{0}"'.format(name.replace('"', '""'))

    @staticmethod
    def _quote_literal(value):
        """Return *value* as a single-quoted SQL string literal (inner quotes
        doubled; assumes standard_conforming_strings=on, the server default)."""
        return "'{0}'".format(value.replace("'", "''"))

    @staticmethod
    def _escape_pgpass(value):
        """Escape backslashes and colons in a .pgpass field, as libpq requires."""
        return '{0}'.format(value).replace('\\', '\\\\').replace(':', '\\:')

    def __dbExists__(self):
        """Return True if a database named self.dbname exists."""
        with self.connect(mode='admin') as conn:
            with conn.cursor() as cursor:
                cursor.execute("SELECT count(*) FROM pg_database WHERE datname = %s;",
                               (self.dbname,))
                count = cursor.fetchone()[0]
        return count == 1

    def __userExists__(self, user, host):
        """Return True if the role *user* exists.

        :param user: role name; self.dbuser is used when None
        :param host: unused here (presumably kept for parity with other
                     EoleDb backends)
        """
        role = self.dbuser if user is None else user
        with self.connect(mode='admin') as conn:
            with conn.cursor() as cursor:
                cursor.execute("SELECT count(*) FROM pg_roles WHERE rolname = %s;",
                               (role,))
                count = cursor.fetchone()[0]
        return count == 1

    def _template_proof_run(step):
        """Decorator letting *step* run on a database flagged as a template.

        When the instance's ``is_template`` attribute is True, the
        ``datistemplate`` flag of self.dbname is cleared before *step* and set
        again afterwards — also when *step* raises.
        """
        from functools import wraps

        template_sql = "UPDATE pg_database SET datistemplate = %s WHERE datname = %s;"

        def set_template_flag(self, flag):
            with self.connect(mode='admin') as conn:
                conn.set_isolation_level(0)  # autocommit
                with conn.cursor() as cursor:
                    cursor.execute(template_sql, (flag, self.dbname))

        @wraps(step)
        def run(self, *args, **kwargs):
            if getattr(self, 'is_template', False) is not True:
                return step(self, *args, **kwargs)
            set_template_flag(self, False)
            try:
                return step(self, *args, **kwargs)
            finally:
                set_template_flag(self, True)
        return run

    def connect(self, mode='owner'):
        """ Open Database connection with predefined parameters:
            - owner: database and user as specified in configuration file
              (currently redirected to 'superuser', see below)
            - superuser: database as specified in configuration file and
            superuser
            - admin: postgres database and superuser
        :param mode: flag determining database name and user used
        :type mode: str 'owner'|'superuser'|'admin'
        :raises Exception: on unknown mode or any connection failure
        """
        conn = None
        # 'owner' is mapped to 'superuser', so the 'owner' branch below is
        # currently unreachable.
        mode = 'superuser' if mode == 'owner' else mode
        try:
            if mode == 'admin':
                conn = psycopg2.connect(host=self.dbhost,
                                        port=self.dbport,
                                        database='postgres',
                                        user=self.dbroot,
                                        password=EoleDbPg.dbrootpwd)
            elif mode == 'superuser':
                conn = psycopg2.connect(host=self.dbhost,
                                        port=self.dbport,
                                        database=self.dbname,
                                        user=self.dbroot,
                                        password=EoleDbPg.dbrootpwd)
            elif mode == 'owner':
                conn = psycopg2.connect(host=self.dbhost,
                                        database=self.dbname,
                                        port=self.dbport,
                                        user=self.dbuser,
                                        password=self.dbpass)
            else:
                raise Exception('Unknown connection mode: {0}'.format(mode))
        except Exception as err:
            print(str(sys.exc_info()[0]))
            raise Exception("Error while connecting to PostgreSQL database: {0}".format(err))
        return conn

    def run(self, statement):
        """Print *statement*; nothing is sent to the server."""
        print("Run Postgresql Statement : \n\t" + statement)

    def create_db(self, context=None):
        """Create self.dbname and give its ownership to self.dbuser.

        Runs the statements of self.createscript when that attribute is set,
        otherwise a CREATE DATABASE using self.template (or DEFAULT).

        :param context: unused
        :return: False if the database already exists, True once created
        """
        if self.__dbExists__():
            return False
        createscript = self.__dict__.get('createscript', None)
        if createscript is not None:
            statements = self.get_statements(createscript)
        else:
            template = self.__dict__.get('template', 'DEFAULT')
            statements = ['CREATE DATABASE {0} TEMPLATE {1};'.format(
                self._quote_ident(self.dbname), template)]
        with self.connect(mode='admin') as conn:
            # CREATE DATABASE cannot run inside a transaction block: autocommit.
            conn.set_isolation_level(0)
            with conn.cursor() as cursor:
                for statement in statements:
                    cursor.execute(statement)
                cursor.execute('ALTER DATABASE {0} OWNER TO {1};'.format(
                    self._quote_ident(self.dbname), self._quote_ident(self.dbuser)))
        return True

    @_template_proof_run
    def instance_db(self):
        """Create the database, then run every script of self.sqlscripts in it.

        :return: False if the database already existed, True otherwise
        """
        if self.__dbExists__():
            return False
        self.create_db()
        with self.connect(mode='superuser') as conn:
            for script in self.sqlscripts:
                for statement in self.get_statements(script):
                    with conn.cursor() as cursor:
                        cursor.execute(statement)
        return True

    @_template_proof_run
    def update_db(self):
        """Reset privileges, then run every script of self.updatescripts.

        :return: True if the database exists (and was updated), else False
        """
        if not self.__dbExists__():
            return False
        self.manage_privileges()
        with self.connect(mode='owner') as conn:
            for script in self.updatescripts:
                for statement in self.get_statements(script):
                    with conn.cursor() as cursor:
                        cursor.execute(statement)
        return True

    def create_role(self, role, role_options=None):
        """Create *role* unless it already exists.

        :param role: role name
        :param role_options: iterable of CREATE ROLE options (e.g. 'LOGIN'),
                             joined with spaces after WITH
        :return: False if the role already existed, True once created
        :raises Exception: if the server does not answer "CREATE ROLE"
        """
        with self.connect(mode='admin') as conn:
            with conn.cursor() as cursor:
                cursor.execute("SELECT 1 FROM pg_roles WHERE rolname = %s;", (role,))
                if cursor.fetchone() is not None:
                    return False
                options = ' '.join(role_options) if role_options is not None else ''
                cursor.execute('CREATE ROLE {0} WITH {1};'.format(self._quote_ident(role),
                                                                 options))
                if cursor.statusmessage != "CREATE ROLE":
                    raise Exception("Erreur à la création du role {}".format(role))
                return True

    def get_schemas_list(self):
        """
        Return the names of all schemas of the current database (self.dbname).
        """
        with self.connect(mode='superuser') as conn:
            with conn.cursor() as cursor:
                # Select the name column explicitly: with "SELECT *", the first
                # column is the oid on PostgreSQL 12 and later.
                cursor.execute("SELECT nspname FROM pg_namespace ORDER BY nspname;")
                schemas = [row[0] for row in cursor]
        return schemas

    def expand_db_objects(self, privileges_dict):
        """
        Return the GRANT targets and privileges described by *privileges_dict*.

        Keys are dotted object paths; values are collections of privilege
        names (a single string is accepted too):
          - ``db``              -> ``DATABASE "db"`` (database privileges)
          - ``db.schema``       -> ``SCHEMA "schema"`` (schema privileges)
          - ``db.schema.table`` -> ``TABLE "schema"."table"`` (table privileges)
          - ``db.schema.*``     -> ``ALL TABLES IN SCHEMA "schema"``
        A ``*`` schema expands to every schema of the current database; a
        specific table under a ``*`` schema is left unqualified. The ``db``
        part of schema/table paths is not used (objects resolve in the
        connected database). Privileges are filtered with ALLOWED_PRIVILEGES,
        ``ALL`` supersedes the others, and entries with no privilege left
        are dropped.

        :param privileges_dict: mapping of object paths to privileges, or None
        :type privileges_dict: dict
        :return: list of ``{'object': <SQL target>, 'privileges': <set>}``
        :raises ValueError: if a path has more than three components
        """
        params = []
        if privileges_dict is None:
            return params
        text_types = (str, type(u''))
        for db_object, privileges in privileges_dict.items():
            if isinstance(privileges, text_types):
                privileges = [privileges]
            privileges = set(privileges)
            parts = db_object.split('.')
            if len(parts) > 3:
                raise ValueError("Invalid database object: {0}".format(db_object))
            if len(parts) == 1:
                kind = 'database'
                target = 'DATABASE {0}'.format(self._quote_ident(parts[0]))
            else:
                if parts[1] == '*':
                    # GRANT accepts a comma-separated list of schema names.
                    schema = ', '.join(self._quote_ident(name)
                                       for name in self.get_schemas_list())
                else:
                    schema = self._quote_ident(parts[1])
                if len(parts) == 2:
                    kind = 'schema'
                    target = 'SCHEMA {0}'.format(schema)
                else:
                    kind = 'table'
                    if parts[2] == '*':
                        target = 'ALL TABLES IN SCHEMA {0}'.format(schema)
                    elif parts[1] == '*':
                        target = 'TABLE {0}'.format(self._quote_ident(parts[2]))
                    else:
                        target = 'TABLE {0}.{1}'.format(schema, self._quote_ident(parts[2]))
            if 'ALL' in privileges:
                granted = set(['ALL'])
            else:
                granted = privileges.intersection(ALLOWED_PRIVILEGES[kind])
            if granted:
                params.append({'object': target, 'privileges': granted})
        return params

    def revoke_privileges(self, role):
        """
        Revoke from *role* every privilege on the database and on all tables
        of every schema. Does nothing for the superuser (self.dbroot).
        :param role: role the privileges are to be revoked
        :type role: str
        """
        if role == self.dbroot:
            return
        quoted_role = self._quote_ident(role)
        schemas = self.get_schemas_list()
        with self.connect(mode='superuser') as conn:
            with conn.cursor() as cursor:
                cursor.execute('REVOKE ALL PRIVILEGES ON DATABASE {0} FROM {1}'.format(
                    self._quote_ident(self.dbname), quoted_role))
                for schema in schemas:
                    cursor.execute(
                        'REVOKE ALL PRIVILEGES ON ALL TABLES IN SCHEMA {0} FROM {1}'.format(
                            self._quote_ident(schema), quoted_role))

    def grant_privileges(self, role, privileges=None):
        """
        Grant *privileges* to *role* (nothing for the superuser, self.dbroot).
        All GRANT statements run on one connection, in one transaction.
        :param role: role whom privileges are to be granted to
        :type role: str
        :param privileges: database object along with privileges
                           (see expand_db_objects)
        :type privileges: dict
        """
        if role == self.dbroot:
            return
        params = self.expand_db_objects(privileges)
        if not params:
            return
        with self.connect(mode='superuser') as conn:
            with conn.cursor() as cursor:
                for param in params:
                    cursor.execute('GRANT {0} ON {1} TO {2};'.format(
                        ', '.join(sorted(param['privileges'])),
                        param['object'],
                        self._quote_ident(role)))

    def manage_privileges(self):
        """
        Reset the privileges of the roles associated with the database.

        self.dbuser gets ALL on the database and becomes its owner. Each
        entry of self.additional_roles gets its 'privileges' mapping. Existing
        privileges are revoked first.
        """
        self.revoke_privileges(self.dbuser)
        self.grant_privileges(self.dbuser, {self.dbname: set(['ALL'])})
        with self.connect(mode='admin') as conn:
            with conn.cursor() as cursor:
                cursor.execute('ALTER DATABASE {0} OWNER TO {1};'.format(
                    self._quote_ident(self.dbname), self._quote_ident(self.dbuser)))
        # Additional roles: the role name is the entry's 'role' value (as used
        # by change_passwords), falling back to the dictionary key.
        for key, role_conf in self.__dict__.get('additional_roles', {}).items():
            role = role_conf.get('role', key)
            self.revoke_privileges(role)
            self.grant_privileges(role, role_conf.get('privileges', None))

    def alter_role_password(self, user, password):
        """Set the password of role *user* with psql.

        psql runs as the system account 'postgres' (demote), with -w (never
        prompt for a password), on database 'postgres' as self.dbroot; -h is
        added only when self.dbhost is not local. Output is discarded.

        NOTE(review): the password is on psql's command line, hence visible
        in the process list while psql runs.

        :return: True (a psql failure raises subprocess.CalledProcessError)
        """
        sql = "ALTER ROLE {0} WITH PASSWORD {1};".format(self._quote_ident(user),
                                                         self._quote_literal(password))
        cmd = ['psql']
        # Option and value must be two separate argv entries.
        if self.dbhost not in ['localhost', '127.0.0.1']:
            cmd.extend(['-h', self.dbhost])
        cmd.extend(['-w', '-c', sql, 'postgres', self.dbroot])
        with open(os.devnull, 'w') as devnull:
            check_call(cmd,
                       stdout=devnull,
                       stderr=devnull,
                       preexec_fn=demote('postgres'))
        return True

    def set_pgpass(self, user, password, database=None, update=True):
        """Build the .pgpass line for *user* and optionally record it.

        Fields are escaped as libpq requires (backslash before ':' and '\\').
        With *update*, a previously recorded line for the same host, port,
        database and user is replaced, and PGPASS_FILE is rewritten with all
        recorded lines, created readable by its owner only.

        :param database: database name, or None for the '*' wildcard
        :return: the .pgpass line
        """
        dbhost = 'localhost' if self.dbhost == '127.0.0.1' else self.dbhost
        fields = {'dbhost': self._escape_pgpass(dbhost),
                  'dbport': self._escape_pgpass(self.dbport),
                  'dbname': '*' if database is None else self._escape_pgpass(database),
                  'dbuser': self._escape_pgpass(user)}
        pgpass = PGPASS_PATTERN.format(dbpass=self._escape_pgpass(password), **fields)
        if update:
            # Drop stale lines for the same key: libpq uses the first match.
            prefix = PGPASS_PATTERN.format(dbpass='', **fields)
            stale = set(line for line in EoleDbPg.pgpass if line.startswith(prefix))
            EoleDbPg.pgpass.difference_update(stale)
            EoleDbPg.pgpass.add(pgpass)
            # Create with mode 0600 so secrets are never exposed with default
            # permissions; chmod also fixes a pre-existing file.
            fd = os.open(PGPASS_FILE, os.O_WRONLY | os.O_CREAT | os.O_TRUNC, 0o600)
            with os.fdopen(fd, 'w') as pgpass_file:
                pgpass_file.write('\n'.join(sorted(EoleDbPg.pgpass)))
            os.chmod(PGPASS_FILE, 0o600)
        return pgpass

    def change_passwords(self, local_conf=None, backup_dir=None):
        """ Renew the passwords of the roles used by this database.

        - superuser (once per process): new password, applied and recorded
          in .pgpass; raises if it cannot be set;
        - self.dbuser (when different from the superuser): role created if
          needed; with dbuser_pwd_mode 'auto' (default), new password,
          recorded in .pgpass and written to the files of self.pwd_files;
        - additional roles: created if needed; new password when their
          pwd_mode is 'auto'.

        :param local_conf: unused
        :param backup_dir: unused
        :return: True
        """
        # .pgpass allows later connections without typing the password
        if EoleDbPg.rtpwdChanged is False:
            root_pwd = EolePgPwd(self.dbroot, None)
            root_pwd.gen_new_password('auto')
            password = root_pwd.ncl_pass
            if password is not None and self.alter_role_password(self.dbroot, password):
                self.set_pgpass(self.dbroot, password)
                EoleDbPg.dbrootpwd = password
                EoleDbPg.rtpwdChanged = True
            else:
                raise Exception("[Error] Root password renew failed")

        # owner of the database if not dbroot
        if self.dbuser != self.dbroot:
            self.create_role(self.dbuser, self.__dict__.get('dbuser_options', None))
            pwd_mode = self.__dict__.get('dbuser_pwd_mode', 'auto')
            if pwd_mode == 'auto':
                user_pwd = EolePgPwd(self.dbuser, None)
                user_pwd.gen_new_password('auto')
                self.dbpass = user_pwd.ncl_pass
                if self.dbpass is not None and self.alter_role_password(self.dbuser, self.dbpass):
                    self.set_pgpass(self.dbuser, self.dbpass, database=self.dbname)

                # propagate the new password to the applications' config files
                for pwd_file in self.pwd_files:
                    try:
                        user_pwd.update_conf_file(pwd_file['file'], pwd_file['pattern'])
                    except KeyError as err:
                        msg = "Missing or unknown parameter {0} in :\n".format(err)
                        for files in self.pwd_files:
                            msg += u'   {0}\n'.format(str(files))
                        msg += "Please check your configuration\n"
                        raise Exception(msg)
        else:
            self.dbpass = self.dbrootpwd

        # other roles if pwd_mode set to auto
        for role in self.__dict__.get('additional_roles', {}).values():
            self.create_role(role['role'], role.get('options', None))
            if role.get('pwd_mode', 'manuel') == 'auto':
                # password object built for the role itself, like the
                # superuser and owner passwords above
                role_pwd = EolePgPwd(role['role'], None)
                role_pwd.gen_new_password('auto')
                if role_pwd.ncl_pass is not None:
                    self.alter_role_password(role['role'], role_pwd.ncl_pass)

        return True

    _template_proof_run = staticmethod(_template_proof_run)
