"""
dfp.py
DFP file parsing and dfp_inspect functionality


Copyright (c) 2024 MemryX Inc.
MIT License

Permission is hereby granted, free of charge, to any person obtaining a
copy of this software and associated documentation files (the "Software"),
to deal in the Software without restriction, including without limitation
the rights to use, copy, modify, merge, publish, distribute, sublicense,
and/or sell copies of the Software, and to permit persons to whom the
Software is furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included
in all copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.

"""

import sys
import json
import struct
import argparse

from io import BytesIO
from pathlib import Path
from pprint import pprint

from mxpack import MxPackDecode

class Dfp:
    """
    A container class for the DFP that provides easy access to metadata like input/output shapes, number of MXAs etc.

    Parameters
    ----------
    path_or_bytes: str/pathlib.Path/bytes/bytearray
        Path to a .dfp file, or the raw DFP contents as bytes/bytearray

    Raises
    ------
    RuntimeError
        If the path does not end in ``.dfp``, the file cannot be opened, the
        argument has an unsupported type, or the DFP version / MXA generation
        code is not recognized.

    Example
    -------

    .. code-block:: python

        nc = NeuralCompiler(models='/models/mobilenet.h5', num_chips=4)
        dfp = nc.run()
        dfp.num_chips # 4
        dfp.num_inports # 1
        dfp.num_outports # 1
        dfp.input_shapes # [[224,224,1,3]]
        dfp.output_shapes # [[1,1,1,1000]]
        dfp.write('mobilenet.dfp') # saves the DFP to mobilenet.dfp

    """

    # Integer MXA generation code stored in the DFP -> display name.
    # The v6 decoder additionally recognizes code 5 ("Detroit"); the v5
    # decoder only recognizes codes 2-4.
    _MXA_GEN_V6 = {2: "Barton", 3: "Cascade", 4: "Cascade+", 5: "Detroit"}
    _MXA_GEN_V5 = {2: "Barton", 3: "Cascade", 4: "Cascade+"}

    # v5 port format code -> (data_type, packing_format)
    _V5_INPORT_FORMATS = {
        0: ("float", "gbf80"),
        1: ("rgb565", "rgb565"),
        2: ("uint8", "rgb888"),
        3: ("yuv422", "yuv422"),
        4: ("float", "bf16"),
        5: ("float", "fp32"),
        6: ("float", "gbf80_row"),
    }
    _V5_OUTPORT_FORMATS = {
        0: ("float", "gbf80"),
        4: ("float", "bf16"),
        5: ("float", "fp32"),
        6: ("float", "gbf80_row"),
    }

    def __init__(self, path_or_bytes):
        """
        Load a DFP from disk or memory and decode its metadata.
        """
        if isinstance(path_or_bytes, (str, Path)):
            if not str(path_or_bytes).endswith('.dfp'):
                raise RuntimeError(f"Input file: {path_or_bytes} is not a .dfp file generated by the NeuralCompiler")
            # Read the whole file into memory and close it right away so this
            # object never holds an OS file handle (the v6 decoder reads the
            # entire stream into memory anyway).
            try:
                with open(str(path_or_bytes), 'rb') as dfp_file:
                    contents = dfp_file.read()
            except (OSError, ValueError) as err:
                raise RuntimeError(f"Error opening file {path_or_bytes}") from err
            self._dfp_bytes = BytesIO(contents)
        elif isinstance(path_or_bytes, (bytes, bytearray)):
            # Wrap in BytesIO so the decoders can use file-like reads; the
            # caller's buffer itself is left untouched.
            self._dfp_bytes = BytesIO(path_or_bytes)
        else:
            raise RuntimeError(f"dfp argument is not type str or bytearray, it is type={type(path_or_bytes)}")

        # The first 8 bytes (little-endian) select the container format.
        version_code = int.from_bytes(self._dfp_bytes.read(8), byteorder='little')
        decoders = {6: self.__decode_v6, 5: self.__decode_v5}
        decoder = decoders.get(version_code)
        if decoder is None:
            raise RuntimeError(f"Unsupported DFP version: {version_code}, please recompile to the latest version")
        metadata = decoder(self._dfp_bytes)

        # rewind so later readers of the stream start from the beginning
        self._dfp_bytes.seek(0)

        self._metadata = metadata

    @property
    def version(self):
        """
        Returns the DFP version
        """
        return self._metadata['dfp_version']

    @property
    def num_inports(self):
        """
        Returns the number of active/used input ports
        """
        return self._metadata['num_inports']

    @property
    def num_outports(self):
        """
        Returns the number of active/used output ports
        """
        return self._metadata['num_outports']

    @property
    def input_shapes(self):
        """
        Returns a list of feature map shapes of the active/used input ports
        """
        return [port['shape'] for port in self._metadata['input_ports'].values() if port['active']]

    @property
    def input_ports(self):
        """
        Returns a dictionary of metadata of the active/used input ports
        """
        return {idx: port for idx, port in self._metadata['input_ports'].items() if port['active']}

    @property
    def output_ports(self):
        """
        Returns a dictionary of metadata of the active/used output ports
        """
        return {idx: port for idx, port in self._metadata['output_ports'].items() if port['active']}

    @property
    def output_shapes(self):
        """
        Returns a list of feature map shapes of the active/used output ports
        """
        return [port['shape'] for port in self._metadata['output_ports'].values() if port['active']]

    @property
    def output_names(self):
        """
        Returns a list of node names of the active/used output ports
        """
        return [port['layer_name'] for port in self._metadata['output_ports'].values() if port['active']]

    @property
    def num_chips(self):
        """
        Returns the number of MXAs the DFP was compiled for
        """
        return self._metadata['num_mxas']

    @property
    def chip_gen(self):
        """
        Returns the generation of the MXA the DFP was compiled for
        """
        return self._metadata['mxa_gen']

    @property
    def models(self):
        """
        Returns a list of paths/type information of the models compiled in the DFP
        """
        return self._metadata['models']

    def write(self, fname):
        """
        Writes the DFP contents to a file

        Parameters
        ----------
        fname: str/pathlib.Path
            Path to write/save the DFP on disk (written exactly as given)

        Returns
        -------
        The ``fname`` argument, unchanged.
        """
        with open(fname, 'wb') as f:
            # getvalue() returns the whole buffer regardless of the current
            # stream position, so a partially-read stream cannot truncate output.
            f.write(self._dfp_bytes.getvalue())
        self._dfp_bytes.seek(0)
        return fname

    # ------------------------------------------------------------------ helpers

    @staticmethod
    def _read_uint(f, nbytes):
        """Read ``nbytes`` from ``f`` and decode them as an unsigned little-endian int."""
        return int.from_bytes(f.read(nbytes), byteorder='little')

    @staticmethod
    def _read_str(f, nbytes, encoding):
        """Read ``nbytes`` from ``f`` and decode them as text in ``encoding``."""
        return str(f.read(nbytes), encoding)

    @staticmethod
    def _read_f32(f):
        """Read a little-endian 32-bit float from ``f``."""
        return struct.unpack('<f', f.read(4))[0]

    @classmethod
    def _read_v5_shape(cls, f):
        """Read a v5 shape record: row/col/z as 2-byte ints, then ch as a 4-byte int."""
        return [cls._read_uint(f, 2), cls._read_uint(f, 2), cls._read_uint(f, 2), cls._read_uint(f, 4)]

    @staticmethod
    def _inactive_inport():
        """Return a fresh placeholder metadata dict for an unused input port."""
        return {"active": False, "set": 0, "mxa_id": -1,
                "model_index": -1, "layer_name": "", "data_type": "float",
                "data_range_enabled": False, "data_range_shift": 0, "data_range_scale": 0,
                "packing_format": "fp32", "shape": [0, 0, 0, 0]}

    @staticmethod
    def _inactive_outport():
        """Return a fresh placeholder metadata dict for an unused output port."""
        return {"active": False, "set": 0, "mxa_id": -1,
                "model_index": -1, "layer_name": "", "data_type": "float", "packing_format": "fp32",
                "shape": [0, 0, 0, 0], "hpoc_enabled": False, "hpoc_list": [],
                "hpoc_fm_shape": [0, 0, 0, 0]}

    # ----------------------------------------------------------------- decoders

    def __decode_v6(self, f):
        """
        Decode the metadata of a v6 DFP.

        ``f`` must be positioned just past the 8-byte version field; the rest of
        the stream is passed to ``MxPackDecode`` and the resulting mapping is
        translated into this class's metadata dict.
        """
        dat = MxPackDecode(bytearray(f.read()))

        meta = {"dfp_version": "6"}

        # metadata section (sizes reported in MB, i.e. bytes / 1,000,000)
        #####################################################################################
        if dat["sim_dfp"]["enabled"]:
            meta["simulator_dfp_size"] = len(dat["sim_dfp"]["data"]) / 1_000_000
        else:
            meta["simulator_dfp_size"] = 0

        meta["hardware_dfp_size"] = len(dat["hw_dfp"]) / 1_000_000

        meta["compile_time"] = dat["compile_timestamp"]
        meta["models"] = dat["models"]
        meta["compiler_version"] = dat["compiler_version"]
        meta["compiler_args"] = dat["compiler_args"]

        meta["num_inports"] = dat["num_inports"]
        meta["num_outports"] = dat["num_outports"]

        meta["sim_info"] = {"towers": dat["sim_meta"]["towers"],
                            "frequency": dat["freq"]}

        intgen = dat["sim_meta"]["intgen"]
        if intgen not in self._MXA_GEN_V6:
            raise RuntimeError(f"Unknown MXA generation {intgen} detected!")
        meta["mxa_gen"] = self._MXA_GEN_V6[intgen]

        meta["num_mxas"] = dat["sim_meta"]["num_mpus"]

        # port info section
        #####################################################################################
        inport_info_dict = {}
        for p in dat["inport_info"]:
            if not p["active"]:
                inport_info_dict[p["port"]] = self._inactive_inport()
                continue
            inport_info_dict[p["port"]] = {
                "active": True,
                "set": p["port_set"],
                "mxa_id": p["mpu_id"],
                "model_index": p["model_index"],
                "layer_name": p["layer_name"],
                "data_type": "float",   # FIXME: we disabled RGB888
                "data_range_enabled": p["range_convert"]["enabled"],
                "data_range_shift": p["range_convert"].get("shift", float(0)),
                "data_range_scale": p["range_convert"].get("scale", float(0)),
                "packing_format": p["packing_format"]["name"],
                "shape": p["mxa_shape"],
            }
        meta['input_ports'] = inport_info_dict

        outport_info_dict = {}
        for p in dat["outport_info"]:
            if not p["active"]:
                outport_info_dict[p["port"]] = self._inactive_outport()
                continue
            outport_info_dict[p["port"]] = {
                "active": True,
                "set": p["port_set"],
                "mxa_id": p["mpu_id"],
                "model_index": p["model_index"],
                "layer_name": p["layer_name"],
                "data_type": "float",
                "packing_format": p["packing_format"]["name"],
                "shape": p["mxa_shape"],
                "hpoc_enabled": p["hpoc"]["enabled"],
                "hpoc_list": p["hpoc"].get("channels", []),
                "hpoc_fm_shape": p["hpoc"].get("shape", [0, 0, 0, 0]),
            }
        meta['output_ports'] = outport_info_dict

        return meta

    def _read_v5_inport(self, f):
        """
        Read one v5 input-port record from ``f``.

        Returns a ``(port_index, info_dict)`` tuple. Bit 7 of the first byte
        flags an active port; only active ports carry a payload after it.
        """
        read_uint = self._read_uint
        idx_and_status = read_uint(f, 1)
        if not idx_and_status & 0x80:
            # inactive port: status bit clear, so the whole byte is the index
            return idx_and_status, self._inactive_inport()

        port_set = read_uint(f, 1)
        mpu_id = read_uint(f, 1)
        model_index = read_uint(f, 1)
        layer_name_len = read_uint(f, 2)
        layer_name = self._read_str(f, layer_name_len, 'utf-8')
        fmt = read_uint(f, 1)
        range_en = read_uint(f, 1)
        range_shift = self._read_f32(f)
        range_scale = self._read_f32(f)
        shape = self._read_v5_shape(f)

        data_type, packing_format = self._V5_INPORT_FORMATS.get(
            fmt, ("unknown", "unknown: " + str(fmt)))
        return idx_and_status & 0x7F, {
            "active": True,
            "mxa_id": mpu_id,
            "model_index": model_index,
            "layer_name": layer_name,
            "set": port_set,
            "data_type": data_type,
            "packing_format": packing_format,
            "data_range_enabled": range_en == 1,
            "data_range_shift": range_shift,
            "data_range_scale": range_scale,
            "shape": shape,
        }

    def _read_v5_outport(self, f):
        """
        Read one v5 output-port record from ``f``.

        Returns a ``(port_index, info_dict)`` tuple. Bit 7 of the first byte
        flags an active port; an active record may carry an optional HPOC
        section (enable byte, feature-map shape, then a channel list).
        """
        read_uint = self._read_uint
        idx_and_status = read_uint(f, 1)
        if not idx_and_status & 0x80:
            # inactive port: status bit clear, so the whole byte is the index
            return idx_and_status, self._inactive_outport()

        port_set = read_uint(f, 1)
        mpu_id = read_uint(f, 1)
        model_index = read_uint(f, 1)
        layer_name_len = read_uint(f, 2)
        layer_name = self._read_str(f, layer_name_len, 'utf-8')
        fmt = read_uint(f, 1)
        shape = self._read_v5_shape(f)

        hpoc_enabled = read_uint(f, 1) == 1
        if hpoc_enabled:
            hpoc_fm_shape = self._read_v5_shape(f)
            hpoc_len = read_uint(f, 2)
            hpoc_list = [read_uint(f, 2) for _ in range(hpoc_len)]
        else:
            hpoc_list = []
            hpoc_fm_shape = []

        data_type, packing_format = self._V5_OUTPORT_FORMATS.get(
            fmt, ("unknown", "unknown: " + str(fmt)))
        return idx_and_status & 0x7F, {
            "active": True,
            "mxa_id": mpu_id,
            "model_index": model_index,
            "layer_name": layer_name,
            "set": port_set,
            "data_type": data_type,
            "packing_format": packing_format,
            "shape": shape,
            "hpoc_enabled": hpoc_enabled,
            "hpoc_list": hpoc_list,
            "hpoc_fm_shape": hpoc_fm_shape,
        }

    def __decode_v5(self, f):
        """
        Decode the metadata of a v5 (fixed binary layout) DFP.

        ``f`` must be positioned just past the 8-byte version field. All
        integers are little-endian and fields are read strictly in file order,
        so the statement order below mirrors the on-disk layout.
        """
        read_uint = self._read_uint
        meta = {"dfp_version": "5"}

        total_sim_bytes = read_uint(f, 8)
        meta['simulator_dfp_size'] = total_sim_bytes / 1_000_000

        # compile date/time
        time_len = read_uint(f, 1)
        meta['compile_time'] = self._read_str(f, time_len, 'ascii')

        # model names (the 4-byte total string length is read only to advance the stream)
        read_uint(f, 4)
        num_models = read_uint(f, 1)
        models = []
        for _ in range(num_models):
            name_len = read_uint(f, 2)
            models.append(self._read_str(f, name_len, 'utf-8'))
        meta['models'] = models

        # compiler version
        ver_len = read_uint(f, 1)
        meta['compiler_version'] = self._read_str(f, ver_len, 'ascii')

        # compiler args (JSON text)
        args_len = read_uint(f, 4)
        meta['compiler_args'] = json.loads(self._read_str(f, args_len, 'utf-8'))

        # low nibble: MXA generation code, high nibble: tower count
        gen_towers = read_uint(f, 1)
        gen_code = gen_towers & 0x0F
        if gen_code not in self._MXA_GEN_V5:
            raise RuntimeError(f"Unknown MXA generation {gen_code} detected!")
        towers = (gen_towers & 0xF0) >> 4
        mpus = read_uint(f, 1)
        frequency = read_uint(f, 2)
        num_inports = read_uint(f, 1)
        num_outports = read_uint(f, 1)
        meta['num_inports'] = num_inports
        meta['num_outports'] = num_outports

        meta['sim_info'] = {'towers': towers, 'frequency': frequency}

        meta['mxa_gen'] = self._MXA_GEN_V5[gen_code]
        meta['num_mxas'] = mpus

        inport_info_dict = {}
        for _ in range(num_inports):
            port, info = self._read_v5_inport(f)
            inport_info_dict[port] = info

        outport_info_dict = {}
        for _ in range(num_outports):
            port, info = self._read_v5_outport(f)
            outport_info_dict[port] = info

        meta['input_ports'] = inport_info_dict
        meta['output_ports'] = outport_info_dict

        # The hardware DFP size field follows the simulator section: skip the
        # 16-byte header (8-byte version + 8-byte sim size) plus the sim bytes.
        f.seek(16 + total_sim_bytes)
        meta['hardware_dfp_size'] = read_uint(f, 8) / 1_000_000

        return meta

def pretty_print(meta):
    """
    Print a human-readable, colorized summary of a DFP metadata dict.

    Parameters
    ----------
    meta: dict
        Metadata dict in the format built by ``Dfp`` (``Dfp(...)._metadata``).
        Requires the keys 'dfp_version', 'mxa_gen', 'num_mxas', 'input_ports'
        and 'output_ports'; the size/version/time/model keys are optional.
    """
    # ANSI escape sequences used for the section headers
    blue = "\033[38;2;0;153;255m"
    BOLD = '\033[1m'
    d = "\033[0m"

    print("\n╔"+"═"*38+"╗\n"+
              "║             DFP Inspect              ║\n"+
              "║  Copyright (c) 2019-2024 MemryX Inc. ║\n"+
              "╚"+"═"*38+"╝\n")

    print("═"*40)
    print(blue+BOLD+"DFP Info:"+d)
    if meta['dfp_version'] == "legacy":
        print(f"DFP format:        {meta['dfp_version']}")
    else:
        print(f"DFP format:        v{meta['dfp_version']}")

    if 'compiler_version' in meta:
        print(f"Compiler version:  {meta['compiler_version']}")

    if 'compile_time' in meta:
        print(f"Compile Time:      {meta['compile_time']}")

    if 'simulator_dfp_size' in meta:
        print(f"Simulator DFP:     {round(meta['simulator_dfp_size'],2)} MB")

    if 'hardware_dfp_size' in meta:
        print(f"Hardware DFP:      {round(meta['hardware_dfp_size'],2)} MB")

    print(f"MXA Generation:    {meta['mxa_gen']}")
    print(f"Number of MXAs:    {meta['num_mxas']}")
    print("─"*40)
    print(blue+BOLD+"Models Info:"+d)

    if 'models' in meta:
        for i, m in enumerate(meta['models']):
            print(f"Model {i:2d}: {m}")

    # Iterate the port dicts themselves (keyed by hardware port index) rather
    # than range(num_*ports): port indices need not be contiguous, and indexing
    # by position raised KeyError or skipped active ports with a high index.
    print("─"*40)
    print(blue+BOLD+"Active Input Ports Configs:"+d)
    for idx, port in sorted(meta['input_ports'].items()):
        if port['active']:
            print(f"Port {idx}:  ", end='')
            pprint(port)

    print("─"*40)
    print(blue+BOLD+"Active Output Ports Configs:"+d)
    for idx, port in sorted(meta['output_ports'].items()):
        if port['active']:
            print(f"Port {idx}:  ", end='')
            pprint(port)

    print("═"*40)


###############################################################################

def main():
    """
    Command-line entry point for the DFP inspector.

    Takes one positional argument, the path to a .dfp file. It decodes the
    file with ``Dfp`` and prints its metadata summary via ``pretty_print``.
    """
    cli = argparse.ArgumentParser(description="\033[34mMemryX DFP Inspector\033[0m")
    cli.add_argument(
        "filename",
        type=str,
        action="store",
        help="path to .dfp file to inspect",
    )

    dfp_path = cli.parse_args().filename
    inspected = Dfp(dfp_path)
    pretty_print(inspected._metadata)


if __name__ == "__main__":
    main()
