Let’s build a simple LSTM model and train it to predict the next token given a prefix of tokens. Now, you might ask what a token is.
Typically, for language models, a token can mean one of the following:
1. A single character (or a single byte)
2. An entire word in the target language
3. Something in between 1 and 2. This is usually called a sub-word
Mapping a single character (or byte) to a token is very restrictive, since each token is then overloaded to carry a lot of context about where it occurs. The character “c”, for example, occurs in many different words, so predicting the character that follows “c” requires the model to look hard at the leading context.
Mapping a single word to a token is also problematic since English itself has anywhere between 250k and 1 million words. In addition, what happens when a new word is added to the language? Do we need to go back and re-train the entire model to account for this new word?
Sub-word tokenization is considered the industry standard in the year 2023. It assigns unique tokens to substrings of bytes that frequently occur together. Typically, language models have anywhere from a few thousand (say 4,000) to tens of thousands (say 60,000) of unique tokens. What constitutes a token is determined by the BPE (byte pair encoding) algorithm.
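To make the idea of merging frequently co-occurring symbols concrete, here is a toy sketch of a BPE-style training loop in plain Python. It is an illustration under simplifying assumptions (character-level symbols instead of bytes, a tiny hand-written corpus, and hypothetical helper names such as `train_toy_bpe`), not the tokenizer used in this post.

```python
from collections import Counter


def count_pairs(words):
    """Count adjacent symbol pairs, weighted by how often each word occurs."""
    pairs = Counter()
    for symbols, freq in words.items():
        for left, right in zip(symbols, symbols[1:]):
            pairs[(left, right)] += freq
    return pairs


def merge_pair(words, pair):
    """Replace every occurrence of `pair` with a single merged symbol."""
    merged = {}
    for symbols, freq in words.items():
        out, i = [], 0
        while i < len(symbols):
            if i + 1 < len(symbols) and (symbols[i], symbols[i + 1]) == pair:
                out.append(symbols[i] + symbols[i + 1])
                i += 2
            else:
                out.append(symbols[i])
                i += 1
        key = tuple(out)
        merged[key] = merged.get(key, 0) + freq
    return merged


def train_toy_bpe(corpus, num_merges):
    """Learn up to `num_merges` merges from a whitespace-separated corpus."""
    words = dict(Counter(tuple(word) for word in corpus.split()))
    merges = []
    for _ in range(num_merges):
        pairs = count_pairs(words)
        if not pairs:
            break
        best = max(pairs, key=pairs.get)  # on ties, the first-seen pair wins
        words = merge_pair(words, best)
        merges.append(best)
    return merges, words


merges, words = train_toy_bpe("low low low lower lowest", num_merges=2)
print(merges)  # [('l', 'o'), ('lo', 'w')]
```

Each merge adds one new symbol to the vocabulary, so in this sketch the number of merges plays the role of the vocabulary-size knob discussed next.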
To choose the number of unique tokens in our vocabulary (called the vocabulary size), we need to be mindful of a few things:
- If we choose too few tokens, we’re back in the regime of a token per character, and it’s hard for the model to learn anything useful.
- If we choose too many tokens, we end up in a situation where the model’s embedding tables overshadow the rest of the model’s weights, and it becomes hard to deploy the model in a constrained environment. The size of the embedding table depends on the number of dimensions we use for each token. It’s not uncommon to use a size of 256, 512, 786, etc. If we use a token embedding dimension of 512 and we have 100k tokens, we end up with an embedding table that uses about 200MiB in memory (100,000 × 512 values at 4 bytes each for float32).
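The embedding-table estimate is just vocabulary size times embedding dimension times bytes per value. A minimal sketch of that arithmetic, assuming 4-byte float32 values (the helper name `embedding_table_mib` is made up for illustration):

```python
def embedding_table_mib(vocab_size, embed_dim, bytes_per_value=4):
    """Size of a dense embedding table in MiB (float32 by default)."""
    return vocab_size * embed_dim * bytes_per_value / (1024 ** 2)


print(round(embedding_table_mib(100_000, 512), 1))  # 195.3
print(round(embedding_table_mib(6_600, 512), 1))    # 12.9
```

The same formula shows why a smaller vocabulary keeps the table small: at 6600 tokens it shrinks to roughly 13MiB.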
Hence, we need to strike a balance when choosing the vocabulary size. In this example, we pick 6600 tokens and train our tokenizer with a vocabulary size of 6600. Next, let’s take a look at the model definition itself.
The PyTorch Model
The model itself is pretty straightforward. We have the following layers:
- Token Embedding (vocab size=6600, embedding dim=512), for a total size of about 13MiB (assuming 4 byte float32 as the embedding table’s data type)
- LSTM (num layers=1, hidden dimension=786) for a total size of about 16MiB
- Multi-Layer Perceptron (786 to 3144 to 6600 dimensions) for a total size of about 93MiB
The complete model has about 31M trainable parameters for a total size of about 120MiB.
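The parameter counts can be reproduced by hand. The sketch below assumes the standard PyTorch layouts: an LSTM layer with four gates and two bias vectors per gate, linear layers with a bias, and a LayerNorm with a learnable scale and shift. The helper names are made up for illustration.

```python
def embedding_params(num_embed, embed_dim):
    return num_embed * embed_dim


def lstm_layer_params(input_dim, hidden_dim):
    # 4 gates, each with input weights, recurrent weights, and two bias vectors
    # (nn.LSTM keeps separate input and hidden biases).
    return 4 * (input_dim * hidden_dim + hidden_dim * hidden_dim + 2 * hidden_dim)


def linear_params(in_dim, out_dim):
    return in_dim * out_dim + out_dim


def layer_norm_params(dim):
    return 2 * dim  # learnable scale and shift


vocab, embed_dim, hidden = 6600, 512, 786
counts = {
    "embedding": embedding_params(vocab, embed_dim),
    "lstm": lstm_layer_params(embed_dim, hidden),
    "mlp": (
        linear_params(hidden, hidden * 4)
        + layer_norm_params(hidden * 4)
        + linear_params(hidden * 4, vocab)
    ),
}
total = sum(counts.values())
print(counts)
print(total)  # 30704016
```

At 4 bytes per parameter, that total works out to roughly 120MB of float32 weights, with the final linear layer accounting for most of it.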
Here’s the PyTorch code for the model.
from torch import nn


class WordPredictionLSTMModel(nn.Module):
    def __init__(self, num_embed, embed_dim, pad_idx, lstm_hidden_dim, lstm_num_layers, output_dim, dropout):
        super().__init__()
        self.vocab_size = num_embed
        self.embed = nn.Embedding(num_embed, embed_dim, pad_idx)
        self.lstm = nn.LSTM(embed_dim, lstm_hidden_dim, lstm_num_layers, batch_first=True, dropout=dropout)
        self.fc = nn.Sequential(
            nn.Linear(lstm_hidden_dim, lstm_hidden_dim * 4),
            nn.LayerNorm(lstm_hidden_dim * 4),
            nn.LeakyReLU(),
            nn.Dropout(p=dropout),
            nn.Linear(lstm_hidden_dim * 4, output_dim),
        )

    def forward(self, x):
        # x: (batch, seq_len) token ids
        x = self.embed(x)        # (batch, seq_len, embed_dim)
        x, _ = self.lstm(x)      # (batch, seq_len, lstm_hidden_dim)
        x = self.fc(x)           # (batch, seq_len, output_dim)
        x = x.permute(0, 2, 1)   # (batch, output_dim, seq_len)
        return x
Here’s the model summary using torchinfo.
LSTM Model Summary
=================================================================
Layer (type:depth-idx)                   Param #
=================================================================
WordPredictionLSTMModel                  --
├─Embedding: 1-1                         3,379,200
├─LSTM: 1-2                              4,087,200
├─Sequential: 1-3                        --
│    └─Linear: 2-1                       2,474,328
│    └─LayerNorm: 2-2                    6,288
│    └─LeakyReLU: 2-3                    --
│    └─Dropout: 2-4                      --
│    └─Linear: 2-5                       20,757,000
=================================================================
Total params: 30,704,016
Trainable params: 30,704,016
Non-trainable params: 0
=================================================================
Interpreting the accuracy: After training this model on 12M English language sentences for about 8 hours on a P100 GPU, we achieved a loss of 4.03, a top-1 accuracy of 29% and a top-5 accuracy of 49%. This means that 29% of the time, the model was able to correctly predict the next token, and 49% of the time, the next token in the training set was one of the top 5 predictions by the model.
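For readers unfamiliar with these metrics, here is a small plain-Python sketch of how top-k accuracy can be computed from per-token scores. The scores and targets below are made-up toy values, and `top_k_accuracy` is a hypothetical helper, not the evaluation code used for the numbers above.

```python
def top_k_accuracy(scores, targets, k):
    """Fraction of positions whose target is among the k highest-scoring tokens."""
    hits = 0
    for row, target in zip(scores, targets):
        top_k = sorted(range(len(row)), key=lambda i: row[i], reverse=True)[:k]
        hits += target in top_k
    return hits / len(targets)


# Three prediction positions over a toy 3-token vocabulary.
scores = [
    [0.1, 0.7, 0.2],
    [0.5, 0.3, 0.2],
    [0.2, 0.3, 0.5],
]
targets = [1, 1, 0]

print(top_k_accuracy(scores, targets, k=1))
print(top_k_accuracy(scores, targets, k=2))
```

In this toy example only the first position is correct at top-1, while the first two are correct at top-2, which mirrors how top-5 accuracy is always at least as high as top-1.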
What should our success metric be? While the top-1 and top-5 accuracy numbers for our model aren’t impressive, they aren’t as important for our problem. Our candidate words are a small set of possible words that fit the swipe pattern. What we want from our model is to be able to select an ideal candidate to complete the sentence such that it is syntactically and semantically coherent. Since our model learns the nature of language through the training data, we expect it to assign a higher probability to coherent sentences. For example, if we have the sentence “The baseball player” and possible completion candidates (“ran”, “swam”, “hid”), then the word “ran” is a better follow-up word than the other two. So, if our model predicts the word ran with a higher probability than the rest, it works for us.
Interpreting the loss: A loss of 4.03 means that the average negative log-likelihood of the correct next token is 4.03, which corresponds to the model assigning the correct token a probability of about e^-4.03 = 0.0178, or roughly 1/56. A randomly initialized model typically has a loss of about 8.8, which is -log_e(1/6600), since it spreads its probability roughly evenly and gives each of the 6600 tokens in the vocabulary a probability of about 1/6600. While a loss of 4.03 may not seem great, it’s important to remember that the trained model assigns the correct token about 120x more probability than an untrained (or randomly initialized) model.
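These figures are easy to check with a few lines of Python. The sketch below only redoes the arithmetic from the paragraph above; the variable names are ours.

```python
import math

vocab_size = 6600
trained_loss = 4.03

# Loss of a model that assigns every token the same probability.
random_loss = math.log(vocab_size)

# Probability implied by the trained model's average loss.
trained_prob = math.exp(-trained_loss)

# How much more probability the trained model gives the correct token.
improvement = trained_prob * vocab_size

print(round(random_loss, 2))   # 8.79
print(round(trained_prob, 4))  # 0.0178
print(round(improvement))      # 117
```

The ratio comes out at roughly 117, which is where the "about 120x" figure comes from.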
Next, let’s take a look at how we can use this model to improve suggestions from our swipe keyboard.
Using the model to prune invalid suggestions
Let’s take a look at a real example. Suppose we have a partial sentence “I think”, and the user makes the swipe pattern shown in blue below, starting at “o”, going between the letters “c” and “v”, and ending between the letters “e” and “v”.
Some possible words that could be represented by this swipe pattern are
- Over
- Oct (short for October)
- Ice
- I’ve (with the apostrophe implied)
Of these suggestions, the most likely one is probably going to be “I’ve”. Let’s feed these suggestions into our model and see what it spits out.
[I think] [I’ve] = 0.00087
[I think] [over] = 0.00051
[I think] [ice] = 0.00001
[I think] [Oct] = 0.00000
The value after the = sign is the probability of the word being a valid completion of the sentence prefix. In this case, we see that the word “I’ve” has been assigned the highest probability. Hence, it is the most likely word to follow the sentence prefix “I think”.
The next question you might have is how we can compute these next-word probabilities. Let’s take a look.
Computing the next word probability
To compute the probability that a word is a valid completion of a sentence prefix, we run the model in eval (inference) mode and feed in the tokenized sentence prefix. We also tokenize the word after adding a whitespace prefix to the word. This is done because the HuggingFace pre-tokenizer splits words with spaces at the beginning of the word, so we want to make sure that our inputs are consistent with the tokenization strategy used by HuggingFace Tokenizers.
Let’s assume that the candidate word is made up of 3 tokens T0, T1, and T2.
1. We first run the model with the original tokenized sentence prefix. For the last token, we check the probability of predicting token T0. We add this to the “probs” list.
2. Next, we run a prediction on the prefix+T0 and check the probability of token T1. We add this probability to the “probs” list.
3. Next, we run a prediction on the prefix+T0+T1 and check the probability of token T2. We add this probability to the “probs” list.
The “probs” list contains the individual probabilities of generating the tokens T0, T1, and T2 in sequence. Since these tokens correspond to the tokenization of the candidate word, we can multiply these probabilities to get the combined probability of the candidate being a completion of the sentence prefix.
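As a rough sketch of that last step, the snippet below multiplies per-token probabilities (by summing logs, which is a common choice for numerical stability and an assumption on our part) and ranks candidates by the result. The per-token probabilities are made-up illustrative numbers, not model outputs, and both helper names are hypothetical.

```python
import math


def completion_probability(token_probs):
    """Multiply per-token probabilities by summing their logs."""
    if any(p <= 0.0 for p in token_probs):
        return 0.0
    return math.exp(sum(math.log(p) for p in token_probs))


def rank_candidates(candidate_token_probs):
    """Sort candidate words by combined probability, highest first."""
    scored = {
        word: completion_probability(probs)
        for word, probs in candidate_token_probs.items()
    }
    return sorted(scored.items(), key=lambda item: item[1], reverse=True)


# Made-up per-token probabilities for candidates following "The baseball player".
# "swam" is pretended to split into two tokens, the others into one.
candidates = {
    "hid": [0.001],
    "swam": [0.01, 0.3],
    "ran": [0.02],
}

ranking = rank_candidates(candidates)
print([word for word, _ in ranking])  # ['ran', 'swam', 'hid']
```

One consequence worth noting: because every factor is at most 1, a candidate that splits into more tokens can only lose probability with each extra token.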
The code for computing the completion probabilities is shown below.
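# Method of the class that holds the model; requires `import torch` at module level.
def get_completion_probability(self, input, completion, tok):
    self.model.eval()
    ids = tok.encode(input).ids
    ids = torch.tensor(ids, device=self.device).unsqueeze(0)
    completion_ids = torch.tensor(tok.encode(completion).ids, device=self.device).unsqueeze(0)
    probs = []
    for i in range(completion_ids.size(1)):
        y = self.model(ids)
        # Model output is (batch, vocab, seq_len); take the distribution at the last position.
        y = y[0, :, -1].softmax(dim=0)
        # prob is the probability of this completion token given everything before it.
        prob = y[completion_ids[0, i]]
        probs.append(prob)
        # Append the token so the next step is conditioned on it.
        ids = torch.cat([ids, completion_ids[:, i:i+1]], dim=1)
    return torch.tensor(probs)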
We can see some more examples below.
[That ice-cream looks] [really] = 0.00709
[That ice-cream looks] [delicious] = 0.00264
[That ice-cream looks] [absolutely] = 0.00122
[That ice-cream looks] [real] = 0.00031
[That ice-cream looks] [fish] = 0.00004
[That ice-cream looks] [paper] = 0.00001
[That ice-cream looks] [atrocious] = 0.00000

[Since we’re heading] [toward] = 0.01052
[Since we’re heading] [away] = 0.00344
[Since we’re heading] [against] = 0.00035
[Since we’re heading] [both] = 0.00009
[Since we’re heading] [death] = 0.00000
[Since we’re heading] [bubble] = 0.00000
[Since we’re heading] [birth] = 0.00000

[Did I make] [a] = 0.22704
[Did I make] [the] = 0.06622
[Did I make] [good] = 0.00190
[Did I make] [food] = 0.00020
[Did I make] [color] = 0.00007
[Did I make] [house] = 0.00006
[Did I make] [colour] = 0.00002
[Did I make] [pencil] = 0.00001
[Did I make] [flower] = 0.00000

[We want a candidate] [with] = 0.03209
[We want a candidate] [that] = 0.02145
[We want a candidate] [experience] = 0.00097
[We want a candidate] [which] = 0.00094
[We want a candidate] [more] = 0.00010
[We want a candidate] [less] = 0.00007
[We want a candidate] [school] = 0.00003

[This is the definitive guide to the] [the] = 0.00089
[This is the definitive guide to the] [complete] = 0.00047
[This is the definitive guide to the] [sentence] = 0.00006
[This is the definitive guide to the] [rapper] = 0.00001
[This is the definitive guide to the] [illustrated] = 0.00001
[This is the definitive guide to the] [extravagant] = 0.00000
[This is the definitive guide to the] [wrapper] = 0.00000
[This is the definitive guide to the] [miniscule] = 0.00000

[Please can you] [check] = 0.00502
[Please can you] [confirm] = 0.00488
[Please can you] [cease] = 0.00002
[Please can you] [cradle] = 0.00000
[Please can you] [laptop] = 0.00000
[Please can you] [envelope] = 0.00000
[Please can you] [options] = 0.00000
[Please can you] [cordon] = 0.00000
[Please can you] [corolla] = 0.00000

[I think] [I’ve] = 0.00087
[I think] [over] = 0.00051
[I think] [ice] = 0.00001
[I think] [Oct] = 0.00000

[Please] [can] = 0.00428
[Please] [cab] = 0.00000

[I’ve scheduled this] [meeting] = 0.00077
[I’ve scheduled this] [messing] = 0.00000
These examples show the probability of the word completing the sentence before it. The candidates are sorted in decreasing order of probability.
Since Transformers are slowly replacing LSTM and RNN models for sequence-based tasks, let’s take a look at what a Transformer model for the same objective would look like.