#include <signal.h>

#include <algorithm>
#include <atomic>
#include <chrono>
#include <filesystem>
#include <fstream>
#include <iomanip>
#include <iostream>
#include <memory>
#include <numeric>
#include <sstream>
#include <string>
#include <thread>
#include <vector>

#include <opencv2/opencv.hpp>

#include "memx/accl/MxAccl.h"

namespace fs = std::filesystem;

// Run/stop flag. Set in main() before streaming; cleared by the SIGINT handler
// and by the input callback when it runs out of images. Polled by main() and
// by the input callback.
std::atomic_bool runflag;

// Imagenet folder: images are read as
// imagenetPath / (baseFilename + <8-digit zero-padded index> + ".JPEG").
const fs::path imagenetPath = "imagenet100/";
std::string baseFilename = "ILSVRC2012_val_";

// Model file (.dfp) loaded by MxAccl in main().
const fs::path modelPath = "resnet.dfp";

// Model info (output feature-map count and sizes), filled in main() via
// get_model_info(0) before streaming starts; read by the output callback.
MX::Types::MxModelInfo model_info;

// Expected class index per image, in file order: entry i labels image number
// i + 1 (100 entries, one per image). printScore() consumes one entry per
// inference result, so this assumes results arrive in the order frames were
// sent — TODO confirm.
std::vector<int> ground_truth  = {  65,970,230,809,516,57,334,415,674,332,109,286,370,757,595,147,473,23,478,517,334,173,
                                    948,727,23,846,270,167,55,858,324,573,150,981,586,887,32,398,777,74,516,756,129,198,
                                    256,725,565,167,717,394,92,29,844,591,358,468,259,994,872,588,474,183,107,46,842,390,
                                    101,887,870,841,467,149,21,476,80,424,159,275,175,461,970,160,788,58,479,498,369,28,487,
                                    50,270,383,366,780,373,705,330,142,949,349 };

// Accuracy tallies: updated by printScore() (called from the output callback)
// and read by main() after accl->stop().
int totalSamples = 0;
int correctTop1 = 0;
int correctTop5 = 0;
// 1-based index of the next image file to load; advanced by the input callback.
int counter = 1;

// SIGINT handler (installed in main): requests a graceful shutdown by clearing
// runflag. The main loop and the input callback poll the flag and wind down on
// their own; nothing else is done here.
void signalHandler([[maybe_unused]] int pSignal){
    runflag = false;
}

/**
 * Prepares a decoded image for the model.
 *
 * The image is resized to 224x224, a single-channel image is expanded to
 * 3 channels, and the pixels are converted to 32-bit float. Values are not
 * rescaled or normalized (convertTo uses its default scale/offset).
 *
 * @param img decoded image; it is resized (and possibly colour-converted) IN PLACE.
 * @return a new CV_32F matrix holding the preprocessed pixels.
 */
cv::Mat preprocessImage(cv::Mat& img) {
    const cv::Size modelInputSize(224, 224);
    cv::resize(img, img, modelInputSize);

    const bool isGrayscale = (img.channels() == 1);
    if (isGrayscale)
        cv::cvtColor(img, img, cv::COLOR_GRAY2BGR);

    cv::Mat asFloat;
    img.convertTo(asFloat, CV_32F);
    return asFloat;
}

/**
 * Input callback: loads the next ImageNet validation image, preprocesses it,
 * and writes it into the accelerator's input feature map.
 *
 * Images are read as imagenetPath / "ILSVRC2012_val_<counter, 8 digits>.JPEG".
 * No frame is supplied (and runflag is cleared) once every image that has a
 * ground-truth label has been sent, when an image cannot be read, or when
 * runflag was already cleared (e.g. by SIGINT).
 *
 * @param dst          input feature maps; only dst[0] is written.
 * @param streamLabel  stream ID (unused).
 * @return true if a frame was supplied, false otherwise.
 */
bool incallback_getframe(vector<const MX::Types::FeatureMap*> dst, int streamLabel){

    // Shutdown already requested: supply no more frames.
    if(!runflag.load()){
        return false;
    }

    // Send exactly as many images as there are labels so printScore() never
    // indexes past the end of ground_truth.
    if(counter > static_cast<int>(ground_truth.size())) {
        std::cout << "\n\nEnd of video/cam/img \n\n\n";
        runflag.store(false);
        return false;
    }

    std::ostringstream filename;
    filename << baseFilename << std::setw(8) << std::setfill('0') << counter << ".JPEG";
    const fs::path fullPath = imagenetPath / filename.str();
    cv::Mat inframe = cv::imread(fullPath.string());
    counter++;

    if(inframe.empty()){
        std::cout << "\n\n End of video/cam/img \n\n\n";
        runflag.store(false);
        return false;  // frame retrieval failed
    }

    // Preprocess frame
    cv::Mat preProcframe = preprocessImage(inframe);

    // Set preprocessed input data to be sent to the accelerator.
    // NOTE(review): preProcframe is destroyed when this function returns, so this
    // relies on set_data copying the buffer — confirm against the MxAccl API.
    dst[0]->set_data(preProcframe.ptr<float>());

    return true;
}

/**
 * Scores one inference result against the ground truth and updates the global
 * tallies totalSamples, correctTop1 and correctTop5. Despite its name it prints
 * nothing; main() prints the final accuracy.
 *
 * Results are matched to labels by arrival order: the N-th scored result is
 * compared with ground_truth[N-1].
 *
 * @param ofmaps output feature maps for one frame. Only ofmaps[0] is read; it is
 *               assumed to hold kNumClasses class scores (TODO confirm against
 *               model_info.out_featuremap_sizes[0]).
 */
void printScore(const std::vector<float*>& ofmaps){
    constexpr int kNumClasses = 1000;
    constexpr int kTopK = 5;

    // Either condition would otherwise index out of bounds; skip the sample.
    if (ofmaps.empty() || ofmaps[0] == nullptr) {
        std::cerr << "printScore: missing output feature map, sample skipped\n";
        return;
    }
    if (totalSamples >= static_cast<int>(ground_truth.size())) {
        std::cerr << "printScore: no ground-truth label left for this result, sample skipped\n";
        return;
    }

    const float* scores = ofmaps[0];
    const int trueIndex = ground_truth[totalSamples];  // label for this sample
    totalSamples += 1;

    std::vector<int> indices(kNumClasses);
    std::iota(indices.begin(), indices.end(), 0);  // class ids 0..kNumClasses-1

    // Only the kTopK best classes are needed, so order just that prefix
    // (descending score) instead of sorting all kNumClasses entries.
    std::partial_sort(indices.begin(), indices.begin() + kTopK, indices.end(),
        [scores](int lhs, int rhs) { return scores[lhs] > scores[rhs]; });

    // Top-1: the highest-scoring class is the true class.
    if (indices[0] == trueIndex) {
        correctTop1++;
    }

    // Top-5: the true class is among the kTopK highest-scoring classes.
    const auto topEnd = indices.begin() + kTopK;
    if (std::find(indices.begin(), topEnd, trueIndex) != topEnd) {
        correctTop5++;
    }
}

/**
 * Output callback: copies each output feature map of one inference result into
 * host buffers and passes them to printScore() for accuracy accounting.
 *
 * @param src          output feature maps for one frame.
 * @param streamLabel  stream ID (unused).
 * @return always true.
 */
bool outcallback_getmxaoutput(vector<const MX::Types::FeatureMap*> src, int streamLabel){

    // One owning buffer per output map, sized from model_info. The buffers are
    // released automatically when the callback returns, so nothing leaks per frame.
    std::vector<std::vector<float>> buffers;

    // The index starts at 0, and the loop also stops at src.size() so a
    // model_info/src mismatch cannot index past the end of src.
    for (int i = 0; i < model_info.num_out_featuremaps && static_cast<std::size_t>(i) < src.size(); ++i) {
        buffers.emplace_back(model_info.out_featuremap_sizes[i]);
        src[i]->get_data(buffers.back().data());
    }

    // printScore() takes raw pointers; these borrow from `buffers`, which
    // outlives the call.
    std::vector<float*> ofmap;
    ofmap.reserve(buffers.size());
    for (std::vector<float>& buf : buffers) {
        ofmap.push_back(buf.data());
    }

    printScore(ofmap);

    return true;
}

int main(){

    signal(SIGINT, signalHandler);

    std::cout << "application start \n";
    std::cout << "model path = " << modelPath.c_str() << "\n";

    MX::Runtime::MxAccl* accl = new MX::Runtime::MxAccl(modelPath);

    runflag.store(true);

    if(runflag.load()){

        model_info = accl->get_model_info(0);

        accl->connect_stream(&incallback_getframe, &outcallback_getmxaoutput, 10 /*unique stream ID*/, 0 /*Model ID */);
        std::cout << "Connected stream \n\n\n";

        accl->start();

        while(runflag.load()){
            std::this_thread::sleep_for(std::chrono::milliseconds(2));
        }  
        accl->stop();

        // Print accuracy
        std::cout << "Top-1 Accuracy: " << static_cast<double>(correctTop1) / totalSamples * 100.0 << "%" << std::endl;
        std::cout << "Top-5 Accuracy: " << static_cast<double>(correctTop5) / totalSamples * 100.0 << "%" << std::endl;

    }

    else{
        std::cout << "App exiting without execution \n\n\n";       
    }

    return 1;
}