from transformers import AutoTokenizer
import onnxruntime as ort
import numpy as np
import memryx as mx
import argparse

def top_k_indices(arr, k):
    """
    Return the k largest values of a 1-D array together with their indices,
    ordered from largest to smallest.

    Parameters
    ----------
    arr: array-like
        One dimensional sequence of scores. Converted with np.asarray, so plain
        lists are accepted as well.
    k: int
        Number of entries to select. Values larger than len(arr) are clamped to
        len(arr); values <= 0 yield empty results.

    Returns
    -------
    (values, indices): tuple of np.ndarray
        values[i] == arr[indices[i]], sorted in descending order of value.
    """
    arr = np.asarray(arr)
    # Clamp k: argpartition raises when k exceeds the array length.
    k = min(int(k), len(arr))
    if k <= 0:
        # Without this guard "[-0:]" would select the whole array instead of nothing.
        return arr[:0], np.empty(0, dtype=np.intp)
    # argpartition moves the k largest entries to the tail without a full sort ...
    indices = np.argpartition(arr, -k)[-k:]
    # ... so only those k entries need to be ordered afterwards.
    sorted_indices = indices[np.argsort(-arr[indices])]
    return arr[sorted_indices], sorted_indices

class MxaTinyStories:
    """
    Load and run the TinyStories 33M parameter model and use beam search to
    generate a story from a single prompt.

    Every inference step chains three stages: an ONNX pre-processing model, the
    core model compiled into a DFP and executed through ``mx.SyncAccl``, and an
    ONNX post-processing model.

    Parameters
    ----------
    model_dir: string
        Path to the directory containing ``tinystories33M.dfp``,
        ``tinystories33M_pre.onnx`` and ``tinystories33M_post.onnx``
    max_len: int
        Context length of the model. This is defined while exporting the model to onnx.
    """
    def __init__(self, model_dir, max_len=128):

        self.dfp = mx.Dfp(f"{model_dir}/tinystories33M.dfp")

        # Both ONNX sessions below share the same session options.
        session_options = ort.SessionOptions()
        session_options.intra_op_num_threads = 4
        session_options.inter_op_num_threads = 1

        session_options.enable_mem_pattern = True
        session_options.enable_cpu_mem_arena = True
        self.embedding_model = ort.InferenceSession(f"{model_dir}/tinystories33M_pre.onnx", session_options)
        self.mxa_core_model = mx.SyncAccl(self.dfp)
        self.rev_embedding_model = ort.InferenceSession(f"{model_dir}/tinystories33M_post.onnx", session_options)
        # Each ONNX model is fed through its first declared input.
        self.embedding_input_name = self.embedding_model.get_inputs()[0].name
        self.rev_embedding_input_name = self.rev_embedding_model.get_inputs()[0].name
        self.max_len = max_len

    def run_inference(self, input_data):
        """
        Run the pre, core and post models on every input and return the outputs.

        Parameters
        ----------
        input_data: List[np.array()]
            List of input token sequences, each of size (1 x max_len)

        Returns
        -------
        List[np.array()]
            First output of the post-processing model for every core output.
            An empty input list yields an empty list without touching the
            accelerator.
        """
        core_shape = self.dfp.input_shapes[0]
        # MXA requires the input shapes to be channel last, so every embedding
        # is reshaped to the input shape recorded in the DFP.
        core_input = [
            self.embedding_model.run(None, {self.embedding_input_name: tokens})[0]
            .reshape(core_shape)
            .astype(np.float32)
            for tokens in input_data
        ]
        if not core_input:
            return []
        channels = core_input[0].shape[-1]

        core_output = self.mxa_core_model.run(core_input)
        # The accelerator result is handled both as a list and as a single array.
        if not isinstance(core_output, list):
            core_output = [core_output]

        # Convert the channel last output back to the (1, max_len, channels)
        # layout used for the onnx post processing.
        post_shape = [1, self.max_len, channels]
        return [
            self.rev_embedding_model.run(
                None, {self.rev_embedding_input_name: out.reshape(post_shape)}
            )[0]
            for out in core_output
        ]

    def generate(self, tokenizer, prompt, beam_width=3):
        """
        Generate tokens with beam search and stream the decoded text while the
        story grows. Generation stops once every beam ends with the tokenizer's
        EOS token, or when the user interrupts with Ctrl-C.

        After every step one token of the best beam is decoded and printed. The
        printed token trails the newest one by the prompt length.

        Parameters
        ----------
        tokenizer: transformers.tokenizer()
            Official tokenizer for this model
        prompt: list of int
            Tokenized input prompt. It is copied, never modified. If it is longer
            than max_len only the most recent max_len tokens are used.
        beam_width: int
            Width of the beam search (defaults to 3). Larger beam_width gives better
            results but takes longer to compute

        Returns
        -------
        list
            Best scoring token sequence when generation stopped. It holds at most
            max_len tokens because the oldest token is dropped once the context
            is full.

        Raises
        ------
        ValueError
            If beam_width is smaller than 1.
        """
        if beam_width < 1:
            raise ValueError("beam_width must be at least 1")

        # Copy so the caller's list is never mutated, and keep only what fits
        # into the context window.
        sequences = [list(prompt)[-self.max_len:]]
        init_len = len(sequences[0])
        scores = [0]
        try:
            while True:
                # One zero padded input per beam slot; slots without a sequence
                # (e.g. on the very first step) stay all zero.
                batch = [np.zeros([1, self.max_len], dtype=np.int64) for _ in range(beam_width)]
                for slot, seq in zip(batch, sequences):
                    slot[:, :len(seq)] = seq
                out = self.run_inference(batch)

                all_candidates = []
                for i, (seq, score) in enumerate(zip(sequences, scores)):
                    # Scores for the next token are read at the last filled position.
                    y = out[i][0, len(seq) - 1, :]
                    top_k_probs, top_k_tokens = top_k_indices(y, beam_width)
                    # Slide the window: drop the oldest token once the context is full.
                    window = seq[1:] if len(seq) == self.max_len else seq
                    for token, prob in zip(top_k_tokens, top_k_probs):
                        all_candidates.append((window + [token], score + prob))

                # Keep the beam_width best candidates (highest accumulated score first).
                all_candidates.sort(key=lambda candidate: candidate[1], reverse=True)
                sequences, scores = zip(*all_candidates[:beam_width])

                if all(seq[-1] == tokenizer.eos_token_id for seq in sequences):
                    break

                best = sequences[0]
                # Clamp the lag so a prompt that fills the whole window cannot
                # index past the start of the sequence.
                lag = min(init_len + 1, len(best))
                print(tokenizer.decode(best[-lag], skip_special_tokens=True), end='', flush=True)
        except KeyboardInterrupt:
            print("")
        return sequences[0]

if __name__ == "__main__":
    # Command line interface: the beam width is the only option.
    cli = argparse.ArgumentParser(description="Tiny stories")
    cli.add_argument(
        '--beam_width',
        type=int,
        default=3,
        required=False,
        help="the width of beam search",
    )
    options = cli.parse_args()

    # Ask for the prompt first, then load the models from the "models" directory.
    models_path = "models"
    user_prompt = input("Enter the prompt: ")
    story_model = MxaTinyStories(models_path)

    # Load the tiny stories tokenizer from the official huggingface repository.
    story_tokenizer = AutoTokenizer.from_pretrained('roneneldan/TinyStories-33M')

    # Encode the prompt and take the first row of the result as a plain list of token ids.
    prompt_ids = story_tokenizer.encode(user_prompt, return_tensors="np")[0].tolist()

    # Run beam search; generate() streams text to stdout while it works.
    best_tokens = story_model.generate(story_tokenizer, prompt_ids, options.beam_width)

    # Decode and print the token sequence returned by generate().
    print(story_tokenizer.decode(best_tokens, skip_special_tokens=True))
