#include <signal.h>

#include <algorithm>
#include <atomic>
#include <chrono>
#include <cmath>
#include <cstdio>
#include <deque>
#include <filesystem>
#include <iostream>
#include <mutex>
#include <numeric>
#include <string>
#include <thread>
#include <utility>
#include <vector>

#include <opencv2/opencv.hpp>    /* imshow */
#include <opencv2/imgproc.hpp>   /* cvtcolor */
#include <opencv2/imgcodecs.hpp> /* imwrite */

#include "memx/accl/MxAccl.h"

namespace fs = std::filesystem;

// Input source selector: true opens camera index 0, false opens videoPath.
// Set from the command line in main() and read by run_inference().
bool use_cam = true; // Use camera flag
// Shared run/stop flag. Stored true when inference starts; stored false by the
// SIGINT handler, by the 'q' key check in the output callback, when a frame
// cannot be read, or when the capture fails to open.
std::atomic_bool runflag; // Control flag for processing loop

// Set by the output callback after it has created the display window, so the
// window setup calls run only on the first processed frame.
bool window_created = false; // Ensure window opens only once

// Model file paths (relative paths as written here; how they are resolved is
// up to the consuming APIs)
const fs::path modelPath = "yolov8m-pose.dfp"; // dfp model file
const fs::path onnx_postprocessing_model_path = "yolov8m-pose_post.onnx"; // Post-processing model file

// Opened by run_inference() when use_cam is false.
const fs::path videoPath = "../Friends.mp4"; // Video file path

// Model information, filled in by run_inference() from the accelerator object
MX::Types::MxModelInfo model_info; // dfp model info
MX::Types::MxModelInfo post_model_info; // Post-processing model info

// One heap buffer per post-processing output feature map. Allocated and freed
// in run_inference(); filled by the output callback on every frame.
std::vector<float*> ofmap;

// Model input size in pixels. These defaults are overwritten in
// run_inference() with the input feature-map shape reported by the model.
int model_input_width = 640; 
int model_input_height = 640; 
// Capture frame size, read from the capture properties in run_inference() and
// used to size the display window.
double origHeight = 0.0;  // Original frame height
double origWidth = 0.0;  // Original frame width

float box_score = 0.25; // Box confidence threshold
float rat = 0.0; // Aspect ratio for resizing (not referenced by the code shown in this file)
float kpt_score = 0.5; // Keypoint confidence threshold
float nms_thr = 0.2; // Non-Maximum Suppression threshold
// Number of candidate detections in the output buffer. It is also the stride
// between consecutive channels: value c of detection i is read from
// ofmap[0][dets_length * c + i].
int dets_length = 8400; // Candidate detections per output channel
int num_kpts = 17; // Number of keypoints per detection (x, y, confidence each)

// Video capture object for camera or file; opened in run_inference() and read
// by the input callback.
cv::VideoCapture vcap; 

// Queue of original (un-resized) frames. The input callback pushes each
// captured frame at the back; the output callback pops from the front to draw
// the overlay on it.
std::deque<cv::Mat> frames_queue; // Frame queue
std::mutex frameQueue_mutex; // Mutex guarding every access to frames_queue

// Number of frames in one FPS averaging window.
#define AVG_FPS_CALC_FRAME_COUNT  50
// Frames counted in the current averaging window; reset to 0 when the window
// completes.
int frame_count = 0;
// Most recently computed frames-per-second value.
float fps_number =.0;
// Text drawn on every displayed frame; rewritten each time a window completes.
char fps_text[64] = "FPS = ";
// Timestamp in milliseconds taken on the first frame of the current window.
std::chrono::milliseconds start_ms;

// Color list for keypoint visualization (20 entries). The output callback
// picks the color for keypoint k as COLOR_LIST[k % COLOR_LIST.size()].
// The scalars are drawn directly onto the captured frame, which the
// preprocessing step treats as BGR, so the components are presumably in
// B, G, R order.
const std::vector<cv::Scalar> COLOR_LIST = {
    cv::Scalar(128, 255, 0), cv::Scalar(255, 128, 50), cv::Scalar(128, 0, 255),
    cv::Scalar(255, 255, 0), cv::Scalar(255, 102, 255), cv::Scalar(255, 51, 255),
    cv::Scalar(51, 153, 255), cv::Scalar(255, 153, 153), cv::Scalar(255, 51, 51),
    cv::Scalar(153, 255, 153), cv::Scalar(51, 255, 51), cv::Scalar(0, 255, 0),
    cv::Scalar(255, 0, 51), cv::Scalar(153, 0, 153), cv::Scalar(51, 0, 51),
    cv::Scalar(0, 0, 0), cv::Scalar(0, 102, 255), cv::Scalar(0, 51, 255),
    cv::Scalar(0, 153, 255), cv::Scalar(0, 153, 153)
};

// Pairs of keypoint indices joined by a line when drawing the skeleton.
// Each index refers to a position in Box::keypoints; the indices used here
// span 0..16, matching num_kpts = 17.
// NOTE(review): the ordering presumably follows the model's keypoint layout
// (looks like the COCO 17-point convention) - confirm against the model.
const std::vector<std::pair<int, int>> KEYPOINT_PAIRS = {
    {0, 1}, {0, 2}, {1, 3}, {2, 4}, {0, 5}, {0, 6}, {5, 7}, {7, 9}, {6, 8},
    {8, 10}, {5, 6}, {5, 11}, {6, 12}, {11, 12}, {11, 13}, {13, 15}, {12, 14}, {14, 16}
};

// Signal handler registered for SIGINT in main(): requests shutdown by
// clearing the shared run flag, which the input callback checks before
// capturing each frame. The signal number itself is not inspected.
void signalHandler(int pSignal){
    (void)pSignal;
    runflag = false;
}

/**
 * Input callback: captures one frame, queues the original for the overlay
 * step, and hands a preprocessed copy to the accelerator.
 *
 * Preprocessing is BGR -> RGB conversion, resize to
 * model_input_width x model_input_height, and conversion to 32-bit float
 * scaled by 1/255.
 *
 * @param dst         Input feature maps supplied by the runtime; only dst[0]
 *                    is written.
 * @param streamLabel Stream identifier supplied by the runtime (unused).
 * @return true when a frame was captured and passed to dst[0]; false when the
 *         run flag is cleared (the capture is released) or when no frame
 *         could be read (the run flag is cleared).
 */
bool incallback_getframe(vector<const MX::Types::FeatureMap<float>*> dst, int streamLabel){

    // Shutdown requested: release the capture and report no input.
    if(!runflag.load()){
        vcap.release(); // Release video capture resources
        return false;
    }

    cv::Mat inframe;
    const bool got_frame = vcap.read(inframe); // Capture a frame from video/camera

    // Treat an empty image the same as a failed read.
    if (!got_frame || inframe.empty()) {
        std::cout << "\n\n No frame - End of cam? \n\n\n";
        runflag.store(false);
        return false;
    }

    // Queue the original frame so the output callback can draw on it later.
    {
        std::lock_guard<std::mutex> flock(frameQueue_mutex); // Lock the queue
        frames_queue.push_back(inframe);
    }

    cv::Mat rgbImage;
    cv::cvtColor(inframe, rgbImage, cv::COLOR_BGR2RGB);

    // Resize to the model input size. fx/fy are passed explicitly as 0 so that
    // the interpolation flag lands in the interpolation parameter rather than
    // in the fx scale-factor slot.
    cv::Mat preProcframe;
    cv::resize(rgbImage, preProcframe,
               cv::Size(model_input_width, model_input_height),
               0, 0, cv::INTER_LINEAR);

    // Scale pixel values by 1/255 into 32-bit floats.
    cv::Mat floatImage;
    preProcframe.convertTo(floatImage, CV_32F, 1.0 / 255.0);

    // Set preprocessed input data to be sent to the accelerator
    dst[0]->set_data(floatImage.ptr<float>(), false);

    return true;
}

// One detected person: bounding-box corners, box confidence, and keypoints.
// Scalar members are zero-initialized so that a default-constructed Box can be
// copied or moved without reading indeterminate values.
struct Box {
    float x1 = 0.0f;         // Left edge of the bounding box
    float y1 = 0.0f;         // Top edge of the bounding box
    float x2 = 0.0f;         // Right edge of the bounding box
    float y2 = 0.0f;         // Bottom edge of the bounding box
    float confidence = 0.0f; // Box confidence score
    // Keypoints stored as (x, y) pairs; the output callback stores (-1, -1)
    // for a keypoint whose confidence is below the threshold.
    std::vector<std::pair<float, float>> keypoints;
};

// Output callback function
bool outcallback_getmxaoutput(vector<const MX::Types::FeatureMap<float>*> src, int streamLabel) {

    // Extract feature maps from the accelerator output
    for(int i = 0; i < post_model_info.num_out_featuremaps; ++i) {
        src[i]->get_data(ofmap[i], false);
    }

    cv::Mat inframe;
    {
        std::lock_guard<std::mutex> flock(frameQueue_mutex); // Lock the frame queue
        inframe = frames_queue.front(); // Get the first frame from the queue
        frames_queue.pop_front(); // Remove the frame from the queue
    }

    // Create display window if not already created
    if(!window_created) {
        cv::namedWindow("Pose Estimation", cv::WINDOW_NORMAL | cv::WINDOW_KEEPRATIO);
        cv::resizeWindow("Pose Estimation", cv::Size(origWidth, origHeight));
        cv::moveWindow("Pose Estimation", 0, 0);
        window_created = true;
    }  

    std::vector<Box> all_boxes;
    std::vector<float> all_scores;
    std::vector<cv::Rect> cv_boxes;
         
    for (int i = 0; i < dets_length; ++i) {

        float x0 = ofmap[0][dets_length * 0 + i]; // Extract x-coordinate of the bounding box center
        float y0 = ofmap[0][dets_length * 1 + i]; // Extract y-coordinate of the bounding box center
        float w = ofmap[0][dets_length * 2 + i];  // Extract width of the bounding box
        float h = ofmap[0][dets_length * 3 + i];  // Extract height of the bounding box
        float confidence = ofmap[0][dets_length * 4 + i]; // Extract confidence score of the detection


        if (confidence > box_score) { // Check if confidence exceeds threshold
            Box box;
            box.confidence = confidence;

            // Scale box coordinates back to the original image size
            float y_factor = inframe.rows / float(model_input_height); // Calculate scaling factor for height
            float x_factor = inframe.cols / float(model_input_width);  // Calculate scaling factor for width
            x0 = x0 * x_factor; // Scale x-coordinate
            y0 = y0 * y_factor; // Scale y-coordinate
            w = w * x_factor;   // Scale width
            h = h * y_factor;   // Scale height

            int x1 = (int)(x0 - 0.5 * w); // Calculate top-left x-coordinate
            int x2 = (int)(x0 + 0.5 * w); // Calculate bottom-right x-coordinate
            int y1 = (int)(y0 - 0.5 * h); // Calculate top-left y-coordinate
            int y2 = (int)(y0 + 0.5 * h); // Calculate bottom-right y-coordinate

            // Extract keypoints
            for (int j = 0; j < num_kpts; ++j) { 
                float kpt_x = ofmap[0][dets_length * (5 + j * 3) + i] * x_factor; // Scale keypoint x-coordinate
                float kpt_y = ofmap[0][dets_length * (5 + j * 3 + 1) + i] * y_factor; // Scale keypoint y-coordinate
                float kpt_conf = ofmap[0][dets_length * (5 + j * 3 + 2) + i]; // Get keypoint confidence

                if (kpt_conf > kpt_score) { // Check if keypoint confidence exceeds threshold
                    box.keypoints.push_back(std::make_pair(kpt_x, kpt_y)); // Add valid keypoint
                } else {
                    box.keypoints.push_back(std::make_pair(-1, -1)); // Add invalid keypoint as placeholder
                }
            }

            all_boxes.push_back(box); // Store the box with keypoints
            all_scores.push_back(confidence); // Store the box confidence score
            cv_boxes.push_back(cv::Rect(x1, y1, x2 - x1, y2 - y1)); // Store bounding box for NMS
        }
    }

    // Apply Non-Maximum Suppression (NMS) to filter overlapping boxes
    std::vector<int> nms_result;
    cv::dnn::NMSBoxes(cv_boxes, all_scores, box_score, nms_thr, nms_result);

    std::vector<Box> filtered_boxes;
    for (int idx : nms_result) {
        filtered_boxes.push_back(all_boxes[idx]); // Keep only the boxes after NMS
    }

    // Draw keypoints and connections on the frame
    for (const auto &box : filtered_boxes) {
        for (const auto &connection : KEYPOINT_PAIRS) {
            int idx1 = connection.first;
            int idx2 = connection.second;

            if (idx1 < box.keypoints.size() && idx2 < box.keypoints.size()) {
                auto kpt1 = box.keypoints[idx1];
                auto kpt2 = box.keypoints[idx2];

                if (kpt1.first != -1 && kpt1.second != -1 && kpt2.first != -1 && kpt2.second != -1) {
                    cv::line(inframe, cv::Point(kpt1.first, kpt1.second), cv::Point(kpt2.first, kpt2.second), cv::Scalar(255, 255, 255), 3);
                }
            }
        }
        
        // Draw keypoints on the frame
        for (int k = 0; k < box.keypoints.size(); ++k) {
            auto &kpt = box.keypoints[k];
            if (kpt.first != -1 && kpt.second != -1) {
                cv::circle(inframe, cv::Point(kpt.first, kpt.second), 4, COLOR_LIST[k % COLOR_LIST.size()], -1);
            }
        }
    }

    //Calulate FPS once every AVG_FPS_CALC_FRAME_COUNT frames     
    frame_count++;
    if (frame_count == 1)
    {
        start_ms = std::chrono::duration_cast<std::chrono::milliseconds>(std::chrono::system_clock::now().time_since_epoch());
    }
    else if (frame_count % AVG_FPS_CALC_FRAME_COUNT == 0)
    {
        std::chrono::milliseconds duration =
            std::chrono::duration_cast<std::chrono::milliseconds>(std::chrono::system_clock::now().time_since_epoch()) - start_ms;
        fps_number = (float)AVG_FPS_CALC_FRAME_COUNT * 1000 / (float)(duration.count());
        sprintf(fps_text, "FPS = %.1f", fps_number);
        frame_count = 0;
    }

    //Write FPS values on the display image
    cv::putText(inframe,fps_text,
        cv::Point2i(450, 30), // origin of text (bottom left of textbox)
        cv::FONT_ITALIC,
        0.8, // font scale
        cv::Scalar(255, 255, 255), // color (green)
        2 // thickness
    );

    // Display the frame with the detected keypoints and connections
    cv::imshow("Pose Estimation", inframe);
    
    if (cv::waitKey(1) == 'q') {
        runflag.store(false);
    }

    return true; 
}

void run_inference() {

    runflag.store(true); 
    
    if(use_cam) { 
        std::cout << "use cam";
        vcap.open(0, cv::CAP_V4L2); 
    } else {
        vcap.open(videoPath.c_str()); 
    }

    if(vcap.isOpened()) { 
        std::cout << "videocapture opened \n";
        origWidth = vcap.get(cv::CAP_PROP_FRAME_WIDTH); // Get original frame width
        origHeight = vcap.get(cv::CAP_PROP_FRAME_HEIGHT); // Get original frame height
    } else {
        std::cout << "videocapture NOT opened \n";
        runflag.store(false); 
    }

    if(runflag.load()) { 

        MX::Runtime::MxAccl accl;
        accl.connect_dfp(modelPath); // Initialize the accelerator with the model
        
        accl.connect_post_model(onnx_postprocessing_model_path, 0); // Connect post-processing model
        post_model_info = accl.get_post_model_info(0); // Get post-processing model info

        model_info = accl.get_model_info(0); // Get main model info

        model_input_height = model_info.in_featuremap_shapes[0][0]; // Get model input height
        model_input_width = model_info.in_featuremap_shapes[0][1]; // Get model input width

        ofmap.reserve(post_model_info.num_out_featuremaps);

        // Allocate memory for feature maps
        for(int i = 0; i < post_model_info.num_out_featuremaps; ++i) {
            ofmap.push_back(new float[post_model_info.out_featuremap_sizes[i]]);
        }

        accl.connect_stream(&incallback_getframe, &outcallback_getmxaoutput, 10 /*unique stream ID*/, 0 /*Model ID*/); // Connect input and output streams
        
        std::cout << "Connected stream \n\n\n";
        accl.start(); // Start inference
        accl.wait();  // Wait for inference to complete
        accl.stop();  // Stop the accelerator

        // Clean up allocated memory for feature maps
        for (auto& fmap : ofmap) {
            delete[] fmap;
            fmap = NULL;
        }
        std::cout << "\n\rAccl stop called \n";  

    }
}

/**
 * Program entry point.
 *
 * Accepts exactly one mode argument: "--cam" to read from the camera or
 * "--video" to read from videoPath. Installs the SIGINT handler and runs
 * inference.
 *
 * @return 0 after inference has run; 1 when the arguments are missing or
 *         not recognised.
 */
int main(int argc, char* argv[]){

    bool args_ok = false;

    if(argc > 1){
        const std::string inputType(argv[1]);

        if(inputType == "--cam"){
            use_cam = true;
            args_ok = true;
        }
        else if(inputType == "--video"){
            use_cam = false;
            args_ok = true;
        }
        else{
            std::cout << "\n\nIncorrect Argument Passed \n\tuse ./app [--cam | --video]\n\n\n";
        }
    }
    else{
        std::cout << "\n\nNo Arguments Passed \n\tuse ./app [--cam | --video]\n\n\n";
    }

    runflag.store(args_ok);

    signal(SIGINT, signalHandler);

    if(!args_ok){
        std::cout << "App exiting without execution \n\n\n";
        return 1; // usage error
    }

    std::cout << "application start \n";
    std::cout << "model path = " << modelPath.string() << "\n";

    run_inference();

    return 0;
}
