Перейти к основному содержимому

Устройство трансформеров — теория и практика с нуля

Разработчику Инженеру

Промышленные модели (GPT, BERT) содержат миллиарды параметров и оптимизации (FlashAttention, fused kernels). Для понимания достаточно собрать один encoder block на PyTorch: scaled dot-product attention, multi-head обёртка, FFN, residual и LayerNorm.

Предполагается базовое знакомство с PyTorch и нейроном. Теория attention — предыдущая статья.


Каркас учебной модели

Упростим задачу:

  • фиксированная длина последовательности seq_len;
  • словарь размером vocab_size;
  • один encoder layer (без полного стека и без decoder);
  • обучаемые positional embeddings.

Полный Transformer из статьи 2017 года — стек из $N$ таких блоков + выходная проекция.


Scaled dot-product attention

import math
import torch
import torch.nn as nn
import torch.nn.functional as F


def scaled_dot_product_attention(q, k, v, mask=None):
# q, k, v: (batch, heads, seq_len, d_k)
d_k = q.size(-1)
scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(d_k)
if mask is not None:
scores = scores.masked_fill(mask == 0, float("-inf"))
weights = F.softmax(scores, dim=-1)
return torch.matmul(weights, v), weights

Mask с shape (batch, 1, seq_len, seq_len) обнуляет запрещённые позиции. Для encoder на классификации mask часто не нужен. Для decoder (GPT) — нижний треугольник (causal mask).


Multi-head attention

class MultiHeadAttention(nn.Module):
def __init__(self, d_model, num_heads):
super().__init__()
assert d_model % num_heads == 0
self.num_heads = num_heads
self.d_k = d_model // num_heads
self.w_q = nn.Linear(d_model, d_model)
self.w_k = nn.Linear(d_model, d_model)
self.w_v = nn.Linear(d_model, d_model)
self.w_o = nn.Linear(d_model, d_model)

def forward(self, x, mask=None):
batch, seq_len, d_model = x.size()
q = self.w_q(x).view(batch, seq_len, self.num_heads, self.d_k).transpose(1, 2)
k = self.w_k(x).view(batch, seq_len, self.num_heads, self.d_k).transpose(1, 2)
v = self.w_v(x).view(batch, seq_len, self.num_heads, self.d_k).transpose(1, 2)
out, _ = scaled_dot_product_attention(q, k, v, mask)
out = out.transpose(1, 2).contiguous().view(batch, seq_len, d_model)
return self.w_o(out)

Position-wise FFN

Два линейных слоя с расширением (обычно d_ff = 4 * d_model):

class PositionwiseFFN(nn.Module):
def __init__(self, d_model, d_ff, dropout=0.1):
super().__init__()
self.net = nn.Sequential(
nn.Linear(d_model, d_ff),
nn.GELU(),
nn.Dropout(dropout),
nn.Linear(d_ff, d_model),
)

def forward(self, x):
return self.net(x)

Encoder block

Post-LN (как в оригинальной статье) — нормализация после residual:

class EncoderBlock(nn.Module):
def __init__(self, d_model, num_heads, d_ff, dropout=0.1):
super().__init__()
self.self_attn = MultiHeadAttention(d_model, num_heads)
self.ffn = PositionwiseFFN(d_model, d_ff, dropout)
self.norm1 = nn.LayerNorm(d_model)
self.norm2 = nn.LayerNorm(d_model)
self.dropout = nn.Dropout(dropout)

def forward(self, x, mask=None):
attn_out = self.self_attn(x, mask)
x = self.norm1(x + self.dropout(attn_out))
ffn_out = self.ffn(x)
x = self.norm2(x + self.dropout(ffn_out))
return x

Современные модели (GPT-2, некоторые LLM) часто используют Pre-LN (нормализация перед подблоком) — стабильнее при большой глубине.


Мини-модель для классификации

class TinyTransformerClassifier(nn.Module):
def __init__(self, vocab_size, d_model=128, num_heads=4,
d_ff=512, num_layers=2, num_classes=2, max_len=128):
super().__init__()
self.token_emb = nn.Embedding(vocab_size, d_model)
self.pos_emb = nn.Embedding(max_len, d_model)
self.layers = nn.ModuleList([
EncoderBlock(d_model, num_heads, d_ff) for _ in range(num_layers)
])
self.head = nn.Linear(d_model, num_classes)

def forward(self, input_ids):
batch, seq_len = input_ids.size()
positions = torch.arange(seq_len, device=input_ids.device).unsqueeze(0)
x = self.token_emb(input_ids) + self.pos_emb(positions)
for layer in self.layers:
x = layer(x)
# [CLS]-подобный pooling — первый токен
return self.head(x[:, 0, :])

Для обучения нужны input_ids, метки классов, CrossEntropyLoss и AdamW. На CPU учебная модель на сотнях примеров сходится за минуты — удобно для эксперимента с attention weights.


Causal mask для decoder

GPT-стиль — нижний треугольник единиц:

def causal_mask(seq_len, device):
return torch.tril(torch.ones(seq_len, seq_len, device=device)).unsqueeze(0).unsqueeze(0)

Применяют в self-attention декодера, чтобы позиция $i$ видела только $0..i$.


Что добавляют «боевые» реализации

КомпонентЗачем
KV cacheПри генерации не пересчитывать K/V для старых токенов
FlashAttentionМеньше памяти и быстрее attention на GPU
RoPE / ALiBiПозиционная информация для длинного контекста
Gradient checkpointingЭкономия VRAM при обучении
Mixed precision (fp16/bf16)Скорость и объём модели

Готовые стеки — Hugging Face transformers, Keras. С нуля имеет смысл один раз для обучения; в проде — проверенная библиотека.

Упражнение
Обучите TinyTransformerClassifier на бинарной классификации (например, RuSentiment с урезанным словарём). Сравните с TF-IDF + LogisticRegression из Scikit-learn — разница в F1 покажет выигрыш контекстуальных эмбеддингов.


Дальше


См. также

Другие статьи этого же раздела в боковом меню (как на странице "О разделе").