#! /usr/bin/env python
# -*- coding: utf-8 -*-

###########################################################################
#
# Eole NG - 2007
# Copyright Pole de Competence Eole  (Ministere Education - Academie Dijon)
# Licence CeCill  cf /root/LicenceEole.txt
# eole@ac-dijon.fr
#
# saml_message.py
#
# librairie de génération de requêtes SAML 2.0 spécifiques à eole-sso
# nécessite les librairies suivantes :
# - python-saml (Copyright (C) 2006 Google Inc - http://www.apache.org/licenses/LICENSE-2.0)
# - python xmlsec (http://pyxmlsec.labs.libre-entreprise.org)
#
###########################################################################

from saml2 import saml, samlp
from saml_utils import format_timestamp, available_contexts
from saml_crypto import sign_doc
import time, codecs
import xmlsec
import SOAPpy
from eolesso.util import gen_random_id
from config import AUTH_SERVER_ADDR, SERVER_IP_ADDR, encoding, STATEMENT_TIMEOUT, TIME_ADJUST

# Constants missing from the saml2 library (some are taken from saml_utils.available_contexts)
# NameID format used by default for subjects and logout requests in this module
NAMEID_FORMAT_TRANSIENT = 'urn:oasis:names:tc:SAML:2.0:nameid-format:transient'
# NOTE(review): SAML 2.0 core (section 8.3.2) defines the email address NameID format as
# 'urn:oasis:names:tc:SAML:1.1:nameid-format:emailAddress'; this value does not match it.
# It is not referenced in this module - confirm usage elsewhere before changing it.
NAMEID_FORMAT_EMAIL = 'urn:oasis:names:tc:SAML:2.0:nameid-format:email'
# authentication context class references shared with saml_utils
URN_PROTECTED_PASSWORD = available_contexts['URN_PROTECTED_PASSWORD']
URN_TIME_SYNC_TOKEN = available_contexts['URN_TIME_SYNC_TOKEN']
# authentication context class for a user authenticated from an existing session
URN_PREVIOUS_SESSION = 'urn:oasis:names:tc:SAML:2.0:ac:classes:PreviousSession'
# NameFormat used for every attribute sent in assertions
NAME_FORMAT_BASIC = 'urn:oasis:names:tc:SAML:2.0:attrname-format:basic'
# default Consent value for AuthnRequests built by gen_request
CONSENT_IMPLICIT = 'urn:oasis:names:tc:SAML:2.0:consent:current-implicit'

def gen_assertion(user_id, attributes, creation_date, auth_instant, \
                  assertion_consumer, assertion_id, issuer, sp_ident, \
                  session_index, from_credentials, auth_class, \
                  client_addr, client_dns):
    """Build the (unsigned) SAML 2.0 Assertion for an authenticated user.

    user_id: NameID value (replaced when attributes contains 'nameid-format')
    attributes: dict of attribute name -> value or list of values, all sent in
        the AttributeStatement (the 'nameid-format' entry included)
    creation_date: epoch timestamp for IssueInstant and the validity window
    auth_instant: epoch timestamp of the authentication (AuthnStatement)
    assertion_consumer: Recipient of the SubjectConfirmationData
    assertion_id: ID attribute of the assertion
    issuer: IdP entity id (Issuer and NameID NameQualifier)
    sp_ident: SP entity id (NameID SPNameQualifier and audience restriction)
    session_index, from_credentials, auth_class, client_addr, client_dns:
        forwarded to gen_statement

    returns a saml.Assertion object
    """
    assertion = saml.Assertion(version = "2.0", id = assertion_id,
                               issue_instant = format_timestamp(creation_date + TIME_ADJUST))
    assertion.issuer = saml.Issuer(text = issuer)

    # Subject NameID: transient format by default. The format could also be
    # defined in the metadata (priority: sp -> idp -> default); SimpleSaml for
    # instance uses "urn:oasis:names:tc:SAML:1.1:nameid-format:emailAddress".
    # CLEON: an explicit format may be supplied in the attributes as
    # "<format>=<user_attribute>". Split on the first '=' only: the format URN
    # contains no '=', but the NameID value may.
    name_id_format = NAMEID_FORMAT_TRANSIENT
    if "nameid-format" in attributes:
        nameid_spec = attributes["nameid-format"]
        # accept a single value as well as a list of values
        if isinstance(nameid_spec, (list, tuple)):
            nameid_spec = nameid_spec[0]
        name_id_format, user_id = nameid_spec.split("=", 1)
    nid = saml.NameID(format = name_id_format, name_qualifier = issuer,
                      sp_name_qualifier = sp_ident, text = user_id)

    subj = saml.Subject(name_id = nid)
    # bearer subject confirmation, valid for STATEMENT_TIMEOUT seconds
    subj_conf = saml.SubjectConfirmation(method = saml.SUBJECT_CONFIRMATION_METHOD_BEARER)
    conf_data = saml.SubjectConfirmationData(not_on_or_after = format_timestamp(creation_date + TIME_ADJUST + STATEMENT_TIMEOUT), recipient = assertion_consumer)
    subj_conf.subject_confirmation_data = conf_data
    subj.subject_confirmation = subj_conf
    assertion.subject = subj
    assertion.conditions = gen_conditions(sp_ident, creation_date)
    assertion.authn_statement = gen_statement(auth_instant, session_index, \
                                from_credentials, auth_class, client_addr, client_dns)
    # attribute statement
    assertion.attribute_statement = gen_attributes(attributes)
    return assertion

def gen_attributes(attributes):
    """Build the AttributeStatement listing the attributes to send.

    attributes: dict mapping attribute name -> single value or list/tuple of values.
    Each name produces one Attribute (NameFormat basic) whose values are typed
    xs:string. Empty-string values are skipped, so an attribute whose values
    are all "" is sent without any AttributeValue.

    returns a saml.AttributeStatement
    """
    attr_statement = saml.AttributeStatement()
    for attr_name, values in attributes.items():
        # wrap a single value; any list/tuple (or subclass) is used as-is
        # (a tuple used to be wrapped and sent as one value)
        if not isinstance(values, (list, tuple)):
            values = [values]
        attr = saml.Attribute(name = attr_name, name_format = NAME_FORMAT_BASIC)
        for value in values:
            if value != "":
                # if ever needed: base64-encode the value here
                attr.attribute_value.append(saml.AttributeValue(text = value, extension_attributes = {'xsi:type':"xs:string"}))
        attr_statement.attribute.append(attr)
    return attr_statement

def gen_statement(auth_instant, session_index, from_credentials, \
                  auth_class, client_addr, client_dns):
    """Build the AuthnStatement of an assertion.

    auth_instant: epoch timestamp of the authentication (AuthnInstant)
    session_index: SSO session identifier, sent as SessionIndex via str()
    from_credentials: currently unused (see the FIXME below); kept for
        interface compatibility
    auth_class: requested authentication class; only URN_TIME_SYNC_TOKEN is
        honoured, anything else yields URN_PROTECTED_PASSWORD
    client_addr, client_dns: SubjectLocality address and DNS name

    returns a saml.AuthnStatement
    """
    # XXX FIXME: the class is forced to PROTECTED_PASSWORD / TIME_SYNC_TOKEN
    # for FIM, even when the user was authenticated from a previous session.
    # The former "class_ref = URN_PREVIOUS_SESSION when not from_credentials"
    # assignment was always overwritten here, so that dead code was removed.
    if auth_class == URN_TIME_SYNC_TOKEN:
        class_ref = URN_TIME_SYNC_TOKEN
    else:
        class_ref = URN_PROTECTED_PASSWORD
    subj_loc = saml.SubjectLocality(address = client_addr, dns_name = client_dns)
    ctx = saml.AuthnContext(authn_context_class_ref = saml.AuthnContextClassRef(text = class_ref))
    statement = saml.AuthnStatement(authn_instant = format_timestamp(auth_instant + TIME_ADJUST),
                             session_index = str(session_index),
                             subject_locality = subj_loc,
                             authn_context = ctx)
    return statement


def gen_conditions(entity_id, creation_date):
    """Return the saml.Conditions for an assertion or request.

    The validity window starts at creation_date and ends STATEMENT_TIMEOUT
    seconds later (both shifted by TIME_ADJUST); the audience is restricted
    to entity_id.
    """
    start = creation_date + TIME_ADJUST
    restriction = saml.AudienceRestriction(audience = saml.Audience(text = entity_id))
    return saml.Conditions(
        not_before = format_timestamp(start),
        not_on_or_after = format_timestamp(start + STATEMENT_TIMEOUT),
        audience_restriction = restriction)


########################################################
#### Génération de requêtes/réponse d'authentification

def gen_response(response_id, authenticated, auth_instant, session_index, \
                 from_credentials, user_id, attributes, \
                 request_id, assertion_consumer, issuer, sp_ident, cert_file, \
                 auth_class, client_addr, client_dns):
    """Build and sign a SAML 2.0 Response to an authentication request.

    When authenticated is false, the response carries an AuthnFailed status
    and no assertion. Otherwise it carries a Success status and one assertion
    built by gen_assertion, whose SubjectConfirmationData.InResponseTo is set
    to request_id.

    user_id, issuer and sp_ident are decoded to unicode with the configured
    encoding when needed.

    returns the result of sign_doc(xml, cert_file, assertion_id), with
    assertion_id None when there is no assertion
    """
    def _as_unicode(value):
        # leave unicode objects untouched, decode anything else
        if type(value) == unicode:
            return value
        return unicode(value, encoding)

    user_id = _as_unicode(user_id)
    issuer = _as_unicode(issuer)
    sp_ident = _as_unicode(sp_ident)

    namespaces = {'xmlns:xs':"http://www.w3.org/2001/XMLSchema",
                  'xmlns:saml':"urn:oasis:names:tc:SAML:2.0:assertion",
                  'xmlns:xsi':"http://www.w3.org/2001/XMLSchema-instance",
                  'xmlns:ds':"http://www.w3.org/2000/09/xmldsig#"}
    response = samlp.Response(extension_attributes = namespaces)
    response.id = response_id
    if request_id:
        response.in_response_to = request_id
    response.version = saml.V2
    now = time.time()
    response.issue_instant = format_timestamp(now + TIME_ADJUST)
    response.destination = assertion_consumer
    response.consent = saml.CONSENT_UNSPECIFIED
    response.issuer = saml.Issuer(text = issuer)

    # status: same structure for success and failure, only the code differs
    response.status = samlp.Status()
    status_value = samlp.STATUS_SUCCESS if authenticated else samlp.STATUS_AUTHN_FAILED
    response.status.status_code = samlp.StatusCode(status_value)
    response.status.status_message = samlp.StatusMessage()
    response.status.status_detail = samlp.StatusDetail()

    assertion_id = None
    if authenticated:
        assertion_id = gen_random_id('_')
        assertion = gen_assertion(user_id, attributes, now, \
                                  auth_instant, assertion_consumer, assertion_id, \
                                  issuer, sp_ident, session_index, from_credentials, \
                                  auth_class, client_addr, client_dns)
        # required for federation with the GAR (#26849)
        assertion.subject.subject_confirmation.subject_confirmation_data.in_response_to = request_id
        response.assertion.append(assertion)

    # serialise, sign and return the document
    return sign_doc(response.ToString(), cert_file, assertion_id)

def gen_request(manager, request_id, issuer, recipient, destination, attr_service_index, cert_file,
                force_auth=False, is_passive=False, consent=CONSENT_IMPLICIT,
                class_ref=URN_PROTECTED_PASSWORD, sign_request=False, comparison='exact'):
    """Build an AuthnRequest and record its parameters in the message manager.

    manager: object providing update_saml_msg(request_id, dict); the request
        parameters are stored there (presumably to check the matching
        response later - confirm against the manager implementation)
    request_id: ID of the request
    issuer: SP entity id, sent as Issuer
    recipient: only stored in the manager, not written into the request
    destination: IdP endpoint (Destination attribute and audience of the conditions)
    attr_service_index: AttributeConsumingServiceIndex
    cert_file: passed to sign_doc when sign_request is true
    force_auth, is_passive: ForceAuthn / IsPassive flags
    consent: Consent attribute
    class_ref, comparison: RequestedAuthnContext class and comparison
    sign_request: sign the XML document with sign_doc

    returns sign_doc(xml_doc, cert_file) when sign_request is true,
    otherwise (None, xml_doc)

    Optional AuthnRequest elements not generated: Subject, Scoping,
    AssertionConsumerServiceURL, ProtocolBinding, ProviderName.
    """
    if type(issuer) != unicode:
        issuer = unicode(issuer, encoding)
    if type(recipient) != unicode:
        recipient = unicode(recipient, encoding)
    request = samlp.AuthnRequest(id = request_id)
    request.version = saml.V2
    creation_date = time.time()
    request.issue_instant = format_timestamp(creation_date + TIME_ADJUST)
    # NOTE(review): NotOnOrAfter is not an AuthnRequest attribute in SAML 2.0 core;
    # whether saml2 serialises it depends on the library - kept unchanged
    request.not_on_or_after = format_timestamp(creation_date + TIME_ADJUST + STATEMENT_TIMEOUT)
    # Issuer used to be assigned twice with the same value; once is enough
    request.issuer = saml.Issuer(text = issuer)
    request.destination = destination
    request.name_id_policy = samlp.NameIDPolicy(format=NAMEID_FORMAT_TRANSIENT)
    request.consent = consent
    request.attribute_consuming_service_index = attr_service_index
    request.requested_authn_context = samlp.RequestedAuthnContext(comparison=comparison, authn_context_class_ref=saml.AuthnContextClassRef(text = class_ref))
    manager.update_saml_msg(request_id,
                            {'class_ref':class_ref,
                             'comparison':comparison,
                             'force_auth':force_auth,
                             'is_passive':is_passive,
                             'destination':destination,
                             'recipient':recipient,
                             'issuer':issuer})
    # boolean attributes are serialised as the xs:boolean literals
    request.force_authn = 'true' if force_auth else 'false'
    request.is_passive = 'true' if is_passive else 'false'
    request.conditions = gen_conditions(destination, creation_date)

    xml_doc = request.ToString()
    if sign_request:
        return sign_doc(xml_doc, cert_file)
    return None, xml_doc

#################################################
#### Génération de réponses/requêtes de logout

def gen_logout_response(request_id, response_id, issuer, destination, cert_file, logout_status, sign = True, status_msg = None):
    """Build a LogoutResponse confirming (or not) a logout request.

    request_id: ID of the LogoutRequest being answered (InResponseTo)
    response_id: ID of the response
    issuer: entity id sent as Issuer (decoded to unicode if needed)
    destination: Destination attribute
    cert_file: passed to sign_doc when sign is true
    logout_status: samlp status code URN
    sign: when true, sign the XML document with sign_doc; otherwise return it
        unsigned with the signature algorithm URI, presumably so the caller
        signs the encoded message itself (confirm with callers)
    status_msg: StatusMessage text; defaults to a translated message for the
        known statuses (empty for other statuses)

    returns sign_doc(xml_doc, cert_file) or (xmlsec.HrefRsaSha1, xml_doc)
    """
    if type(issuer) != unicode:
        issuer = unicode(issuer, encoding)
    logout_msg = {samlp.STATUS_SUCCESS: _('Single logout success'),
                  samlp.STATUS_PARTIAL_LOGOUT: _('Partial single logout success'),
                  samlp.STATUS_UNKNOWN_PRINCIPAL: _('Unknown user session'),
                  samlp.STATUS_UNSUPPORTED_BINDING: _('Unsuported binding'),
                  samlp.STATUS_AUTHN_FAILED: _('Authentication failure'),
                 }
    response = samlp.LogoutResponse(in_response_to = request_id, id = response_id)
    response.version = saml.V2
    creation_date = time.time()
    response.issue_instant = format_timestamp(creation_date + TIME_ADJUST)
    response.not_on_or_after = format_timestamp(creation_date + TIME_ADJUST + STATEMENT_TIMEOUT)
    response.issuer = saml.Issuer(text = issuer)
    response.destination = destination
    response.status = samlp.Status()
    response.status.status_code = samlp.StatusCode(logout_status)
    if not status_msg:
        # unknown statuses used to raise KeyError: send an empty message instead
        status_msg = logout_msg.get(logout_status, u'')
        # str() used to raise UnicodeEncodeError on non-ASCII translations:
        # normalise the message to unicode instead
        if isinstance(status_msg, str):
            status_msg = unicode(status_msg, encoding)
        elif not isinstance(status_msg, unicode):
            status_msg = unicode(status_msg)
    response.status.status_message = samlp.StatusMessage(text = status_msg)
    response.status.status_detail = samlp.StatusDetail()
    xml_doc = response.ToString()
    if not sign:
        # unsigned document returned with the algorithm to sign it with
        return xmlsec.HrefRsaSha1, xml_doc
    # signature embedded in the XML document
    return sign_doc(xml_doc, cert_file)

def gen_logout_request(request_id, session_index, user_id, issuer, sp_ident, destination, cert_file=None, sign=True):
    """Build a LogoutRequest for a user session.

    request_id: ID of the request
    session_index: SessionIndex of the session to close
    user_id: NameID value (transient format)
    issuer: entity id sent as Issuer and NameID NameQualifier
    sp_ident: NameID SPNameQualifier
    destination: Destination attribute
    cert_file: passed to sign_doc when sign is true (defaults to None, so
        callers asking for a signature must provide it)
    sign: when true, sign the XML document with sign_doc; otherwise return it
        unsigned with the signature algorithm URI, presumably so the caller
        signs the encoded message itself (confirm with callers)

    user_id, issuer and sp_ident are decoded to unicode when needed.

    returns sign_doc(xml_doc, cert_file) or (xmlsec.HrefRsaSha1, xml_doc)
    """
    if type(user_id) != unicode:
        user_id = unicode(user_id, encoding)
    if type(issuer) != unicode:
        issuer = unicode(issuer, encoding)
    if type(sp_ident) != unicode:
        sp_ident = unicode(sp_ident, encoding)
    request = samlp.LogoutRequest(id = request_id)
    request.version = saml.V2
    creation_date = time.time()
    request.issue_instant = format_timestamp(creation_date + TIME_ADJUST)
    request.not_on_or_after = format_timestamp(creation_date + TIME_ADJUST + STATEMENT_TIMEOUT)
    # Issuer used to be assigned twice with the same value; once is enough
    request.issuer = saml.Issuer(text = issuer)
    request.destination = destination
    request.session_index = samlp.SessionIndex(text = session_index)
    request.name_id = saml.NameID(format = NAMEID_FORMAT_TRANSIENT, text = user_id, name_qualifier = issuer, \
                                  sp_name_qualifier = sp_ident)
    request.consent = saml.CONSENT_UNSPECIFIED
    xml_doc = request.ToString()
    if not sign:
        # unsigned document returned with the algorithm to sign it with
        return xmlsec.HrefRsaSha1, xml_doc
    # signature embedded in the XML document
    return sign_doc(xml_doc, cert_file)

def gen_status_response(cert_file, issuer, destination, code, request_id=None, msg=None):
    """Build and sign a Response carrying only a status (typically an error),
    answering a given request.

    cert_file: passed to sign_doc
    issuer: entity id sent as Issuer (decoded to unicode if needed, like the
        other generators of this module)
    destination: Destination attribute
    code: samlp status code URN
    request_id: InResponseTo value (optional)
    msg: StatusMessage text (optional; byte strings are decoded to unicode)

    Available codes:

    samlp.STATUS_AUTHN_FAILED
    samlp.STATUS_NO_PASSIVE
    samlp.STATUS_REQUEST_DENIED
    samlp.STATUS_RESOURCE_NOT_RECOGNIZED
    samlp.STATUS_UNKNOWN_PRINCIPAL
    samlp.STATUS_INVALID_ATTR_NAME_OR_VALUE
    samlp.STATUS_NO_SUPPORTED_IDP
    samlp.STATUS_REQUEST_UNSUPPORTED
    samlp.STATUS_RESPONDER
    samlp.STATUS_UNSUPPORTED_BINDING
    samlp.STATUS_INVALID_NAMEID_POLICY
    samlp.STATUS_PARTIAL_LOGOUT
    samlp.STATUS_REQUEST_VERSION_DEPRECATED
    samlp.STATUS_SUCCESS
    samlp.STATUS_VERSION_MISMATCH
    samlp.STATUS_NO_AUTHN_CONTEXT
    samlp.STATUS_PROXY_COUNT_EXCEEDED
    samlp.STATUS_REQUEST_VERSION_TOO_HIGH
    samlp.STATUS_TOO_MANY_RESPONSES
    samlp.STATUS_NO_AVAILABLE_IDP
    samlp.STATUS_REQUESTER
    samlp.STATUS_REQUEST_VERSION_TOO_LOW
    samlp.STATUS_UNKNOWN_ATTR_PROFILE

    returns the result of sign_doc(xml_doc, cert_file)
    """
    # non-ASCII byte strings would break XML serialisation: decode them
    if type(issuer) != unicode:
        issuer = unicode(issuer, encoding)
    if isinstance(msg, str):
        msg = unicode(msg, encoding)
    response_id = gen_random_id('_')
    issue_instant = time.time()
    st = samlp.Status(samlp.StatusCode(value=code), samlp.StatusMessage(text=msg))
    response = samlp.Response(id = response_id,
                              in_response_to = request_id,
                              version = saml.V2,
                              issue_instant = format_timestamp(issue_instant + TIME_ADJUST),
                              destination = destination,
                              consent = saml.CONSENT_UNSPECIFIED,
                              issuer = saml.Issuer(text = issuer),
                              status = st)
    xml_doc = response.ToString()
    # sign and return the document
    return sign_doc(xml_doc, cert_file)

###########################################################################
## Réponse auth/requête logout SAML 1.1 pour utilisation en mode CAS/SAML

def gen_saml11_response(ticket, auth_instant, from_url, issuer, username, user_infos):
    """Build a SAML 1.1 response, for compatibility with recent CAS clients.

    ticket: unused (kept for interface compatibility)
    auth_instant: creation date of the user's SSO session (epoch timestamp)
    from_url: service URL, used as Recipient and Audience
    issuer: issuer sent in the assertion
    username: authenticated user name
    user_infos: pre-formatted attributes inserted as-is into the template

    The response is built from templates/cas_client_saml_response.tmpl
    (path relative to the current working directory) using %-formatting.

    returns the response as a unicode string
    """
    # read the template through codecs (decoded with the configured encoding)
    # and make sure the file is closed
    with codecs.open('templates/cas_client_saml_response.tmpl', 'r', encoding) as tmpl_file:
        response_tmpl = tmpl_file.read()
    # decode byte strings with the configured encoding, so mixing them with
    # the unicode template cannot fail on non-ASCII data
    if type(username) != unicode:
        username = unicode(username, encoding)
    if type(issuer) != unicode:
        issuer = unicode(issuer, encoding)
    if type(user_infos) != unicode:
        user_infos = unicode(user_infos, encoding)
    if type(from_url) != unicode:
        from_url = unicode(from_url, encoding)
    data = {'attributes':user_infos}
    data['response_id'] = gen_random_id('_')
    data['assertion_id'] = gen_random_id('_')
    data['recipient'] = from_url
    data['issuer'] = issuer
    # auth_instant: creation date of the user's SSO session
    data['auth_instant'] = format_timestamp(auth_instant + TIME_ADJUST)
    issue_instant = time.time()
    data['issue_instant'] = format_timestamp(issue_instant + TIME_ADJUST)
    data['not_before'] = format_timestamp(issue_instant + TIME_ADJUST)
    data['not_after'] = format_timestamp(issue_instant + TIME_ADJUST + STATEMENT_TIMEOUT)
    data['audience'] = from_url
    data['username'] = username
    saml_response = response_tmpl % data
    # the SOAP envelope is part of the template (after problems with
    # pronote.net) instead of using SOAPpy.buildSOAP
    return saml_response

def gen_saml11_error(error_msg, from_url):
    """Build a SAML 1.1 error response, for compatibility with recent CAS clients.

    error_msg: error message inserted in the response
    from_url: service URL, used as Recipient

    The response is built from templates/cas_client_saml_error.tmpl
    (path relative to the current working directory) using %-formatting.

    returns the response as a unicode string
    """
    # read the template through codecs (decoded with the configured encoding)
    # and make sure the file is closed
    with codecs.open('templates/cas_client_saml_error.tmpl', 'r', encoding) as tmpl_file:
        response_tmpl = tmpl_file.read()
    data = {}
    data['response_id'] = gen_random_id('_')
    if type(from_url) != unicode:
        from_url = unicode(from_url, encoding)
    data['recipient'] = from_url
    issue_instant = time.time()
    data['issue_instant'] = format_timestamp(issue_instant + TIME_ADJUST)
    if type(error_msg) != unicode:
        error_msg = unicode(error_msg, encoding)
    data['message'] = error_msg
    saml_response = response_tmpl % data
    return saml_response

def gen_cas_logout_request(ticket, use_SOAP):
    """Build a CAS (SAML 1.1 style) logout request for a service ticket.

    ticket: service ticket, sent as the session index
    use_SOAP: wrap the request in a SOAP envelope built with SOAPpy

    The request is built from templates/cas_client_logout_request.tmpl
    (path relative to the current working directory) using str.format.

    returns the request as a string
    """
    # template decoded with the configured encoding; the file is closed after reading
    with codecs.open('templates/cas_client_logout_request.tmpl', 'r', encoding) as tmpl_file:
        request_tmpl = tmpl_file.read()
    data = {'session_index': ticket}
    data['req_id'] = gen_random_id('_')
    # apply TIME_ADJUST like every other IssueInstant generated in this module
    data['issue_instant'] = format_timestamp(time.time() + TIME_ADJUST)
    data['nameid'] = "@NOT_USED@"
    request = request_tmpl.format(**data)
    if use_SOAP:
        # '%s' survives SOAPpy serialisation and is then replaced by the raw request XML
        request = SOAPpy.buildSOAP('\n%s') % request
    return request
