import memryx as mx
import argparse
import shutil

# ANSI escape sequences used to colorize the script's status messages.
BLUE, GREEN, RED, RESET = (
    "\033[94m",  # bright blue: [INFO] lines and separators
    "\033[92m",  # bright green: [Done] / [RESULT] / [UPDATE] lines
    "\033[91m",  # bright red: [ERROR] lines
    "\033[0m",   # reset to the terminal's default attributes
)

# Silence every Python warning process-wide ("ignore" with no category matches all).
import warnings
warnings.filterwarnings("ignore")

###############################################################################

def compile_model(model_path, split_number=0):
    """
    Run the memryx NeuralCompiler for a single split configuration.

    The compiled output is requested under the file name
    ``dfp_<split_number>.dfp``. Any exception raised while constructing or
    running the compiler is caught and reported. An exception whose message
    contains "overflow sweep" is treated as the signal that no further valid
    split numbers exist.

    Parameters:
        model_path (str): The file path to the model.
        split_number (int): The split number passed to the compiler.

    Returns:
        tuple: ``(dfp, all_splits_compiled)`` where
            - dfp is the object returned by ``NeuralCompiler.run()``, or None
              if compilation raised.
            - all_splits_compiled (bool) is True only when the raised error
              reported "overflow sweep", i.e. the valid splits are exhausted.
    """
    try:
        compiler = mx.NeuralCompiler(
            models=model_path,
            verbose=1,
            autocrop=True,
            no_sim_dfp=True,
            effort="hard",
            dfp_fname=f"dfp_{split_number}.dfp",
            split_number=split_number,
        )
        return compiler.run(), False
    except Exception as err:
        # The compiler's error text is the only exhaustion signal visible here.
        exhausted = "overflow sweep" in str(err)
        if exhausted:
            print(f"{GREEN}[Done]{RESET} All valid splits have been exhausted.")
        else:
            print(f"{RED}[ERROR]{RESET} Error compiling model with split_number={split_number}: {err}")
        return None, exhausted

###############################################################################

def benchmark_dfp(dfp, num_runs=3, *, frames=500):
    """
    Benchmark a compiled dfp several times and return the best FPS observed.

    Each pass opens a fresh ``mx.Benchmark`` context and runs ``frames``
    frames through it. The highest FPS over all passes is returned.

    Parameters:
        dfp: The compiled dfp object to benchmark (as returned by
            ``compile_model``).
        num_runs (int): The number of benchmark passes to perform.
        frames (int): Number of frames given to each ``Benchmark.run`` call.
            Keyword-only; defaults to 500, the previously hard-coded value.

    Returns:
        float: The maximum frames-per-second reported across all passes, or
        0 when ``num_runs`` is less than 1.
    """
    best_fps = 0
    for _ in range(num_runs):
        # A new Benchmark context per pass, so each measurement is taken independently.
        with mx.Benchmark(dfp=dfp, verbose=1) as bm:
            _, _, fps = bm.run(frames=frames)
        best_fps = max(best_fps, fps)
    return best_fps

###############################################################################

def main(model_path, *, max_splits=25):
    """
    Main function to perform closed-loop compilation and benchmarking of model splits.

    Split numbers ``0 .. max_splits - 1`` are compiled in order with
    ``compile_model``. Each one that compiles is benchmarked with
    ``benchmark_dfp``. The sweep stops early as soon as ``compile_model``
    reports that all valid splits are exhausted. The dfp file of the
    fastest split (``dfp_<n>.dfp``) is then copied to ``best_split.dfp``.

    Parameters:
        model_path (str): The file path to the model.
        max_splits (int): Exclusive upper bound on the split numbers tried.
            Keyword-only; defaults to 25, the previously hard-coded limit.
    """
    best_fps = -1
    best_split_number = -1
    separator = f"{BLUE}" + "=" * 50 + f"{RESET}"

    print(separator)
    print("Starting Closed-Loop Compilation and Benchmarking")
    print(separator)

    for split_number in range(max_splits):
        print(f"\n{BLUE}[INFO]{RESET} Processing split number: {split_number}")
        dfp, all_splits_compiled = compile_model(model_path, split_number)
        if all_splits_compiled:
            break
        if dfp is None:
            # Only this split failed to compile; keep sweeping.
            continue
        fps = benchmark_dfp(dfp)
        print(f"{GREEN}[RESULT]{RESET} Split {split_number} - FPS: {fps:.0f}")

        if fps > best_fps:
            best_fps = fps
            best_split_number = split_number
            print(f"{GREEN}[UPDATE]{RESET} New best split: {best_split_number} with FPS: {best_fps:.0f}")

        print(f"\n{separator}")
    else:
        # The loop hit the limit without the compiler reporting exhaustion,
        # so split numbers >= max_splits were never tried.
        print(f"\n{BLUE}[INFO]{RESET} Stopped after trying {max_splits} split numbers; "
              "higher split numbers were not explored.")

    if best_split_number < 0:
        print(f"\n{BLUE}[INFO]{RESET} No valid split found.")
        return

    source_file = f"dfp_{best_split_number}.dfp"
    destination_file = "best_split.dfp"
    try:
        shutil.copyfile(source_file, destination_file)
    except OSError as e:
        # copyfile raises OSError subclasses for file-system failures
        # (missing source, permissions, shutil.SameFileError, ...).
        print(f"{RED}[ERROR]{RESET} Failed to copy best split dfp: {e}")
    else:
        print(f"\n{BLUE}[INFO]{RESET} Best split dfp copied from {source_file} to {destination_file}")

###############################################################################

if __name__ == '__main__':
    # Command-line entry point: the only option is the required model path.
    cli = argparse.ArgumentParser(description='Compile and benchmark model splits.')
    cli.add_argument(
        '-m', '--model',
        type=str,
        required=True,
        help='Path to the model file.',
    )
    main(cli.parse_args().model)
