# -*- coding: utf-8 -*-
"""
implementation de l'algorithme de merge par héritage des modèles
"""

from xml.etree import ElementTree as ET
from xml.etree.ElementTree import tostring


def _append_directives(flux1, flux2, tag):
    """Move the ``tag/directive`` elements of *flux2* into flux1's ``<tag>`` section.

    The incoming directives are taken in ascending ``priority`` order and
    renumbered ``n+1, n+2, ...`` where ``n`` is the number of directives
    flux1 already has under ``tag``.
    """
    incoming = flux2.findall(tag + '/directive')
    if not incoming:
        return
    offset = len(flux1.findall(tag + '/directive'))
    container = flux1.find(tag)
    if container is None:
        # flux1 has no <tag> section yet: create one instead of failing on
        # None.append().
        container = ET.SubElement(flux1, tag)
    ordered = sorted(incoming, key=lambda directive: int(directive.get('priority')))
    # enumerate gives each directive's rank directly (list.index was O(n^2)).
    for rank, directive in enumerate(ordered, start=offset + 1):
        directive.set('priority', str(rank))
        container.append(directive)


def add(flux1, flux2):
    """Accumulate the directives of flux2 onto flux1 **without any uniqueness check**.

    flux1 = [d1, d2]
    flux2 = [d'1, d'2]
    d'1 becomes d3 (its priority is changed)
    d'2 becomes d4 (its priority is changed)
    add(flux1, flux2) = [d1, d2, d3, d4]

    Both the ``montantes`` and the ``descendantes`` sections are handled,
    in that order. flux1 is modified in place, and the directive elements
    taken from flux2 get their ``priority`` attribute rewritten. A missing
    section in flux1 is created when flux2 has directives for it.

    :param flux1: ``<flux>`` element that receives the directives
    :param flux2: ``<flux>`` element whose directives are appended
    """
    _append_directives(flux1, flux2, 'montantes')
    _append_directives(flux1, flux2, 'descendantes')

def complementary_flux(flux1, flux2):
    """Return the fluxes of *flux2* whose zone pair does not appear in *flux1*.

    Two fluxes are considered the same when their ``(zoneA, zoneB)``
    attribute pairs are equal (a missing attribute counts as ``None``).

    :param flux1: iterable of ``<flux>`` elements already present
    :param flux2: iterable of candidate ``<flux>`` elements
    :returns: flux list. The fluxes from flux2 not in flux1, in flux2 order
    """
    # A set makes each membership test O(1) instead of a list scan.
    existing_flux = {(flx.get('zoneA'), flx.get('zoneB')) for flx in flux1}
    return [flx for flx in flux2
            if (flx.get('zoneA'), flx.get('zoneB')) not in existing_flux]


def main(fname1):
    """Entry point: load *fname1*, resolve its models and serialize the result.

    :param fname1: xml file name
    :returns: the merged document as UTF-8 encoded ``bytes``, prefixed with
              an XML declaration
    """
    fwtree = read_from_file(fname1)
    root = fw_recurse(fwtree)
    # puts a generated name for convenience
    root.set('name', "Concatenated_Do_Not_Edit")
    # Serialize to text and encode as UTF-8 ourselves. The default (us-ascii)
    # serialization turns every non-ASCII character into a numeric reference
    # such as &#233;, and only é and è used to be converted back, which left
    # characters like à, ê or ç escaped. The &apos; replacement is kept from
    # the original code.
    body = tostring(root, encoding='unicode').replace('&apos;', "'")
    return b'<?xml version="1.0" encoding="UTF-8" ?>\n' + body.encode('utf-8')


def fw_recurse(fwtree, _chain=()):
    """Trigger the vertical recursion (the target file has a ``model`` attribute).

    When *fwtree* carries a ``model`` attribute (a comma-separated list of
    XML file names), each model is loaded, resolved recursively, then merged
    with the current tree via ``merge(current, model)``. The merged tree
    replaces the current one before the next model is processed. Directive
    priorities of the resulting fluxes are renumbered with sort_directive().

    :param fwtree: xml tree node
    :param _chain: internal use only; absolute paths of the models being
                   resolved on the current branch, used to detect cycles
    :returns: xml tree node (a new merged root when models were applied)
    :raises Exception: if a model is, directly or indirectly, its own ancestor
    """
    import os

    if 'model' in fwtree.attrib:
        fnames = [fname.strip() for fname in fwtree.attrib['model'].split(',')]
        for fname in fnames:
            if not fname:
                # tolerate empty entries, e.g. a trailing comma
                continue
            key = os.path.abspath(fname)
            if key in _chain:
                # without this check a cycle ends in an opaque RecursionError
                raise Exception('heritage cyclique des modeles : {0}'.format(
                    ' -> '.join(_chain + (key,))))
            model = read_from_file(fname)
            fwtree = merge(fwtree, fw_recurse(model, _chain + (key,)))
    sort_directive(fwtree.findall('flux-list/flux'))
    return fwtree

def sort_directive(flux):
    """Renumber the directive priorities of each flux in document order.

    Elements are not reordered. Within each flux, the ``montantes`` directives
    and then the ``descendantes`` directives get ``priority`` values
    "1", "2", ... following the order in which they appear in the tree.

    :param flux: iterable of ``<flux>`` elements, modified in place
    """
    for fx in flux:
        for path in ('montantes/directive', 'descendantes/directive'):
            for rank, directive in enumerate(fx.findall(path), start=1):
                directive.set('priority', str(rank))

def _new_section(parent, tag):
    """Create an empty ``<tag>`` child of *parent*, pre-indented for its children."""
    section = ET.SubElement(parent, tag)
    section.text = "\n        "
    return section


def _concat_section(root, section_tag, child_tag, fw1, fw2):
    """Create ``<section_tag>`` under *root* with every ``child_tag`` of fw1, then of fw2.

    Plain concatenation: no duplicate detection.
    """
    section = _new_section(root, section_tag)
    path = section_tag + '/' + child_tag
    section.extend(fw1.findall(path))
    section.extend(fw2.findall(path))
    return section


def merge(fw1, fw2):
    """Concatenate two ``<firewall>`` element trees into a new root.

    IMPORTANT: **fw2** drives the options, in particular the ``model``
    attribute. fw2's root attributes are copied onto the result; fw1's are not.

    Section by section:
    - zones: all fw1 zones, then fw2 zones whose name is not already present
    - include: texts of fw1 then fw2, stripped, one per line
    - services: fw1 services, then fw2 services whose name is not present;
      then fw1 ``groupe`` elements, then fw2 ``groupe`` whose id is not present
    - extremites, ranges, user_groups, applications: plain concatenation
    - qosclasses: always emitted empty (download/upload set to '')
    - flux-list: fw2 fluxes sharing a zone pair with a fw1 flux are
      accumulated into it with add(); the other fw2 fluxes are appended

    Child elements are reused, not copied, and fw1's fluxes are modified in
    place by add().

    :param fw1: xml root tree node description
    :param fw2: xml root tree node description
    :returns: merged root element (``xml.etree.ElementTree.Element``)
    """
    root = ET.Element("firewall")
    root.text = '\n    '
    root.attrib.update(fw2.attrib)

    # zones (deduplicated by name; sets avoid rescanning a list per zone)
    zones = fw1.findall("zones/zone")
    zone_names = {z.get('name') for z in zones}
    for zone in fw2.findall("zones/zone"):
        name = zone.get('name')
        if name not in zone_names:
            zones.append(zone)
            zone_names.add(name)
    zone_root = _new_section(root, "zones")
    zone_root.extend(zones)

    # includes: a missing <include> contributes nothing, an empty one a blank line
    incl_root = ET.SubElement(root, "include")
    pieces = [text.strip()
              for text in (fw1.findtext('include'), fw2.findtext('include'))
              if text is not None]
    incl_root.text = ''.join('\n' + piece for piece in pieces) + '\n    '

    # services (deduplicated by name)
    service_root = _new_section(root, "services")
    service_root.extend(fw1.findall('services/service'))
    service_names = {child.get('name') for child in service_root}
    for serv in fw2.findall('services/service'):
        name = serv.get('name')
        if name not in service_names:
            service_root.append(serv)
            service_names.add(name)

    # service groups (deduplicated by id)
    service_root.extend(fw1.findall('services/groupe'))
    # NOTE(review): ids are compared against every child of <services>,
    # plain services included, so a fw2 groupe without an 'id' is dropped as
    # soon as any service lacks one too -- confirm this is intended.
    group_ids = {child.get('id') for child in service_root}
    for grp in fw2.findall('services/groupe'):
        gid = grp.get('id')
        if gid not in group_ids:
            service_root.append(grp)
            group_ids.add(gid)

    _concat_section(root, "extremites", "extremite", fw1, fw2)
    _concat_section(root, "ranges", "range", fw1, fw2)
    _concat_section(root, "user_groups", "user_group", fw1, fw2)
    _concat_section(root, "applications", "application", fw1, fw2)

    # qosclasses
    # NOTE(review): the qosclasses content of fw1/fw2 is not carried over;
    # an empty section is always emitted -- confirm this is intended.
    qos_root = _new_section(root, "qosclasses")
    qos_root.set('download', '')
    qos_root.set('upload', '')

    # flux
    flux_root = _new_section(root, "flux-list")
    flux1 = fw1.findall('flux-list/flux')
    flux2 = fw2.findall('flux-list/flux')
    different_flux = complementary_flux(flux1, flux2)

    # accumulate fluxes sharing the same zones; indexing fw2 by zone pair
    # keeps the original call order of add() without a nested scan
    flux2_by_zones = {}
    for fx2 in flux2:
        flux2_by_zones.setdefault((fx2.get('zoneA'), fx2.get('zoneB')), []).append(fx2)
    for fx1 in flux1:
        for fx2 in flux2_by_zones.get((fx1.get('zoneA'), fx1.get('zoneB')), ()):
            add(fx1, fx2)

    flux_root.extend(flux1)
    # add the complementary fluxes
    flux_root.extend(different_flux)

    return root


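# NOTE: unused draft kept for reference only; main() builds the XML
# declaration itself. As written it would fail: `root` is undefined (the
# parameter is `root_node`), and tostring(..., encoding='utf-8') returns bytes,
# which cannot be concatenated to a str.
#def add_xml_declaration(root_node):
#    """
#    :param root_node: element tree root node
#    :returns: xml string with xml declaration
#    """
#    return '<?xml version="1.0" encoding="UTF-8" ?>\n' + \
#           tostring(root,  encoding='utf-8')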


def read_from_file(fname):
    """Parse an XML file and return its root element.

    :param fname: xml filename
    :returns: element tree node (the document root)
    :raises Exception: if the file cannot be read or parsed; the underlying
                       error is attached as ``__cause__``
    """
    try:
        xml = ET.parse(fname)
    except Exception as err:
        # chain explicitly so the original OSError / ParseError stays visible
        raise Exception('impossible de charger {0} : {1}'.format(fname, err)) from err
    return xml.getroot()


def fromstring(xmlstr):
    """Parse an XML document held in memory.

    :param xmlstr: xml as a string
    :returns: element tree node (the document root)
    """
    return ET.fromstring(xmlstr)

if __name__ == '__main__':
    import sys
    if len(sys.argv) < 2:
        sys.exit('usage: {0} <firewall.xml>'.format(sys.argv[0]))
    # main() returns UTF-8 encoded bytes: write them as-is. print() would
    # output their repr (b'...\n...') instead of the XML document.
    sys.stdout.buffer.write(main(sys.argv[1]) + b'\n')
