
import os, glob
from PIL import Image

os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
import tensorflow as tf
from tensorflow import keras
tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.ERROR)

import numpy as np
from memryx import NeuralCompiler, AsyncAccl, SyncAccl
import time 

import subprocess

# One-time setup: unpack the dataset, export the Keras model, and compile it
# for the MXA. Each artifact is checked independently so that a missing
# resnet.dfp is regenerated even when resnet.h5 already exists.
if not os.path.exists('imagenet100'):
    # check=True raises on a failed extraction instead of continuing with a
    # missing dataset directory.
    subprocess.run(['tar', '-xzf', 'imagenet100.tar.gz'], check=True)

need_dfp = not os.path.exists('resnet.dfp')
if not os.path.exists('resnet.h5'):
    keras.applications.ResNet50().save('resnet.h5')
    # A freshly exported model is always recompiled so the DFP matches it.
    need_dfp = True
if need_dfp:
    dfp = NeuralCompiler(models='resnet.h5', verbose=1).run()
    dfp.write('resnet.dfp')
# Make sure to configure the path correctly to point to the downloaded dataset
imagenet_path = 'imagenet100'

# Load ground truth: one label per line, later compared against the class ids
# returned by decode_predictions. splitlines() keeps the last label even when
# the file lacks a trailing newline, and strips CRLF line endings.
with open(os.path.join(imagenet_path, 'ground_truth'), 'r', encoding='utf-8') as f:
    ground_truth = f.read().splitlines()

# Load images. Paths are sorted so that image i is evaluated against
# ground_truth[i] (presumably the ground_truth file is in filename order —
# verify against the dataset).
image_paths = sorted(glob.glob(os.path.join(imagenet_path, '*.JPEG')))
if not image_paths:
    raise FileNotFoundError(
        "No *.JPEG files found in '{}'".format(imagenet_path))

images = []
for image_path in image_paths:
    # The context manager closes the file handle. convert('RGB') turns
    # grey-scale images into 3 identical channels and also normalises CMYK or
    # RGBA JPEGs, which would otherwise yield 4-channel arrays.
    with Image.open(image_path) as img:
        image = np.array(img.resize((224, 224)).convert('RGB'))
    # Preprocessing (shift+scale)
    image = keras.applications.resnet.preprocess_input(image)
    images.append(image.astype(np.float32))

# Prepare the images for the CPU: a single (N, 224, 224, 3) float32 batch.
images = np.stack(images)


# CPU run: time a single batched predict() over the whole image stack.
model = keras.models.load_model('resnet.h5')
start = time.time()
# `images` is already an ndarray (built with np.stack), so it is passed
# directly; wrapping it in np.array() would copy the batch inside the timed
# region.
# NOTE(review): the first predict() call likely includes graph-building
# overhead; add a warm-up call if a steady-state number is wanted.
cpu_outputs = model.predict(images)
cpu_inference_time = time.time() - start


# MXA run
def run_async_accl():
    """Stream every image through the MXA using the callback-based AsyncAccl.

    Frames are fed one at a time, each with a leading batch axis of 1. The
    first output of every callback is collected.

    Returns:
        A tuple ``(outputs, seconds)``:
        - ``outputs`` stacks the squeezed per-image results.
        - ``seconds`` is the wall-clock time from just before the callbacks
          are connected until ``wait()`` returns.
    """
    frames = (np.expand_dims(frame, 0) for frame in images)
    collected = []

    def next_frame():
        # Yields None once the images run out. Presumably AsyncAccl treats
        # None as end-of-stream — confirm against the MemryX API.
        return next(frames, None)

    def on_output(*outputs):
        # Keep only the first model output, dropping its batch axis.
        collected.append(np.squeeze(outputs[0], 0))

    accl = AsyncAccl(dfp='resnet.dfp')
    t0 = time.time()
    accl.connect_input(next_frame)
    accl.connect_output(on_output)
    accl.wait()
    elapsed = time.time() - t0

    stacked = np.stack([np.squeeze(result) for result in collected])
    return stacked, elapsed


def run_sync_accl():
    """Run the whole image batch through the MXA with the blocking SyncAccl.

    Returns:
        A tuple ``(outputs, seconds)``:
        - ``outputs`` stacks the squeezed per-image results returned by
          ``SyncAccl.run``.
        - ``seconds`` is the wall-clock duration of that ``run`` call.
    """
    accl = SyncAccl(dfp='resnet.dfp')

    t0 = time.time()
    raw = accl.run(images)
    elapsed = time.time() - t0

    return np.stack([np.squeeze(result) for result in raw]), elapsed


# Run the MXA benchmark (swap the commented line to use the synchronous API).
mxa_outputs, mxa_inference_time = run_async_accl()
# mxa_outputs, mxa_inference_time = run_sync_accl()

# Decode the ResNet50 outputs into top-5 (class_id, name, score) lists with
# the resnet module's decoder, matching resnet.preprocess_input used above.
cpu_preds = keras.applications.resnet.decode_predictions(cpu_outputs, top=5)
mxa_preds = keras.applications.resnet.decode_predictions(mxa_outputs, top=5)

# Report the real number of images rather than assuming the dataset has 100.
num_images = len(images)
print("CPU Inference time ({} images): {:.1f} msec".format(num_images, cpu_inference_time*1000))
print("MXA Inference time ({} images): {:.1f} msec".format(num_images, mxa_inference_time*1000))

def compare_with_ground_truth(predictions, labels=None):
    """Print top-1 and top-5 accuracy of decoded predictions against labels.

    Args:
        predictions: One entry per image. Each entry is a best-first sequence
            of ``(class_id, class_name, score)`` tuples, as produced by
            ``decode_predictions``.
        labels: Ground-truth class ids, aligned by index with
            ``predictions``. Defaults to the module-level ``ground_truth``.

    Returns:
        ``(top1, top5)`` hit counts.

    Raises:
        IndexError: If there are fewer labels than predictions.
    """
    if labels is None:
        labels = ground_truth

    total = len(predictions)
    top1 = top5 = 0
    for i, pred in enumerate(predictions):
        gt = labels[i]
        classes = [guess[0] for guess in pred]
        if gt in classes:
            top5 += 1
        # Guard against an empty guess list before indexing the best guess.
        if classes and gt == classes[0]:
            top1 += 1

    def pct(hits):
        # Avoid ZeroDivisionError when there are no predictions.
        return hits / total * 100 if total else 0.0

    print("Top 1: ({}/{})  {:.2f} % ".format(top1, total, pct(top1)))
    print("Top 5: ({}/{})  {:.2f} % ".format(top5, total, pct(top5)))
    return top1, top5

# Report accuracy for each backend, CPU first and then MXA.
for backend, backend_preds in (("CPU", cpu_preds), ("MXA", mxa_preds)):
    print(backend + " Results: ")
    compare_with_ground_truth(backend_preds)
