
GPT2 Distributed Training

Jun 6, 2025 · 19 min read

GPT-2 was a huge leap forward from the original GPT because it demonstrated zero-shot task transfer: the ability to perform various NLP tasks without any explicit supervised training on them. The authors also show that model size is crucial for this emergent multitask behavior. This finding suggests a promising direction towards general-purpose language systems that learn to perform tasks from naturally occurring text rather than from curated training data for each task.

In this post, I will implement GPT-2 from scratch via the task of language modeling (next-word prediction). The model is a decoder-only transformer with approximately 124M parameters. The pipeline has two stages:

  1. unsupervised (self-supervised) pre-training on WebText
  2. zero-shot inference on downstream tasks by providing suitable text prompts.

Along the way, I will also introduce several further optimization tricks, arguably the most important part of this article, since they carry over to other language models and large language models.

Architecture

Essentially, these are the pre-defined configurations of our GPT-2:

  - vocab_size: 50257 (50,000 BPE merges + 256 byte tokens + 1 <|endoftext|> token)
  - block_size: 1024 (maximum sequence length)
  - n_layer: 12
  - n_head: 12
  - n_embd: 768
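The configuration above fits naturally into a small dataclass. This is a minimal sketch: the class name GPTConfig matches the one used later in this post, but its exact definition here, including the hypothetical head_size helper, is an assumption.

```python
from dataclasses import dataclass


@dataclass
class GPTConfig:
    block_size: int = 1024   # maximum sequence length (context window)
    vocab_size: int = 50257  # 50,000 BPE merges + 256 byte tokens + <|endoftext|>
    n_layer: int = 12        # number of transformer blocks
    n_head: int = 12         # attention heads per block
    n_embd: int = 768        # embedding (model) dimension C

    @property
    def head_size(self) -> int:
        # each head works on a C / n_head dimensional slice of the embedding
        assert self.n_embd % self.n_head == 0, "n_embd must be divisible by n_head"
        return self.n_embd // self.n_head


config = GPTConfig()
print(config.head_size)  # 64
```

Later in the post the model is built as GPT(GPTConfig(vocab_size=50304)), which simply overrides one field of this default configuration.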

Compared with the original GPT, GPT-2 moves layer normalization to the input of each sub-block (pre-norm) and adds an extra layer norm after the final self-attention block. It also uses a scaled initialization for the weights of the residual projections to stabilize training. It is also worth noting that the positional embeddings in GPT-2 are fully learned, not fixed sinusoidal functions as in the original Transformer.

def _init_weights(self, module):
    if isinstance(module, nn.Linear):
        std = 0.02
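        # residual output projections are flagged with GPT2_SCALE_INIT; their std is
        # scaled by 1/sqrt(2 * n_layer) because every block adds two residual contributions
        if hasattr(module, 'GPT2_SCALE_INIT'):
            std *= (2 * self.config.n_layer) ** -0.5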
        torch.nn.init.normal_(module.weight, mean=0.0, std=std)
        if module.bias is not None:
            torch.nn.init.zeros_(module.bias)
    elif isinstance(module, nn.Embedding):
        torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
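Why the factor (2 * n_layer) ** -0.5? Each block adds two contributions to the residual stream (attention and MLP), so with 12 layers the stream is a sum of 24 terms and its standard deviation grows like sqrt(24). The NumPy sketch below illustrates this under a simplifying assumption: the contributions are modeled as independent Gaussian noise, which real layer outputs are not.

```python
import numpy as np

rng = np.random.default_rng(0)
n_layer = 12
n_residual = 2 * n_layer  # attention + MLP contribution per block


def residual_stream_std(scale: float, dim: int = 100_000) -> float:
    """Std of a residual stream after adding n_residual independent unit-variance terms times scale."""
    x = np.zeros(dim)
    for _ in range(n_residual):
        x += scale * rng.standard_normal(dim)
    return float(x.std())


unscaled = residual_stream_std(1.0)               # grows like sqrt(24), about 4.9
scaled = residual_stream_std(n_residual ** -0.5)  # stays close to 1.0
print(f"unscaled std: {unscaled:.2f}, scaled std: {scaled:.2f}")
```

For GPT-2 small this gives the residual projections an initial std of 0.02 * 24 ** -0.5, roughly 0.0041.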

GPT-2 makes extensive use of multi-head attention. To recapitulate, multi-head attention lets the model capture different aspects of the input sequence by running several attention heads in parallel, each learning a different representation. Each head works independently on its own subspace of size C / n_head (768 / 12 = 64 for GPT-2 small), and the head outputs are then concatenated back into a C-dimensional vector (C = n_embd).
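The head split and re-concatenation can be sketched in NumPy for a single sequence. This is an illustration only: the function name and the random weights are hypothetical, the batch dimension is dropped, and the output projection (c_proj) that follows in the real model is omitted.

```python
import numpy as np


def causal_multi_head_attention(x, w_qkv, n_head):
    """x: (T, C); w_qkv: (C, 3C). Returns (T, C). Output projection omitted for brevity."""
    T, C = x.shape
    hs = C // n_head                                      # head size
    q, k, v = np.split(x @ w_qkv, 3, axis=1)              # each (T, C)
    # split the channel dimension into heads: (T, C) -> (n_head, T, hs)
    q, k, v = (a.reshape(T, n_head, hs).transpose(1, 0, 2) for a in (q, k, v))
    att = q @ k.transpose(0, 2, 1) / np.sqrt(hs)          # (n_head, T, T)
    mask = np.triu(np.ones((T, T), dtype=bool), k=1)      # True above the diagonal = future tokens
    att = np.where(mask, -np.inf, att)
    att = np.exp(att - att.max(axis=-1, keepdims=True))
    att = att / att.sum(axis=-1, keepdims=True)           # softmax over keys
    y = att @ v                                           # (n_head, T, hs)
    # concatenate the heads back into one C-dimensional vector per position
    return y.transpose(1, 0, 2).reshape(T, C)


rng = np.random.default_rng(0)
T, C, n_head = 5, 16, 4
x = rng.standard_normal((T, C))
w_qkv = rng.standard_normal((C, 3 * C)) * 0.1
y = causal_multi_head_attention(x, w_qkv, n_head)
print(y.shape)  # (5, 16)
```

Because of the causal mask, changing later tokens never changes the outputs at earlier positions, which is exactly what next-token prediction requires.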

GPT-2 is trained with a language modeling objective: maximize the likelihood of the next token given the preceding text. When we compute the loss, the target for each position is the token immediately to its right, so the probability of a sequence factorizes as:

$$p(x) = \prod_{i=1}^{n} p(s_i \mid s_1, s_2, \ldots, s_{i-1})$$
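In practice, maximizing this product means minimizing the average negative log-likelihood (cross-entropy) of each next token. In the minimal NumPy sketch below, hypothetical logits stand in for model outputs, and the targets are simply the input sequence shifted left by one. As in training, the first token is only used as context, so p(s_1) itself is not scored.

```python
import numpy as np


def next_token_loss(logits, tokens):
    """logits: (T, V) predictions at positions 0..T-1; tokens: a length T+1 sequence.
    Position i is scored against tokens[i + 1], the token to its right."""
    targets = np.asarray(tokens[1:])
    # numerically stable log-softmax over the vocabulary
    shifted = logits - logits.max(axis=-1, keepdims=True)
    log_probs = shifted - np.log(np.exp(shifted).sum(axis=-1, keepdims=True))
    nll = -log_probs[np.arange(len(targets)), targets]
    return nll.mean()


V = 10
tokens = [3, 1, 4, 1, 5, 9]               # s_1 .. s_6
logits = np.zeros((len(tokens) - 1, V))   # a model that predicts uniformly
print(next_token_loss(logits, tokens))    # ln(10), about 2.3026
```

Multiplying the mean loss by the number of scored positions and negating it gives the log of the product above, so minimizing one maximizes the other.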

Moreover, the authors cast every downstream task as next-word prediction, which is a clever way to steer the model into producing the output we want. For example, for translation, they place a few example sentence pairs in the form "English sentence = French sentence" into the context. For question answering, they condition the model on a passage and a dialogue of question-answer pairs ending with a final question, and prompt it to generate the answer (e.g., format: Document... Q: (question)? A: → model generates answer).
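The two prompt formats can be sketched as plain string builders. The helper names and the exact separators and newlines are assumptions for illustration; the paper describes the formats only conceptually.

```python
def translation_prompt(examples, source_sentence):
    """Few-shot 'English sentence = French sentence' prompt; the model continues after the final '='."""
    lines = [f"{en} = {fr}" for en, fr in examples]
    lines.append(f"{source_sentence} =")
    return "\n".join(lines)


def qa_prompt(document, history, question):
    """Document, then previous Q/A pairs, then a final question; the model continues after 'A:'."""
    parts = [document]
    for q, a in history:
        parts.append(f"Q: {q}\nA: {a}")
    parts.append(f"Q: {question}\nA:")
    return "\n".join(parts)


print(translation_prompt([("Good morning.", "Bonjour.")], "Thank you."))
```

Whatever the model generates after the trailing "=" or "A:" is then read off as its translation or answer.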

Modified Byte-pair Encoding

The input text is processed with a custom byte-level Byte Pair Encoding (BPE) tokenizer, designed to be both general and efficient. The authors modify BPE to operate on raw bytes rather than Unicode code points. This means the base vocabulary includes all 256 possible byte values, so in principle any Unicode text can be represented without loss or “[UNK]” tokens. Starting from this base, the BPE algorithm repeatedly merges the most frequent adjacent pair of tokens into a new, larger token. The result is a vocabulary of 50,257 tokens (including the <|endoftext|> special token) covering a wide range of subwords and symbols.
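The merge loop is easy to see on a toy example. The sketch below is a minimal byte-level BPE under stated simplifications: it is not OpenAI's implementation, it has no regex pre-splitting or special tokens, and all function names are hypothetical.

```python
def get_stats(ids):
    """Count adjacent pairs, keeping first-occurrence order for tie-breaking."""
    counts = {}
    for pair in zip(ids, ids[1:]):
        counts[pair] = counts.get(pair, 0) + 1
    return counts


def merge(ids, pair, new_id):
    """Replace non-overlapping occurrences of pair (scanned left to right) with new_id."""
    out, i = [], 0
    while i < len(ids):
        if i < len(ids) - 1 and (ids[i], ids[i + 1]) == pair:
            out.append(new_id)
            i += 2
        else:
            out.append(ids[i])
            i += 1
    return out


def train_bpe(text, num_merges):
    ids = list(text.encode("utf-8"))   # base vocabulary: the 256 byte values
    merges = {}                        # (id, id) -> new id, in merge order
    for k in range(num_merges):
        stats = get_stats(ids)
        if not stats:
            break
        pair = max(stats, key=stats.get)  # most frequent adjacent pair
        merges[pair] = 256 + k
        ids = merge(ids, pair, 256 + k)
    return merges


def encode(text, merges):
    ids = list(text.encode("utf-8"))
    while len(ids) >= 2:
        stats = get_stats(ids)
        # apply the earliest-learned merge that occurs in the sequence
        pair = min(stats, key=lambda p: merges.get(p, float("inf")))
        if pair not in merges:
            break
        ids = merge(ids, pair, merges[pair])
    return ids


def decode(ids, merges):
    vocab = {i: bytes([i]) for i in range(256)}
    for (a, b), new_id in merges.items():   # insertion order = merge order
        vocab[new_id] = vocab[a] + vocab[b]
    return b"".join(vocab[i] for i in ids).decode("utf-8", errors="replace")


merges = train_bpe("aaabdaaabac", num_merges=3)
print(encode("aaabdaaabac", merges))  # [258, 100, 258, 97, 99]
```

Because the base vocabulary is every byte, encode and decode round-trip any UTF-8 string, including characters the merges never saw.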

The modified tokenizer also preserves text fidelity. However, naive BPE would waste vocabulary slots on variants of the same word with punctuation attached, so the authors restrict the merges: BPE is not allowed to merge across character categories (such as a letter followed by punctuation or a digit), with the only exception that a space may merge with the word following it. As a result, the vocabulary does not contain dog, dog! and dog? as separate tokens.
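OpenAI's released GPT-2 encoder enforces this rule by first splitting text with a regular expression, so that merges can only happen inside each chunk. The sketch below is an ASCII-only approximation of that pattern using Python's built-in re module; the real pattern relies on the third-party regex module's Unicode classes (\p{L}, \p{N}).

```python
import re

# ASCII-only approximation of the GPT-2 split pattern. The real one uses \p{L}
# (letters) and \p{N} (numbers) from the third-party `regex` module.
SPLIT_PATTERN = re.compile(
    r"'s|'t|'re|'ve|'m|'ll|'d| ?[A-Za-z]+| ?[0-9]+| ?[^\sA-Za-z0-9]+|\s+(?!\S)|\s+"
)


def pre_tokenize(text):
    """Split text into chunks; BPE merges are then learned and applied only within a chunk."""
    return SPLIT_PATTERN.findall(text)


print(pre_tokenize("dog. dog! dog?"))  # ['dog', '.', ' dog', '!', ' dog', '?']
```

The word and its punctuation always land in different chunks, while the leading space stays attached to the word, which is exactly the exception described above.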

Parameter Sharing

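In our model, we share the same weight matrix between the token embedding layer (wte) and the pre-softmax linear transformation (lm_head). This matters for the smallest, 124M-parameter GPT-2: the 50257 × 768 classifier at the top is its largest matrix, holding about 38.6M parameters (roughly 30% of the model), and it also performs the largest matrix multiplication.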

# weight sharing scheme
self.transformer.wte.weight = self.lm_head.weight # tie token embedding and output projection (shared reference, no copy)

Parameter sharing in this context serves both efficiency and consistency in representation learning. The reasons can be broken down as follows:

  1. Reduced number of parameters: GPT-2 has a large vocabulary, and token embeddings are stored in the embedding matrix $W_e$, which maps tokens to high-dimensional vector representations. At the output, a projection matrix $W_o$ converts the final hidden state back into vocabulary logits for computing probabilities. If $W_e$ and $W_o$ were separate, the model would maintain two large matrices of size $V \times d$ (where $V$ is the vocabulary size and $d$ is the embedding dimension); tying them removes $V \times d$ parameters, about 38.6M for GPT-2 small.
  2. Aligning representations: the Transformer learns an embedding space where words are mapped to continuous vector representations. Sharing weights ensures that the same space is used for both input encoding and output decoding. This helps in learning better token representations because the input and output layers speak the same language in the learned embedding space.
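The savings can be checked with a quick parameter count. This sketch assumes the standard GPT-2 layer layout (biases in every Linear layer except lm_head, an MLP hidden size of 4C, LayerNorms with weight and bias); the function name is hypothetical.

```python
def gpt2_param_count(vocab_size=50257, block_size=1024, n_layer=12, n_embd=768, tied=True):
    """Count parameters of a GPT-2-style model."""
    C = n_embd
    wte = vocab_size * C                          # token embedding
    wpe = block_size * C                          # learned positional embedding
    ln = 2 * C                                    # LayerNorm weight + bias
    attn = (C * 3 * C + 3 * C) + (C * C + C)      # c_attn (q, k, v) + c_proj
    mlp = (C * 4 * C + 4 * C) + (4 * C * C + C)   # c_fc + c_proj
    block = 2 * ln + attn + mlp
    lm_head = 0 if tied else vocab_size * C       # no bias; shared with wte when tied
    return wte + wpe + n_layer * block + ln + lm_head  # the trailing ln is the final ln_f


print(gpt2_param_count())            # 124439808
print(gpt2_param_count(tied=False))  # 163037184
```

With tying, the model lands at the familiar "124M"; without it, the same architecture would carry about 163M parameters.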

Optimization

The original model is trained on a new dataset called WebText, created to be large-scale and diverse while maintaining quality. WebText was built by scraping all outbound links posted on Reddit that received at least 3 karma. The underlying assumption is that if multiple Reddit users found a link interesting or useful, the content likely has reasonably high quality or informative value. This heuristic yields a vast collection of web pages curated by human preferences, spanning many topics and styles. The data was collected up to December 2017. WebText itself was never publicly released, so the code in this post trains on the FineWeb-Edu dataset instead.

Below, I outline the optimization tricks used to speed up convergence and reduce training time.

Data Loader.

The DataLoaderLite class loads and iterates over a tokenized dataset stored in shards. It produces batches of inputs x and targets y (the inputs shifted by one token) for next-word prediction. In a distributed run, each process starts at an offset determined by its rank and advances by B * T * num_processes tokens per step, so the processes read disjoint slices of the data.

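import os
import numpy as np
import torch

def load_tokens(filename):
    # each shard is a .npy array of token ids; convert to int64 for embedding lookups
    npt = np.load(filename).astype(np.int32)
    return torch.tensor(npt, dtype=torch.long)

class DataLoaderLite:
    def __init__(self, B, T, process_rank, num_processes, split):
        self.B = B # batch size
        self.T = T # sequence length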
        self.process_rank = process_rank
        self.num_processes = num_processes
        assert split in {"train", "val"}
 
        # get the shard filenames
        data_root = "edu_fineweb10B"
        shards = os.listdir(data_root)
        shards = [s for s in shards if split in s]
        shards = sorted(shards)
        shards = [os.path.join(data_root, s) for s in shards]
        self.shards = shards
        assert len(shards) > 0, f"No shards found for split {split}"
        # master_process is a global flag set by the distributed setup (True on rank 0)
        if master_process:
            print(f"found {len(shards)} shards for split {split}")
        self.reset()
 
    def reset(self):
        # state, init at shard zero
        self.current_shard = 0
        self.tokens = load_tokens(self.shards[self.current_shard])
        self.current_position = self.B * self.T * self.process_rank  
 
    def next_batch(self):
        B, T = self.B, self.T
        buf = self.tokens[self.current_position : self.current_position+B*T+1]
        x = (buf[:-1]).view(B, T) # inputs
        y = (buf[1:]).view(B, T) # targets for next-word prediction / language modeling
        # advance the position in the tensor
        self.current_position += B * T * self.num_processes
        # if loading the next batch would be out of bounds, advance to next shard
        if self.current_position + (B * T * self.num_processes + 1) > len(self.tokens):
            self.current_shard = (self.current_shard + 1) % len(self.shards)
            self.tokens = load_tokens(self.shards[self.current_shard])
            self.current_position = B * T * self.process_rank
        return x, y
 
train_loader = DataLoaderLite(B=B, T=T, process_rank=ddp_rank, num_processes=ddp_world_size, split="train")
val_loader = DataLoaderLite(B=B, T=T, process_rank=ddp_rank, num_processes=ddp_world_size, split="val")
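The two pieces of bookkeeping in DataLoaderLite, the one-token shift between x and y and the rank-strided positions, can be checked in plain Python. Both helper names are hypothetical, and window_starts simply wraps around where the real loader would move on to the next shard.

```python
def make_xy(buf, B, T):
    """Split a window of B*T + 1 tokens into inputs x and next-token targets y (nested lists)."""
    assert len(buf) == B * T + 1
    x = [buf[r * T:(r + 1) * T] for r in range(B)]
    y = [buf[r * T + 1:(r + 1) * T + 1] for r in range(B)]
    return x, y


def window_starts(num_tokens, B, T, rank, world_size, steps):
    """Start offsets read by one rank, mirroring DataLoaderLite's position updates within a shard."""
    pos = B * T * rank
    starts = []
    for _ in range(steps):
        starts.append(pos)
        pos += B * T * world_size
        if pos + (B * T * world_size + 1) > num_tokens:
            pos = B * T * rank  # DataLoaderLite loads the next shard here; we just wrap
    return starts


x, y = make_xy(list(range(9)), B=2, T=4)
print(x, y)  # [[0, 1, 2, 3], [4, 5, 6, 7]] [[1, 2, 3, 4], [5, 6, 7, 8]]
print(window_starts(1000, B=2, T=4, rank=0, world_size=2, steps=3))  # [0, 16, 32]
print(window_starts(1000, B=2, T=4, rank=1, world_size=2, steps=3))  # [8, 24, 40]
```

Rank 0 and rank 1 interleave in blocks of B * T tokens, so together they cover the stream without reading the same inputs twice.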

Mixed Precision Training.

This refers to the technique of using multiple floating-point precision formats within a single computation to maximize performance while maintaining accuracy. Here are the main takeaways:

AI Training: FP16 Tensor Core & BFLOAT16 provide the best balance between speed and precision.

AI Inference: INT8 is preferred due to its high efficiency and low power consumption. The fewer bits per value, the less data has to be moved around.

Scientific & HPC (High-Performance Computing): FP64 or FP64 Tensor Core are necessary for maximum accuracy.

Gaming & Graphics: FP32 is commonly used, balancing speed and precision.

We often prefer BFLOAT16 (Brain Floating Point) over FP16 (half precision) for training. Both formats use 16 bits, so the choice is a trade-off between range and precision:

  1. BFLOAT16 has a larger exponent (8 bits, like FP32), so it covers the same dynamic range as FP32 and avoids issues like gradient underflow.
  2. FP16 has a larger mantissa (10 bits versus 7 for BFLOAT16), so it represents numbers more precisely, but its 5-bit exponent gives a much narrower dynamic range, which leads to underflow in deep learning computations (hence the need for gradient scaling).
  3. Net effect: switching from FP16 to BFLOAT16 does not improve precision (it actually lowers it). However, BFLOAT16's wider range makes training more stable, especially for gradient updates, and removes the need for a gradient scaler.
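The range-versus-precision trade-off can be observed with Python's struct module, which supports IEEE FP16 through the "e" format. BFLOAT16 is emulated here by keeping the top 16 bits of the FP32 bit pattern with round-to-nearest-even; this is a simplified sketch that ignores NaN and infinity inputs, and both helper names are hypothetical.

```python
import struct


def to_fp16(x: float) -> float:
    """Round x to IEEE FP16 (1 sign, 5 exponent, 10 mantissa bits) and back."""
    return struct.unpack("<e", struct.pack("<e", x))[0]


def to_bf16(x: float) -> float:
    """Round x to BFLOAT16 (1 sign, 8 exponent, 7 mantissa bits) and back.
    Keeps the top 16 bits of the FP32 pattern, rounding to nearest even; NaN/inf not handled."""
    bits = struct.unpack("<I", struct.pack("<f", x))[0]
    bits += 0x7FFF + ((bits >> 16) & 1)
    bits &= 0xFFFF0000
    return struct.unpack("<f", struct.pack("<I", bits))[0]


tiny = 1e-8  # e.g. a small gradient value
print(to_fp16(tiny), to_bf16(tiny))    # FP16 underflows to 0.0, BF16 keeps roughly 1e-8
print(to_fp16(1.001), to_bf16(1.001))  # 1.0009765625 1.0: FP16 is more precise near 1
```

At the other end of the range, struct.pack("<e", 1e5) raises OverflowError because FP16 tops out at 65504, while to_bf16(1e5) still returns a value close to 1e5.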

Memory bandwidth matters just as much. Many operations are memory-bound: the GPU can compute faster than data can be fetched from memory, so the compute units sit idle waiting for data. Do not blindly optimize for raw compute speed; lower-precision formats also help here simply because they move fewer bytes.

TensorCore and Matrix Decomposition.

Tensor Cores are specialized hardware units in NVIDIA GPUs designed to accelerate matrix multiplications by performing fused multiply-accumulate (FMA) operations. They significantly boost performance for deep learning workloads by efficiently handling FP16, BFLOAT16, and INT8 computations. By leveraging parallel execution and optimized precision, Tensor Cores enable faster training and inference with lower power consumption.

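Matrix decomposition (more precisely, tiling or blocking) breaks a large matrix multiplication into many small tile-by-tile products that can run in parallel, and Tensor Cores operate on exactly such small fixed-size tiles. Tiling does not reduce the number of floating-point operations; it improves efficiency because each tile of the inputs, once loaded into fast on-chip memory, is reused for many multiply-accumulates.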
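A tiled matrix multiplication can be sketched in NumPy to show that blocking changes only the order of the work, not the result. This is purely illustrative: real Tensor Core kernels are written in CUDA (or generated by compilers), and the function name and tile size here are hypothetical.

```python
import numpy as np


def tiled_matmul(A, B, tile=16):
    """Compute A @ B one (tile x tile) output block at a time, accumulating over K in tile-sized chunks."""
    M, K = A.shape
    K2, N = B.shape
    assert K == K2, "inner dimensions must match"
    C = np.zeros((M, N), dtype=np.result_type(A, B))
    for i in range(0, M, tile):
        for j in range(0, N, tile):
            acc = np.zeros((min(tile, M - i), min(tile, N - j)), dtype=C.dtype)
            for k in range(0, K, tile):
                # each small product reuses one tile of A and one tile of B held in fast memory
                acc += A[i:i + tile, k:k + tile] @ B[k:k + tile, j:j + tile]
            C[i:i + tile, j:j + tile] = acc
    return C


rng = np.random.default_rng(0)
A = rng.standard_normal((50, 40))
B = rng.standard_normal((40, 70))
print(np.allclose(tiled_matmul(A, B), A @ B))  # True
```

Edge tiles are simply smaller when the dimensions are not multiples of the tile size, which is one reason "nice" dimensions (discussed below) keep hardware fully utilized.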

That is where torch.autocast comes in. Inside an autocast region, eligible operations such as matrix multiplications run in the lower-precision dtype you specify (here BFLOAT16), while precision-sensitive operations such as softmax, layer normalization and loss computations stay in FP32. This gives most of the speed and memory benefits of mixed precision without manual per-operation casting, and with BFLOAT16 no gradient scaler is needed.

with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
    logits, loss = model(x, y)

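In addition, we set the following once, before training, so that the remaining FP32 matrix multiplications may use the faster TF32 Tensor Core path where available: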

torch.set_float32_matmul_precision('high')

GPU

Obviously, pre-training a language model such as GPT-2 requires a huge amount of compute; below is a comparison of some older but still powerful GPUs.

[Figure: comparison of older but still powerful GPUs]

And this is the code that selects the best available device (CUDA GPU, then Apple's MPS backend, falling back to CPU) and runs a minimal training loop:

import torch
# GPT, GPTConfig and train_loader are defined earlier in this post

device = "cpu"
if torch.cuda.is_available():
    device = "cuda"
elif hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
    device = "mps"
print("using device:", device)

model = GPT(GPTConfig())
model.to(device)

optimizer = torch.optim.AdamW(model.parameters(), lr=3e-4)
for i in range(50):
    x, y = train_loader.next_batch()  # fetch a fresh batch each step
    x, y = x.to(device), y.to(device)
    optimizer.zero_grad()
    logits, loss = model(x, y)
    loss.backward()
    optimizer.step()
    print(f"step {i}, loss: {loss.item()}")

Compile and Kernel Fusion

This single line of code is very powerful. Just as GCC compiles C++ into optimized machine instructions, torch.compile() compiles PyTorch models into optimized low-level code, reducing Python overhead and improving execution speed. Instead of executing PyTorch operations one by one in eager mode, torch.compile() captures the model as a computation graph and hands it to a backend (TorchInductor by default, which generates Triton kernels on GPUs). Kernel fusion then merges consecutive operations into a single kernel, avoiding redundant round trips to GPU memory and improving hardware efficiency.

model = torch.compile(model)

Flash Attention

Flash Attention is an optimized, exact attention algorithm that uses kernel fusion to significantly speed up Transformer models while reducing memory overhead. Standard attention materializes the full T × T attention matrix and moves it back and forth between GPU high-bandwidth memory and the compute units, which is slow and memory-hungry. Flash Attention fuses the matrix multiplications, masking and softmax into a single GPU kernel and computes the softmax block by block (an "online" softmax), so the T × T matrix is never written to memory. Concretely, instead of using:

att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1)))
att = att.masked_fill(self.bias[:, :, :T, :T] == 0, float('-inf'))
att = F.softmax(att, dim=-1)
y = att @ v # (B, nh, T, T) x (B, nh, T, hs) => (B, nh, T, hs)
self.register_buffer("bias", torch.tril(torch.ones(config.block_size, config.block_size)).view(1, 1, config.block_size, config.block_size)) # soon to delete after applying flash attention 

We can use:

y = F.scaled_dot_product_attention(q, k, v, is_causal=True)
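The trick that makes this fusion possible is the online softmax: keys and values are processed block by block while keeping a running maximum and a running denominator, and earlier partial results are rescaled whenever the maximum grows. The NumPy sketch below shows that idea and checks it against the naive version. It is illustrative only: the function names are hypothetical, it loops over one query at a time for clarity, and the real FlashAttention kernel also tiles the queries and keeps the tiles in on-chip SRAM.

```python
import numpy as np


def naive_causal_attention(q, k, v):
    """Reference: materializes the full (T, T) score matrix."""
    T, d = q.shape
    scores = q @ k.T / np.sqrt(d)
    scores = np.where(np.triu(np.ones((T, T), dtype=bool), k=1), -np.inf, scores)
    p = np.exp(scores - scores.max(axis=1, keepdims=True))
    return (p / p.sum(axis=1, keepdims=True)) @ v


def online_softmax_attention(q, k, v, block=4):
    """Causal attention over key/value blocks with a running max and running sum,
    so a full row of T scores is never stored (the core idea behind FlashAttention)."""
    T, d = q.shape
    out = np.zeros((T, v.shape[1]))
    for i in range(T):
        m, l = -np.inf, 0.0             # running max and running softmax denominator
        acc = np.zeros(v.shape[1])      # running, unnormalized weighted sum of values
        for start in range(0, i + 1, block):
            end = min(start + block, i + 1)          # causal: only keys 0..i
            s = k[start:end] @ q[i] / np.sqrt(d)     # scores for this block
            m_new = max(m, s.max())
            correction = np.exp(m - m_new)           # rescale earlier blocks to the new max
            p = np.exp(s - m_new)
            l = l * correction + p.sum()
            acc = acc * correction + p @ v[start:end]
            m = m_new
        out[i] = acc / l
    return out


rng = np.random.default_rng(0)
T, d = 10, 8
q, k, v = (rng.standard_normal((T, d)) for _ in range(3))
print(np.allclose(online_softmax_attention(q, k, v), naive_causal_attention(q, k, v)))  # True
```

The result is exact, not an approximation; only the order of computation and the memory footprint change.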

Nice and Ugly Numbers

In deep learning and computer architecture, "nice numbers" refer to numbers that are highly optimized for efficient computation, while "ugly numbers" cause inefficiencies due to poor memory alignment or hardware constraints. Choosing the right numbers for batch sizes, tensor dimensions, and sequence lengths can dramatically improve training and inference speed. As a rule of thumb, do:

  1. Use powers of 2 for batch sizes, hidden dimensions, tensor shapes
  2. Multiples of 16, 32 for sequence lengths
  3. Avoid prime and odd numbers
  4. Ensure divisible-by-8 dimensions

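This simple trick can noticeably improve throughput at no cost to the model. In this post, for example, the vocabulary size is padded from 50257 to 50304 (a multiple of 128) when the model is initialized for multi-GPU training.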
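Padding a dimension up to a "nice" multiple is a one-line computation. The helper name below is hypothetical.

```python
def round_up(n: int, multiple: int) -> int:
    """Smallest multiple of `multiple` that is >= n."""
    return ((n + multiple - 1) // multiple) * multiple


padded = round_up(50257, 128)
print(padded, padded - 50257)  # 50304 47
```

The 47 extra vocabulary entries never appear in the data, so the model simply learns to assign them near-zero probability.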

Gradient Clipping

Gradient clipping prevents exploding gradients by rescaling the gradients whenever their global L2 norm (computed over all parameters together) exceeds a threshold. When gradients become too large, they can cause unstable updates, leading to poor convergence or NaN losses. By capping the norm (in this case at 1.0), gradient clipping keeps the optimization stable, especially in deep networks or reinforcement learning.

torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
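What clip_grad_norm_ does can be sketched in NumPy for a list of gradient arrays. This is a simplified re-implementation for illustration (the function name is hypothetical); the small eps in the denominator mirrors the constant PyTorch adds.

```python
import numpy as np


def clip_grad_norm(grads, max_norm, eps=1e-6):
    """Scale a list of gradient arrays in place so their global L2 norm is at most max_norm.
    Returns the norm measured before clipping."""
    total_norm = float(np.sqrt(sum(np.sum(g ** 2) for g in grads)))
    clip_coef = max_norm / (total_norm + eps)
    if clip_coef < 1.0:
        for g in grads:
            g *= clip_coef  # same factor for every tensor, so the update direction is preserved
    return total_norm


grads = [np.array([3.0]), np.array([4.0])]
norm = clip_grad_norm(grads, 1.0)
print(norm)  # 5.0
```

Like the PyTorch function, this returns the pre-clipping norm, which is worth logging: sudden spikes are an early sign of training instability.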

Learning Rate Scheduler

A learning rate scheduler dynamically adjusts the learning rate during training to improve convergence and prevent overshooting. Popular schedules include step decay, cosine annealing, and warm-up strategies, which help balance exploration (large updates) and fine-tuning (small updates). Proper scheduling enhances training efficiency and can significantly boost model performance by adapting learning rates over time.

import math

max_lr = 6e-4
min_lr = max_lr * 0.1
warmup_steps = 200
max_steps = 1501
 
def get_lr(it):
    if it < warmup_steps:
        return max_lr * (it+1) / warmup_steps
    if it > max_steps:
        return min_lr
    decay_ratio = (it - warmup_steps) / (max_steps - warmup_steps)
    coeff = 0.5 * (1.0 + math.cos(math.pi * decay_ratio))
    return min_lr + coeff * (max_lr - min_lr)
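Written as a formula, the schedule implemented by get_lr is linear warm-up followed by cosine decay, with η_max = max_lr = 6e-4, η_min = min_lr = 6e-5, t_w = warmup_steps = 200 and t_max = max_steps = 1501:

$$
\eta(t) =
\begin{cases}
\eta_{\max} \dfrac{t + 1}{t_w}, & t < t_w \\[4pt]
\eta_{\min} + \dfrac{1}{2}\left(1 + \cos\left(\pi \dfrac{t - t_w}{t_{\max} - t_w}\right)\right)\left(\eta_{\max} - \eta_{\min}\right), & t_w \le t \le t_{\max} \\[4pt]
\eta_{\min}, & t > t_{\max}
\end{cases}
$$

At t = t_w the cosine equals 1, so the rate is exactly η_max; at t = t_max it equals -1, so the rate has decayed to η_min.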

Weight Decay

Weight decay is another regularization technique used in deep learning to prevent overfitting by discouraging large weights. With plain SGD it is equivalent to adding an L2 penalty to the loss; AdamW instead decouples it from the gradient and shrinks every decayed weight directly at each step (θ ← θ - lr · λ · θ). Weight decay is usually not applied to 1-dimensional tensors, such as biases and layer normalization parameters, because they benefit little from regularization and decaying them can hurt training.
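A single AdamW step for one parameter array makes the "decoupled" part concrete: the decay multiplies the weights directly and never passes through Adam's moment estimates. This NumPy sketch follows the algorithm as documented for torch.optim.AdamW; the function name is hypothetical, and the defaults use the betas from this post and the weight decay of 0.01 passed to configure_optimizers below.

```python
import numpy as np


def adamw_step(p, g, m, v, t, lr=6e-4, weight_decay=0.01, beta1=0.9, beta2=0.95, eps=1e-8):
    """One AdamW update for a single parameter array (t is the 1-based step count)."""
    p = p * (1 - lr * weight_decay)           # decoupled weight decay: shrink the weights directly
    m = beta1 * m + (1 - beta1) * g           # first moment (momentum)
    v = beta2 * v + (1 - beta2) * g * g       # second moment
    m_hat = m / (1 - beta1 ** t)              # bias correction
    v_hat = v / (1 - beta2 ** t)
    p = p - lr * m_hat / (np.sqrt(v_hat) + eps)
    return p, m, v


# with a zero gradient only the decay acts: each weight shrinks by the factor (1 - lr * wd)
p = np.ones(3)
p, m, v = adamw_step(p, np.zeros(3), np.zeros(3), np.zeros(3), t=1)
print(p)
```

Parameters in the no-decay group correspond to calling the same update with weight_decay=0.0.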

FusedAdamW is a faster implementation of AdamW that fuses multiple tensor operations into a single GPU kernel. A non-fused AdamW step launches several separate kernels for the moment updates, bias correction, weight decay and the parameter update, which costs extra memory traffic and kernel-launch overhead. The fused version performs these operations in one pass, reducing memory access and improving throughput; in PyTorch it is enabled with torch.optim.AdamW(..., fused=True) on CUDA. The speedup applies to the optimizer step itself, so its effect on total training time depends on how expensive that step is relative to the forward and backward passes, which makes it most valuable for large models such as LLMs. It is commonly combined with mixed precision.

def configure_optimizers(self, weight_decay, learning_rate, device_type):
        # Start with all candidate params (those requiring gradients)
        param_dict= {pn: p for pn, p in self.named_parameters()}
        param_dict = {pn: p for pn, p in param_dict.items() if p.requires_grad}
        # Create optim groups: every parameter with 2+ dims (matmul weights, embeddings) is weight-decayed; 1-D params (biases, LayerNorm) are not
        decay_params = [p for n, p in param_dict.items() if p.dim() >= 2]
        nodecay_params = [p for n, p in param_dict.items() if p.dim() < 2]  
        optim_groups = [
            {"params": decay_params, "weight_decay": weight_decay},
            {"params": nodecay_params, "weight_decay": 0.0},
        ]
        num_decay_params = sum(p.numel() for p in decay_params)
        num_nodecay_params = sum(p.numel() for p in nodecay_params)
        if master_process:
            print(f"num decayed parameter tensors: {len(decay_params)}, with {num_decay_params} parameters")
            print(f"num non-decayed parameter tensors: {len(nodecay_params)}, with {num_nodecay_params} parameters")
        # Create AdamW optimizer and use the fused version if available (requires `import inspect` at module level)
        fused_available = 'fused' in inspect.signature(torch.optim.AdamW).parameters
        use_fused = fused_available and device_type == 'cuda'
        if master_process:
            print(f"using fused AdamW: {use_fused}")
        optimizer = torch.optim.AdamW(optim_groups, lr=learning_rate, betas=(0.9, 0.95), eps=1e-8, fused=use_fused)
        return optimizer

and later during initialization:

device_type = "cuda" if device.startswith("cuda") else "cpu"
optimizer = model.configure_optimizers(weight_decay=0.01, learning_rate=6e-4, device_type=device_type)

Gradient Accumulation

Gradient accumulation is a technique that allows training with effectively larger batch sizes, even when GPU memory is limited. In standard training, gradients are computed and updated after processing each batch, but with gradient accumulation, multiple mini-batches are processed before updating the model weights. Instead of performing an optimizer step after every batch, the gradients are accumulated (summed) across multiple batches, and only after a set number of accumulation steps (e.g., accumulation_steps = 4), the optimizer updates the weights.

For example, if a model requires a batch size of 128 for stable training but the GPU can only handle a batch size of 32, accumulating gradients over 4 steps lets the optimizer update as if training with batch size 128. Because the loss is a mean over the batch, each micro-batch loss must be divided by the number of accumulation steps; the summed gradients then equal the gradient of the mean loss over the full batch of 128.

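# total_batch_size is counted in tokens and must be divisible by the tokens processed per micro-step
assert total_batch_size % (B * T * ddp_world_size) == 0
grad_accum_steps = total_batch_size // (B * T * ddp_world_size)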
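The equivalence can be verified numerically on a toy problem. In this sketch a linear model with a mean-squared-error loss stands in for GPT-2 (an assumption made only to keep the gradient analytic), and 4 micro-batches of 32 are compared against one batch of 128.

```python
import numpy as np

rng = np.random.default_rng(0)
X = rng.standard_normal((128, 5))   # a batch of 128 examples
y = rng.standard_normal(128)
w = rng.standard_normal(5)


def mse_grad(Xb, yb, w):
    """Gradient of mean((Xb @ w - yb) ** 2) with respect to w."""
    return 2.0 * Xb.T @ (Xb @ w - yb) / len(yb)


full_grad = mse_grad(X, y, w)

grad_accum_steps = 4
micro = len(y) // grad_accum_steps  # 32 examples per micro-batch
accum = np.zeros_like(w)
for s in range(grad_accum_steps):
    sl = slice(s * micro, (s + 1) * micro)
    # dividing by grad_accum_steps turns the sum of micro-batch means into the full-batch mean
    accum += mse_grad(X[sl], y[sl], w) / grad_accum_steps

print(np.allclose(accum, full_grad))  # True
```

Dropping the division would make the accumulated gradient 4 times too large, which is why the training loop divides each micro-batch loss by grad_accum_steps.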

Hyperparameter Tuning

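Be careful when interpreting loss improvements. Much of the early loss drop typically comes from the model learning to assign near-zero probability to tokens that never occur in the training data (such as the extra entries added when padding the vocabulary), rather than from learning anything meaningful. Do not draw false conclusions from it.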
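A useful sanity check before reading any loss curve: at initialization the logits are small, so the model's predictions are roughly uniform over the vocabulary and the cross-entropy should start near ln(vocab_size).

```python
import math

vocab_size = 50257
# roughly uniform predictions at init: each target gets probability about 1 / vocab_size
expected_initial_loss = -math.log(1 / vocab_size)  # equals ln(vocab_size)
print(f"{expected_initial_loss:.2f}")  # 10.82
```

A first-step loss far above this value usually points to an initialization problem rather than a learning problem.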

Finally, there is the good old trick of adjusting hyperparameters such as learning rate, batch size and dropout rate to improve performance. Ad hoc manual tweaking rarely leads to noticeable improvement, but systematic search methods such as grid search, random search or Bayesian optimization can find better configurations. Proper tuning significantly affects training stability, convergence speed and final model quality.
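Random search over the learning rate is the simplest of these to sketch. The objective below is a hypothetical stand-in for "train briefly and return the validation loss" (its minimum is placed at lr = 10**-3.5 purely for illustration); note that learning rates are sampled log-uniformly, since their useful values span orders of magnitude.

```python
import math
import random


def toy_val_loss(lr):
    """Hypothetical stand-in for a short training run; minimized at lr = 10**-3.5."""
    return (math.log10(lr) + 3.5) ** 2 + 3.0


def random_search(objective, low=1e-5, high=1e-2, trials=200, seed=0):
    """Sample learning rates log-uniformly in [low, high] and keep the best one."""
    rng = random.Random(seed)
    best_lr, best_loss = None, float("inf")
    for _ in range(trials):
        lr = 10 ** rng.uniform(math.log10(low), math.log10(high))  # log-uniform sampling
        loss = objective(lr)
        if loss < best_loss:
            best_lr, best_loss = lr, loss
    return best_lr, best_loss


best_lr, best_loss = random_search(toy_val_loss)
print(f"best lr: {best_lr:.2e}")
```

With a real model each trial is expensive, so in practice the number of trials is small and each run is kept short.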

That's all for training on a single GPU. This is sufficient for moderate models and workloads, but larger models and datasets call for multiple GPUs, which we will explore next.

Distributed Training

Distributed Data Parallel is a technique in PyTorch that enables efficient deep learning training across multiple GPUs by splitting data, synchronizing gradients, and managing communication between processes. This is crucial for training large language models (LLMs), computer vision networks, and deep reinforcement learning agents, where single-GPU training can be too slow or memory-intensive.

The core idea behind DDP is replicating the model across all GPUs, assigning each process (worker) a subset of the training data, and synchronizing gradients at each step so that all GPUs contribute to the optimization process. Each GPU only computes its local portion of the batch (determined by ddp_rank), and gradients are averaged across all GPUs before updating the model parameters. This allows PyTorch to leverage multiple GPUs effectively, making training significantly faster and enabling larger batch sizes without exceeding the memory limit of a single GPU.

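Firstly, we detect whether the script was launched with torchrun, which starts one process per GPU and sets the RANK, LOCAL_RANK and WORLD_SIZE environment variables (e.g. torchrun --standalone --nproc_per_node=8 train.py, where train.py stands for your training script), and set the DDP parameters:

import os
import contextlib
import torch
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP

ddp = int(os.environ.get("RANK", -1)) != -1  # True when launched via torchrun
if ddp:
    assert torch.cuda.is_available(), "DDP with the NCCL backend needs CUDA"
    dist.init_process_group(backend="nccl")
    ddp_rank = int(os.environ["RANK"])              # unique global ID of this process
    ddp_local_rank = int(os.environ["LOCAL_RANK"])  # GPU index on this machine
    ddp_world_size = int(os.environ["WORLD_SIZE"])  # total number of processes
    device = f"cuda:{ddp_local_rank}"
    torch.cuda.set_device(device)
    master_process = ddp_rank == 0  # rank 0 handles logging and checkpointing
else:
    ddp_rank, ddp_local_rank, ddp_world_size = 0, 0, 1
    master_process = True
    device = "cuda" if torch.cuda.is_available() else "cpu"
    print("Using device:", device)

Then we create compatible DataLoaders; each rank starts at its own offset, so processes never read the same tokens:

train_loader = DataLoaderLite(B=B, T=T, process_rank=ddp_rank, num_processes=ddp_world_size, split="train")
val_loader = DataLoaderLite(B=B, T=T, process_rank=ddp_rank, num_processes=ddp_world_size, split="val")

Then we initialize the model, padding the vocabulary to 50304 (a multiple of 128), and wrap it in DDP:

model = GPT(GPTConfig(vocab_size=50304))
model.to(device)
if ddp:
    model = DDP(model, device_ids=[ddp_local_rank])
raw_model = model.module if ddp else model  # unwrapped model, e.g. for raw_model.configure_optimizers(...)

DDP synchronizes the gradients for us: during loss.backward() it all-reduces (averages) them across processes. With gradient accumulation that communication is only needed on the last micro-step, so the other micro-steps run inside model.no_sync():

for micro_step in range(grad_accum_steps):
    x, y = train_loader.next_batch()
    x, y = x.to(device), y.to(device)
    last_micro_step = (micro_step == grad_accum_steps - 1)
    # skip the gradient all-reduce except on the final micro-step
    sync_ctx = model.no_sync() if (ddp and not last_micro_step) else contextlib.nullcontext()
    with sync_ctx:
        with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
            logits, loss = model(x, y)
        loss = loss / grad_accum_steps
        train_losses.append(loss.detach())
        loss.backward()

At the end of training, each process should call dist.destroy_process_group().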

And finally, we have the complete training loop:

for step in range(max_steps):
    t0 = time.time()
    last_step = (step == max_steps - 1)
 
    # --- Validation loop (omitted here) ---
    if step % 250 == 0 or last_step:
        pass
    # --- HellaSwag evaluation (omitted here; use_compile is the flag that controls torch.compile) ---
    if (step % 250 == 0 or last_step) and (not use_compile):
        pass
    # --- Generation sample (omitted here) ---
    if ((step > 0 and step % 250 == 0) or last_step) and (not use_compile):
        pass
    # --- Training step ---
    model.train()
    optimizer.zero_grad()
    train_losses = []
    for micro_step in range(grad_accum_steps):
        x, y = train_loader.next_batch()
        x, y = x.to(device), y.to(device)
        with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
            logits, loss = model(x, y)
        # Ensure loss is reduced to a scalar and scale for gradient accumulation
        loss = loss.mean() / grad_accum_steps
        train_losses.append(loss.detach())
        loss.backward()
    train_loss_tensor = torch.stack(train_losses)
    avg_train_loss = train_loss_tensor.mean()
 
    torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
    lr = get_lr(step)
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr
    optimizer.step()
    if device.startswith("cuda"):
        torch.cuda.synchronize()
    t1 = time.time()
    dt = t1 - t0
    tokens_processed = train_loader.B * train_loader.T * grad_accum_steps * ddp_world_size
    tokens_per_sec = tokens_processed / dt
    if master_process:
        print(f"step {step:5d} | loss: {avg_train_loss.item():.6f} | lr {lr:.4e} | dt: {dt*1000:.2f}ms | tok/sec: {tokens_per_sec:.2f}")
        with open(log_file, "a") as f:
            f.write(f"{step} train {avg_train_loss.item():.6f}\n")

In essence, the data preparation steps that produce the shards read by DataLoaderLite are:

  1. Download the FineWeb-Edu dataset.
  2. Tokenize the text using the GPT-2 tokenizer (the same one as in the original paper).
  3. Process the data in parallel using multiprocessing.
  4. Split the tokenized data into .npy shards of 100M tokens each.
  5. Save the first shard as 'val' (validation set) and the rest as 'train'.
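Steps 4 and 5 can be sketched as follows. The sketch assumes the documents are already tokenized (it does not download or tokenize anything and runs in a single process), and the function name and the shard_{split}_{index}.npy naming are hypothetical; the only requirement taken from DataLoaderLite is that filenames contain "train" or "val". Each document is prefixed with the <|endoftext|> token (id 50256) as a separator.

```python
import os
import tempfile

import numpy as np

EOT = 50256  # GPT-2 <|endoftext|> token id, used as a document separator


def write_shards(docs_tokens, out_dir, shard_size):
    """Pack tokenized documents into fixed-size uint16 .npy shards.
    Shard 0 is the validation split, the rest are training shards."""
    os.makedirs(out_dir, exist_ok=True)
    buf = np.empty(shard_size, dtype=np.uint16)  # GPT-2 token ids fit in 16 bits
    filled, paths = 0, []

    def flush(n):
        split = "val" if not paths else "train"
        path = os.path.join(out_dir, f"shard_{split}_{len(paths):06d}.npy")
        np.save(path, buf[:n])
        paths.append(path)

    for doc in docs_tokens:
        tokens = [EOT] + list(doc)
        i = 0
        while i < len(tokens):
            take = min(shard_size - filled, len(tokens) - i)
            buf[filled:filled + take] = tokens[i:i + take]
            filled += take
            i += take
            if filled == shard_size:  # shard full: write it and start the next one
                flush(filled)
                filled = 0
    if filled:                        # last, partially filled shard
        flush(filled)
    return paths


paths = write_shards([[1, 2, 3], [4, 5], [6, 7, 8, 9]], tempfile.mkdtemp(), shard_size=5)
print([os.path.basename(p) for p in paths])
# ['shard_val_000000.npy', 'shard_train_000001.npy', 'shard_train_000002.npy']
```

In the real pipeline the shard size is 100M tokens and tokenization runs in a multiprocessing pool; a document that straddles a shard boundary is simply split across two shards.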

That's everything we need to replicate the results of the original paper, "Language Models are Unsupervised Multitask Learners". If you wish to see the full source code, I recommend watching Andrej Karpathy's video; his channel is a real gem for the AI community. In the future, I hope to take advantage of every optimization technique listed in this article to train better models.