import torch
import numpy as np
import json
from pathlib import Path
import argparse, os  
import onnx

from ultralytics import YOLO
from ultralytics.models.yolo.detect.val import DetectionValidator
from ultralytics.models.yolo.pose.val import PoseValidator
from ultralytics.models.yolo.segment.val import SegmentationValidator
from ultralytics.data.utils import check_det_dataset
from ultralytics.utils import LOGGER, TQDM

import memryx as mx
import onnxruntime as ort


class MxaDetectionValidator(DetectionValidator):
    """
    Detection validator that runs YOLOv8 inference on the MXA.

    The Validator must be a child of BaseValidator which is the parent
    of DetectionValidator. The BaseValidator defines the __call__
    method which we need to override.

    Inference is split in two stages: the compiled DFP runs on the MXA via
    ``mx.SyncAccl`` and its output feature maps are finished on the host by
    the ``<model>_post.onnx`` graph via onnxruntime.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

        # Set required attributes
        self.stride = 32
        self.training = False

        model_name = Path(self.args.model).stem
        LOGGER.info(f"\033[32mRunning {model_name} inference on MXA\033[0m")

        # Ensure your paths/naming scheme matches
        self.mxa = mx.SyncAccl(f"weights/{model_name}.dfp")
        self.ort = ort.InferenceSession(f"weights/{model_name}_post.onnx")

        # The post-processing graph's input names are fixed for this session,
        # so look them up once here instead of on every batch.
        self.onnx_inp_names = [inp.name for inp in self.ort.get_inputs()]

    def __call__(self, model):
        """
        Run the full validation loop with MXA inference and return the metrics.

        Args:
            model: The loaded PyTorch model. Here it is only put in eval mode
                and passed to ``init_metrics``. Inference itself runs on the MXA.

        Returns:
            stats: Metrics from ``get_stats``, passed through ``eval_json``.
        """
        model.eval()

        # Create COCO dataloader
        self.data = check_det_dataset(self.args.data)
        self.dataloader = self.get_dataloader(
            self.data.get(self.args.split), self.args.batch
        )

        # Validation Loop
        self.init_metrics(model)
        self.jdict = []
        progress_bar = TQDM(
            self.dataloader, desc=self.get_desc(), total=len(self.dataloader)
        )
        for batch in progress_bar:
            batch = self.preprocess(batch)
            preds = self.mxa_detect(batch["img"])
            preds = self.postprocess(preds)
            self.update_metrics(preds, batch)

        # Compute and print stats
        stats = self.get_stats()
        self.finalize_metrics()
        self.print_results()

        # Save predictions and evaluate on pycocotools
        with open(str(self.save_dir / "predictions.json"), "w") as f:
            LOGGER.info(f"Saving {f.name}...")
            json.dump(self.jdict, f)
        stats = self.eval_json(stats)

        return stats

    def mxa_detect(self, img):
        """
        Detection using MXA accelerator.

        Args:
            img (torch.Tensor): Input image. (1, 3, 640, 640)

        Returns:
            preds (list): List of length 2.
                preds[0] (torch.Tensor): Predictions. (1, 84, 8400)
                preds[1] (None): Unused fmaps
        Notes:
            Fj in (64, 80) and Fi in (80, 40, 20)
        """
        # Pass images through accelerator
        img = img.detach().cpu().numpy()  # (1, 3, 640, 640)
        accl_out = self.mxa.run(img)

        # Pair each accelerator output, in order, with the matching ONNX input.
        input_feed = dict(zip(self.onnx_inp_names, accl_out))

        # Pass fmaps through onnxruntime
        onnx_out = self.ort.run(None, input_feed)
        out = torch.from_numpy(onnx_out[0])  # (1, 84, 8400)

        return [out, None]


class MxaSegmentationValidator(SegmentationValidator):
    """
    Instance-segmentation validator that runs YOLOv8-seg inference on the MXA.

    The Validator must be a child of BaseValidator which is the parent
    of SegmentationValidator. The BaseValidator defines the __call__
    method which we need to override.

    Inference is split in two stages: the compiled DFP runs on the MXA via
    ``mx.SyncAccl`` and its output feature maps are finished on the host by
    the ``<model>_post.onnx`` graph via onnxruntime.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

        # Set required attributes
        self.stride = 32
        self.training = False
        self.args.plots = False

        model_name = Path(self.args.model).stem
        LOGGER.info(f"\033[32mRunning {model_name} inference on MXA\033[0m")

        # Ensure your paths/naming scheme matches
        self.mxa = mx.SyncAccl(f"weights/{model_name}.dfp")
        self.ort = ort.InferenceSession(f"weights/{model_name}_post.onnx")

        # Reorder the ONNX input names to match the accelerator's output order:
        # move the name at index 7 to index 3, then the name now at index 8 to
        # index 6. The mapping is fixed for a DFP / post-ONNX pair, so compute
        # it once here instead of on every batch.
        onnx_inp_names = [inp.name for inp in self.ort.get_inputs()]
        for i, j in [(3, 7), (6, 8)]:
            onnx_inp_names.insert(i, onnx_inp_names.pop(j))
        self.onnx_inp_names = onnx_inp_names

    def __call__(self, model):
        """
        Run the full validation loop with MXA inference and return the metrics.

        Args:
            model: The loaded PyTorch model. Here it is only put in eval mode
                and passed to ``init_metrics``. Inference itself runs on the MXA.

        Returns:
            stats: Metrics from ``get_stats``, passed through ``eval_json``.
        """
        model.eval()

        # Create COCO dataloader
        self.data = check_det_dataset(self.args.data)
        self.dataloader = self.get_dataloader(
            self.data.get(self.args.split), self.args.batch
        )

        # Validation Loop
        self.init_metrics(model)
        self.jdict = []
        progress_bar = TQDM(
            self.dataloader, desc=self.get_desc(), total=len(self.dataloader)
        )
        for i, batch in enumerate(progress_bar):
            self.batch_i = i  # For plots
            batch = self.preprocess(batch)
            preds = self.mxa_segment(batch["img"])
            preds = self.postprocess(preds)
            self.update_metrics(preds, batch)

        # Compute and print stats
        stats = self.get_stats()
        self.finalize_metrics()
        self.print_results()

        # Save predictions and evaluate on pycocotools
        with open(str(self.save_dir / "predictions.json"), "w") as f:
            LOGGER.info(f"Saving {f.name}...")
            json.dump(self.jdict, f)
        stats = self.eval_json(stats)

        return stats

    def mxa_segment(self, img):
        """
        Segmentation using MXA accelerator.

        Args:
            img (torch.Tensor): Input image. (1, 3, 640, 640)

        Returns:
            preds (list): List of length 2.
                preds[0] (torch.Tensor): Boxes (1, 116, 8400)
                preds[1] (torch.Tensor): Masks (1, 32, 160, 160)

        Notes:
            For shapes: Fj in (64, 80) and Fi in (80, 40, 20)
        """
        # Pass images through accelerator
        img = img.detach().cpu().numpy()  # (1, 3, 640, 640)
        accl_out = self.mxa.run(img)  # (10, ...)

        # Trailing reshapes need to be handled manually: inputs whose name
        # contains "Reshape" are flattened to (1, C, -1). The leading 1
        # assumes batch size 1.
        input_feed = {
            name: (
                fmap
                if "Reshape" not in name
                else np.reshape(fmap, (1, fmap.shape[1], -1))
            )
            for name, fmap in zip(self.onnx_inp_names, accl_out)
        }

        onnx_out = self.ort.run(None, input_feed)
        return [
            torch.from_numpy(onnx_out[1]),  # Boxes (1, 116, 8400)
            torch.from_numpy(onnx_out[0]),  # Masks (1, 32, 160, 160)
        ]


class MxaPoseValidator(PoseValidator):
    """
    Pose-estimation validator that runs YOLOv8-pose inference on the MXA.

    The Validator must be a child of BaseValidator which is the parent
    of PoseValidator. The BaseValidator defines the __call__ method
    which we need to override.

    Inference is split in two stages: the compiled DFP runs on the MXA via
    ``mx.SyncAccl`` and its output feature maps are finished on the host by
    the ``<model>_post.onnx`` graph via onnxruntime.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

        # Set required attributes
        self.stride = 32
        self.training = False

        model_name = Path(self.args.model).stem
        LOGGER.info(f"\033[32mRunning {model_name} inference on MXA\033[0m")

        # Create MXA and Onnx runtimes
        self.mxa = mx.SyncAccl(f"weights/{model_name}.dfp")
        self.ort = ort.InferenceSession(f"weights/{model_name}_post.onnx")

        # Reorder the ONNX input names to match the accelerator's output order:
        # move the name at index 6 to index 2, then the name now at index 7 to
        # index 5. The mapping is fixed for a DFP / post-ONNX pair, so compute
        # it once here instead of on every batch.
        onnx_inp_names = [inp.name for inp in self.ort.get_inputs()]
        for i, j in [(2, 6), (5, 7)]:
            onnx_inp_names.insert(i, onnx_inp_names.pop(j))
        self.onnx_inp_names = onnx_inp_names

    def __call__(self, model):
        """
        Run the full validation loop with MXA inference and return the metrics.

        Args:
            model: The loaded PyTorch model. Here it is only put in eval mode
                and passed to ``init_metrics``. Inference itself runs on the MXA.

        Returns:
            stats: Metrics from ``get_stats``, passed through ``eval_json``.
        """
        model.eval()

        # Create COCO dataloader
        self.data = check_det_dataset(self.args.data)
        self.dataloader = self.get_dataloader(
            self.data.get(self.args.split), self.args.batch
        )

        # Validation Loop
        self.init_metrics(model)
        self.jdict = []
        progress_bar = TQDM(
            self.dataloader, desc=self.get_desc(), total=len(self.dataloader)
        )
        for batch in progress_bar:
            batch = self.preprocess(batch)
            preds = self.mxa_pose(batch["img"])
            preds = self.postprocess(preds)
            self.update_metrics(preds, batch)

        # Compute and print stats
        stats = self.get_stats()
        self.finalize_metrics()
        self.print_results()

        # Save predictions and evaluate on pycocotools
        with open(str(self.save_dir / "predictions.json"), "w") as f:
            LOGGER.info(f"Saving {f.name}...")
            json.dump(self.jdict, f)
        stats = self.eval_json(stats)

        return stats

    def mxa_pose(self, img):
        """
        Pose Estimation using MXA accelerator.

        Args:
            img (torch.Tensor): Input image. (1, 3, 640, 640)

        Returns:
            preds (list): List of length 2.
                preds[0] (torch.Tensor): Predictions. (1, 56, 8400)
                preds[1] (None): Unused loss output
        """
        # Pass images through accelerator
        img = img.detach().cpu().numpy()  # (1, 3, 640, 640)
        accl_out = self.mxa.run(img)  # (9, ...)

        # Trailing reshapes need to be handled manually: inputs whose name
        # contains "Reshape" are flattened to (1, C, -1). The leading 1
        # assumes batch size 1.
        input_feed = {
            name: (
                fmap
                if "Reshape" not in name
                else np.reshape(fmap, (1, fmap.shape[1], -1))
            )
            for name, fmap in zip(self.onnx_inp_names, accl_out)
        }

        # Pass fmaps through onnxruntime
        onnx_out = self.ort.run(None, input_feed)
        out = torch.from_numpy(onnx_out[0])  # (1, 56, 8400)

        return [out, None]


# Absolute path of the "weights" folder under the current working directory;
# the compiled DFP and the exported / post-processing ONNX files live here.
weights_dir = Path.cwd() / "weights"


def dfp_exists(model):
    """Return True only when both compiled artifacts for ``model`` exist.

    The artifacts are looked up in ``weights_dir`` and named after the
    checkpoint file's stem:
      * ``<stem>.dfp``       - the compiled DFP program
      * ``<stem>_post.onnx`` - the host-side post-processing ONNX graph
    """
    stem = Path(model.ckpt_path).stem
    required = (
        weights_dir / f"{stem}.dfp",
        weights_dir / f"{stem}_post.onnx",
    )
    return all(artifact.exists() for artifact in required)


def compile_model(model):
    """Export ``model`` to ONNX and compile it into a DFP for the MXA.

    Steps:
      1. ``model.export`` exports a simplified ONNX graph with batch size 1.
         The graph is then loaded from ``weights_dir/<stem>.onnx``.
      2. ``mx.NeuralCompiler`` (with ``autocrop=True``) compiles it to
         ``weights_dir/<stem>.dfp``.
      3. The ``main_graph_crop.onnx`` / ``main_graph_post.onnx`` files in
         ``weights_dir`` are renamed to ``<stem>_crop.onnx`` /
         ``<stem>_post.onnx`` so later runs can find them by model name.

    Args:
        model: Ultralytics ``YOLO`` model. The stem of its ``ckpt_path``
            names every artifact.
    """
    model_name = Path(model.ckpt_path).stem
    onnx_path = weights_dir / f"{model_name}.onnx"
    dfp_path = weights_dir / f"{model_name}.dfp"

    # Export to onnx
    model.export(format="onnx", simplify=True, batch=1)
    # NOTE(review): assumes the export is written next to the checkpoint,
    # i.e. into weights_dir; this holds when the checkpoint lives in ./weights.
    onnx_model = onnx.load(onnx_path)

    # Compile and save the DFP file
    nc = mx.NeuralCompiler(
        models=onnx_model,
        autocrop=True,
        no_sim_dfp=True,
        dfp_fname=dfp_path,
        verbose=1,
    )
    nc.run()

    # Rename the exported ONNX files. os.replace (unlike os.rename) also
    # overwrites an existing destination on Windows. This matters when a
    # previous partial run left a stale <stem>_post.onnx behind, because
    # dfp_exists then reports False and we recompile.
    for part in ("crop", "post"):
        os.replace(
            weights_dir / f"main_graph_{part}.onnx",
            weights_dir / f"{model_name}_{part}.onnx",
        )
    

if __name__ == "__main__":

    # Per-task configuration: (checkpoint, validator class, extra model.val kwargs).
    tasks = {
        "detect": ("weights/yolov8m.pt", MxaDetectionValidator, {}),
        "segment": ("weights/yolov8m-seg.pt", MxaSegmentationValidator, {}),
        "pose": (
            "weights/yolov8m-pose.pt",
            MxaPoseValidator,
            {"data": "coco-pose.yaml"},
        ),
    }

    parser = argparse.ArgumentParser(description="Validate YOLOv8 on the MXA")
    parser.add_argument(
        "--model",
        type=str,
        choices=list(tasks),
        default="detect",
        help="Specify the yolov8 model you wish to use for validation on the mxa. Choices are [detect, segment, pose]. Default is detect.",
    )
    args = parser.parse_args()

    ckpt, validator_cls, extra_val_kwargs = tasks[args.model]
    model = YOLO(ckpt)

    # Compile the DFP / post-processing ONNX only if they are not already present.
    if not dfp_exists(model):
        compile_model(model)

    # batch=1 matches the batch size used when exporting for the MXA.
    # NOTE(review): rect=False presumably keeps every input at the fixed
    # compiled shape (1, 3, 640, 640); confirm against the DFP.
    model.val(validator=validator_cls, batch=1, rect=False, **extra_val_kwargs)

