#! /usr/bin/env python
# -*- coding: utf-8 -*-

###########################################################################
#
# Eole NG - 2008
# Copyright Pole de Competence Eole  (Ministere Education - Academie Dijon)
# Licence CeCill  cf /root/LicenceEole.txt
# eole@ac-dijon.fr
#
# saml_utils.py
#
# utilitaires pour la gestion des requêtes SAML et des données metadata
#
###########################################################################

import os, datetime, time, codecs, calendar
from saml_crypto import check_signed_request, check_signed_doc
from M2Crypto import X509
from xml.etree import ElementTree
from cgi import escape
from urllib import quote, unquote
import base64, re, zlib, StringIO, traceback
from config import METADATA_DIR, IDP_IDENTITY, AUTH_FORM_URL, encoding, LC_ALL, TIME_ADJUST, DEFAULT_MIN_CONTEXT, RNE
from saml2 import samlp, saml
import SOAPpy
from page import log, trace

from eolesso.util import is_true

# tolerance applied to the date checks on received SAML messages
delta_adjust = datetime.timedelta(seconds=TIME_ADJUST)

# authentication contexts this server can provide when acting as identity provider
available_contexts = {
    'URN_PROTECTED_PASSWORD': 'urn:oasis:names:tc:SAML:2.0:ac:classes:PasswordProtectedTransport',
    'URN_TIME_SYNC_TOKEN': 'urn:oasis:names:tc:SAML:2.0:ac:classes:TimeSyncToken',
}

# relative security level of each known context (a higher value means a stronger context)
context_levels = {
    'urn:oasis:names:tc:SAML:2.0:ac:classes:Password': 0,
    'urn:oasis:names:tc:SAML:2.0:ac:classes:PasswordProtectedTransport': 10,
    'urn:oasis:names:tc:SAML:2.0:ac:classes:TimeSyncToken': 20,
}

# federation options used for an identity provider when none are configured
default_idp_options = {
    'attribute_set': 'default',
    'allow_idp': 'true',
    'allow_idp_initiated': 'true',
}

def check_required_contexts(comparison, requested_contexts, available=available_contexts.values()):
    """Select an authentication context matching a SAML RequestedAuthnContext.

    To keep things simple, only the contexts whose security level is known
    (context_levels) are taken into account.

    :param comparison: SAMLv2 comparison method ('exact', 'minimum',
        'maximum' or 'better'); any other value never matches
    :param requested_contexts: requested authentication context class URIs
    :param available: candidate context class URIs, in order of preference
        (defaults to the contexts provided by this identity provider)
    :return: (True, selected context) when a context matches,
        (False, error message) otherwise
    """
    if comparison == 'exact':
        # first available context that was explicitly requested
        for candidate in available:
            if candidate in requested_contexts:
                return True, candidate
        return False, _('No matching authentication context')

    # the other comparisons rely on security levels: every context must be known
    for context_list in (available, requested_contexts):
        unknown = [ctx for ctx in context_list if ctx not in context_levels]
        if unknown:
            return False, _('Unknown authentication context ({0})').format(unknown[0])

    selected = None
    if comparison == 'minimum':
        # first available context at least as strong as the weakest requested one
        floor = min(context_levels[ctx] for ctx in requested_contexts)
        candidates = [ctx for ctx in available if context_levels[ctx] >= floor]
        if candidates:
            selected = candidates[0]
    elif comparison == 'maximum':
        # strongest available context that does not exceed the weakest requested one
        ceiling = min(context_levels[ctx] for ctx in requested_contexts)
        candidates = [ctx for ctx in available if context_levels[ctx] <= ceiling]
        if candidates:
            selected = max(candidates, key=lambda ctx: context_levels[ctx])
    elif comparison == 'better':
        # weakest available context that is stronger than every requested one
        threshold = max(context_levels[ctx] for ctx in requested_contexts)
        candidates = [ctx for ctx in available if context_levels[ctx] > threshold]
        if candidates:
            selected = min(candidates, key=lambda ctx: context_levels[ctx])

    if selected:
        return True, selected
    return False, _('No matching authentication context')

# exceptions
class Redirect(Exception):
    """Exception used to request an HTTP redirection of the client.

    :param location: target URL; stored in ``location`` and also used as the
        exception message
    """

    def __init__(self, location):
        self.location = location
        super(Redirect, self).__init__(location)

class InternalError(Exception):
    """Internal server error, identified by the code 'INTERNAL_ERROR'.

    :param reason: human readable description, kept in ``reason`` and used
        as the exception message
    """

    def __init__(self, reason):
        super(InternalError, self).__init__(reason)
        self.code = 'INTERNAL_ERROR'
        self.reason = reason

    def as_xml(self):
        """Return 'CODE:reason', with the reason HTML/XML-escaped."""
        escaped_reason = escape(self.reason)
        return "%s:%s" % (self.code, escaped_reason)

def split_form_arg(arg_data, max_len=80):
    """Split arg_data into lines of at most max_len characters.

    The chunks are joined with newlines; an empty input gives an empty string.

    :param arg_data: string to split
    :param max_len: maximum length of each line (must be >= 1)
    :raises ValueError: if max_len is lower than 1 (the previous loop never
        terminated in that case)
    """
    if max_len < 1:
        raise ValueError("max_len must be a positive integer")
    return "\n".join([arg_data[start:start + max_len]
                      for start in range(0, len(arg_data), max_len)])

# outils de formatage des dates

def date_from_string(date_string):
    """Convert a SAML UTC date string ('YYYY-MM-DDTHH:MM:SSZ') to a naive datetime.

    Fractional seconds are dropped (everything after the last '.' is replaced
    by 'Z') because the parsing format does not handle them.
    """
    head, dot, fraction = date_string.rpartition('.')
    if dot:
        date_string = head + 'Z'
    return datetime.datetime.strptime(date_string, "%Y-%m-%dT%H:%M:%SZ")

def format_timestamp(timest = None):
    """Format a POSIX timestamp as a SAML UTC date string ('...Z').

    Uses the current time when timest is None; sub-second precision is dropped.
    """
    moment = datetime.datetime.utcfromtimestamp(time.time() if timest is None else timest)
    return moment.replace(microsecond=0).isoformat() + "Z"

def check_date(date_string, not_before = None):
    """Return True if date_string lies in the future (compared to utcnow()) and,
    when not_before is given, the current time is already after not_before.

    not_before is compared to a naive UTC datetime (presumably naive UTC too).
    """
    now = datetime.datetime.utcnow()
    if date_from_string(date_string) <= now:
        return False
    return not_before is None or now > not_before

# outils de décodage/encodage des requêtes SAML (urlencode, base64 et zlib deflate/inflate)
def htc(match):
    """Return the character whose hexadecimal code is captured by group 1 of match."""
    hex_code = match.group(1)
    return chr(int(hex_code, 16))

def urldecode(url):
    """décode une url encodée
    """
    rex = re.compile('%([0-9a-hA-H][0-9a-hA-H])', re.M)
    return rex.sub(htc, url)

def decode_request(request, compress = False):
    """Decode a SAML message and inflate it if needed.

    Without compression the message is only base64-decoded. With compression
    (HTTP-Redirect binding) it is URL-decoded, base64-decoded, then inflated.
    """
    if not compress:
        return base64.decodestring(request)
    # negative MAX_WBITS: raw deflate stream without zlib header, as produced by
    # PHP's gzdeflate (needed for compatibility with PHP gzinflate/gzdeflate)
    raw_deflate = base64.decodestring(urldecode(request))
    return zlib.decompress(raw_deflate, -zlib.MAX_WBITS)

def encode_idp_cookie(data_list):
    """Encode a list of identifiers as the value of an IDP discovery cookie.

    Each identifier is base64-encoded, the encoded values are separated by a
    single space and the whole value is URL-encoded.
    """
    # b64encode, unlike encodestring, adds no trailing (or every-76-chars)
    # newline, which previously ended up as '%0A' inside the cookie value
    return quote(" ".join([base64.b64encode(data) for data in data_list]))

def encode_request(request, compress = False):
    """Base64-encode a SAML message, deflating it first when compress is true.

    Compression produces a raw DEFLATE stream (negative window bits, no zlib
    header/checksum), which zlib.compress cannot produce, hence the use of a
    compressobj. The returned string contains no newline.
    """
    if compress:
        compressor = zlib.compressobj(zlib.Z_DEFAULT_COMPRESSION, zlib.DEFLATED, -zlib.MAX_WBITS)
        # keep the output of compress() as well as flush(): the former chunked
        # loop discarded what compress() returned, corrupting larger messages
        request = compressor.compress(request) + compressor.flush()
    return base64.b64encode(request)

# outils de gestion des fichiers metadata
def find_tag(node, tag, namespace = "urn:oasis:names:tc:SAML:2.0:metadata"):
    """Return the descendants of node named tag.

    Elements qualified with the given namespace are searched first; when there
    are none, unqualified elements with that name are returned instead.
    """
    qualified = node.findall(".//{%s}%s" % (namespace, tag))
    return qualified if qualified else node.findall(".//%s" % tag)

def find_issuer(xml_string):
    """Return the text of the first saml:Issuer element of the document,
    encoded with the configured encoding, or None when there is none.
    """
    issuers = find_tag(ElementTree.fromstring(xml_string), 'Issuer',
                       'urn:oasis:names:tc:SAML:2.0:assertion')
    if not issuers:
        return None
    return issuers[0].text.encode(encoding)

def get_metadata(identifier):
    """Return a dictionary describing the metadata of the given entity.

    The metadata are read from METADATA_DIR/<identifier>.xml, with os.sep in
    the identifier replaced by '_'.

    Keys of the returned dictionary:
      - 'entityID' / 'entity_local_id': entityID and ID attributes of the
        EntityDescriptor (encoded with `encoding` when unicode), plus every
        other EntityDescriptor attribute stored under its own name
      - 'SPSSODescriptor' and 'IDPSSODescriptor': dict which may contain the
        WantAuthnRequestsSigned / AuthnRequestsSigned / WantAssertionsSigned /
        protocolSupportEnumeration attributes, 'NameIDFormat' (list), one list
        of (Location, Binding, ResponseLocation, isDefault, index) tuples per
        endpoint element name, 'SignCert' (text of the signing certificate) and
        'AttributeConsumingServices' ({index: (is_default, required attribute
        names, optional attribute names)})

    :raises InternalError: if no metadata file exists for this identifier
    """
    sp_meta = {}
    meta_file = os.path.join(METADATA_DIR, "%s.xml" % identifier.replace(os.sep,'_'))
    if not os.path.isfile(meta_file):
        raise InternalError(_('No metada found for identity {0}').format(identifier))
    doc = ElementTree.parse(meta_file)
    # entity id, read from EntityDescriptor elements found below the root
    # (ElementTree's './/' path used by find_tag excludes the element itself, so
    # a root EntityDescriptor is handled by the fallback further down)
    # NOTE(review): with several EntityDescriptor elements, the last one wins
    for entity in find_tag(doc, 'EntityDescriptor'):
        if 'entityID' in entity.attrib:
            sp_meta['entityID'] = entity.get('entityID')
        if 'ID' in entity.attrib:
            sp_meta['entity_local_id'] = entity.get('ID')
        # also keep the other EntityDescriptor attributes
        # NOTE(review): 'entity_local_id' is not an XML attribute name ('ID' was
        # probably meant), so the ID attribute is also stored as sp_meta['ID']
        for attr_name in entity.attrib:
            if attr_name not in ('entityID', 'entity_local_id'):
                sp_meta[attr_name] = entity.get(attr_name)
    if not 'entityID' in sp_meta:
        # fallback: the root element itself carries the entity attributes
        if 'entityID' in doc.getroot().attrib:
            sp_meta['entityID'] = doc.getroot().get('entityID')
            if 'ID' in doc.getroot().attrib:
                sp_meta['entity_local_id'] = doc.getroot().get('ID')
            # also keep the other attributes of the root element
            for attr_name in doc.getroot().attrib:
                if attr_name not in ('entityID', 'entity_local_id'):
                    sp_meta[attr_name] = doc.getroot().get(attr_name)
    # identifiers are returned as encoded byte strings
    if type(sp_meta.get('entityID','')) == unicode:
        sp_meta['entityID'] = sp_meta['entityID'].encode(encoding)
    if type(sp_meta.get('entity_local_id','')) == unicode:
        sp_meta['entity_local_id'] = sp_meta['entity_local_id'].encode(encoding)
    # load the service provider / identity provider descriptions
    for ent_type in ('SPSSODescriptor', 'IDPSSODescriptor'):
        data = {}
        # general information (descriptor attributes)
        for desc_node in find_tag(doc, ent_type):
            if 'WantAuthnRequestsSigned' in desc_node.attrib:
                data['WantAuthnRequestsSigned'] = desc_node.get('WantAuthnRequestsSigned')
            if 'AuthnRequestsSigned' in desc_node.attrib:
                data['AuthnRequestsSigned'] = desc_node.get('AuthnRequestsSigned')
            if 'WantAssertionsSigned' in desc_node.attrib:
                data['WantAssertionsSigned'] = desc_node.get('WantAssertionsSigned')
            if 'protocolSupportEnumeration' in desc_node.attrib:
                data['protocolSupportEnumeration'] = desc_node.get('protocolSupportEnumeration')
            # optional information about the supported NameID formats
            for entity in find_tag(desc_node, 'NameIDFormat'):
                nameids = data.get('NameIDFormat', [])
                nameids.append(entity.text.encode(encoding))
                data['NameIDFormat'] = nameids
            # service URLs (endpoints)
            # NOTE(review): 'IDPSSODescriptor' is a descriptor, not an endpoint
            # element; its presence in this list looks accidental
            for endpoint in ('AssertionConsumerService', 'SingleLogoutService', 'IDPSSODescriptor', 'SingleSignOnService'):
                for entity in find_tag(desc_node, endpoint):
                    resp_loc = ""
                    is_default = False
                    index = None
                    if is_true(entity.get('isDefault')):
                        is_default = True
                    if 'index' in entity.attrib:
                        # a non numeric index is ignored (kept as None)
                        try:
                            index = int(entity.get('index'))
                        except ValueError:
                            index = None
                    if 'ResponseLocation' in entity.attrib:
                        # responses must be sent to a different URL
                        resp_loc = entity.get('ResponseLocation')
                    # endpoints without a Location attribute are skipped
                    if 'Location' in entity.attrib:
                        data_endp = data.get(endpoint, [])
                        data_endp.append((entity.get('Location'), entity.get('Binding'), resp_loc, is_default, index))
                        data[endpoint] = data_endp

            # certificates used for signing (KeyDescriptor without 'use' or
            # with use="signing"); the last certificate found is kept
            for entity in find_tag(desc_node, 'KeyDescriptor'):
                if 'use' in entity.attrib:
                    if entity.get('use') != 'signing':
                        continue
                for cert_data in find_tag(entity, 'X509Certificate', 'http://www.w3.org/2000/09/xmldsig#'):
                    data['SignCert'] = cert_data.text.encode(encoding)
            # attributes required/optional for the service to work
            # NOTE(review): the 'index' attribute is assumed present; int(None)
            # would raise TypeError
            attr_consuming_services = {}
            for attr_serv in find_tag(desc_node, 'AttributeConsumingService'):
                index = attr_serv.get('index')
                opt_attr = []
                required_attr = []
                for attr in find_tag(attr_serv, 'RequestedAttribute'):
                    if 'isRequired' in attr.attrib and is_true(attr.get('isRequired')):
                        required_attr.append(attr.get('Name'))
                    else:
                        opt_attr.append(attr.get('Name'))
                if 'isDefault' in attr_serv.attrib and is_true(attr_serv.get('isDefault')):
                    attr_consuming_services[int(index)] = (True, required_attr, opt_attr)
                else:
                    attr_consuming_services[int(index)] = (False, required_attr, opt_attr)
            # NOTE(review): overwritten for each descriptor node of this type, so
            # only the services of the last one are kept
            data['AttributeConsumingServices'] = attr_consuming_services
        sp_meta[ent_type] = data
    return sp_meta

def gen_metadata(manager, server_url, cert_file):
    """Generate the XML description of this server's metadata (IdP and SP).

    :param manager: object providing attribute_sets, attribute_sets_ids and
        associations, used to describe one AttributeConsumingService per
        attribute set
    :param server_url: base URL of the server, used to build the endpoints
    :param cert_file: path of the certificate published in the metadata
    :return: the metadata document built from templates/metadata.tmpl
    """
    # dynamic data: endpoints derived from the server URL
    logout_endpoint = "%s/logout" % server_url
    sso_endpoint = "%s/saml" % server_url
    consumer_endpoint = "%s/saml/acs" % server_url
    # template of one AttributeConsumingService element
    # (template files are closed explicitly instead of relying on the GC)
    with codecs.open('templates/attr_service.tmpl', 'r', encoding) as tmpl_file:
        attr_service_tmpl = tmpl_file.read()
    # language of the descriptions: it does not depend on the attribute set,
    # so it is computed once instead of once per set
    try:
        xml_lang = LC_ALL[:2]
    except TypeError:
        # LC_ALL cannot be sliced (e.g. None)
        xml_lang = "fr"
    # the 'default' attribute set is always listed first
    set_names = []
    for set_name in manager.attribute_sets.keys():
        if set_name == 'default':
            set_names.insert(0, set_name)
        else:
            set_names.append(set_name)
    # the set flagged isDefault may be redefined by the 'default' association
    default_set = manager.associations.get('default', {}).get('attribute_set', 'default')
    # description of the attribute sets (AttributeConsumingService elements)
    attr_consuming_services = []
    for set_name in set_names:
        index = manager.attribute_sets_ids[set_name]
        attribute_set = manager.attribute_sets[set_name]
        if set_name == default_set:
            default = 'true'
        else:
            default = 'false'
        if set_name == 'default':
            # attribute set provided by default (Eole)
            set_descr = _('default attribute set')
        else:
            set_descr = "{0} {1}".format(_('custom attribute set'), set_name)
        set_descr = unicode(set_descr, encoding)
        req_attrs = []
        opt_attrs = []
        for attr_type, attributes in attribute_set.items():
            for attr in attributes:
                if attr_type == 'optional_attrs':
                    attr_descr = u"{0}: {1}".format(_("optional attribute"), attr)
                    opt_attrs.append(u"""<RequestedAttribute FriendlyName="%s" Name="%s" NameFormat="urn:oasis:names:tc:SAML:2.0:attrname-format:basic" isRequired="false"/>""" % (attr_descr, attr))
                else:
                    attr_descr = u"{0}: {1}".format(_("required attribute"), attr)
                    req_attrs.append(u"""<RequestedAttribute FriendlyName="%s" Name="%s" NameFormat="urn:oasis:names:tc:SAML:2.0:attrname-format:basic" isRequired="true"/>""" % (attr_descr, attr))
        # required attributes first, then optional ones
        req_attrs.extend(opt_attrs)
        attrs = u'\n\t\t\t'.join(req_attrs)
        attr_consuming_services.append(attr_service_tmpl.format(index, default, xml_lang, set_name, xml_lang, set_descr, attrs))
    # certificate content without the PEM BEGIN/END lines
    cert = X509.load_cert(cert_file)
    cert_data = []
    for line in cert.as_pem().split('\n'):
        if '-----BEGIN CERTIFICATE-----' not in line and '-----END CERTIFICATE-----' not in line:
            cert_data.append(line)
    cert_data = ''.join(cert_data).strip()
    # extra attributes of the EntityDescriptor element; the leading empty string
    # makes " ".join() insert a separating space before the first attribute
    extra_attrs = [""]
    if RNE:
        extra_attrs.append('uaj="{0}"'.format(RNE))
    # main metadata template
    with open('templates/metadata.tmpl') as tmpl_file:
        md_tmpl = tmpl_file.read().strip()
    return md_tmpl % (IDP_IDENTITY, " ".join(extra_attrs), cert_data, sso_endpoint, sso_endpoint, logout_endpoint, logout_endpoint,
                      cert_data, logout_endpoint, logout_endpoint, consumer_endpoint, consumer_endpoint, ''.join(attr_consuming_services))

def get_attributes(sp_meta, index = None):
    """Return the (required, optional) attribute names requested in the metadata.

    :param sp_meta: metadata dictionary as returned by get_metadata
    :param index: index of the AttributeConsumingService to use (int or numeric
        string; 0 is a valid index). When it is missing or unknown, the service
        flagged isDefault is used, or else the one with the lowest index.
    :return: tuple (required attribute names, optional attribute names); two
        empty lists when the metadata describe no AttributeConsumingService
    """
    services = sp_meta['SPSSODescriptor'].get('AttributeConsumingServices', {})
    if not services:
        return [], []
    # test against None/'' rather than truthiness: index 0 is a valid value
    if index not in (None, '') and int(index) in services:
        _is_default, required, optional = services[int(index)]
        return required, optional
    # service flagged as default, looked up in index order so that the result
    # does not depend on dictionary ordering
    for service_index in sorted(services):
        is_default, required, optional = services[service_index]
        if is_default:
            return required, optional
    # no default service: use the one with the lowest index
    _is_default, required, optional = services[min(services)]
    return required, optional

def get_endpoint(sp_meta, service, ent_type = 'SPSSODescriptor', allowed_bindings = None, index = None):
    """Find the address and binding to use to send SAML messages to a service.

    :param sp_meta: metadata dictionary as returned by get_metadata
    :param service: endpoint element name (e.g. 'SingleLogoutService')
    :param ent_type: descriptor to look into ('SPSSODescriptor' by default)
    :param allowed_bindings: accepted bindings (HTTP-POST and HTTP-Redirect by default)
    :param index: when given (0 included), only the endpoint with this index is used
    :return: (binding, service URL, response URL); the response URL falls back
        to the service URL when the endpoint defines no ResponseLocation
    :raises InternalError: if no usable endpoint is found
    """
    if allowed_bindings is None:
        allowed_bindings = [samlp.BINDING_HTTP_POST, samlp.BINDING_HTTP_REDIRECT]
    # test against None/'' rather than truthiness: index 0 is a valid value
    index_requested = index not in (None, '')
    service_url = binding = response_url = ""
    if ent_type in sp_meta:
        for serv_url, bind, response, isdefault, ind in sp_meta[ent_type].get(service, [("", "", "", False, None)]):
            if index_requested and str(ind) != str(index):
                # a specific index was requested: skip the other endpoints
                continue
            # FIXME TODO : verifier le certificat de l'assertion consumer url (agriates/pki nationale)
            if bind not in allowed_bindings:
                continue
            # select the endpoint if it has the requested index, is the default
            # one, or if nothing was selected yet
            if isdefault == True or service_url == "" or index_requested:
                service_url = serv_url
                binding = bind
                response_url = response
    if service_url == "":
        raise InternalError(u'%s : %s' % (sp_meta.get('entityID', _('service provider')), _('no usable endpoint for {0}').format(service)))
    if not binding in allowed_bindings:
        raise InternalError(u'%s : %s' % (sp_meta.get('entityID', _('service provider')), _('no suppported binding for {0}').format(service)))
    return binding, service_url, response_url or service_url

def extract_message(request, message_type, check_signature=True, compressed=True):
    """Decode the SAML message found in request.args[message_type] and check
    its signature when check_signature is true.

    The message is inflated only for non-POST requests when compressed is not
    False. The signing certificate is looked up in METADATA_DIR/certs/ from the
    message issuer.

    :return: (True, decoded XML) on success, (False, reason) when the
        signature check fails
    """
    inflate = request.method != 'POST' and compressed != False
    xml_doc = decode_request(request.args[message_type][0], inflate)
    # issuer of the message and path of its signing certificate
    issuer_ident = find_issuer(xml_doc)
    cert_file = os.path.join(METADATA_DIR, 'certs', '%s.crt' % issuer_ident.replace(os.sep,'_'))
    if not check_signature:
        return True, xml_doc
    # GET: signature carried in the query arguments; otherwise embedded in the document
    if request.method == 'GET':
        signed, reason = check_signed_request(request.args, message_type, cert_file)
    else:
        signed, reason = check_signed_doc(xml_doc, cert_file)
    if not signed:
        return False, reason
    return True, xml_doc

def process_auth_request(manager, request, check_signature=False, compressed=True):
    """Process a SAML authentication request (AuthnRequest).

    Example of a received request (FIM Creteil):

    <samlp:AuthnRequest xmlns:saml="urn:oasis:names:tc:SAML:2.0:assertion" xmlns:samlp="urn:oasis:names:tc:SAML:2.0:protocol"
           AssertionConsumerServiceIndex="0"
           Destination="https://...:8443/saml"
           ID="acbd1461df9a8c1a9c8923dc6283242a"
           IssueInstant="2012-07-23T13:20:15Z"
           ProviderName="SP Creteil - PIA"
           Version="2.0">
      <saml:Issuer>...</saml:Issuer>
      <samlp:NameIDPolicy AllowCreate="true" Format="urn:oasis:names:tc:SAML:2.0:nameid-format:persistent"></samlp:NameIDPolicy> --> not handled
      <samlp:RequestedAuthnContext>
        <saml:AuthnContextClassRef>urn:oasis:names:tc:SAML:2.0:ac:classes:PasswordProtectedTransport</saml:AuthnContextClassRef>
      </samlp:RequestedAuthnContext>
    </samlp:AuthnRequest>

    :return: tuple (issuer, consumer_url, binding, request id, passive,
        force_auth, comparison, requested contexts). consumer_url and binding
        describe the same AssertionConsumerService endpoint, and are both None
        when no endpoint of the metadata matches the request.
    :raises InternalError: if the signature check fails
    :raises AssertionError: if the issuer has no metadata, or if the
        destination, issue instant or AssertionConsumerServiceURL is invalid
    """
    allowed_bindings = [samlp.BINDING_HTTP_POST, samlp.BINDING_HTTP_REDIRECT]
    consumer_url = binding = None
    signature_ok, saml_req = extract_message(request, 'SAMLRequest', check_signature, compressed)
    if not signature_ok:
        # FIXME TODO: generate a saml negative saml response
        raise InternalError("signature error : %s" % str(saml_req))
    req = samlp.AuthnRequestFromString(saml_req)
    # validation uses explicit raises instead of assert statements, which are
    # stripped when python runs with -O; the exception type is unchanged
    # issuer and metadata check
    req_issuer = req.issuer.text.encode(encoding)
    if req_issuer not in manager.sp_meta:
        raise AssertionError(_('No metada found for identity {0}').format(req_issuer))
    sp_meta = manager.get_metadata(req_issuer)
    # AttributeConsumingServiceIndex is not handled: attributes are selected by
    # the attribute filter associated with this service provider
    # Destination
    if req.destination not in ("%s%s" % (AUTH_FORM_URL, request.path), IDP_IDENTITY):
        raise AssertionError(_('Invalid assertion destination'))
    # IssueInstant
    instant = date_from_string(req.issue_instant)
    if instant > datetime.datetime.utcnow():
        raise AssertionError(_('assertion issue instant is in future'))
    # ForceAuthn
    force_auth = False
    if is_true(req.force_authn):
        force_auth = True
    # IsPassive: if the user is not logged in, no session is established
    passive = False
    if is_true(req.is_passive):
        passive = True
    # response endpoint: given either by ProtocolBinding/AssertionConsumerServiceURL
    # or by AssertionConsumerServiceIndex; a matching endpoint is looked up in
    # the metadata of the partner entity
    requested_index = req.assertion_consumer_service_index
    for location, endp_binding, resp_loc, is_default, index in sp_meta['SPSSODescriptor'].get('AssertionConsumerService', []):
        # the binding is recorded together with the URL, so the returned binding
        # always belongs to the selected endpoint (it used to be the binding of
        # the last endpoint iterated)
        if requested_index:
            if str(index) == str(requested_index.encode(encoding)):
                # endpoint with the requested index
                consumer_url = resp_loc or location
                binding = endp_binding
                break
        elif req.protocol_binding and endp_binding == req.protocol_binding:
            # endpoint with the requested binding
            consumer_url = resp_loc or location
            binding = endp_binding
            # check the AssertionConsumerServiceURL
            if req.assertion_consumer_service_url and req.assertion_consumer_service_url != consumer_url:
                raise AssertionError(_('Invalid assertion ConsumerService URL'))
            break
        elif is_default or consumer_url is None:
            # nothing specified: use the default endpoint, or the first
            # compatible one when no default is defined
            if endp_binding in allowed_bindings:
                consumer_url = resp_loc or location
                binding = endp_binding

    # FIXME TODO: other attributes we should manage
    # Subject ? (asks for authentication information for a known user)
    # NameIDPolicy (AllowCreate) ?
    req_contexts = []
    comparison = 'exact'
    if req.requested_authn_context:
        # requested comparison method, exact comparison by default
        comparison = req.requested_authn_context.comparison or "exact"
        req_contexts = [class_ref.text for class_ref in req.requested_authn_context.authn_context_class_ref]
    return req_issuer, consumer_url, binding, req.id, passive, force_auth, comparison, req_contexts

def check_subject_confirmation(subject_confirmations, now, request):
    """Check that at least one SubjectConfirmation element is valid.

    Only the bearer method is supported. For a bearer confirmation, checks the
    recipient (must be this endpoint's URL), NotBefore / NotOnOrAfter against
    now (shifted by delta_adjust) and Address (if present) against the client
    address.

    :param subject_confirmations: SubjectConfirmation elements of the assertion
    :param now: current date (naive UTC datetime, as produced by utcnow())
    :param request: HTTP request that carried the assertion
    :return: True as soon as one confirmation is valid, False otherwise
    """
    for conf in subject_confirmations:
        try:
            if conf.method != saml.SUBJECT_CONFIRMATION_METHOD_BEARER:
                # unsupported confirmation method: try the next one
                continue
            data = conf.subject_confirmation_data
            # This server does not emit authentication requests yet, so
            # in_response_to should be empty.
            # XXX FIXME: in_response_to is not compared to the originating request id
            if data.recipient != "%s%s" % (AUTH_FORM_URL, request.path):
                raise AssertionError(_('Invalid subject confirmation recipient'))
            # recipient ok, check the time constraints
            # NOTE(review): with a positive TIME_ADJUST both checks become stricter
            # rather than more tolerant; confirm the intended sign convention
            not_before = data.not_before
            if not_before and now < date_from_string(not_before) + delta_adjust:
                raise AssertionError(_('Assertion is not valid yet'))
            not_after = data.not_on_or_after
            if not_after and now + delta_adjust > date_from_string(not_after):
                raise AssertionError(_('Assertion has expired'))
            # check the client address if required
            address = data.address
            if address and address != request.remoteAddr.host:
                raise AssertionError(_('subject confirmation address does not match client address'))
            # subject of the assertion confirmed
            return True
        except Exception:
            # this confirmation is invalid: report it and try the next one
            # (the first failure used to end the search with False)
            traceback.print_exc()
    # no subject could be confirmed
    return False

def process_assertion(manager, request):
    """Process the data sent in a SAML response (service provider side).

    Checks, in order: signature, replay of the response, InResponseTo against
    the pending requests, issue instant and status. Then validates each
    assertion: issue instant, version, subject confirmation, audience,
    authentication statements (authentication context) and attributes.

    :return: (True, {assertion id: (session_index, issuer, attributes,
        name_id, auth_instant, class_ref)}) when at least one assertion is
        valid, (False, error message) otherwise
    :raises InternalError: if the signature check fails
    """
    signature_ok, saml_resp = extract_message(request, 'SAMLResponse')
    # SIGNATURE CHECK
    if not signature_ok:
        # FIXME TODO: generate a saml negative saml response
        raise InternalError("%s : %s" % (_('signature error'), str(saml_resp)))
    # GENERAL CHECKS
    now = datetime.datetime.utcnow()
    response = samlp.ResponseFromString(saml_resp)
    response_id = response.id
    # reject a response that has already been received
    if manager.replayed_saml_msg(response_id):
        return False, _('response message has been replayed ({0})').format(response_id)

    # XXX FIXME: the issuer of the response may differ from the issuer of the assertion
    # resp_issuer = response.issuer.text
    in_response_to = response.in_response_to
    if in_response_to:
        # when in_response_to is set, it must match a pending request
        orig_request_data = manager.get_saml_msg(in_response_to)
        if orig_request_data is None:
            return False, _('corresponding request not found')
        if orig_request_data.get('response', None) is not None:
            # a response has already been received for this request
            return False, _('response already received for request {0}').format(in_response_to)
    issue_instant = date_from_string(response.issue_instant)
    if issue_instant + delta_adjust > now:
        return False, _('Invalid date in response : {0}').format(response.issue_instant)
    # STATUS_CODE CHECK
    if response.status.status_code.value != samlp.STATUS_SUCCESS:
        status_msg = _("Authentication failure")
        if response.status.status_message is not None:
            status_msg = response.status.status_message.text.encode(encoding)
        return False, status_msg
    # look for valid assertions
    # The checks below raise AssertionError explicitly instead of using assert
    # statements: those are stripped under python -O, which would silently
    # disable these security checks.
    valid_assertions = {}
    for assertion in response.assertion:
        resp_attrs = {}
        try:
            assertion_id = assertion.id
            issuer = assertion.issuer.text.encode(encoding)
            log.msg(_('Checking SAML assertion issued by {0} ({1})').format(issuer, assertion_id))
            instant = date_from_string(assertion.issue_instant)
            if instant + delta_adjust > now:
                raise AssertionError()
            # version
            if assertion.version != saml.V2:
                raise AssertionError()
            # subject nameid
            # XXX FIXME check that the entities match?
            # qualifier = assertion.subject.name_id.name_qualifier
            # sp_qualifier = assertion.subject.name_id.sp_name_qualifier
            name_id = assertion.subject.name_id.text.encode(encoding)
            if not check_subject_confirmation(assertion.subject.subject_confirmation, now, request):
                raise AssertionError()
            # the subject is confirmed: check the assertion conditions
            if assertion.conditions.audience_restriction:
                audiences = [aud.audience.text.encode(encoding) for aud in assertion.conditions.audience_restriction]
                if IDP_IDENTITY not in audiences and "%s%s" % (AUTH_FORM_URL, request.path) not in audiences:
                    raise AssertionError(_('No valid recipient in Assertion'))
            # without an AuthnStatement, session_index/auth_instant/class_ref
            # would be undefined (or silently reused from a previous assertion)
            if not assertion.authn_statement:
                raise AssertionError(_('No authentication statement in assertion'))
            # authn_statement
            for statement in assertion.authn_statement:
                # session index on the partner entity (needed for single logout)
                session_index = statement.session_index
                auth_instant = statement.authn_instant
                if date_from_string(auth_instant) + delta_adjust > now:
                    raise AssertionError(_('Assertion statement instant is in future'))
                # auth_instant converted to a POSIX timestamp
                auth_instant = calendar.timegm(date_from_string(auth_instant).utctimetuple())
                # check the address/dns name and the authentication type?
                class_ref = statement.authn_context.authn_context_class_ref.text
                if in_response_to:
                    # context requested in the original request
                    req_comparison = orig_request_data.get('comparison', 'exact')
                    # the message must come from the entity the request was sent to
                    req_recipient = orig_request_data.get('recipient', '').encode(encoding)
                    if issuer != req_recipient:
                        raise AssertionError(_('Assertion issuer should be {0}').format(req_recipient))
                    req_class_ref = orig_request_data.get('class_ref', None)
                else:
                    # response sent spontaneously by an identity provider:
                    # federation from this provider must be allowed
                    if not manager.check_federation_allowed(issuer, idp_initiated=True):
                        raise AssertionError("%s (%s)" % (_('no federation allowed for this entity'), _('idp initiated federation')))
                    # options defined for this provider
                    idp_options = manager.get_federation_options(issuer)
                    req_class_ref = idp_options.get('req_context', DEFAULT_MIN_CONTEXT)
                    req_comparison = idp_options.get('comparison', 'minimum')
                # the authentication class must satisfy the required level
                ctxt_ok = check_required_contexts(req_comparison, [req_class_ref], [class_ref])[0]
                if not ctxt_ok:
                    raise AssertionError("%s (%s)" % (_('Insufficient authentication context'), class_ref))
                # address checks? -> problem if the user is not on the same
                # network as the idp with respect to the service provider (us)
                # dns_name = statement.subject_locality.dns_name
                # address = statement.subject_locality.address
            # attribute_statement
            for attr_stat in assertion.attribute_statement:
                for attribute in attr_stat.attribute:
                    # format not checked
                    # format = attribute.name_format
                    resp_attrs[attribute.name] = []
                    for val in attribute.attribute_value:
                        # the value may be empty
                        if val.text:
                            resp_attrs[attribute.name].append(val.text.encode(encoding))
            valid_assertions[assertion_id] = (session_index, issuer, resp_attrs, name_id, auth_instant, class_ref)
        except AssertionError as err:
            log.msg(_('Assertion rejected from {0}: {1}').format(issuer, str(err)))
            continue
        except Exception:
            traceback.print_exc()
            continue
    if valid_assertions != {}:
        return True, valid_assertions

    return False, _('No valid assertion')

def get_saml11_artifact(xml_stream):
    """Parse a SAML 1.1 SOAP request and extract its AssertionArtifact value.

    :return: the artifact, or None when the request has no
        Request/AssertionArtifact element
    """
    soap_req = SOAPpy.parseSOAP(xml_stream)
    try:
        return soap_req.Request.AssertionArtifact
    except Exception:
        # missing element or unexpected structure: no artifact
        # (Exception rather than a bare except, so SystemExit and
        # KeyboardInterrupt are no longer swallowed)
        return None
