#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Created on Thu Nov 28 17:41:06 2019

@author: vagrant
"""

import pysmile
try:
    import pysmile_license
except:
    pass



def update_and_print_results(node):
    """Run inference on the module-level ``net`` and print the beliefs of ``node``.

    Args:
        node: Node handle in ``net`` whose results should be printed.
    """
    # Inference has to run first so that print_results reads fresh beliefs.
    net.update_beliefs()
    print_results(node)

def print_cpt_matrix(node_handle):
    """Print every entry of a CPT node's conditional probability table.

    Each entry is printed as ``    P(outcome | Parent=state,...)=p``; the
    ``| ...`` part is omitted when the node has no parents. For a node that
    is not a CPT node only a warning is printed.

    Relies on the module-level ``net`` network object.

    Args:
        node_handle: Handle of the node in ``net``.
    """
    if net.get_node_type(node_handle) != pysmile.NodeType.CPT:
        print("WARNING: CALLED print_cpt_matrix ON A NON CPT NODE!")
        return

    table = net.get_node_definition(node_handle)
    parent_handles = net.get_parents(node_handle)
    # Flat table is decoded with parent dimensions first (in get_parents
    # order) and the node's own outcomes as the last, fastest-varying axis.
    sizes = [net.get_outcome_count(p) for p in parent_handles]
    sizes.append(net.get_outcome_count(node_handle))
    position = [0] * len(sizes)

    for flat_idx, prob in enumerate(table):
        index_to_coords(flat_idx, sizes, position)
        text = "    P(" + net.get_outcome_id(node_handle, position[-1])
        if parent_handles:
            conditions = [
                net.get_node_id(p) + "=" + net.get_outcome_id(p, position[k])
                for k, p in enumerate(parent_handles)
            ]
            text += " | " + ",".join(conditions)
        print(text + ")=" + str(prob))

def print_results(node_handle):
    """Print the current beliefs of a node in a form suited to its type.

    CPT nodes in the temporal plate are printed slice by slice, other CPT
    nodes outcome by outcome, and equation nodes via their value statistics.
    Any other node type produces only a warning line.

    Relies on the module-level ``net`` network object. Callers in this file
    run ``net.update_beliefs()`` before calling this.

    Args:
        node_handle: Handle of the node in ``net``.
    """
    node_type = net.get_node_type(node_handle)
    if node_type == pysmile.NodeType.CPT:
        if net.get_node_temporal_type(node_handle) == pysmile.NodeTemporalType.PLATE:
            show_temporal_results(node_handle)
        else:
            print_cpt_believes(node_handle)
    elif node_type == pysmile.NodeType.EQUATION:
        print_equation_stats(node_handle)
    else:
        # node_type is a pysmile.NodeType value (compared against its members
        # above), not a str; plain "+" concatenation would raise TypeError.
        print("WARNING: CANT PRINT BELIEVES, UNKOWN NODE TYPE: " + str(node_type))

def print_equation_stats(node_handle):
    """Report the value of an equation node.

    If evidence is set on the node, only the evidence value is printed.
    Otherwise a discretized node gets one line per discretization interval
    with its probability, and a non-discretized node gets its sample mean,
    standard deviation, minimum and maximum.

    Relies on the module-level ``net`` network object.

    Args:
        node_handle: Handle of the equation node in ``net``.
    """
    name = net.get_node_id(node_handle)

    if net.is_evidence(node_handle):
        evidence = net.get_cont_evidence(node_handle)
        print(name + " has evidence set " + str(evidence))
        return

    if not net.is_value_discretized(node_handle):
        stats = net.get_node_sample_stats(node_handle)
        print(name + ": mean=" + str(stats[0]) + " stddev=" + str(stats[1])
              + " min=" + str(stats[2]) + " max=" + str(stats[3]))
        return

    print(name + " is discretized.")
    intervals = net.get_node_equation_discretization(node_handle)
    equation_bounds = net.get_node_equation_bounds(node_handle)
    probabilities = net.get_node_value(node_handle)
    # Intervals are contiguous: each one starts where the previous ended,
    # the first one at the lower equation bound.
    lower = equation_bounds[0]
    for idx, p in enumerate(probabilities):
        upper = intervals[idx].boundary
        print("\tP(" + name + " in " + str(lower) + ".." + str(upper)
              + ")=" + str(p))
        lower = upper
        
def print_cpt_believes(node):
    """Print the current belief for each outcome of a discrete node.

    One line per outcome, formatted ``    outcome_id=probability``.
    Relies on the module-level ``net`` network object.

    Args:
        node: Handle of the node in ``net``.
    """
    probabilities = net.get_node_value(node)
    for outcome_idx, p in enumerate(probabilities):
        print("    " + net.get_outcome_id(node, outcome_idx) + "=" + str(p))

def show_temporal_results(node_handle):
    """Print a temporal-plate node's beliefs, one line per time slice.

    The flat vector from ``net.get_node_value`` is read as ``outcome_count``
    consecutive values per slice, slice 0 first. A blank line is printed at
    the end in every case; for a node outside the temporal plate, that blank
    line is the only output.

    Relies on the module-level ``net`` network object.

    Args:
        node_handle: Handle of the node in ``net``.
    """
    n_slices = net.get_slice_count()
    if net.get_node_temporal_type(node_handle) != pysmile.NodeTemporalType.PLATE:
        print("")
        return

    width = net.get_outcome_count(node_handle)
    print("Temporal beliefs for " + net.get_node_id(node_handle) + ":")
    flat = net.get_node_value(node_handle)
    for t in range(n_slices):
        cells = "".join(" " + str(flat[t * width + k]) for k in range(width))
        print("\tt=" + str(t) + ":" + cells)
    print("")
    
def index_to_coords(index, dim_sizes, coords):
    """Convert a flat index into per-dimension coordinates (row-major order).

    The last dimension varies fastest, so with ``dim_sizes == [2, 3]`` the
    index 5 maps to ``[1, 2]``.

    Args:
        index: Non-negative flat index.
        dim_sizes: Size of each dimension, outermost first.
        coords: Output list with at least ``len(dim_sizes)`` slots; its first
            ``len(dim_sizes)`` entries are overwritten in place.

    Returns:
        ``coords`` itself (it is also mutated in place, as before).
    """
    stride = 1
    for i in range(len(dim_sizes) - 1, -1, -1):
        # Integer floor division: going through float true division
        # (int(index / stride)) rounds incorrectly once values exceed 2**53.
        coords[i] = (index // stride) % dim_sizes[i]
        stride *= dim_sizes[i]
    return coords


import os

# Input model, output path for the learned model (input name + "B" suffix),
# and the training data.
networkFile = "dev_PerfNetwork1_3BE_1574983185.5145876.xdsl"
networkOutFile = os.path.splitext(networkFile)[0] + "B.xdsl"
dataFile = "PerfNetwork1_3BE_Data.csv"
net = pysmile.Network()
net.read_file(networkFile)

# Learn the network parameters from the data set with EM, then save the
# learned network.
ds = pysmile.learning.DataSet()
ds.read_file(dataFile)
print("learning from records: " + str(ds.get_record_count()))
matching = ds.match_network(net)
em = pysmile.learning.EM()
em.learn(ds, net, matching)
lastScore = em.get_last_score()
net.write_file(networkOutFile)

print(lastScore)

# The loop below only reads from the network (no evidence or definitions are
# changed), so a single inference pass serves every node.
net.update_beliefs()

for node in net.get_all_nodes():
    print("\nnode id/name: " + str(node) + "/" + str(net.get_node_id(node)))
    print("    " + str(net.get_node_name(node)))
    print_cpt_matrix(node)

    print("no evidence probabilities: ")
    print_results(node)
    print("number of node outcomes: " + str(net.get_outcome_count(node)))

    if net.get_node_temporal_type(node) == pysmile.NodeTemporalType.PLATE:
        tempOrder = net.get_max_node_temporal_order(node)
        print('temporal order maximum: ' + str(tempOrder))
        for order in range(1, tempOrder + 1):
            print('trying order ' + str(order))
            parents = net.get_temporal_parents(node, order)
            if len(parents) > 0:
                print('temporal cpt')
                cpt = net.get_node_temporal_definition(node, order)
                print(str(cpt))
            for parent in parents:
                print('temporal parents: handle, id, order')
                print(str([parent.handle,
                           parent.id,
                           parent.order]))
print("score: " + str(lastScore))