# -*- coding: utf-8 -*-
#
##########################################################################
# eoleauth - cachedsession.py
# Copyright © 2013 Pôle de compétences EOLE <eole@ac-dijon.fr>
#
# License CeCILL:
#  * in french: http://www.cecill.info/licences/Licence_CeCILL_V2-fr.html
#  * in english http://www.cecill.info/licences/Licence_CeCILL_V2-en.html
#
# Based on http://flask.pocoo.org/snippets/109/
#
# minimal in memory session management used as fallback when redis
# is not yet configured
##########################################################################
import os
import sys
import pickle
import base64
import hmac
import hashlib
import random
import string

import datetime
from uuid import uuid4
from collections import OrderedDict

from werkzeug.datastructures import CallbackDict
from flask.sessions import SecureCookieSessionInterface, SessionMixin
from flask import current_app
from eoleauthlib.i18n import _

def _generate_sid(prefix):
    return "{0}{1}".format(prefix,str(uuid4()))

def _calc_hmac(body, secret):
    if sys.version_info[0] < 3:
        secret = secret.encode('utf8')
    else:
        if isinstance(secret, str):
            secret = secret.encode()
        body = body.encode()
    sec = hmac.new(secret, body, hashlib.sha1).digest()
    return base64.b64encode(sec)

class ManagedSession(CallbackDict, SessionMixin):
    """Server-side session: a dict that sets ``modified`` on every change.

    Attributes:
        sid: session id (also the storage key used by the session managers).
        new: flag passed by the creator; stored as is.
        modified: set to True by any mutation of the dict contents.
        randval: random salt mixed into the HMAC when the session is signed.
        hmac_digest: HMAC of ``"<sid>:<randval>"``; together with ``sid`` it
            forms the cookie value.
    """

    # randval is part of a security token: draw it from the OS-backed
    # generator rather than the default (predictable) Mersenne Twister.
    _rng = random.SystemRandom()

    def __init__(self, initial=None, sid=None, new=False, randval=None, hmac_digest=None):
        def on_update(session):
            # called by CallbackDict with the dict itself after each mutation
            session.modified = True

        CallbackDict.__init__(self, initial, on_update)
        self.sid = sid
        self.new = new
        # loading *initial* is not a modification
        self.modified = False
        self.randval = randval
        self.hmac_digest = hmac_digest

    def sign(self, secret):
        """Compute ``randval`` and ``hmac_digest`` unless already signed.

        Args:
            secret: HMAC key (see ``_calc_hmac`` for accepted types).
        """
        if not self.hmac_digest:
            alphabet = string.ascii_lowercase + string.digits
            self.randval = ''.join(self._rng.sample(alphabet, 20))
            self.hmac_digest = _calc_hmac('%s:%s' % (self.sid, self.randval), secret)

class SessionManager(object):
    """Interface of the server-side session stores.

    Concrete managers override every method; these base implementations
    only raise NotImplementedError.
    """

    def new_session(self):
        'Create and return a new session'
        raise NotImplementedError

    def exists(self, sid):
        'Return True if something is stored under the given session-id'
        raise NotImplementedError

    def remove(self, sid):
        'Remove what is stored under *sid* (a session-id or a mapping key)'
        raise NotImplementedError

    def get(self, sid, digest):
        'Retrieve a managed session by session-id, checking the HMAC digest'
        raise NotImplementedError

    def get_map(self, map_key):
        'Return the session-id stored under the mapping key *map_key*, or None'
        raise NotImplementedError

    def put(self, app, session):
        'Store a managed session (and its attribute mapping, if any)'
        raise NotImplementedError

class CachingSessionManager(SessionManager):
    """Keep recently used sessions in memory in front of a *parent* manager.

    Cached sessions live in an OrderedDict ordered from least to most
    recently used.  Once more than ``num_to_store`` entries are cached, the
    oldest ones are dropped until the cache is back to 80% of that size.
    Creation, storage and removal are forwarded to *parent*.
    """

    def __init__(self, parent, num_to_store):
        self.parent = parent
        self.num_to_store = num_to_store
        self._cache = OrderedDict()

    @staticmethod
    def _digest_matches(expected, given):
        """Compare two HMAC digests given as text or bytes.

        A session signed in this process holds the value returned by
        ``base64.b64encode`` (bytes on Python 3) while the cookie digest is
        text, so both sides are converted to bytes before comparing.  The
        comparison is constant-time when ``hmac.compare_digest`` exists.
        """
        if expected is None or given is None:
            return False
        if not isinstance(expected, bytes):
            expected = expected.encode('utf-8')
        if not isinstance(given, bytes):
            given = given.encode('utf-8')
        compare = getattr(hmac, 'compare_digest', None)
        if compare is None:  # Python < 2.7.7
            return expected == given
        return compare(expected, given)

    def _normalize(self):
        'Evict the oldest entries once the cache holds more than num_to_store'
        if len(self._cache) > self.num_to_store:
            current_app.logger.info(_("Session cache size: {0}, flushing 20%").format(len(self._cache)))
            while len(self._cache) > (self.num_to_store * 0.8):  # flush 20% of the cache
                self._cache.popitem(last=False)
            # TODO : Purge files older than N days or use cron task ?

    def new_session(self):
        session = self.parent.new_session()
        self._cache[session.sid] = session
        self._normalize()
        return session

    def remove(self, sid):
        self.parent.remove(sid)
        self._cache.pop(sid, None)

    def exists(self, sid):
        return sid in self._cache or self.parent.exists(sid)

    def get(self, sid, digest):
        """Return the session for *sid* when *digest* matches.

        A cached session is only used when its digest matches.  Otherwise
        the lookup goes to *parent*, whose ``get`` may return a new session
        with a different sid (FileBackedSessionManager does so when the
        digest is wrong).  The result is therefore cached under its own
        ``session.sid``, and a request carrying a bad digest does not evict
        the legitimate cached session.
        """
        session = self._cache.get(sid)
        if session is not None and self._digest_matches(session.hmac_digest, digest):
            # cache hit: drop it so the re-insertion below marks it most recent
            del self._cache[sid]
        else:
            session = self.parent.get(sid, digest)
            # replace any stale entry for the sid actually returned
            self._cache.pop(session.sid, None)

        self._cache[session.sid] = session
        self._normalize()
        return session

    def get_map(self, map_key):
        return self.parent.get_map(map_key)

    def put(self, app, session):
        self.parent.put(app, session)
        # re-insert so the session becomes the most recently used entry
        self._cache.pop(session.sid, None)
        self._cache[session.sid] = session
        self._normalize()

class FileBackedSessionManager(SessionManager):
    """Store each session as a pickle file named after its session id.

    A session file holds the pickled tuple ``(randval, hmac_digest, data)``.
    When the application defines ``eoleauth_map_attr``, ``put`` also writes
    a ``map_<prefix><attribute value>`` file holding the pickled session id,
    which ``get_map`` reads back.

    Session ids come straight from the session cookie (see
    ManagedSessionInterface.open_session), so every externally supplied
    name goes through ``_fname``.  Otherwise a name such as ``../x`` or
    ``/abs/path`` would make the server unpickle (and potentially execute)
    or delete an arbitrary file.
    """

    def __init__(self, path, secret, prefix="session:"):
        self.path = path
        self.secret = secret
        self.prefix = prefix
        if not os.path.exists(self.path):
            os.makedirs(self.path)

    @staticmethod
    def _digest_matches(expected, given):
        """Compare two HMAC digests given as text or bytes.

        Both sides are converted to bytes; the comparison is constant-time
        when ``hmac.compare_digest`` exists.
        """
        if expected is None or given is None:
            return False
        if not isinstance(expected, bytes):
            expected = expected.encode('utf-8')
        if not isinstance(given, bytes):
            given = given.encode('utf-8')
        compare = getattr(hmac, 'compare_digest', None)
        if compare is None:  # Python < 2.7.7
            return expected == given
        return compare(expected, given)

    def _fname(self, name):
        """Return the storage path for *name*, or None if it is not a plain file name.

        Rejects empty names, ``.``/``..``, NUL bytes and any path separator,
        so the result always lies directly inside ``self.path``.
        """
        if not name or name in (os.curdir, os.pardir) or '\0' in name:
            return None
        if any(sep and sep in name for sep in ('/', os.sep, os.altsep)):
            return None
        return os.path.join(self.path, name)

    def exists(self, sid):
        fname = self._fname(sid)
        return fname is not None and os.path.exists(fname)

    def remove(self, sid):
        current_app.logger.info(_('Removing session: {0}').format(sid))
        fname = self._fname(sid)
        if fname is not None and os.path.exists(fname):
            os.unlink(fname)

    def new_session(self):
        sid = _generate_sid(self.prefix)
        fname = os.path.join(self.path, sid)

        while os.path.exists(fname):
            sid = _generate_sid(self.prefix)
            fname = os.path.join(self.path, sid)

        # touch the file: reserves the sid and makes exists() true for it
        with open(fname, 'w'):
            pass

        current_app.logger.info(_("New session created: {0}").format(sid))

        return ManagedSession(sid=sid)

    def get(self, sid, digest):
        """Retrieve a managed session by session-id, checking the HMAC digest.

        Returns a new session (from ``new_session``) when *sid* is not a
        valid name, when the file is missing, unreadable or holds no data,
        or when *digest* does not match the stored digest.
        """
        current_app.logger.debug(_("Looking for session: {0}, {1}").format(sid, digest))

        fname = self._fname(sid)
        data = None
        hmac_digest = None
        randval = None

        if fname is not None and os.path.exists(fname):
            try:
                # NOTE: pickle.load can execute code. Only files written by
                # put() should live in self.path; _fname keeps client input
                # from pointing anywhere else.
                with open(fname, 'rb') as f:
                    randval, hmac_digest, data = pickle.load(f)
            except Exception:
                current_app.logger.warning(_("Error loading session file"), exc_info=True)

        if not data:
            current_app.logger.debug(_("Missing session data ?"))
            return self.new_session()

        # This assumes the file is correct, if you really want to
        # make sure the session is good from the server side, you
        # can re-calculate the hmac
        if not self._digest_matches(hmac_digest, digest):
            current_app.logger.debug(_("Invalid HMAC for session"))
            return self.new_session()

        if sys.version_info[0] >= 3 and isinstance(hmac_digest, bytes):
            hmac_digest = hmac_digest.decode('utf-8')

        return ManagedSession(data, sid=sid, randval=randval, hmac_digest=hmac_digest)

    def get_map(self, map_key):
        'Retrieves a session id matching a predefined session attribute'
        fname = self._fname(map_key)
        if fname is None or not os.path.exists(fname):
            return None
        # binary mode: 'br' (as previously used) is rejected by Python 2
        with open(fname, 'rb') as f:
            session_id = pickle.load(f)
        if self.exists(session_id):
            return session_id
        return None

    def put(self, app, session):
        'Store a managed session'

        app.logger.debug(_("Storing session: %s") % session.sid)

        if not session.hmac_digest:
            session.sign(self.secret)

        # session.sid was either generated by new_session or validated by get
        fname = os.path.join(self.path, session.sid)
        with open(fname, 'wb') as f:
            pickle.dump((session.randval, session.hmac_digest, dict(session)), f)

        # store attribute/session mapping if defined in auth plugin
        map_attr = getattr(app, 'eoleauth_map_attr', None)
        if not map_attr:
            return
        map_key = session.get(map_attr, None)
        if not (map_key and hasattr(session, 'sid')):
            return
        fmap_name = self._fname('map_' + self.prefix + map_key)
        if fmap_name is None:
            # the attribute value cannot be used as a plain file name
            app.logger.warning(_('invalid session mapping key, mapping not stored'))
            return
        with open(fmap_name, 'wb') as f:
            app.logger.debug(_('storing session mapping ({0})').format(map_attr))
            pickle.dump(session.sid, f)

class ManagedSessionInterface(SecureCookieSessionInterface):
    """Flask session interface keeping session data on the server side.

    The cookie only carries ``<sid>!<hmac digest>``; the session contents
    are loaded from and stored to ``manager``.
    """

    def __init__(self, manager, skip_paths, cookie_timedelta):
        # manager: session manager used to create, load, store and remove sessions
        # skip_paths: path prefixes for which no session is created when the
        #   request carries no usable cookie (e.g. static resources)
        # cookie_timedelta: only used by the disabled get_expiration_time below
        self.manager = manager
        self.skip_paths = skip_paths
        self.cookie_timedelta = cookie_timedelta

    # disabled: use standard flask function
    #def get_expiration_time(self, app, session):
    #    if session.permanent:
    #        return app.permanent_session_lifetime
    #    return datetime.datetime.now() + self.cookie_timedelta

    @staticmethod
    def _cookie_name(app):
        """Return the session cookie name.

        Reads the config key directly: it holds the same value as the legacy
        ``app.session_cookie_name`` attribute, which recent Flask removed.
        """
        return app.config['SESSION_COOKIE_NAME']

    def open_session(self, app, request):
        """Return the session designated by the request cookie.

        Returns None (no session) for requests under ``skip_paths`` without a
        usable cookie, and a new session when the cookie is missing, malformed
        or designates an unknown session.
        """
        cookie_val = request.cookies.get(self._cookie_name(app))

        if not cookie_val or '!' not in cookie_val:
            # Don't bother creating a cookie for static resources
            if any(request.path.startswith(sp) for sp in self.skip_paths):
                return None

            app.logger.debug(_('Missing cookie'))
            return self.manager.new_session()

        sid, digest = cookie_val.split('!', 1)

        if self.manager.exists(sid):
            return self.manager.get(sid, digest)

        return self.manager.new_session()

    def save_session(self, app, session, response):
        """Persist a modified session and (re)issue its cookie.

        An empty session is removed from the manager and, if it was
        modified, its cookie is deleted.  An unmodified session is left
        untouched.
        """
        domain = self.get_cookie_domain(app)
        cookie_name = self._cookie_name(app)
        if not session:
            if hasattr(session, 'sid') and self.manager.exists(session.sid):
                # should not happen, session is deleted
                # with its mapping in invalidate_session
                self.manager.remove(session.sid)
            if session.modified:
                response.delete_cookie(cookie_name, domain=domain)
            return

        if not session.modified:
            # no need to save an unaltered session
            # TODO: put logic here to test if the cookie is older than N days, if so, update the expiration date
            return

        self.manager.put(app, session)
        session.modified = False

        cookie_exp = self.get_expiration_time(app, session)
        hmac_digest = session.hmac_digest
        if sys.version_info[0] >= 3 and isinstance(hmac_digest, bytes):
            hmac_digest = hmac_digest.decode('utf-8')
        # secure follows SESSION_COOKIE_SECURE (False by default, as before)
        response.set_cookie(cookie_name,
                            '%s!%s' % (session.sid, hmac_digest),
                            expires=cookie_exp, httponly=True, domain=domain,
                            secure=self.get_cookie_secure(app))

    def remove_session(self, map_key):
        """eoleauth specific: use mapping to a specific session attribute to
        invalidate server side session (attribute defined in auth plugin and
        stored in app.eoleauth_map_attr)

        Returns True when a mapped session was found and removed together
        with its mapping entry, False otherwise.  Assumes ``self.manager``
        has a ``parent`` exposing ``prefix`` (a caching manager in front of
        a file-backed one) -- TODO confirm for other managers.
        """
        map_key = 'map_' + self.manager.parent.prefix + map_key
        current_app.logger.debug(_('session removal requested (mapping:{0})').format(map_key))
        session_id = self.manager.get_map(map_key)
        if session_id:
            current_app.logger.debug(_('deleteting session from file cache : {0}').format(session_id))
            self.manager.remove(session_id)
            self.manager.remove(map_key)
            return True
        current_app.logger.debug(_('no session found for this mapping'))
        return False
