# -*- coding: utf-8 -*-

import pysmile
import pysmile_license
import numpy as np
import cpd

# This class creates the network: plate nodes, arcs and probability tables.
class CollaborationNetwork:
    """Build the collaboration network with pysmile and save it to disk.

    Constructing an instance creates nine plate nodes that all share the
    outcomes in ``OUTCOMES``, connects them with regular and order-1
    temporal arcs, fills in their probability tables from the ``cpd``
    helpers, writes the result to ``colaborar_din.xdsl`` and keeps the
    network so that ``get_network()`` can return it.
    """

    # Outcome ids given to every node, in the order used by the tables.
    OUTCOMES = ["Bajo", "Medio", "Alto"]

    def __init__(self):
        """Create, parameterise and save the network.

        Side effects: writes ``colaborar_din.xdsl`` in the current working
        directory and prints a confirmation line.
        """
        net = pysmile.Network()

        # (node id, display name) pairs, in creation order.
        node_specs = [
            ("Colaborar_G", "Colaborar en X con Otros"),
            ("Proponer_G", "Proponer sobre X a Otros"),
            ("Aportar_G", "Aportar a X con Otros"),
            ("Colaborar_E", "Colaborar en Proyecto con Otros"),
            ("Proponer_E", "Proponer sobre Proyecto a Otros"),
            ("Aportar_E", "Aportar a Proyecto con Otros"),
            ("e_colaborar", "Evidencia de Colaborar en Proyecto con Otros"),
            ("e_proponer", "Evidencia de Proponer sobre Proyecto a Otros"),
            ("e_aportar", "Evidencia de Aportar a Proyecto con Otros"),
        ]
        cg, pg, ag, ce, pe, ae, ec, ep, ea = [
            self.create_cpt_node(net, node_id, name, self.OUTCOMES)
            for node_id, name in node_specs
        ]

        # Arcs inside a time slice, as (parent, child).
        # Node identifiers could be used here instead of handles, e.g.
        # net.add_arc("Colaborar_G", "Proponer_G").
        # NOTE(review): the flattened tables set below presumably rely on
        # the order in which parents are attached - confirm before
        # reordering these arcs.
        intra_slice_arcs = [
            (cg, pg), (cg, ag), (cg, ce),
            (pg, pe), (ce, pe),
            (ag, ae), (ce, ae),
            (ce, ec), (pe, ep), (ae, ea),
        ]
        for parent, child in intra_slice_arcs:
            net.add_arc(parent, child)

        # Every non-evidence node also depends on itself one step earlier.
        for node in (cg, pg, ag, ce, pe, ae):
            net.add_temporal_arc(node, node, 1)

        # Non-temporal conditionals.
        t = self.stream_table(cpd.flatDist(3))
        net.set_node_definition(cg, t)

        t = self.stream_table(cpd.incsub(2))
        net.set_node_definition(pg, t)
        net.set_node_definition(ag, t)

        t = self.stream_table(cpd.genesp())
        net.set_node_definition(ce, t)

        t = self.stream_table(cpd.join(cpd.genesp(), cpd.incsub(2)))
        net.set_node_definition(pe, t)
        net.set_node_definition(ae, t)

        t = self.stream_table(cpd.compevi())
        net.set_node_definition(ec, t)
        net.set_node_definition(ep, t)
        net.set_node_definition(ea, t)

        # Conditionals that include the order-1 temporal parent.
        tpp = cpd.paspre()
        t = self.stream_table(tpp)
        net.set_node_temporal_definition(cg, 1, t)

        t = self.stream_table(cpd.join(cpd.incsub(2), tpp))
        net.set_node_temporal_definition(pg, 1, t)
        net.set_node_temporal_definition(ag, 1, t)

        t = self.stream_table(cpd.join(cpd.genesp(), tpp))
        net.set_node_temporal_definition(ce, 1, t)

        t = self.stream_table(cpd.join(cpd.genesp(), cpd.incsub(2), tpp))
        net.set_node_temporal_definition(pe, 1, t)
        net.set_node_temporal_definition(ae, 1, t)

        net.write_file("colaborar_din.xdsl")
        print("Network written to colaborar_din.xdsl")

        self._network = net

    def get_network(self):
        """Return the pysmile network built by the constructor."""
        return self._network

    def create_cpt_node(self, net, id, name, outcomes):
        """Add a plate node to ``net`` and return its handle.

        Args:
            net: pysmile network that receives the node.
            id: identifier assigned to the node.
            name: display name assigned to the node.
            outcomes: outcome ids for the node, in order.

        Returns:
            The handle returned by ``net.add_node()``.
        """
        handle = net.add_node()
        net.set_node_temporal_type(handle, pysmile.NodeTemporalType.PLATE)
        net.set_node_id(handle, id)
        net.set_node_name(handle, name)

        # The new node already has some outcomes: rename those first, then
        # append whatever is still missing.
        initial_outcome_count = net.get_outcome_count(handle)

        for i in range(0, initial_outcome_count):
            net.set_outcome_id(handle, i, outcomes[i])

        for i in range(initial_outcome_count, len(outcomes)):
            net.add_outcome(handle, outcomes[i])

        return handle

    def stream_table(self, a):
        """Flatten the numpy array ``a`` into a plain Python list.

        The original author assumes ``a`` is a numpy array of columns.
        """
        return a.flatten().tolist()

    @staticmethod
    def recreate_table(l, n):
        """Rebuild a 2-D numpy array from the flat sequence ``l``.

        Consecutive groups of ``n`` values become the rows of the result;
        trailing values that do not fill a complete row are dropped.

        Declared as a static method so it can be called on the class or on
        an instance; it was previously missing ``self`` and failed when
        called on an instance.

        Raises:
            ZeroDivisionError: if ``n`` is 0.
        """
        full_rows = len(l) // n
        rows = [list(l[r * n:(r + 1) * n]) for r in range(full_rows)]
        return np.array(rows)


class InferenceOnNetwork:
    """Enter a fixed set of temporal evidence and report the posteriors.

    All the work happens in the constructor: it sets the slice count,
    enters the evidence, updates beliefs, prints every node's posteriors
    and writes the network to ``colaborar_din2.xdsl``.
    """

    def __init__(self, net):
        """Run the inference scenario on ``net``.

        Args:
            net: pysmile network, e.g. the one returned by
                ``CollaborationNetwork.get_network()``. It is modified in
                place (slice count and evidence).
        """
        print("Starting inferencing...")

        # To see the baseline first, call net.update_beliefs() and
        # self.print_all_posteriors(net) here, before any evidence is set
        # (the printed header would be "Posteriors with no evidence set:").

        # Set the slice count before entering evidence so that every
        # evidence slice below (up to 14) lies inside the unrolled range.
        # NOTE(review): this assumes SMILE may reject or ignore evidence
        # for slices beyond the current slice count - confirm.
        net.set_slice_count(16)

        # (evidence node id, time slice, observed outcome id)
        observations = [
            ('e_proponer', 1, 'Bajo'),
            ('e_aportar', 2, 'Bajo'),
            ('e_proponer', 4, 'Medio'),
            ('e_aportar', 5, 'Medio'),
            ('e_colaborar', 7, 'Medio'),
            ('e_proponer', 10, 'Alto'),
            ('e_aportar', 11, 'Medio'),
            ('e_colaborar', 14, 'Medio'),
        ]
        for node_id, n_time, outcome_id in observations:
            self.change_evidence_and_update(net, node_id, n_time, outcome_id)

        net.update_beliefs()
        print("Posteriors with evidence set:")
        self.print_all_posteriors(net)

        net.write_file("colaborar_din2.xdsl")
        print("Network written to colaborar_din2.xdsl")

    def print_posteriors(self, net, node_handle):
        """Print the evidence or the posterior values of one node.

        Args:
            net: pysmile network holding the node.
            node_handle: handle of the node to report.
        """
        node_id = net.get_node_id(node_handle)
        if net.is_evidence(node_handle):
            evidence_outcome = net.get_outcome_id(
                node_handle, net.get_evidence(node_handle))
            print(node_id + " has evidence set (" + evidence_outcome + ")")
        else:
            posteriors = net.get_node_value(node_handle)
            n = net.get_outcome_count(node_handle)
            # The value list can be longer than the outcome count, so the
            # outcome id is taken from the index modulo n.
            # NOTE(review): presumably one group of n values per time
            # slice - confirm; the slice itself is not printed.
            for i in range(0, len(posteriors)):
                value = net.get_outcome_id(node_handle, i % n)
                p = posteriors[i]
                print("P(" + node_id + " = " + value + ") = " + str(p))

    def print_all_posteriors(self, net):
        """Print the posteriors of every node in ``net``."""
        handles = net.get_all_nodes()
        for h in handles:
            self.print_posteriors(net, h)

    def change_evidence_and_update(self, net, node_id, n_time, outcome_id):
        """Set or clear temporal evidence on one node at one time slice.

        Despite its name this method does not update beliefs; the caller
        must call ``net.update_beliefs()`` afterwards.

        Args:
            net: pysmile network to modify.
            node_id: identifier of the node.
            n_time: time slice the evidence applies to.
            outcome_id: outcome to set as evidence, or ``None`` to clear
                the evidence at that slice.
        """
        if outcome_id is not None:
            net.set_temporal_evidence(node_id, n_time, outcome_id)
        else:
            net.clear_temporal_evidence(node_id, n_time)
        

# Script entry: build the network, then run the evidence scenario on it
# (InferenceOnNetwork does all of its work in its constructor).
cn = CollaborationNetwork()
InferenceOnNetwork(net=cn.get_network())
