Building a GPT from scratch: adding single-head and multi-head attention
In this post, I walk through the process of implementing both single-head and multi-head attention mechanisms for my GPT model. Starting with the basics of attention, I describe how I build the single-head version and then extend it to handle multiple attention heads. Each step involves clear and concise Python code, emphasizing how the model handles queries, keys, and values to compute attention scores. I then explain how to switch between the two attention mechanisms, maintaining flexibility in my model’s architecture.
Single-head attention implementation
The single-head attention mechanism starts by projecting the input embeddings into queries, keys, and values:
import math
import torch
import torch.nn as nn

class SingleHeadAttention(nn.Module):
    def __init__(self, d_model):
        super().__init__()
        # Linear projections for queries, keys, and values
        self.query = nn.Linear(d_model, d_model)
        self.key = nn.Linear(d_model, d_model)
        self.value = nn.Linear(d_model, d_model)
        # Scale factor for the dot products (1 / sqrt(d_model))
        self.scale = 1.0 / math.sqrt(d_model)

    def forward(self, x):
        seq_length = x.size(1)
        Q = self.query(x)
        K = self.key(x)
        V = self.value(x)
        # Scaled dot-product scores: (batch, seq_length, seq_length)
        attention_scores = torch.matmul(Q, K.transpose(-2, -1)) * self.scale
        # Causal mask: each position may only attend to itself and earlier positions
        mask = torch.triu(torch.ones(seq_length, seq_length, device=x.device),
                          diagonal=1).bool()
        attention_scores = attention_scores.masked_fill(mask, float('-inf'))
        attention_weights = torch.softmax(attention_scores, dim=-1)
        # Output is the attention-weighted sum of the values
        attention_output = torch.matmul(attention_weights, V)
        return attention_output
This class computes the attention weights using the queries, keys, and values, and applies a causal mask to ensure the model doesn’t attend to future tokens. After softmax normalization, the output is a weighted sum of the values.
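To make the shapes concrete, a quick sanity check like the one below is helpful; the sizes here (a batch of 2 sequences of 10 tokens with d_model=64) are illustrative values, not part of the model above:
# Quick shape check with illustrative values (batch=2, seq=10, d_model=64)
attn = SingleHeadAttention(d_model=64)
x = torch.randn(2, 10, 64)   # (batch, seq_length, d_model)
out = attn(x)
print(out.shape)             # torch.Size([2, 10, 64])
Each row of attention_weights sums to 1 after the softmax, and positions blocked by the causal mask contribute nothing to the output.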
Multi-head attention implementation
Next, I extend the single-head attention to multi-head attention to improve the model’s capacity to capture complex dependencies across different subspaces:
class MultiHeadAttention(nn.Module):
    def __init__(self, d_model, num_heads):
        super().__init__()
        assert d_model % num_heads == 0, "d_model must be divisible by num_heads"
        self.num_heads = num_heads
        self.head_dim = d_model // num_heads
        self.query = nn.Linear(d_model, d_model)
        self.key = nn.Linear(d_model, d_model)
        self.value = nn.Linear(d_model, d_model)
        # Output projection applied after the heads are concatenated
        self.fc_out = nn.Linear(d_model, d_model)
        # Scale by the per-head dimension, not the full model dimension
        self.scale = 1.0 / math.sqrt(self.head_dim)

    def forward(self, x):
        B, seq_length, d_model = x.shape
        Q = self.query(x)
        K = self.key(x)
        V = self.value(x)
        # Split into heads: (B, num_heads, seq_length, head_dim)
        Q = Q.view(B, seq_length, self.num_heads, self.head_dim).transpose(1, 2)
        K = K.view(B, seq_length, self.num_heads, self.head_dim).transpose(1, 2)
        V = V.view(B, seq_length, self.num_heads, self.head_dim).transpose(1, 2)
        # Scores per head: (B, num_heads, seq_length, seq_length)
        attention_scores = torch.matmul(Q, K.transpose(-2, -1)) * self.scale
        # Causal mask, broadcast across batch and heads
        mask = torch.triu(torch.ones(seq_length, seq_length, device=x.device),
                          diagonal=1).bool()
        attention_scores = attention_scores.masked_fill(mask, float('-inf'))
        attention_weights = torch.softmax(attention_scores, dim=-1)
        attention_output = torch.matmul(attention_weights, V)
        # Merge the heads back to (B, seq_length, d_model), then project
        attention_output = attention_output.transpose(1, 2).contiguous().view(B, seq_length, d_model)
        out = self.fc_out(attention_output)
        return out
In multi-head attention, I split the queries, keys, and values into multiple heads. Each head computes attention over its own lower-dimensional slice of the representation, so different heads can focus on different relationships in the sequence. The per-head results are then concatenated and projected back to the original dimension.
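As a rough sketch of how the tensors move through the module (d_model=64 and num_heads=8 are assumed here purely for illustration):
# Illustrative shape trace (assumed values: d_model=64, num_heads=8, so head_dim=8)
mha = MultiHeadAttention(d_model=64, num_heads=8)
x = torch.randn(2, 10, 64)    # (batch, seq_length, d_model)
out = mha(x)
print(out.shape)              # torch.Size([2, 10, 64])
# Inside forward, Q, K, and V are reshaped to (2, 8, 10, 8): eight heads, each
# attending over an 8-dimensional slice, then merged back to (2, 10, 64) before fc_out.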
Integrating attention into my model
To allow flexibility in my model architecture, I add the option to switch between single-head and multi-head attention:
class MyGPT(nn.Module):
    def __init__(self, vocab_size, d_model, max_len=5000, hidden_dim=2048,
                 use_multiple_head=True, num_heads=8):
        super().__init__()
        # Token embedding table used by forward()
        self.embedding = nn.Embedding(vocab_size, d_model)
        # Toggle between the two attention implementations
        if use_multiple_head:
            self.attention = MultiHeadAttention(d_model, num_heads)
        else:
            self.attention = SingleHeadAttention(d_model)
        self.ffn = FeedForwardNN(d_model, hidden_dim)
        self.ln_f = nn.LayerNorm(d_model)
        self.fc_out = nn.Linear(d_model, vocab_size)

    def forward(self, inputs, targets=None):
        embeddings = self.embedding(inputs)
        attn_output = self.attention(embeddings)
        ffn_output = self.ffn(attn_output)
        # Final layer norm, then project to vocabulary logits
        logits = self.fc_out(self.ln_f(ffn_output))
        if targets is not None:
            loss = self.compute_loss(logits, targets)
            return logits, loss
        return logits

    def compute_loss(self, logits, targets):
        # Cross-entropy between the predicted distributions and the target token ids
        return nn.functional.cross_entropy(
            logits.view(-1, logits.size(-1)), targets.view(-1))
The model allows me to choose between single-head and multi-head attention by toggling a flag, making it easy to experiment with different configurations and adapt the model for various tasks.
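A usage sketch of that toggle might look like the following; it assumes the FeedForwardNN module referenced above is available in scope, and the vocabulary size and tensor shapes are arbitrary example values:
# Illustrative only: assumes FeedForwardNN is defined; sizes are example values
model_multi = MyGPT(vocab_size=1000, d_model=64, use_multiple_head=True, num_heads=8)
model_single = MyGPT(vocab_size=1000, d_model=64, use_multiple_head=False)

tokens = torch.randint(0, 1000, (2, 10))     # (batch, seq_length) of token ids
targets = torch.randint(0, 1000, (2, 10))
logits, loss = model_multi(tokens, targets)  # training path: logits (2, 10, 1000) plus a scalar loss
logits = model_single(tokens)                # inference path: logits only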