# -*- coding: UTF-8 -*-
###########################################################################
#
# Eole NG - 2007
# Copyright Pole de Competence Eole  (Ministere Education - Academie Dijon)
# Licence CeCill  cf /root/LicenceEole.txt
# eole@ac-dijon.fr
#
# ticketcache.py
#
# librairie de gestion de session et de tickets d'authentification
# fonctionnement inspiré de la librairie cas :
# http://www.ja-sig.org/products/cas/
#
###########################################################################

from twisted.internet import reactor
from eolesso.util import gen_ticket_id


class InvalidSession(Exception):
    """levée quand un détecte une session invalide"""

class TicketCache:
    """classe gérant les sessions actives"""

    def __init__(self, timeout, address, ticket_prefix="LT", renew_tickets=False):
        self.timeout = timeout
        self.ticket_cache = {}
        self.prefix = ticket_prefix
        self.renew_tickets = renew_tickets
        self.address = address

    def end_session(self,session_id,cancel=False):
        """supprime une session du cache"""
        if self.ticket_cache.has_key(session_id):
            # on annule le callback de timeout du ticket
            if cancel:
                callID = self.ticket_cache[session_id][0]
                callID.cancel()
            # et on le supprime
            del self.ticket_cache[session_id]

    def gen_ticket_id(self):
        """à redéfinir dans la classe dérivée"""
        return gen_ticket_id(self.prefix, self.address)

    def add_session(self, data,session_id = None):
        """démarre une session et stocke les données correspondantes"""
        # on génère un identifiant de session (bidon)
        if session_id == None:
            session_id = self.gen_ticket_id()
        self.ticket_cache[session_id] = []
        # on donne une limite de temps pour la session
        callID = reactor.callLater(self.timeout, self.end_session, session_id)
        self.ticket_cache[session_id] = [callID,data]
        # on retourne l'identifiant de session
        return session_id

    def count(self):
        return len(self.ticket_cache)

    def reset_timer(self, session_id):
        """réinitialise la durée d'expiration de la session"""
        callID = self.ticket_cache[session_id][0]
        callID.cancel()
        # nouveau timeout
        callID=reactor.callLater(self.timeout, self.end_session, session_id)
        self.ticket_cache[session_id][0] = callID

    def get_session_info(self, session_id):
        """renvoie les infos connues sur la session en cours"""
        if self.ticket_cache.has_key(session_id):
            data = self.ticket_cache[session_id][1]
            if not self.renew_tickets:
                self.end_session(session_id,True)
            # on retourne les données stockées
            return data
        else:
            raise InvalidSession(session_id)

    def validate_session(self, session_id):
        """vérifie la validité d'une session"""
        # on regarde si la session a été délivrée par ce cache
        if session_id != '':
            session_address = "-".join(session_id.split('-')[1:-1])
            if session_address == self.address:
                if self.ticket_cache.has_key(session_id):
                    return True
        return False

    def get_session():
        pass


class SamlMsgCache(dict):

    def __init__(self, maxlen = 100):
        """
        stocke les derniers identifiants de messages Saml (ou assertion/requêtes/...) traités
        permet de limiter les risques de DOS par rejeu rapide d'un même message
        maxlen : nombre maximum d'identifiant à garder en mémoire
        """
        self.maxlen = maxlen
        self.ordered_list = []

    def get_msg(self, msg_id):
        """checks if an id already exists in this cache"""
        if msg_id in self:
            return self[msg_id]
        return None

    def add(self, msg_id, data=None):
        """adds an id in the cache, if maxlen is reached,
        destroy the oldest entry
        """
        if len(self) == self.maxlen:
            oldest = self.ordered_list.pop(0)
            del(self[oldest])
        self.ordered_list.append(msg_id)
        self[msg_id] = data

