#include <cstddef>
#include <cstdint>
#include <memory>
#include <set>
#include <stdexcept>
#include <string>
#include <vector>

#include "mex.hpp"
#include "mexAdapter.hpp"
#include "smile_license.h"
#include "smile.h"

// Method identifiers selected by the first MATLAB argument. The method-name
// strings that map onto these values are listed in MexFunction::checkMethod.
// Typed constants instead of macros: they respect scope and show up in debuggers.
constexpr int CREATENETWORK = 1;
constexpr int FREENETWORK = 2;
constexpr int NODESCOUNT = 3;
constexpr int READFILE = 4;
constexpr int UPDATEBELIEFS = 5;
constexpr int SETEVIDENCE = 6;
constexpr int GETVALUE = 7;

class MexFunction : public matlab::mex::Function {
public:
    void operator()(matlab::mex::ArgumentList outputs, matlab::mex::ArgumentList inputs) {
        std::shared_ptr<matlab::engine::MATLABEngine> matlabPtr = getEngine();
        matlab::data::ArrayFactory factory;
        int method = checkMethod(inputs);
        checkArguments(method, outputs, inputs);
        wrapperBegin(method, outputs, inputs);
    }

private:    
    int checkMethod(matlab::mex::ArgumentList & inputs) {
        std::shared_ptr<matlab::engine::MATLABEngine> matlabPtr = getEngine();
        matlab::data::ArrayFactory factory;
        matlab::data::CharArray charArray = inputs[0];
        std::string methodString = charArray.toAscii();
        int method = -1;
        if (!methodString.compare("createnetwork")) {
            method = CREATENETWORK;
        } else if (!methodString.compare("freenetwork")) {
            method = FREENETWORK;
        } else if (!methodString.compare("nodescount")) {
            method = NODESCOUNT;
        } else if (!methodString.compare("readfile")) {
            method = READFILE;
        } else if (!methodString.compare("updatebeliefs")) {
            method = UPDATEBELIEFS;
        } else if (!methodString.compare("setevidence")) {
            method = SETEVIDENCE;
        } else if (!methodString.compare("getvalue")) {
            method = GETVALUE;
        }
        return method;
    }
    
    void checkArguments(int method, matlab::mex::ArgumentList outputs, matlab::mex::ArgumentList inputs) {
        std::shared_ptr<matlab::engine::MATLABEngine> matlabPtr = getEngine();
        matlab::data::ArrayFactory factory;
        auto inputsSize = inputs.size();
        auto outputsSize = outputs.size();
        if (method < 0) {
            matlabPtr->feval(u"error", 
                0, std::vector<matlab::data::Array>({ factory.createScalar("Unknown method usage detected.") }));
        } else if (method == CREATENETWORK) {
            if (!(outputsSize == 1 && inputsSize == 1)) {
                matlabPtr->feval(u"error", 
                0, std::vector<matlab::data::Array>({ factory.createScalar("Create Network requires one input and one output.") }));
            }
        } else if (method == FREENETWORK) {
            if (!(outputsSize == 0 && inputsSize == 2)) {
                matlabPtr->feval(u"error", 
                0, std::vector<matlab::data::Array>({ factory.createScalar("Free Network requires two inputs.") }));
            }
        } else if (method == NODESCOUNT) {
            if (!(outputsSize <= 1 && inputsSize == 2)) {
                matlabPtr->feval(u"error", 
                0, std::vector<matlab::data::Array>({ factory.createScalar("Nodes Count requires two inputs.") }));
            }
        } else if (method == READFILE) {
            if (!(outputsSize == 0 && inputsSize == 3)) {
                matlabPtr->feval(u"error", 
                0, std::vector<matlab::data::Array>({ factory.createScalar("Read File requires three inputs") }));
            }
        } else if (method == UPDATEBELIEFS) {
            if (!(outputsSize == 0 && inputsSize == 2)) {
                matlabPtr->feval(u"error", 
                0, std::vector<matlab::data::Array>({ factory.createScalar("Update Beliefs requires two input.") }));
            }
        } else if (method == SETEVIDENCE) {
            if (!(outputsSize == 0 && inputsSize == 4)) {
                matlabPtr->feval(u"error", 
                0, std::vector<matlab::data::Array>({ factory.createScalar("Set Evidence requires four inputs.") }));
            }
        } else if (method == GETVALUE) {
            if (!(outputsSize == 0 && inputsSize == 3)) {
                matlabPtr->feval(u"error", 
                0, std::vector<matlab::data::Array>({ factory.createScalar("Get Value requires three inputs.") }));
            }
        }
    }
    
    
    
    // NET = matsmile('createnetwork');   
    void * createNetwork() {
        DSL_network * net = new DSL_network();
        return (void*)net;
    }
  
    // matsmile('freenetwork', NET);
    void freeNetwork(const void * netPtr) {
        DSL_network * net = (DSL_network*)netPtr;
        delete net;
    }
    
    //COUNT = matsmile('nodescount', NET);
    int nodesCount(const void * netPtr) {
        DSL_network * net = (DSL_network*)netPtr;
        return net->GetNumberOfNodes();
    }
    
    // matsmile('readfile', NET);
    void readFile(const void * netPtr, std::string filePath) {
        DSL_network * net = (DSL_network*)netPtr;
        net->ReadFile(filePath.c_str());
    }
    
    // matsmile('updatebeliefs', NET);
    void updateBeliefs(const void * netPtr) {
        DSL_network * net = (DSL_network*)netPtr;
        net->UpdateBeliefs();
    }
    
    // matsmile('setevidence', NET, 'nodeid', 'outcomeid');
    void setEvidence(const void * netPtr, std::string nodeId, std::string outcomeId) {
        DSL_network * net = (DSL_network*)netPtr;
        int nodeHandle = net->FindNode(nodeId.c_str());
        DSL_node * node = net->GetNode(nodeHandle);
        int outcomeIndex = node->Definition()->GetOutcomesNames()->FindPosition(outcomeId.c_str());
        node->Value()->SetEvidence(outcomeIndex);
    }
    
    // matsmile('getvalue', NET, 'nodeid');
    void getValue(const void * netPtr, std::string nodeId, std::vector<double> & resultVector) {
        DSL_network * net = (DSL_network*)netPtr;
        int nodeHandle = net->FindNode(nodeId.c_str());
        DSL_nodeValue * nodeValue = net->GetNode(nodeHandle)->Value();
        if(!nodeValue->IsValueValid()) {
            return;
        }
        DSL_Dmatrix *m = NULL;
        nodeValue->GetValue(&m);
        const DSL_doubleArray & arr = m->GetItems();
        resultVector = std::vector<double>();
        int arrSize = arr.GetSize();
        for (int i = 0; i < arrSize; i++) {
            resultVector.push_back(arr[i]);
        }
    }
    
    void wrapperBegin(int method, matlab::mex::ArgumentList & outputs, matlab::mex::ArgumentList & inputs) {
        if (method == CREATENETWORK) {
            wrapperCreateNetwork(outputs, inputs);
        } else if (method == FREENETWORK) {
            wrapperFreeNetwork(outputs, inputs);
        } else if (method == NODESCOUNT) {
            wrapperNodesCount(outputs, inputs);
        } else if (method == READFILE) {
            wrapperReadFile(outputs, inputs);
        } else if (method == UPDATEBELIEFS) {
            wrapperUpdateBeliefs(outputs, inputs);
        } else if (method == SETEVIDENCE) {
            wrapperSetEvidence(outputs, inputs);
        } else if (method == GETVALUE) {
            wrapperGetValue(outputs, inputs);
        }
    }
    
    long long encodePtr(void* thepointer) {
        union {uint64_t theinteger; void *thepointer;} ivp;
        ivp.thepointer = thepointer;
        return ivp.theinteger;
    }
    
    void * decodePtr(uint64_t theinteger) {
        union {uint64_t theinteger; void *thepointer;} ivp;
        ivp.theinteger = theinteger;
        return ivp.thepointer;
    }
    
    void wrapperCreateNetwork(matlab::mex::ArgumentList & outputs, matlab::mex::ArgumentList & inputs) {
        matlab::data::ArrayFactory factory;
        auto netPtr = createNetwork();
        outputs[0] = factory.createScalar<uint64_t>(encodePtr(netPtr));
    }
    
    void wrapperFreeNetwork(matlab::mex::ArgumentList & outputs, matlab::mex::ArgumentList & inputs) {
        matlab::data::TypedArray<uint64_t> arr = inputs[1];
        auto netPtr = decodePtr(arr[0]);
        freeNetwork(netPtr);
    }
    
    void wrapperNodesCount(matlab::mex::ArgumentList & outputs, matlab::mex::ArgumentList & inputs) {
        matlab::data::ArrayFactory factory;
        matlab::data::TypedArray<uint64_t> arr = inputs[1];
        auto netPtr = decodePtr(arr[0]);
        outputs[0] = factory.createScalar(nodesCount(netPtr));
    }
    
    void wrapperReadFile(matlab::mex::ArgumentList & outputs, matlab::mex::ArgumentList & inputs) {
        matlab::data::ArrayFactory factory;
        matlab::data::TypedArray<uint64_t> arr = inputs[1];
        auto netPtr = decodePtr(arr[0]);
        matlab::data::CharArray charArray = inputs[2];
        auto filePath = charArray.toAscii();
        readFile(netPtr, filePath);
    }
    
    void wrapperUpdateBeliefs(matlab::mex::ArgumentList & outputs, matlab::mex::ArgumentList & inputs) {
        matlab::data::TypedArray<uint64_t> arr = inputs[1];
        auto netPtr = decodePtr(arr[0]);
        updateBeliefs(netPtr);
    }
    
    void wrapperSetEvidence(matlab::mex::ArgumentList & outputs, matlab::mex::ArgumentList & inputs) {
        matlab::data::TypedArray<uint64_t> arr = inputs[1];
        auto netPtr = decodePtr(arr[0]);
        matlab::data::CharArray nodeIdCharArray = inputs[2];
        matlab::data::CharArray outcomeIdCharArray = inputs[3];
        auto nodeId = nodeIdCharArray.toAscii();
        auto outcomeId = outcomeIdCharArray.toAscii();
        setEvidence(netPtr, nodeId, outcomeId);
    }
    
    void wrapperGetValue(matlab::mex::ArgumentList & outputs, matlab::mex::ArgumentList & inputs) {
        matlab::data::ArrayFactory factory;
        matlab::data::TypedArray<uint64_t> arr = inputs[1];
        auto netPtr = decodePtr(arr[0]);
        matlab::data::CharArray nodeIdCharArray = inputs[2];
        auto nodeId = nodeIdCharArray.toAscii();
        std::vector<double> resultVector;
        getValue(netPtr, nodeId, resultVector);
        matlab::data::TypedArray<double> resArr = factory.createArray( {resultVector.size(), 1}, resultVector.begin(), resultVector.end());
        outputs[0] = resArr;
    }
};