#! /usr/bin/env python
# -*- coding: utf-8 -*-

###########################################################################
#
# Eole NG - 2007
# Copyright Pole de Competence Eole  (Ministere Education - Academie Dijon)
# Licence CeCill  cf /root/LicenceEole.txt
# eole@ac-dijon.fr
#
# saml_crypto.py
#
# librairie de signature de requêtes SAML 2.0
# nécessite les librairies suivantes :
# - python-xmlsec (http://pyxmlsec.labs.libre-entreprise.org)
# - python-M2Crypto
#
###########################################################################

import os, traceback
import base64, urllib
import libxml2, xmlsec
from M2Crypto import EVP, m2, X509
from config import CERTFILE, KEYFILE
from page import log

#########################
#  Fonctions Utilitaires

def init_libs():
    """Initialise the libxml2 parser, then xmlsec and its crypto layer.

    The three initialisers are invoked in that fixed order; cleanup() is the
    counterpart that shuts the same libraries down.
    """
    for initializer in (libxml2.initParser, xmlsec.init, xmlsec.cryptoInit):
        initializer()

def cleanup(doc = None, dsig_ctx = None, res = False):
    """Release libxml2/xmlsec resources and shut both libraries down.

    doc : parsed libxml2 document to free (skipped when None)
    dsig_ctx : xmlsec signature context to destroy (skipped when None)
    res : value handed back unchanged, so callers can write
          ``return cleanup(doc, ctx, status)``

    Returns res.
    """
    releasers = []
    # the signature context is destroyed before the document is freed
    if dsig_ctx is not None:
        releasers.append(dsig_ctx.destroy)
    if doc is not None:
        releasers.append(doc.freeDoc)
    # library shutdown always happens, whatever was passed in
    releasers.extend((xmlsec.cryptoShutdown, xmlsec.shutdown, libxml2.cleanupParser))
    for release in releasers:
        release()
    return res

def dtd_log(msg, arg):
    """No-op handler used to silence warnings emitted during DTD validation.

    Both arguments are ignored.
    """
    return None

def find_key_file(cert_file):
    """Return the path of the file holding the key of a certificate.

    Resolution order:
    1. the configured KEYFILE, when it is set;
    2. a file with the same base name as cert_file and a '.key' extension,
       when such a file exists;
    3. cert_file itself.
    """
    if KEYFILE:
        return KEYFILE
    candidate = os.path.splitext(cert_file)[0] + '.key'
    if os.path.isfile(candidate):
        return candidate
    return cert_file

def recreate_cert(data):
    """Reformat the base64 payload of a certificate as a PEM block.

    data : base64 body of the certificate, without the BEGIN/END markers;
           it may contain arbitrary whitespace (spaces, tabs, LF, CRLF)

    Returns a string made of the BEGIN CERTIFICATE marker, the payload
    wrapped at 64 characters per line, and the END CERTIFICATE marker
    followed by a newline.
    """
    # split() without argument splits on any run of whitespace, so this drops
    # spaces, tabs, "\n" and "\r" wherever they appear in the payload
    payload = "".join(data.split())
    lines = [payload[start:start + 64] for start in range(0, len(payload), 64)]
    if not lines:
        # an empty payload historically produced one empty body line
        lines = [""]
    return "\n".join(["-----BEGIN CERTIFICATE-----"]
                     + lines
                     + ["-----END CERTIFICATE-----\n"])


###################################################
# Méthode de signature/vérification des paramètres
# passés dans la requête (binding HTTP-Redirect binding)

def sign_request(cert_file, sign_method, message_type, message, relay_state = None):
    """Sign the query parameters of a SAML message (HTTP-Redirect binding).

    The signature is computed over the url-encoded string

        <message_type>=<message>&RelayState=<relay_state>&SigAlg=<sign_method>

    where the RelayState part is only present when relay_state is set.

    cert_file : certificate path, used to locate the key (see find_key_file)
    sign_method : value advertised in the SigAlg parameter
    message_type : name of the message parameter ('SAMLRequest', 'SAMLResponse')
    message : value of the message parameter, url-encoded as given
              (presumably already deflated and base64 encoded by the caller)
    relay_state : optional RelayState value

    Returns the list of (name, value) parameters, in signing order, with the
    ('Signature', <base64 signature>) pair appended.

    NOTE(review): sign_method is only advertised, it does not select the
    digest; the key is used with the default digest of EVP.load_key
    (presumably SHA-1, to be confirmed against the M2Crypto version in use).
    """
    req_args = [(message_type, message)]
    if relay_state:
        req_args.append(('RelayState', relay_state))
    req_args.append(('SigAlg', sign_method))
    req_string = urllib.urlencode(req_args)
    # load the key and sign the url-encoded parameter string
    pkey_file = find_key_file(cert_file)
    k = EVP.load_key(pkey_file)
    k.sign_init()
    k.sign_update(req_string)
    # b64encode emits no line breaks: the HTTP-Redirect binding requires the
    # signature to be base64 encoded with any whitespace removed, whereas
    # encodestring inserts a newline every 76 characters and at the end
    signature = base64.b64encode(k.final())
    # add the signature to the request parameters
    req_args.append(('Signature', signature))
    return req_args

def check_signed_request(args, message_type, cert_file):
    """Vérifie que la signature présente dans les paramètres
    correspond au contenu de la requête

    args : paramètres de la requête
    message_type : type de message ('SAMLResponse, ...)
    """
    req_args = []
    msg = "%s %s" % (message_type, str(_("Error verifying signature")))
    # support de rsa-sha1 seulement
    try:
        if 'Signature' in args and args.get('SigAlg', [''])[0] == xmlsec.HrefRsaSha1:
            signature_value = base64.decodestring(args['Signature'][0])
            # reconstruction de la chaîne de données signées
            if message_type in args:
                req_args.append((message_type, args[message_type][0]))
                if 'RelayState' in args:
                    req_args.append(('RelayState', args['RelayState'][0]))
                if 'SigAlg' in args:
                    req_args.append(('SigAlg', args['SigAlg'][0]))
                    # Toutes les données sont disponibles
                    req_string = urllib.urlencode(req_args)
                    # vérification à l'aide du certificat
                    if cert_file is None:
                        k = EVP.load_key(CERTFILE)
                    else:
                        cert = X509.load_cert(cert_file)
                        k = cert.get_pubkey()
                    k.verify_init()
                    k.verify_update(req_string)
                    # checking that signature is valid
                    # HACK : we should use k.verify_final, but as it is broken
                    # in older m2crypto versions, we use the underlying m2 module
                    if m2.verify_final(k.ctx, signature_value, k.pkey) == 1:
                        return True, ""
                    else:
                        msg = _("invalid signature value")
    except Exception, err:
        traceback.print_exc()
        msg = str(err)
    return False, msg


###############################################################
# signature/vérification d'un document XML (binding HTTP-POST)

def sign_doc(xml_src, cert_file, reference = None, key_info = True):
    """Sign an XML document with an enveloped RSA-SHA1 signature.

    xml_src : XML document to sign, as a string
    cert_file : path of the PEM certificate; the key file is located with
                find_key_file(cert_file)
    reference : ID of the Assertion node to sign; the whole document is
                signed when it is None (or empty)
    key_info : when true, KeyInfo and X509Data nodes are added to the
               signature template

    Returns a (xmlsec.HrefRsaSha1, signed_document) tuple, the document
    being serialized as a string.

    Raises Exception when the certificate cannot be loaded, when the key
    name cannot be set or when the signature fails; cleanup() is called
    before each of these raises.
    """
    init_libs()
    # validate against a minimal DTD; per the original author's note this is
    # what allows nodes to be referenced by their ID
    doc = libxml2.parseDoc(xml_src)
    ctx = libxml2.newValidCtxt()
    # validation errors and warnings are routed to the no-op dtd_log handler
    ctx.setValidityErrorHandler(dtd_log, dtd_log)
    # relative path: presumably resolved against the process working directory
    dtd = libxml2.parseDTD(None, "templates/saml.dtd")
    # validate the document (the result is not checked)
    doc.validateDtd(ctx, dtd)
    # create the signature template (exclusive C14N, RSA-SHA1)
    sign_node = xmlsec.TmplSignature(doc, xmlsec.transformExclC14NId(), xmlsec.transformRsaSha1Id(), None)
    parent = doc.getRootElement()
    sibling = None
    # look for the node after which the signature node is inserted
    # (inside the referenced assertion, or at document level)
    if reference:
        for ch_node in doc.getRootElement().get_children():
            if ch_node.name == 'Assertion':
                id_attr = ch_node.hasProp('ID')
                if id_attr and id_attr.content == reference:
                    # no break in this inner loop: the last Issuer child wins
                    for ass_child in ch_node.get_children():
                        if ass_child.name == 'Issuer':
                            sibling = ass_child
                    # stop at the first assertion carrying the requested ID
                    break
        # xmlsec.addIDs(doc, doc.getRootElement(), ('ID'))
    else:
        # whole-document signature: insert after the first top-level Issuer
        for ch_node in doc.getRootElement().get_children():
            if ch_node.name == 'Issuer':
                sibling = ch_node
                break
    # add the <Signature/> node to the document; when no Issuer was found
    # it is appended as the last child of the root element
    if sibling:
        sibling.addNextSibling(sign_node)
    else:
        parent.addChild(sign_node)
    # initialise the signature context
    dsig_ctx = xmlsec.DSigCtx()
    # add the reference (SHA1 digest): "#<reference>" when an ID was given,
    # no URI otherwise
    if reference:
        ref_node = sign_node.addReference(xmlsec.transformSha1Id(), None, "#%s" % reference, None)
        # only same-document URIs are enabled for references and key retrieval
        dsig_ctx.enabledReferenceUris = xmlsec.TransformUriTypeSameDocument
        dsig_ctx.keyInfoReadCtx.retrievalMethodCtx.enabledUris = xmlsec.TransformUriTypeSameDocument
    else:
        ref_node = sign_node.addReference(xmlsec.transformSha1Id(), None, None, None)
    # add the 'enveloped' transform, followed by exclusive C14N
    ref_node.addTransform(xmlsec.transformEnvelopedId())
    ref_node.addTransform(xmlsec.transformExclC14NId())
    key_file = find_key_file(cert_file)
    key = xmlsec.cryptoAppKeyLoad(key_file, xmlsec.KeyDataFormatPem, None, None, None)
    # attach the certificate to the key
    if xmlsec.cryptoAppKeyCertLoad(key, cert_file, xmlsec.KeyDataFormatPem) < 0:
        cleanup(doc, dsig_ctx)
        raise Exception("Error: failed to load pem certificate \"%s\"" % cert_file)
    # the key is named after the path of the file it was loaded from
    if key.setName(key_file) < 0:
        cleanup(doc, dsig_ctx)
        raise Exception("Error: failed to set key name for key from \"%s\"" % key_file)
    if key_info:
        # add the <dsig:KeyInfo/> and <dsig:X509Data/> nodes when requested
        key_info_node = sign_node.ensureKeyInfo(None)
        key_info_node.addX509Data()

    dsig_ctx.signKey = key
    # sign the template
    if dsig_ctx.sign(sign_node) < 0:
        cleanup(doc, dsig_ctx)
        raise Exception("Error: signature failed")
    # return the document as a string; it is serialized before cleanup()
    # frees it
    data = doc.serialize()
    cleanup(doc, dsig_ctx)
    return xmlsec.HrefRsaSha1, data

def verify_doc(cert_file, xml_src):
    """Verify the XML signature contained in a document.

    cert_file : path of the PEM certificate holding the verification key
    xml_src : XML document to check, as a string

    Returns True when xmlsec reports the signature as valid, False in every
    other case (unusable document, key loading failure, missing signature
    node, verification error or invalid signature). Failures are reported
    through log.msg. The libraries are shut down with cleanup() on every
    return path.
    """
    init_libs()
    status = False
    doc = libxml2.parseDoc(xml_src)
    # check the parse result before the document is used for anything else
    if doc is None or doc.getRootElement() is None:
        log.msg(_("Error: unable to parse request"))
        return cleanup(doc)

    ctx = libxml2.newValidCtxt()
    # validation errors and warnings are routed to the no-op dtd_log handler
    ctx.setValidityErrorHandler(dtd_log, dtd_log)
    dtd = libxml2.parseDTD(None, "templates/saml.dtd")
    # validate the document; per the original author's note this is what
    # allows IDs to be processed (the result itself is not checked)
    doc.validateDtd(ctx, dtd)

    node = xmlsec.findNode(doc.getRootElement(),
                           xmlsec.NodeSignature, xmlsec.DSigNs)
    # initialise the signature context
    dsig_ctx = xmlsec.DSigCtx()
    try:
        dsig_ctx.signKey = xmlsec.cryptoAppKeyLoad(cert_file, xmlsec.KeyDataFormatCertPem, None, None, None)
    except Exception:
        # Exception rather than a bare except, so that KeyboardInterrupt and
        # SystemExit are not reported as a key loading problem
        log.msg(_("Error loading key from certificate %s") % cert_file)
        return cleanup(doc, dsig_ctx)
    if node is None:
        log.msg(_("Error: signature node not found in request"))
        return cleanup(doc, dsig_ctx)
    # verify the signature
    if dsig_ctx.verify(node) < 0:
        log.msg(_("Error verifying signature"))
        return cleanup(doc, dsig_ctx)
    if dsig_ctx.status == xmlsec.DSigStatusSucceeded:
        status = True
    return cleanup(doc, dsig_ctx, status)

def check_signed_doc(xml_src, cert_file):
    """Check the signature of an XML document against a certificate.

    xml_src : XML document to check, as a string
    cert_file : path of the certificate used for the verification

    Returns (True, "") when the signature is valid, (False, message)
    otherwise. An unreadable certificate file, or any exception raised
    during the verification, is printed and reported with a generic
    error message.
    """
    try:
        # explicit check instead of an assert: asserts are stripped when
        # python runs with -O, which would silently skip this validation
        if not os.access(cert_file, os.R_OK):
            raise IOError("certificate file %r is not readable" % (cert_file,))
        res = verify_doc(cert_file, xml_src)
    except Exception:
        traceback.print_exc()
        return False, _("An error occured while checking signature")
    if res:
        return True, ""
    return False, _("invalid XML signature")

