TinyStories#
Introduction#
TinyStories is a model created with the goal of making language models small enough to run on edge hardware, hence the term Small Language Model (SLM). In this tutorial, we use the MXA to generate stories with the popular TinyStories SLM, demonstrating that the MXA works well not only with vision models but also with language models.
Note
This application assumes that the MXA drivers, runtimes, and compilers have been successfully installed. For more information, please refer to the Install page.
Export the Model#
The official model is in PyTorch, so we make a few changes, export it to ONNX, and then compile the ONNX model to a DFP for use on the MXA.
from transformers import AutoTokenizer, AutoModelForCausalLM
import torch
import torch.nn as nn
tokenizer = AutoTokenizer.from_pretrained("roneneldan/TinyStories-33M")
model = AutoModelForCausalLM.from_pretrained("roneneldan/TinyStories-33M")
model.eval()
max_length = 128
dummy_input = tokenizer("Once upon a time", return_tensors="pt").input_ids
dummy_input = torch.cat([dummy_input, torch.zeros(1, max_length - dummy_input.size(1), dtype=torch.long)], dim=1)
# Wrapper that returns only the logits tensor from the model output
class WrappedModel(nn.Module):
    def __init__(self, original_model):
        super(WrappedModel, self).__init__()
        self.original_model = original_model

    def forward(self, x):
        out = self.original_model(x)
        return out[0]
wrapped_model = WrappedModel(model)
wrapped_model.eval()
torch.onnx.export(wrapped_model, dummy_input,
                  "models/tinystories33M.onnx", input_names=["input_ids"],
                  output_names=["output"], dynamic_axes=None)
The parameter max_length is used to set the maximum context length of the model. Other values can also be used to experiment with speed and accuracy.
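For example, to experiment with a longer context, simply change max_length before exporting (the value below is arbitrary, and the model must be re-exported and recompiled afterwards):
max_length = 256  # larger context: room for longer stories, at the cost of slower compilation and inference
dummy_input = tokenizer("Once upon a time", return_tensors="pt").input_ids
dummy_input = torch.cat([dummy_input, torch.zeros(1, max_length - dummy_input.size(1), dtype=torch.long)], dim=1)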
Note
Support for transformer models is rapidly expanding. For more information, please check the Transformer Support page.
Compile the Model#
Compiling this model is more involved than compiling the vision models in the other tutorials, as it requires a few experimental flags. The --effort hard flag is required, which makes compilation take longer than a standard compile. For this reason, we encourage users to use the precompiled models provided in the zip file attached at the end of the tutorial.
cd models
mx_nc -m tinystories33M.onnx --inputs /original_model/transformer/Add_output_0:0,/original_model/transformer/Add_output_0:1,/original_model/transformer/Add_output_0:2 --outputs /original_model/transformer/ln_f/Add_1_output_0 --graph_extensions TinyStories --insert_onnx_io_format_adapters "io" --effort hard
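If you compile the model yourself, a quick sanity check is to load the resulting DFP and print its input shapes (a minimal sketch; it assumes the MemryX Python API is imported as mx, as in the application code below):
import memryx as mx  # assumption: the MemryX Python API package name

dfp = mx.Dfp("tinystories33M.dfp")
print(dfp.input_shapes)  # shapes the MXA core model expects (channel-last)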
Model Details#
The TinyStories model takes a sequence of input tokens and outputs a sequence of tokens. We are only interested in the final token of the output, since the earlier tokens were produced in previous iterations. We use a concept called beam search, which selects the top n tokens of the output instead of just the best one. This approach produces much better results than using a single token. With a beam_width of 1, only one story is generated, while with a beam_width of 3, an average of 50 or more stories are generated.
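To make the beam-search step concrete, here is a minimal sketch (with hypothetical names, not the tutorial's actual implementation) of how the top beam_width candidate tokens can be picked from the logits at the last generated position:
import numpy as np

def top_candidates(logits, beam_width=3):
    # logits: (vocab_size,) scores for the next token at the current position.
    # Returns the beam_width highest-scoring token ids, best first.
    return np.argsort(logits)[-beam_width:][::-1]
Each kept sequence is then extended with each of its candidate tokens, which is why a beam_width greater than 1 branches into multiple stories.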
Model Initialization#
Certain layers in the model that are not supported on the accelerator are cropped into pre/post-processing models, which run on the host system. Below is the code that initializes each model based on where it runs:
# Assumes onnxruntime is imported as ort and the MemryX Python API as mx
self.dfp = mx.Dfp(f"{model_dir}/tinystories33M.dfp")

session_options = ort.SessionOptions()
session_options.intra_op_num_threads = 4
session_options.inter_op_num_threads = 1
session_options.enable_mem_pattern = True
session_options.enable_cpu_mem_arena = True

# Cropped embedding (pre-processing) model runs on the host through ONNX Runtime
self.embedding_model = ort.InferenceSession(f"{model_dir}/tinystories33M_pre.onnx", session_options)
# Core transformer layers run on the MXA
self.mxa_core_model = mx.SyncAccl(self.dfp)
# Cropped reverse-embedding (post-processing) model runs on the host through ONNX Runtime
self.rev_embedding_model = ort.InferenceSession(f"{model_dir}/tinystories33M_post.onnx", session_options)

self.embedding_input_name = self.embedding_model.get_inputs()[0].name
self.rev_embedding_input_name = self.rev_embedding_model.get_inputs()[0].name
self.max_len = max_len
Input Encoding#
Encoding the input to send to the model for inference:
prompt = input("Enter the prompt: ")
input_ids = tokenizer.encode(prompt, return_tensors="np")
sequences = [input_ids]
while(True):
    inp = [np.zeros([1, max_len], dtype=np.int64) for _ in range(beam_width)]
    n = len(sequences)
    for i in range(n):
        idx = sequences[i].shape[-1]  # number of tokens already in this sequence
        inp[i][:, 0:idx] = sequences[i]
In the above code, we start with a prompt and then, on each subsequent iteration, build the input encodings from the outputs of the current inference and the chosen beam_width.
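The snippet above only shows how the padded inputs are built. As a rough sketch of the rest of an iteration (not the tutorial's actual code; run_inference stands in for the pipeline shown in the next section), each sequence could be extended with its top beam_width candidate tokens before being re-encoded:
outputs = run_inference(inp)  # hypothetical: returns one (1, max_len, vocab_size) logits array per input
new_sequences = []
for seq, logits in zip(sequences, outputs):
    seq = np.ravel(seq)                   # 1-D array of token ids
    last = len(seq) - 1                   # logits at the last real token predict the next one
    for token_id in np.argsort(logits[0, last])[-beam_width:]:
        new_sequences.append(np.append(seq, token_id))  # one new branch per candidate
sequences = new_sequences
# A stopping condition (an end-of-text token or reaching max_len) would break out of the loop.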
Inference#
The inference pipeline runs the cropped pre/post-processing models on the host and the core model on the MXA:
core_input = []
for inp in input_data:
    core_inp = self.embedding_model.run(None, {self.embedding_input_name: inp})[0]
    # MXA requires the input shapes to be channel-last
    core_inp = core_inp.reshape(self.dfp.input_shapes[0]).astype(np.float32)
    core_input.append(core_inp)

channels = core_input[0].shape[-1]
core_output = self.mxa_core_model.run(core_input)

rev_input = []
if isinstance(core_output, list):
    for out in core_output:
        # Converting the channel-last output back to channel-first for ONNX post-processing
        out = out.reshape([1, self.max_len, channels])
        rev_input.append(out)
else:
    rev_input = [core_output.reshape([1, self.max_len, channels])]

output_data = []
for out in rev_input:
    output_data.append(self.rev_embedding_model.run(None, {self.rev_embedding_input_name: out})[0])

return output_data
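Once enough tokens have been generated, each finished sequence can be decoded back into readable text (a minimal sketch; finished_sequences is a hypothetical name for the candidate sequences collected by the generation loop):
for seq in finished_sequences:
    story = tokenizer.decode(np.ravel(seq), skip_special_tokens=True)
    print(story)
    print("-" * 40)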
Usage#
To use this application, simply run the stories.py file and provide an input prompt when asked. A typical prompt might be Once upon a time.
python stories.py --beam_width=3
Final Thoughts#
Try using different beam_width values to observe the variation in story generation. The MXA excels in throughput when used with large batch sizes in Sync mode, so a larger beam_width will be much faster on the MXA than on a CPU.
Third-Party License#
This tutorial uses third-party software and libraries. Below are the details of the licenses for these dependencies:
Model: Copyright (c) Ronen Eldan, MIT license