pytorch-lightning

Framework PyTorch de haut niveau avec classe Trainer, entraînement distribué automatique (DDP/FSDP/DeepSpeed), système de callbacks et code minimal. Évolue de…

npx skills add https://github.com/firecrawl/ai-research-skills --skill pytorch-lightning

PyTorch Lightning - High-Level Training Framework

Quick start

PyTorch Lightning organizes PyTorch code to eliminate boilerplate while maintaining flexibility.

Installation:

pip install lightning

Convert PyTorch to Lightning (3 steps):

import lightning as L
import torch
from torch import nn
from torch.utils.data import DataLoader, Dataset

# Step 1: Define LightningModule (organize your PyTorch code)
class LitModel(L.LightningModule):
    def __init__(self, hidden_size=128):
        super().__init__()
        self.model = nn.Sequential(
            nn.Linear(28 * 28, hidden_size),
            nn.ReLU(),
            nn.Linear(hidden_size, 10)
        )

    def training_step(self, batch, batch_idx):
        x, y = batch
        y_hat = self.model(x)
        loss = nn.functional.cross_entropy(y_hat, y)
        self.log('train_loss', loss)  # Auto-logged to TensorBoard
        return loss

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=1e-3)

# Step 2: Create data
train_loader = DataLoader(train_dataset, batch_size=32)

# Step 3: Train with Trainer (handles everything else!)
trainer = L.Trainer(max_epochs=10, accelerator='gpu', devices=2)
model = LitModel()
trainer.fit(model, train_loader)

That's it! Trainer handles:

  • GPU/TPU/CPU switching
  • Distributed training (DDP, FSDP, DeepSpeed)
  • Mixed precision (FP16, BF16)
  • Gradient accumulation
  • Checkpointing
  • Logging
  • Progress bars

Common workflows

Workflow 1: From PyTorch to Lightning

Original PyTorch code:

model = MyModel()
optimizer = torch.optim.Adam(model.parameters())
model.to('cuda')

for epoch in range(max_epochs):
    for batch in train_loader:
        batch = batch.to('cuda')
        optimizer.zero_grad()
        loss = model(batch)
        loss.backward()
        optimizer.step()

Lightning version:

class LitModel(L.LightningModule):
    def __init__(self):
        super().__init__()
        self.model = MyModel()

    def training_step(self, batch, batch_idx):
        loss = self.model(batch)  # No .to('cuda') needed!
        return loss

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters())

# Train
trainer = L.Trainer(max_epochs=10, accelerator='gpu')
trainer.fit(LitModel(), train_loader)

Benefits: 40+ lines → 15 lines, no device management, automatic distributed

Workflow 2: Validation and testing

class LitModel(L.LightningModule):
    def __init__(self):
        super().__init__()
        self.model = MyModel()

    def training_step(self, batch, batch_idx):
        x, y = batch
        y_hat = self.model(x)
        loss = nn.functional.cross_entropy(y_hat, y)
        self.log('train_loss', loss)
        return loss

    def validation_step(self, batch, batch_idx):
        x, y = batch
        y_hat = self.model(x)
        val_loss = nn.functional.cross_entropy(y_hat, y)
        acc = (y_hat.argmax(dim=1) == y).float().mean()
        self.log('val_loss', val_loss)
        self.log('val_acc', acc)

    def test_step(self, batch, batch_idx):
        x, y = batch
        y_hat = self.model(x)
        test_loss = nn.functional.cross_entropy(y_hat, y)
        self.log('test_loss', test_loss)

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=1e-3)

# Train with validation
trainer = L.Trainer(max_epochs=10)
trainer.fit(model, train_loader, val_loader)

# Test
trainer.test(model, test_loader)

Automatic features:

  • Validation runs every epoch by default
  • Metrics logged to TensorBoard
  • Best model checkpointing based on val_loss

Workflow 3: Distributed training (DDP)

# Same code as single GPU!
model = LitModel()

# 8 GPUs with DDP (automatic!)
trainer = L.Trainer(
    accelerator='gpu',
    devices=8,
    strategy='ddp'  # Or 'fsdp', 'deepspeed'
)

trainer.fit(model, train_loader)

Launch:

# Single command, Lightning handles the rest
python train.py

No changes needed:

  • Automatic data distribution
  • Gradient synchronization
  • Multi-node support (just set num_nodes=2)

Workflow 4: Callbacks for monitoring

from lightning.pytorch.callbacks import ModelCheckpoint, EarlyStopping, LearningRateMonitor

# Create callbacks
checkpoint = ModelCheckpoint(
    monitor='val_loss',
    mode='min',
    save_top_k=3,
    filename='model-{epoch:02d}-{val_loss:.2f}'
)

early_stop = EarlyStopping(
    monitor='val_loss',
    patience=5,
    mode='min'
)

lr_monitor = LearningRateMonitor(logging_interval='epoch')

# Add to Trainer
trainer = L.Trainer(
    max_epochs=100,
    callbacks=[checkpoint, early_stop, lr_monitor]
)

trainer.fit(model, train_loader, val_loader)

Result:

  • Auto-saves best 3 models
  • Stops early if no improvement for 5 epochs
  • Logs learning rate to TensorBoard

Workflow 5: Learning rate scheduling

class LitModel(L.LightningModule):
    # ... (training_step, etc.)

    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.parameters(), lr=1e-3)

        # Cosine annealing
        scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
            optimizer,
            T_max=100,
            eta_min=1e-5
        )

        return {
            'optimizer': optimizer,
            'lr_scheduler': {
                'scheduler': scheduler,
                'interval': 'epoch',  # Update per epoch
                'frequency': 1
            }
        }

# Learning rate auto-logged!
trainer = L.Trainer(max_epochs=100)
trainer.fit(model, train_loader)

When to use vs alternatives

Use PyTorch Lightning when:

  • Want clean, organized code
  • Need production-ready training loops
  • Switching between single GPU, multi-GPU, TPU
  • Want built-in callbacks and logging
  • Team collaboration (standardized structure)

Key advantages:

  • Organized: Separates research code from engineering
  • Automatic: DDP, FSDP, DeepSpeed with 1 line
  • Callbacks: Modular training extensions
  • Reproducible: Less boilerplate = fewer bugs
  • Tested: 1M+ downloads/month, battle-tested

Use alternatives instead:

  • Accelerate: Minimal changes to existing code, more flexibility
  • Ray Train: Multi-node orchestration, hyperparameter tuning
  • Raw PyTorch: Maximum control, learning purposes
  • Keras: TensorFlow ecosystem

Common issues

Issue: Loss not decreasing

Check data and model setup:

# Add to training_step
def training_step(self, batch, batch_idx):
    if batch_idx == 0:
        print(f"Batch shape: {batch[0].shape}")
        print(f"Labels: {batch[1]}")
    loss = ...
    return loss

Issue: Out of memory

Reduce batch size or use gradient accumulation:

trainer = L.Trainer(
    accumulate_grad_batches=4,  # Effective batch = batch_size × 4
    precision='bf16'  # Or 'fp16', reduces memory 50%
)

Issue: Validation not running

Ensure you pass val_loader:

# WRONG
trainer.fit(model, train_loader)

# CORRECT
trainer.fit(model, train_loader, val_loader)

Issue: DDP spawns multiple processes unexpectedly

Lightning auto-detects GPUs. Explicitly set devices:

# Test on CPU first
trainer = L.Trainer(accelerator='cpu', devices=1)

# Then GPU
trainer = L.Trainer(accelerator='gpu', devices=1)

Advanced topics

Callbacks: See references/callbacks.md for EarlyStopping, ModelCheckpoint, custom callbacks, and callback hooks.

Distributed strategies: See references/distributed.md for DDP, FSDP, DeepSpeed ZeRO integration, multi-node setup.

Hyperparameter tuning: See references/hyperparameter-tuning.md for integration with Optuna, Ray Tune, and WandB sweeps.

Hardware requirements

  • CPU: Works (good for debugging)
  • Single GPU: Works
  • Multi-GPU: DDP (default), FSDP, or DeepSpeed
  • Multi-node: DDP, FSDP, DeepSpeed
  • TPU: Supported (8 cores)
  • Apple MPS: Supported

Precision options:

  • FP32 (default)
  • FP16 (V100, older GPUs)
  • BF16 (A100/H100, recommended)
  • FP8 (H100)

Resources

Plus de skills de firecrawl

oracle
firecrawl
Meilleures pratiques pour utiliser l'interface en ligne de commande oracle (invite + regroupement de fichiers, moteurs, sessions et modèles de pièces jointes).
official
firecrawl-monitor
firecrawl
Détectez quand le contenu d'un site web change et recevez une notification par webhook ou e-mail — sans cron jobs, scrapers ni scripts de diff. Utilisez cette compétence lorsque l'utilisateur souhaite suivre les modifications d'une page, surveiller les prix des concurrents, être alerté de nouvelles offres d'emploi ou articles de blog, surveiller les pages de documentation/changelog/statut, ou dit « surveiller », « suivre », « tracker », « alerte-moi quand », « notifie-moi quand X change », « préviens-moi si », « envoie-moi un e-mail quand » ou « envoie un webhook quand ». Un juge IA intégré filtre la mise en forme, les horodatages et...
officialweb-scrapingresearch
firecrawl-deep-research
firecrawl
Effectuer une recherche approfondie multi-sources avec Firecrawl. À utiliser lorsque l'utilisateur demande de rechercher un sujet, comparer des perspectives, produire un briefing sourcé, investiguer une question technique ou de marché, ou synthétiser des preuves web provenant de nombreuses sources.
officialresearchweb-scraping
firecrawl-research-papers
firecrawl
Trouver et synthétiser des articles de recherche, livres blancs, PDF, rapports techniques et sources académiques avec Firecrawl. À utiliser lorsque l'utilisateur souhaite une revue de littérature, un résumé d'article, un panorama de la recherche ou une synthèse sourcée à partir de PDF et de publications académiques ou industrielles.
officialresearchweb-scraping
firecrawl-market-research
firecrawl
Extraire les métriques de marché, financières, de résultats, sectorielles et d'entreprise avec Firecrawl. À utiliser lorsque l'utilisateur demande des études de marché, des tendances sectorielles, des données sur les entreprises publiques, des comparaisons financières, des recherches sur les résultats ou des rapports de marché structurés.
officialresearchweb-scraping
firecrawl-website-design-clone
firecrawl
Extraire le système de design de n'importe quel site web dans un DESIGN.md prêt pour un agent, en utilisant les preuves de scraping de Firecrawl. À utiliser lorsque l'utilisateur souhaite obtenir des couleurs, polices, espacements, composants, motifs de mise en page ou directives de marque/UI d'un site web, afin que des agents IA puissent créer de nouveaux sites web, cloner une apparence ou construire des pages inspirées de ce design.
officialdesignweb-scraping
firecrawl-knowledge-base
firecrawl
Construisez une base de connaissances à partir de contenu web avec Firecrawl. Utilisez-la pour des documents de référence locaux, des segments prêts pour le RAG, des jeux de données de fine-tuning, des miroirs de documentation, des corpus thématiques ou du markdown prêt pour LLM organisé à partir de sources web.
officialweb-scrapingresearch
firecrawl-lead-research
firecrawl
Produire des fiches de renseignement pré-réunion avec Firecrawl. À utiliser lorsque l'utilisateur a besoin de recherches sur une entreprise, une personne, d'actualités récentes, de points de discussion, de points sensibles ou de préparation de prospection avant un appel commercial, une réunion de partenariat, une conversation avec un investisseur ou un entretien client.
officialresearchweb-scraping