Building a GPT from scratch: understanding token embeddings and training
In this post, I explain how I built a GPT model from scratch using PyTorch, focusing on the token embedding process, the model layers, and the training routine. My approach revolves around initializing a custom GPT model, understanding how tokens are embedded and passed through the model, and tuning the training loop so the model learns to predict the next token in a sequence.
Initialization of the GPT model
To start, I defined the dimensionality of my embeddings using the d_model variable. This is crucial, as it determines how each token from my vocabulary is represented as a dense vector. Below is how I initialize the GPT model:
d_model = 1024 # Define size of embeddings
model = MyGPT(vocab_size=vocab_size, d_model=d_model).to(device)
The d_model variable defines the size of the embeddings that represent each token. The GPT model is initialized using both the vocabulary size and the embedding dimension, and it is moved to the correct device for computation.
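For context, the vocab_size and device used above might be set up along the following lines. This is only a sketch under the assumption of a simple character-level vocabulary built from a hypothetical data.txt corpus; your tokenizer and data source may differ.

import torch

# Hypothetical setup: build a character-level vocabulary from the training text.
with open("data.txt", encoding="utf-8") as f:   # assumed training corpus
    text = f.read()
chars = sorted(set(text))
vocab_size = len(chars)   # number of distinct tokens the embedding table must cover

# Pick the compute device; the model and every batch are moved here.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")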
Inside MyGPT class
The core of my model is implemented in the MyGPT class. In this class, the token embeddings are created, layer normalization is applied, and a linear output layer is used to predict the next token. Here’s a simplified version of how this is done:
import torch
import torch.nn as nn
import torch.nn.functional as F

class MyGPT(nn.Module):
    def __init__(self, vocab_size, d_model):
        super().__init__()
        self.wte = nn.Embedding(vocab_size, d_model)   # Token embeddings
        self.ln_f = nn.LayerNorm(d_model)              # Layer normalization
        self.fc_out = nn.Linear(d_model, vocab_size)   # Final output layer
The forward pass of my model involves converting input tokens into embeddings, normalizing them, and generating logits through a final linear layer.
    def forward(self, inputs, targets=None):
        embeddings = self.wte(inputs)
        logits = self.fc_out(self.ln_f(embeddings))
        loss = None
        if targets is not None:
            # Flatten (batch, seq, vocab) logits and (batch, seq) targets for cross-entropy.
            loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1))
        return logits, loss
In the forward pass, embeddings are passed through layer normalization and then fed into the output layer. If target labels are provided (during training), I calculate the cross-entropy loss, which is essential for guiding the model in predicting the correct tokens.
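As a quick sanity check, a dummy batch can be pushed through the model to confirm the shapes of the logits and the initial loss. The batch size and sequence length below are made up for illustration; this is not part of the training code itself.

# Sketch: a random batch of token IDs (batch of 4, sequence length 16).
dummy_inputs = torch.randint(0, vocab_size, (4, 16), device=device)
dummy_targets = torch.randint(0, vocab_size, (4, 16), device=device)

logits, loss = model(dummy_inputs, dummy_targets)
print(logits.shape)   # torch.Size([4, 16, vocab_size])
print(loss.item())    # roughly ln(vocab_size) before any training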
Training the model
Training the GPT model is a critical aspect of my work. I initialize the Adam optimizer, which adapts the effective step size for each parameter and updates the model’s weights on every batch of each epoch.
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
I begin by feeding the model with batches of sequences. Inputs are taken from the sequence excluding the last token, while targets are the same sequence shifted by one token. This way, the model learns to predict the next token in the sequence.
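To make the shift concrete, here is a tiny illustration with made-up token IDs; at each position, the model is trained to predict the token that follows it.

# Hypothetical batch containing a single sequence of token IDs.
batch = torch.tensor([[12, 7, 99, 3, 41]])
inputs = batch[:, :-1]    # [[12, 7, 99, 3]]  - everything except the last token
targets = batch[:, 1:]    # [[7, 99, 3, 41]]  - everything except the first token

With that shape convention in mind, the full training loop looks like this: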
for epoch in range(num_epochs):
    model.train()
    total_loss = 0
    for batch in train_loader:
        inputs = batch[:, :-1].to(device)
        targets = batch[:, 1:].to(device)
        optimizer.zero_grad()
        logits, loss = model(inputs, targets)
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
    avg_loss = total_loss / len(train_loader)
    print(f"Epoch [{epoch}/{num_epochs}], Loss: {avg_loss:.4f}")
During each epoch, the training loss is accumulated, and the model parameters are updated based on the computed gradients. After each epoch, I compute the average training loss to monitor progress. Evaluating the model on validation data follows a similar approach, though without updating the parameters.
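A validation pass might look like the sketch below. It assumes a val_loader built the same way as train_loader over held-out sequences; the key differences are eval mode and disabling gradient tracking.

model.eval()
val_loss = 0
with torch.no_grad():                      # no gradients needed for evaluation
    for batch in val_loader:
        inputs = batch[:, :-1].to(device)
        targets = batch[:, 1:].to(device)
        _, loss = model(inputs, targets)
        val_loss += loss.item()
print(f"Validation loss: {val_loss / len(val_loader):.4f}")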
Saving the model
Once training is completed, I save the model’s configuration and weights using torch.save(). This step is vital, as it allows me to reload the model later without needing to retrain it from scratch. Here’s how I save the model:
torch.save({
    'vocab_size': vocab_size,
    'd_model': d_model,
    'state_dict': model.state_dict()
}, './build/gpt_model.pth')
The state_dict holds all the parameters of the model (such as weights), while vocab_size and d_model are used to reconstruct the model architecture when loading it for inference or further training.
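For completeness, reloading the checkpoint could look roughly like this; it is a sketch that simply mirrors the save call above.

# Load the checkpoint and rebuild the architecture from the stored hyperparameters.
checkpoint = torch.load('./build/gpt_model.pth', map_location=device)
model = MyGPT(vocab_size=checkpoint['vocab_size'], d_model=checkpoint['d_model']).to(device)
model.load_state_dict(checkpoint['state_dict'])
model.eval()   # switch to inference mode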
Conclusion
Building a GPT model requires careful consideration of the embeddings, layer normalization, and the training loop to ensure that the model can learn effectively from sequences of tokens. Understanding how to embed tokens and process them through transformer layers plays a critical role in sequence prediction tasks. For more insights into this topic, you can find the details here.