# -*- coding: UTF-8 -*-
###########################################################################
#
# Eole NG - 2007
# Copyright Pole de Competence Eole  (Ministere Education - Academie Dijon)
# Licence CeCill  cf /root/LicenceEole.txt
# eole@ac-dijon.fr
#
# libsecure.py
#
# classes utilitaires pour lancement des services en https
# et vérification des certificats
#
###########################################################################

import sys, os
from glob import glob
import httplib, xmlrpclib
from urlparse import urlparse
from twisted.python import log
from twisted.web.client import HTTPClientFactory
from twisted.internet import protocol, reactor
from twisted.protocols import policies

from time import time

# imports SSL
# On utilise M2Crypto au lieu de python OpenSSL (permet la vérification des AltNames)
# Attention : Utiliser TwistedProtocolWrapper.connectSSL (connectTCP/listenSSL/listenTCP)
# au lieu des fonctions standard de twisted.internet.reactor
from M2Crypto import SSL
from M2Crypto.SSL.TwistedProtocolWrapper import TLSProtocolWrapper

def get_cert_subjects(X509_cert):
    """Return the list of names/addresses identifying an X509 certificate.

    The result contains every CN entry of the certificate subject, followed
    by the DNS and IP entries of the subjectAltName extension when that
    extension is present. Every value is lower-cased and stripped.

    X509_cert must provide the get_subject() and get_ext() methods used
    below (presumably an M2Crypto X509 object -- confirm against callers).
    """
    subj = X509_cert.get_subject()
    subjects = [entry.get_data().as_text().lower().strip()
                for entry in subj.get_entries_by_nid(subj.nid['CN'])]
    # collect the AltNames when the extension is available
    try:
        altnames = X509_cert.get_ext('subjectAltName').get_value()
    except LookupError:
        # extension not present
        return subjects
    for altname in altnames.split(','):
        # split on the first colon only: IPv6 addresses contain colons
        parts = altname.split(':', 1)
        if len(parts) != 2:
            continue
        kind = parts[0].lower().strip()
        # 'ip address' is the label tested by ClientContextFactory._cert_verify;
        # 'ip' is kept for backward compatibility
        if kind in ('dns', 'ip', 'ip address'):
            subjects.append(parts[1].lower().strip())
    return subjects

def X509_verify_cert_error_string(errnum):
    """Return the verification error number converted to a string.

    No error-message lookup is performed: the value is passed through str().
    """
    as_text = str(errnum)
    return as_text

# transport sécurisé utilisant un certificat
class TransportEole(xmlrpclib.SafeTransport):
    """XML-RPC HTTPS transport that presents a client certificate."""

    def __init__(self, cert_file, key_file = None):
        """Store the certificate and key file paths.

        cert_file -- path of the certificate file
        key_file  -- path of the private key file; defaults to cert_file
        """
        self.cert_file = cert_file
        self.key_file = key_file or cert_file
        # Explicit base-class call: super() only accepts new-style classes,
        # and the xmlrpclib transports are classic classes on Python 2.
        xmlrpclib.SafeTransport.__init__(self)

    def make_connection(self, host):
        """Create an HTTPS connection object from a host descriptor.

        host may be a string, or a (host, x509-dict) tuple.
        Raises NotImplementedError when httplib provides no HTTPS class.
        """
        localhost, extra_headers, x509 = self.get_host_info(host)
        try:
            HTTPS = httplib.HTTPS
        except AttributeError:
            raise NotImplementedError(
                "your version of httplib doesn't support HTTPS"
                )
        # NOTE(review): httplib.HTTPS is the legacy compatibility class;
        # confirm it matches what the target xmlrpclib version expects.
        # The key given to the constructor is used here (it was previously
        # ignored in favour of the certificate file).
        return HTTPS(localhost, None,
                     key_file = self.key_file,
                     cert_file = self.cert_file)

class ClientContextFactory:
    """Context factory for SSL clients (M2Crypto based)."""
    isClient = 1

    def __init__(self, certfile, keyfile='', ca_location='', mode=None):
        """Store the settings used by getContext().

        certfile    -- path of the client certificate
        keyfile     -- path of the private key (optional, see getContext)
        ca_location -- file or directory holding the CA certificates
        mode        -- verification mode given to the context; when None,
                       SSL_VERIFY_PEER | SSL_VERIFY_FAIL_IF_NO_PEER_CERT
        """
        self.certfile = certfile
        self.keyfile = keyfile
        self.ca_location = ca_location
        self.x509_checker = SSL.Checker.Checker()
        if mode is None:
            self.verify_mode = SSL.m2.SSL_VERIFY_PEER|SSL.m2.SSL_VERIFY_FAIL_IF_NO_PEER_CERT
        else:
            self.verify_mode = mode
        # filled by _cert_verify() with the subjects of the checked certificate
        self.peer_cert_data = {}

    ## certificate verification function (used for proxies)
    def _cert_verify(self, peer_cert, host=None):
        """Record the possible subjects of peer_cert and return its validity.

        The CN entries and the DNS / IP Address AltNames are stored
        (lower-cased and stripped) in self.peer_cert_data['subject'] for
        further checks by the application.

        The M2Crypto checker call is currently disabled, so host is unused
        and the returned validity is always True.
        """
        # XXX FIXME : re-enable name validation ?
        # for now the check is done in the proxy verification callback
        # cert_validity = self.x509_checker(peer_cert, host)
        cert_validity = True
        subjects = []
        if cert_validity:
            # subject of the certificate (CN)
            subj = peer_cert.get_subject()
            for subj_cn in subj.get_entries_by_nid(subj.nid['CN']):
                subjects.append(subj_cn.get_data().as_text().lower().strip())
            # AltNames, when available
            try:
                altnames = peer_cert.get_ext('subjectAltName').get_value()
                for altname in altnames.split(','):
                    # split on the first colon only: IPv6 addresses contain
                    # colons and would otherwise be truncated
                    parts = altname.split(':', 1)
                    if parts[0].lower().strip() in ('dns', 'ip address'):
                        subjects.append(parts[1].lower().strip())
            except LookupError:
                # extension not present
                pass
        self.peer_cert_data['subject'] = subjects
        return cert_validity

    def getContext(self):
        """Build and return the SSL context for a client connection.

        The key is read from self.keyfile when that file exists, otherwise
        from the certificate file itself; if loading fails, a '.key' file
        next to the certificate is tried (errors of that last attempt
        propagate). CA certificates are then loaded from self.ca_location
        (a file or a directory); load errors are only logged.
        """
        ctx = SSL.Context()
        try:
            if self.keyfile and os.path.isfile(self.keyfile):
                ctx.load_cert(self.certfile, self.keyfile)
            else:
                ctx.load_cert(self.certfile, self.certfile)
        except Exception:
            # the key could not be loaded: look for an associated .key file
            alt_keyfile = os.path.splitext(self.certfile)[0] + '.key'
            ctx.load_cert(self.certfile, alt_keyfile)
        # load the certificate authorities (if available)
        if os.path.isdir(self.ca_location):
            ca_certs = [os.path.join(self.ca_location, ca_file) for ca_file in os.listdir(self.ca_location)]
        elif os.path.isfile(self.ca_location):
            ca_certs = [self.ca_location]
        else:
            ca_certs = []
            log.msg(_("Invalid CA location"), self.ca_location)
        for ca_cert in ca_certs:
            # only files with a known certificate extension are considered
            if os.path.splitext(ca_cert)[1] not in (".pem", ".der", ".crt", ".cert"):
                continue
            try:
                ctx.load_verify_locations(ca_cert)
            except SSL.SSLError:
                # sys.exc_info() keeps this valid on every Python version
                err = sys.exc_info()[1]
                log.msg(_('Error loading CA certificate %s : %s') % (ca_cert, str(err)))
        ctx.set_verify(self.verify_mode, 10)
        return ctx

class ServerContextFactory:
    """
    Factory creating the SSL context of a server.
    """

    def __init__(self, certfile, keyfile='', ca_file='', protocol='tlsv1'):
        """Store the settings used by getContext().

        certfile -- path of the server certificate (chain) file
        keyfile  -- path of the private key; when empty, the key is read
                    from certfile
        ca_file  -- file holding the CA certificates sent to the client
        protocol -- protocol name given to SSL.Context (default 'tlsv1',
                    the value that was previously hard-coded)
        """
        self.certfile = certfile
        self.keyfile = keyfile
        self.ca_file = ca_file
        self.protocol = protocol

    def getContext(self):
        """Create and return the server SSL context.

        The certificate chain is loaded from self.certfile with the key
        taken from self.keyfile (or from the certificate file when no key
        file was given). If that fails, a '.key' file next to the
        certificate is tried; an SSL error on that second attempt is only
        logged. A failure to load the client CA file is only logged too.
        """
        ctx = SSL.Context(protocol=self.protocol)
        try:
            ctx.load_cert_chain(self.certfile, self.keyfile or self.certfile)
        except Exception:
            # the key could not be loaded: look for an associated .key file
            alt_keyfile = os.path.splitext(self.certfile)[0] + '.key'
            try:
                ctx.load_cert_chain(self.certfile, alt_keyfile)
            except SSL.SSLError:
                # sys.exc_info() keeps this valid on every Python version
                err = sys.exc_info()[1]
                log.msg(_('Error loading Key file %s : %s') % (alt_keyfile, str(err)))
        if self.ca_file:
            try:
                # load the CAs to send to the client
                ctx.load_client_ca(self.ca_file)
            except Exception:
                err = sys.exc_info()[1]
                log.msg(_('Error loading certificates chain file %s : %s') % (self.ca_file, str(err)))
        return ctx

def getPageM2(url, checker=None, contextFactory=None, *args, **kwargs):
    """Replacement for twisted.web.client.getPage usable with M2Crypto.

    Download a page. Return the deferred of the HTTPClientFactory built
    for the request.

    url            -- address of the page; for the 'https' scheme the TLS
                      wrapper is active from the start, for any other
                      scheme it starts in pass-through mode
    checker        -- post connection check given to the TLS wrapper
                      (a default SSL.Checker.Checker() when None)
    contextFactory -- SSL context factory given to the TLS wrapper

    Extra positional and keyword arguments are handed to HTTPClientFactory.
    """
    location = urlparse(url)
    is_https = location.scheme == 'https'
    target_host = location.hostname
    target_port = location.port

    client_factory = HTTPClientFactory(url, *args, **kwargs)
    wrapper = policies.WrappingFactory(client_factory)
    if checker is None:
        checker = SSL.Checker.Checker()

    # the two schemes only differ by the pass-through flag, the startTLS
    # marker and the default port
    if is_https:
        pass_through, tls_flag, default_port = 0, 1, 443
    else:
        pass_through, tls_flag, default_port = 1, 0, 80

    def build_protocol(owner, wrappedProtocol):
        return TLSProtocolWrapper(owner,
                                  wrappedProtocol,
                                  startPassThrough=pass_through,
                                  client=1,
                                  contextFactory=contextFactory,
                                  postConnectionCheck=checker)

    wrapper.protocol = build_protocol
    client_factory.startTLS = tls_flag
    if target_port is None:
        target_port = default_port
    reactor.connectTCP(target_host, target_port, wrapper, timeout=30)
    return client_factory.deferred
