#include <cstdint>
#include <map>
#include <memory>
#include <string>
#include <utility>
#include <vector>

#include "mex.hpp"
#include "mexAdapter.hpp"
#include "smile_license.h"
#include "smile.h"

// Method identifiers returned by MexFunction::checkMethod (-1 means unknown).
// Typed compile-time constants instead of macros: they obey scope and type
// rules and are visible in the debugger.
constexpr int CREATENETWORK = 1;
constexpr int FREENETWORK = 2;
constexpr int NODESCOUNT = 3;
constexpr int READFILE = 4;
constexpr int UPDATEBELIEFS = 5;
constexpr int SETEVIDENCE = 6;
constexpr int GETVALUE = 7;

class MexFunction : public matlab::mex::Function {
public:
    void operator()(matlab::mex::ArgumentList outputs, matlab::mex::ArgumentList inputs) {
        int method = checkMethod(inputs);
        checkArguments(method, outputs, inputs);
        wrapperBegin(method, outputs, inputs);
    }

private:    
    
    std::shared_ptr<matlab::engine::MATLABEngine> matlabPtr = getEngine();
    matlab::data::ArrayFactory factory;
    
    int checkMethod(matlab::mex::ArgumentList & inputs) {
        matlab::data::CharArray charArray = inputs[0];
        std::string methodString = charArray.toAscii();
        int method = -1;
        if (!methodString.compare("createnetwork")) {
            method = CREATENETWORK;
        } else if (!methodString.compare("freenetwork")) {
            method = FREENETWORK;
        } else if (!methodString.compare("nodescount")) {
            method = NODESCOUNT;
        } else if (!methodString.compare("readfile")) {
            method = READFILE;
        } else if (!methodString.compare("updatebeliefs")) {
            method = UPDATEBELIEFS;
        } else if (!methodString.compare("setevidence")) {
            method = SETEVIDENCE;
        } else if (!methodString.compare("getvalue")) {
            method = GETVALUE;
        }
        return method;
    }
    
    void checkArguments(int method, matlab::mex::ArgumentList outputs, matlab::mex::ArgumentList inputs) {
        auto inputsSize = inputs.size();
        auto outputsSize = outputs.size();
        if (method < 0) {
            matlabPtr->feval(u"error", 
                0, std::vector<matlab::data::Array>({ factory.createScalar("Unknown method usage detected.") }));
        } else if (method == CREATENETWORK) {
            if (!(outputsSize == 1 && inputsSize == 1)) {
                matlabPtr->feval(u"error", 
                0, std::vector<matlab::data::Array>({ factory.createScalar("Create Network requires one input and one output.") }));
            }
        } else if (method == FREENETWORK) {
            if (!(outputsSize == 0 && inputsSize == 2)) {
                matlabPtr->feval(u"error", 
                0, std::vector<matlab::data::Array>({ factory.createScalar("Free Network requires two inputs.") }));
            }
        } else if (method == NODESCOUNT) {
            if (!(outputsSize <= 1 && inputsSize == 2)) {
                matlabPtr->feval(u"error", 
                0, std::vector<matlab::data::Array>({ factory.createScalar("Nodes Count requires two inputs.") }));
            }
        } else if (method == READFILE) {
            if (!(outputsSize == 0 && inputsSize == 3)) {
                matlabPtr->feval(u"error", 
                0, std::vector<matlab::data::Array>({ factory.createScalar("Read File requires three inputs") }));
            }
        } else if (method == UPDATEBELIEFS) {
            if (!(outputsSize == 0 && inputsSize == 2)) {
                matlabPtr->feval(u"error", 
                0, std::vector<matlab::data::Array>({ factory.createScalar("Update Beliefs requires two input.") }));
            }
        } else if (method == SETEVIDENCE) {
            if (!(outputsSize == 0 && inputsSize == 4)) {
                matlabPtr->feval(u"error", 
                0, std::vector<matlab::data::Array>({ factory.createScalar("Set Evidence requires four inputs.") }));
            }
        } else if (method == GETVALUE) {
            if (!(outputsSize == 0 && inputsSize == 3)) {
                matlabPtr->feval(u"error", 
                0, std::vector<matlab::data::Array>({ factory.createScalar("Get Value requires three inputs.") }));
            }
        }
    }
    
    DSL_node * validateNodeHandle(const DSL_network & net, int handle) {
        DSL_node *node = net.GetNode(handle);
        if (NULL == node)
        {
            matlabPtr->feval(u"error", 
                0, std::vector<matlab::data::Array>({ factory.createScalar("Cannot find given node.") }));
        }

        return node;
    }
    
    int validateNodeId(const DSL_network & net, const char * nodeId) {
        int handle = net.FindNode(nodeId);
        if (handle < 0) {
            matlabPtr->feval(u"error", 
                0, std::vector<matlab::data::Array>({ factory.createScalar("Cannot find given node.") }));
        }
        return handle;
    }
    
    DSL_node * validateOutcomeIndex(DSL_network & net, int nodeHandle, int outcomeIndex) {
        DSL_node *node = validateNodeHandle(net, nodeHandle);
        int count = node->Definition()->GetNumberOfOutcomes();
        if (outcomeIndex < 0 || outcomeIndex >= count)
        {
        	std::string msg;
        	msg = "Invalid outcome index ";
        	msg += outcomeIndex;
        	msg += " for node '";
        	msg += node->Info().Header().GetId();
        	msg += "', valid indices are 0..";
        	DSL_appendInt(msg, count - 1);
            matlabPtr->feval(u"error", 
                0, std::vector<matlab::data::Array>({ factory.createScalar(msg.c_str()) }));
        }

        return node;
    }
    
    int validateOutcomeId(const DSL_network & net, int nodeHandle, const char * outcomeId) {
        DSL_node * node = validateNodeHandle(net, nodeHandle);

        DSL_idArray *outcomeNames = node->Definition()->GetOutcomesNames();
        int outcomeIndex = outcomeNames->FindPosition(outcomeId);
        if (outcomeIndex < 0)
        {
            std::string msg;
            msg = "Invalid outcome identifier '";
            msg += outcomeId;
            msg += "' for node '";
            msg += node->Info().Header().GetId();
            msg += '\'';
            matlabPtr->feval(u"error", 
                0, std::vector<matlab::data::Array>({ factory.createScalar(msg.c_str()) }));
        }

        return outcomeIndex;
    }
    
    // NET = matsmile('createnetwork');   
    void * createNetwork() {
        DSL_network * net = new DSL_network();
        return (void*)net;
    }
  
    // matsmile('freenetwork', NET);
    void freeNetwork(const void * netPtr) {
        DSL_network * net = (DSL_network*)netPtr;
        delete net;
    }
    
    //COUNT = matsmile('nodescount', NET);
    int nodesCount(const void * netPtr) {
        DSL_network * net = (DSL_network*)netPtr;
        return net->GetNumberOfNodes();
    }
    
    // matsmile('readfile', NET);
    void readFile(const void * netPtr, std::string filePath) {
        DSL_network * net = (DSL_network*)netPtr;
        net->ReadFile(filePath.c_str());
    }
    
    // matsmile('updatebeliefs', NET);
    void updateBeliefs(const void * netPtr) {
        DSL_network * net = (DSL_network*)netPtr;
        net->UpdateBeliefs();
    }
    
    // matsmile('setevidence', NET, 'nodeid', 'outcomeid');
    void setEvidence(const void * netPtr, std::string nodeId, std::string outcomeId) {
        DSL_network * net = (DSL_network*)netPtr;
        int nodeHandle = validateNodeId(*net, nodeId.c_str());
        int outcomeIndex = validateOutcomeId(*net, nodeHandle, outcomeId.c_str());
        validateOutcomeIndex(*net, nodeHandle, outcomeIndex)->Value()->SetEvidence(outcomeIndex);
    }
    
    // matsmile('getvalue', NET, 'nodeid');
    void getValue(const void * netPtr, std::string nodeId, std::vector<double> & resultVector) {
        DSL_network * net = (DSL_network*)netPtr;
        int nodeHandle = validateNodeId(*net, nodeId.c_str());
        DSL_nodeValue * nodeValue = validateNodeHandle(*net,nodeHandle)->Value();
        if(!nodeValue->IsValueValid()) {
            matlabPtr->feval(u"error", 
                0, std::vector<matlab::data::Array>({ factory.createScalar("Invalid node value") }));
            return;
        }
        DSL_Dmatrix *m = NULL;
        nodeValue->GetValue(&m);
        const DSL_doubleArray & arr = m->GetItems();
        resultVector = std::vector<double>();
        int arrSize = arr.GetSize();
        resultVector.resize(arrSize);
        for (int i = 0; i < arrSize; i++) {
            resultVector[i] = arr[i];
        }
    }
    
    void wrapperBegin(int method, matlab::mex::ArgumentList & outputs, matlab::mex::ArgumentList & inputs) {
        if (method == CREATENETWORK) {
            wrapperCreateNetwork(outputs, inputs);
        } else if (method == FREENETWORK) {
            wrapperFreeNetwork(outputs, inputs);
        } else if (method == NODESCOUNT) {
            wrapperNodesCount(outputs, inputs);
        } else if (method == READFILE) {
            wrapperReadFile(outputs, inputs);
        } else if (method == UPDATEBELIEFS) {
            wrapperUpdateBeliefs(outputs, inputs);
        } else if (method == SETEVIDENCE) {
            wrapperSetEvidence(outputs, inputs);
        } else if (method == GETVALUE) {
            wrapperGetValue(outputs, inputs);
        }
    }
    
    long long encodePtr(void* thepointer) {
        union {uint64_t theinteger; void *thepointer;} ivp;
        ivp.thepointer = thepointer;
        return ivp.theinteger;
    }
    
    void * decodePtr(uint64_t theinteger) {
        union {uint64_t theinteger; void *thepointer;} ivp;
        ivp.theinteger = theinteger;
        return ivp.thepointer;
    }
    
    void wrapperCreateNetwork(matlab::mex::ArgumentList & outputs, matlab::mex::ArgumentList & inputs) {
        auto netPtr = createNetwork();
        outputs[0] = factory.createScalar<uint64_t>(encodePtr(netPtr));
    }
    
    void wrapperFreeNetwork(matlab::mex::ArgumentList & outputs, matlab::mex::ArgumentList & inputs) {
        matlab::data::TypedArray<uint64_t> arr = inputs[1];
        auto netPtr = decodePtr(arr[0]);
        freeNetwork(netPtr);
    }
    
    void wrapperNodesCount(matlab::mex::ArgumentList & outputs, matlab::mex::ArgumentList & inputs) {
        matlab::data::TypedArray<uint64_t> arr = inputs[1];
        auto netPtr = decodePtr(arr[0]);
        outputs[0] = factory.createScalar(nodesCount(netPtr));
    }
    
    void wrapperReadFile(matlab::mex::ArgumentList & outputs, matlab::mex::ArgumentList & inputs) {
        matlab::data::TypedArray<uint64_t> arr = inputs[1];
        auto netPtr = decodePtr(arr[0]);
        matlab::data::CharArray charArray = inputs[2];
        auto filePath = charArray.toAscii();
        readFile(netPtr, filePath);
    }
    
    void wrapperUpdateBeliefs(matlab::mex::ArgumentList & outputs, matlab::mex::ArgumentList & inputs) {
        matlab::data::TypedArray<uint64_t> arr = inputs[1];
        auto netPtr = decodePtr(arr[0]);
        updateBeliefs(netPtr);
    }
    
    void wrapperSetEvidence(matlab::mex::ArgumentList & outputs, matlab::mex::ArgumentList & inputs) {
        matlab::data::TypedArray<uint64_t> arr = inputs[1];
        auto netPtr = decodePtr(arr[0]);
        matlab::data::CharArray nodeIdCharArray = inputs[2];
        matlab::data::CharArray outcomeIdCharArray = inputs[3];
        auto nodeId = nodeIdCharArray.toAscii();
        auto outcomeId = outcomeIdCharArray.toAscii();
        setEvidence(netPtr, nodeId, outcomeId);
    }
    
    void wrapperGetValue(matlab::mex::ArgumentList & outputs, matlab::mex::ArgumentList & inputs) {
        matlab::data::TypedArray<uint64_t> arr = inputs[1];
        auto netPtr = decodePtr(arr[0]);
        matlab::data::CharArray nodeIdCharArray = inputs[2];
        auto nodeId = nodeIdCharArray.toAscii();
        std::vector<double> resultVector;
        getValue(netPtr, nodeId, resultVector);
        matlab::data::TypedArray<double> resArr = factory.createArray( {resultVector.size(), 1}, resultVector.begin(), resultVector.end());
        outputs[0] = resArr;
    }
};