import json
import argparse
import os
import shutil
from memryx import NeuralCompiler

def set_compiler_config(nc, params):
    """Apply every leaf option in ``params`` to ``nc`` via ``nc.set_config``.

    Nested dictionaries are flattened: their keys are applied as top-level
    options, in the same depth-first order in which they appear. The section
    names of nested dicts are not passed to the compiler. No validation is
    done here; each key/value is forwarded as ``nc.set_config(key=value)``.

    Args:
        nc: Object exposing ``set_config(**kwargs)`` (the NeuralCompiler).
        params: Mapping of option name -> value, possibly nested.
    """
    for name, setting in params.items():
        if not isinstance(setting, dict):
            nc.set_config(**{name: setting})
            continue
        # A nested section: descend and apply its leaves as top-level options.
        set_compiler_config(nc, setting)

def compile_group(nc, group_data):
    """Compile one DFP group with the (shared) NeuralCompiler instance.

    Steps: reset the compiler configuration, apply every group-level option
    except ``models`` / ``models_dir`` (nested dicts are flattened by
    ``set_compiler_config``), set the model list and DFP file name, then run
    the compiler from inside the group's output directory so the DFP file is
    written there. The working directory is always restored afterwards.

    Args:
        nc: The NeuralCompiler instance, reused across groups.
        group_data: One entry of the JSON "dfps" list. Must contain
            ``models`` and ``models_dir``; may contain ``dfp_fname`` and any
            other option accepted by ``nc.set_config``.

    Raises:
        KeyError: If ``models`` or ``models_dir`` is missing from ``group_data``.
    """
    # Start from a clean configuration so options from a previous group do not leak in.
    nc.reset_config()

    # Group-level options; "models" and "models_dir" are handled explicitly below.
    group_params = {
        key: value
        for key, value in group_data.items()
        if key not in ("models", "models_dir")
    }
    set_compiler_config(nc, group_params)

    # Compile all models in the group together into a single DFP.
    models = group_data["models"]
    nc.set_config(models=models)

    dfp_dir = group_data["models_dir"]

    # Default the DFP name to the last path component of the group directory.
    # The compiler runs from inside dfp_dir, so using the full directory path
    # (e.g. "out/g1.dfp" for models_dir "out/g1", or "g1/.dfp" for "g1/")
    # would resolve to a subdirectory of dfp_dir that is never created.
    default_fname = f"{os.path.basename(os.path.normpath(dfp_dir))}.dfp"
    dfp_fname = group_data.get("dfp_fname", default_fname)
    nc.set_config(dfp_fname=dfp_fname)

    # Show the effective configuration before compiling.
    config = nc.get_config()
    print(f"Compiling group: {dfp_dir} with models: {models}")

    print("printing config after final set ")
    print(config)

    # Create the group's output directory if needed. An existing directory is
    # reused as-is (exist_ok=True); its previous contents are NOT removed.
    os.makedirs(dfp_dir, exist_ok=True)

    # Run the compiler from inside the output directory so the DFP lands there.
    # NOTE(review): if the compiler opens relative paths in `models` during
    # run(), they resolve against dfp_dir rather than the caller's cwd — confirm.
    original_dir = os.getcwd()
    os.chdir(dfp_dir)
    try:
        nc.run()
    finally:
        # Always restore the caller's working directory, even if compilation fails.
        os.chdir(original_dir)

    # Report the file name without doubling the extension when the
    # configured name already ends in ".dfp".
    saved_name = dfp_fname if str(dfp_fname).endswith(".dfp") else f"{dfp_fname}.dfp"
    print(f"DFP file saved to {saved_name}")


def compile_models_recursively(nc, groups_data):
    """Compile every DFP group in ``groups_data`` in order, using ``nc``.

    The name is kept for backward compatibility. The groups are processed
    with a plain loop rather than recursion. The earlier recursive version
    copied the remaining list on every call, which is O(n^2). It also hit
    Python's recursion limit at roughly 1000 groups.

    Compilation stops at the first group whose ``compile_group`` call raises.
    The exception propagates to the caller.

    Args:
        nc: The NeuralCompiler instance shared by all groups.
        groups_data: Sequence of group dicts (the JSON "dfps" list). An empty
            sequence or None compiles nothing.
    """
    # Nothing to compile.
    if not groups_data:
        return

    for group_data in groups_data:
        compile_group(nc, group_data)

def compile_groups(json_file):
    """Load DFP group definitions from ``json_file`` and compile every group.

    The file must contain a JSON object with a "dfps" list. Each entry of the
    list is compiled by ``compile_models_recursively`` / ``compile_group``
    with a single shared NeuralCompiler instance.

    Args:
        json_file: Path to the JSON configuration file.

    Raises:
        OSError: If the file cannot be opened.
        json.JSONDecodeError: If the file is not valid JSON.
        KeyError: If the top-level "dfps" key is missing.
    """
    # JSON text is UTF-8 (RFC 8259). Do not depend on the platform's
    # default encoding, which may be e.g. cp1252 on Windows.
    with open(json_file, "r", encoding="utf-8") as file:
        groups_data = json.load(file)["dfps"]

    # One compiler instance is shared; compile_group resets its config per group.
    nc = NeuralCompiler()

    compile_models_recursively(nc, groups_data)

if __name__ == "__main__":
    # Command-line entry point. The single positional argument is the JSON
    # file that describes the DFP groups to compile.
    cli = argparse.ArgumentParser(
        description="Compile neural network models using configurations given in a JSON file"
    )
    cli.add_argument(
        "json_file",
        type=str,
        help="Path to the input JSON configuration file.",
    )

    compile_groups(cli.parse_args().json_file)
