Training deep neural networks requires more than just architecture design — practical techniques prevent common failure modes. Weight initialisation ensures gradients flow properly at training start. Gradient clipping prevents exploding gradients (common in RNNs). Batch size affects gradient noise and generalisation. Learning rate scheduling adapts the learning rate over training. Monitoring training and validation loss curves diagnoses overfitting, underfitting, and other training pathologies early.
Weight initialisation strategies
Xavier, He, and orthogonal initialisation compared
```python
import torch
import torch.nn as nn
import numpy as np

# ── ZERO INITIALISATION (never use for weights) ──
# All neurons compute the same output → symmetry is never broken → useless

# ── RANDOM INITIALISATION (naive) ──
# Problem: activation variance grows or shrinks with depth → explodes or vanishes

# ── XAVIER/GLOROT INITIALISATION ──
# For sigmoid/tanh: balances activation and gradient variance across layers
# Uniform form: W ~ U(-sqrt(6/(n_in + n_out)), sqrt(6/(n_in + n_out)))
# Normal form:  W ~ N(0, 2/(n_in + n_out))
n_in, n_out = 256, 128
xavier_std = np.sqrt(2.0 / (n_in + n_out))
print(f"Xavier std: {xavier_std:.4f}")

# ── HE INITIALISATION (Kaiming) ──
# For ReLU: ReLU zeroes ~half the activations, so the variance is doubled
# W ~ N(0, 2/n_in), i.e. std = sqrt(2/n_in)
he_std = np.sqrt(2.0 / n_in)
print(f"He (Kaiming) std: {he_std:.4f}")

# PyTorch: initialise layers manually
def init_weights(module):
    if isinstance(module, nn.Linear):
        nn.init.kaiming_normal_(module.weight, mode='fan_in',
                                nonlinearity='relu')  # He for ReLU
        nn.init.zeros_(module.bias)
    elif isinstance(module, nn.Conv2d):
        nn.init.kaiming_normal_(module.weight, mode='fan_out',
                                nonlinearity='relu')
        if module.bias is not None:
            nn.init.zeros_(module.bias)
    elif isinstance(module, nn.Embedding):
        nn.init.normal_(module.weight, mean=0, std=0.01)

model = nn.Sequential(
    nn.Linear(784, 512), nn.ReLU(),
    nn.Linear(512, 256), nn.ReLU(),
    nn.Linear(256, 10)
)
model.apply(init_weights)  # Apply to all submodules

# ── ORTHOGONAL INITIALISATION ──
# For RNNs: initialise the recurrent weight matrix as an orthogonal matrix
# Preserves gradient norm over many timesteps
gru = nn.GRU(128, 256)
nn.init.orthogonal_(gru.weight_hh_l0)

# ── PRETRAINED INITIALISATION ──
# Most powerful: start from a pretrained model (ImageNet, BERT, etc.)
# Fine-tuning is usually better than training from scratch
# from transformers import BertModel
# model = BertModel.from_pretrained('bert-base-uncased')
```

Gradient clipping — taming exploding gradients
Gradient clipping and monitoring
```python
import torch
import torch.nn as nn

model = nn.GRU(128, 256, 2, batch_first=True)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
loss_fn = nn.CrossEntropyLoss()

# Problem without clipping in RNNs:
# gradients multiply across timesteps and can grow exponentially → exploding
# Symptom: loss goes to NaN after a few steps

# SOLUTION: gradient clipping — cap the global gradient norm at max_norm
for epoch in range(100):
    # Forward + loss (toy example: the 256-dim hidden state doubles as logits)
    output, hidden = model(torch.randn(32, 50, 128))
    logits = output[:, -1]  # Last timestep
    loss = loss_fn(logits, torch.randint(0, 256, (32,)))
    optimizer.zero_grad()
    loss.backward()
    # Clip gradients BEFORE optimizer.step()
    total_norm = nn.utils.clip_grad_norm_(
        model.parameters(),
        max_norm=1.0  # Clip if the global gradient norm exceeds 1.0
        # Typical values: 0.5, 1.0, 5.0 depending on the task
    )
    # Monitor: if total_norm exceeds max_norm often, reduce the learning rate
    if total_norm > 1.0:
        pass  # print(f"Clipping: norm={total_norm:.2f}")
    optimizer.step()

# Gradient norm monitoring without clipping (for debugging)
def get_gradient_norm(model):
    total_norm = 0.0
    for p in model.parameters():
        if p.grad is not None:
            total_norm += p.grad.detach().norm(2).item() ** 2
    return total_norm ** 0.5
```

Batch size, epochs, and training diagnostics
| Hyperparameter | Small value | Large value | Sweet spot |
|---|---|---|---|
| Batch size | Noisy gradients, slow per epoch, better generalisation, less memory | Smooth gradients, fast training, may overfit, more memory | 32-256 for most tasks; transformers 512-4096 |
| Epochs | Underfitting — not enough training | Overfitting — memorising training data | Use early stopping with patience |
| Learning rate | Slow convergence | Divergence / oscillation | Use LR finder or warmup + cosine decay |
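The "early stopping with patience" suggested in the table can be sketched as a small helper class; the names `EarlyStopping`, `patience`, and `min_delta` are illustrative, not from a specific library:

```python
class EarlyStopping:
    """Stop training when validation loss stops improving."""
    def __init__(self, patience=5, min_delta=0.0):
        self.patience = patience    # epochs to wait after the last improvement
        self.min_delta = min_delta  # minimum decrease that counts as improvement
        self.best_loss = float('inf')
        self.counter = 0
        self.should_stop = False

    def step(self, val_loss):
        if val_loss < self.best_loss - self.min_delta:
            self.best_loss = val_loss  # improvement: reset the counter
            self.counter = 0
        else:
            self.counter += 1          # no improvement this epoch
            if self.counter >= self.patience:
                self.should_stop = True
        return self.should_stop

# Usage: simulated validation losses that plateau after epoch 2
stopper = EarlyStopping(patience=3)
for epoch, vl in enumerate([1.0, 0.8, 0.7, 0.71, 0.72, 0.70]):
    if stopper.step(vl):
        print(f"Early stopping at epoch {epoch}")  # triggers at epoch 5
        break
```

In a real training loop, `stopper.step(val_loss)` is called once per epoch after validation, and the best model weights are checkpointed whenever the counter resets.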
Training loop with loss curve monitoring
```python
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset
from torch.utils.tensorboard import SummaryWriter

# TensorBoard: visualise training in real time
writer = SummaryWriter('runs/experiment_1')

# Synthetic regression data stands in for a real dataset here
X, y = torch.randn(512, 10), torch.randn(512, 1)
train_loader = DataLoader(TensorDataset(X[:400], y[:400]), batch_size=32)
val_loader = DataLoader(TensorDataset(X[400:], y[400:]), batch_size=32)

model = nn.Linear(10, 1)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=100)
loss_fn = nn.MSELoss()

train_losses, val_losses = [], []
for epoch in range(100):
    # ── Training ──
    model.train()  # Enables dropout, batchnorm in training mode
    train_loss = 0.0
    for X_batch, y_batch in train_loader:
        optimizer.zero_grad()
        loss = loss_fn(model(X_batch), y_batch)
        loss.backward()
        nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
        optimizer.step()
        train_loss += loss.item()
    train_loss /= len(train_loader)

    # ── Validation ──
    model.eval()  # Disables dropout, uses running stats for batchnorm
    with torch.no_grad():
        val_loss = sum(loss_fn(model(X_b), y_b).item() for X_b, y_b in val_loader)
    val_loss /= len(val_loader)

    scheduler.step()
    train_losses.append(train_loss)
    val_losses.append(val_loss)

    # Log to TensorBoard
    writer.add_scalars('Loss', {'Train': train_loss, 'Val': val_loss}, epoch)
    writer.add_scalar('LR', optimizer.param_groups[0]['lr'], epoch)

# Diagnose from the loss curves:
# Train↓ Val↓  → both still improving → undertrained: train longer or enlarge the model
# Train↓ Val→ (gap)   → overfitting → regularise, add data, or shrink the model
# Train↓ Val↑ (cross) → classic overfitting → early stop at the crossover
# Both oscillating    → LR too high → reduce the learning rate
# Both stuck          → LR too low, or vanishing gradients
```

Practice questions
- Why does He initialisation use sqrt(2/n_in) while Xavier uses sqrt(2/(n_in+n_out))? (Answer: He accounts for ReLU zeroing out ~50% of activations. With ReLU, effective fan-in is n_in/2. Multiplying by 2 compensates, maintaining variance across layers. Xavier was designed for sigmoid/tanh which do not zero out activations.)
- Gradient clipping: global norm clipping vs per-parameter clipping — which is preferred? (Answer: Global norm clipping (clip_grad_norm_): computes the global L2 norm of all parameters' gradients, scales all down proportionally if it exceeds max_norm. Preserves gradient direction. Per-parameter clipping (clip_grad_value_): clips each gradient value independently — can distort the overall gradient direction. Use global norm clipping.)
- Training loss = 0.01, validation loss = 2.50. What is happening and what do you do? (Answer: Severe overfitting — the model has memorised training data. Fix: (1) Increase regularisation (dropout, weight decay). (2) Reduce model capacity (fewer layers/neurons). (3) Get more training data or augment. (4) Use early stopping. (5) Apply L1/L2 weight decay.)
- Why does small batch size (e.g., 8) often generalise better than large batch size (e.g., 4096)? (Answer: Small batches introduce gradient noise — gradients are computed from 8 examples, not the whole dataset. This noise acts as regularisation, helping the model find flatter minima in the loss landscape that generalise better. Large batches find sharper minima closer to training data, potentially overfitting.)
- model.train() vs model.eval() — what changes between these modes? (Answer: Dropout: active in train(), disabled in eval() (all neurons used). BatchNorm: uses mini-batch statistics in train(), uses running mean/variance in eval(). Failing to switch to eval() for inference gives inconsistent predictions and incorrect BatchNorm behaviour — always use model.eval() + torch.no_grad() for evaluation and deployment.)
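The clipping question above can be verified directly: global norm clipping rescales every gradient by the same factor, so the direction is preserved, while per-value clipping caps each entry independently and distorts it. A minimal sketch using PyTorch's two built-in clipping utilities:

```python
import torch
import torch.nn as nn

def grads_after(clip_fn):
    """Build a tiny layer, plant a known gradient, apply a clipping function."""
    layer = nn.Linear(2, 1, bias=False)
    layer.weight.grad = torch.tensor([[3.0, 4.0]])  # L2 norm = 5
    clip_fn(layer)
    return layer.weight.grad.clone()

# Global norm clipping: whole gradient scaled by 1.0/5 → ≈ [[0.6, 0.8]],
# still pointing along the original (3, 4) direction
g_norm = grads_after(lambda m: nn.utils.clip_grad_norm_(m.parameters(), max_norm=1.0))
print(g_norm)

# Per-value clipping: each entry clamped to ±1.0 → [[1.0, 1.0]],
# now pointing along (1, 1) — the direction has changed
g_value = grads_after(lambda m: nn.utils.clip_grad_value_(m.parameters(), clip_value=1.0))
print(g_value)
```

The first result has unit norm and the original direction; the second does not, which is why global norm clipping is the default choice in practice.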
On LumiChats
LumiChats can diagnose training issues from your loss curves, suggest the right initialisation strategy for your architecture, and write complete training loops with proper weight initialisation, gradient clipping, learning rate scheduling, and early stopping.
Try it free