
import os
import urllib.request

import numpy as np
import torch
from PIL import Image
from torchvision import transforms

from memryx import NeuralCompiler, Simulator

# Fetch the pretrained MobileNetV2 classifier from PyTorch Hub (pytorch/vision
# at tag v0.10.0) and put it in inference mode before tracing it for export.
model = torch.hub.load('pytorch/vision:v0.10.0', 'mobilenet_v2', pretrained=True)
model.eval()

# Trace the network with a random 1x3x224x224 tensor and serialise it to ONNX;
# the resulting file is what the MemryX NeuralCompiler consumes further down.
sample_input = torch.randn(1, 3, 224, 224)
onnx_export_options = dict(
    export_params=True,        # embed the trained weights in the .onnx file
    opset_version=17,          # ONNX operator-set version to target
    do_constant_folding=True,  # pre-compute constant sub-graphs at export time
    input_names=['input'],     # graph input name(s)
    output_names=['output'],   # graph output name(s)
)
torch.onnx.export(model, sample_input, "mobilenet_v2.onnx", **onnx_export_options)


# Load the test image and apply the standard ImageNet preprocessing expected by
# torchvision's pretrained classifiers: resize the shorter side to 256,
# centre-crop to 224x224, convert to a [0, 1] float tensor, then normalise each
# channel with the ImageNet mean/std.
# convert('RGB') is required: PNG files are frequently RGBA, greyscale or
# palette-indexed, which would make ToTensor() yield a 4- or 1-channel tensor
# (or raw palette indices) that the 3-channel Normalize() cannot handle.
# The `with` block closes the file handle; convert() returns a loaded copy.
with Image.open('image.png') as img:
    input_image = img.convert('RGB')
preprocess = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])
input_tensor = preprocess(input_image)
# Prepend a batch dimension -> (1, 3, 224, 224) and hand the simulator a NumPy array.
image = input_tensor.unsqueeze(0).numpy()

# Compile the exported ONNX graph with the MemryX NeuralCompiler; run() returns
# the compiled program (DFP) that the simulator executes.
compiler = NeuralCompiler(models='mobilenet_v2.onnx', verbose=1)
dfp = compiler.run()

# Execute the compiled program on the MXA simulator: a single inference on the
# preprocessed image, followed by a short benchmark whose result is unpacked
# into (latency, fps) for the report at the end of the script.
# NOTE(review): exact meaning of frames=4 is defined by the memryx API — confirm.
s = Simulator(dfp=dfp, verbose=1)
outputs = s.infer(inputs=image)
latency, fps = s.benchmark(frames=4)

# Load the ImageNet class labels (one label per line) so the predicted output
# index can be mapped to a human-readable name.
# The list is downloaded only when it is not already on disk, so re-runs do not
# refetch it. urlretrieve needs no external `wget` binary and raises on a failed
# download instead of letting the script continue without the file.
CLASSES_URL = "https://raw.githubusercontent.com/pytorch/hub/master/imagenet_classes.txt"
CLASSES_PATH = "imagenet_classes.txt"
if not os.path.isfile(CLASSES_PATH):
    urllib.request.urlretrieve(CLASSES_URL, CLASSES_PATH)
with open(CLASSES_PATH, 'r', encoding='utf-8') as f:
    # splitlines() also copes with CRLF endings and yields no trailing empty entry.
    classes = f.read().splitlines()

def softmax(x, axis=None):
    """Return the numerically stable softmax of *x*.

    Args:
        x: Array-like of raw scores (logits).
        axis: Axis along which to normalise. ``None`` (the default) treats all
            elements as one distribution; pass e.g. ``axis=-1`` to normalise
            each row of a batch independently.

    Returns:
        ``np.ndarray`` with the same shape as ``x`` whose entries along
        ``axis`` (or overall, when ``axis`` is ``None``) sum to 1.
    """
    x = np.asarray(x)
    # Shifting by the max does not change the result but keeps exp() from
    # overflowing on large logits; keepdims lets the shift broadcast per axis.
    exps = np.exp(x - np.max(x, axis=axis, keepdims=True))
    return exps / np.sum(exps, axis=axis, keepdims=True)

# Turn the simulator's raw logits into probabilities, drop singleton
# dimensions, and report the most likely ImageNet class with its confidence.
outputs = np.squeeze(softmax(outputs))
idx = np.argmax(outputs)
top_label = classes[idx]
top_percent = outputs[idx] * 100
print("I see a '{}' with {:.1f} % certainty".format(top_label, top_percent))

# Simulator benchmark figures gathered earlier in the script.
print("Simulated MXA FPS: ", fps)
print("Simulated MXA Latency: ", latency)



