ray-train

Distributed training orchestration across clusters. Scales PyTorch/TensorFlow/HuggingFace from laptop to 1000s of nodes. Built-in hyperparameter tuning with…

npx skills add https://github.com/firecrawl/ai-research-skills --skill ray-train

Ray Train - Distributed Training Orchestration

Quick start

Ray Train scales machine learning training from a single GPU to multi-node clusters with minimal code changes.

Installation:

pip install -U "ray[train]"

Basic PyTorch training (single node):

import torch
import torch.nn as nn
from ray import train
from ray.train import ScalingConfig
from ray.train.torch import TorchTrainer

# Define training function (runs on every worker)
def train_func(config):
    # Your normal PyTorch code
    model = nn.Linear(10, 1)

    # Prepare for distributed (Ray handles device placement and DDP wrapping)
    model = train.torch.prepare_model(model)
    # Create the optimizer after prepare_model (parameters now live on the worker's device)
    optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
    device = train.torch.get_device()  # this worker's GPU (or CPU)

    for epoch in range(10):
        # Your training loop; inputs must be on the same device as the model
        output = model(torch.randn(32, 10, device=device))
        loss = output.sum()
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()

        # Report metrics (available as result.metrics after training)
        train.report({"loss": loss.item(), "epoch": epoch})

# Run distributed training
trainer = TorchTrainer(
    train_func,
    scaling_config=ScalingConfig(
        num_workers=4,  # 4 workers, 1 GPU each
        use_gpu=True
    )
)

result = trainer.fit()
print(f"Final loss: {result.metrics['loss']}")

That's it! Ray handles:

  • Distributed coordination
  • GPU allocation
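  • Fault tolerance (worker-group restarts, enabled via FailureConfig)
  • Checkpoint persistence (you write the files; Ray stores and tracks them)
  • Metric collection (whatever you pass to train.report)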
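To make "distributed coordination" concrete: with more than one worker, prepare_model wraps the model in PyTorch DistributedDataParallel, which averages gradients across workers after each backward pass (an all-reduce). The plain-Python sketch below uses a hypothetical one-parameter linear model and made-up, equal-sized data shards to check that averaging the per-worker gradients gives the same result as the gradient on the combined batch. It illustrates the math only, not how Ray or PyTorch implement the all-reduce.

```python
# Simplified data-parallel gradient averaging (what DDP's all-reduce computes).
# Hypothetical model: y_hat = w * x, loss = mean squared error over a batch.

def grad_mse(w, batch):
    """d/dw of mean((w*x - y)^2) over the batch."""
    return sum(2 * (w * x - y) * x for x, y in batch) / len(batch)

def allreduce_mean(values):
    """Stand-in for an all-reduce: every worker ends up with the mean."""
    mean = sum(values) / len(values)
    return [mean] * len(values)

w = 0.5
# Made-up data split into equal-sized shards, one per worker
shards = [
    [(1.0, 2.0), (2.0, 4.0)],
    [(3.0, 6.0), (4.0, 8.0)],
]
local_grads = [grad_mse(w, shard) for shard in shards]  # computed independently
synced = allreduce_mean(local_grads)                     # identical on every worker
full_batch_grad = grad_mse(w, [pair for shard in shards for pair in shard])

print(synced[0], full_batch_grad)
```

The equality relies on the shards being the same size; that is one reason DistributedSampler pads the dataset so every worker gets the same number of samples.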

Common workflows

Workflow 1: Scale existing PyTorch code

Original single-GPU code:

model = MyModel().cuda()
optimizer = torch.optim.Adam(model.parameters())

for epoch in range(epochs):
    for batch in dataloader:
        loss = model(batch)
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()  # clear gradients before the next batch

Ray Train version (scales to multi-GPU/multi-node):

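import torch
from torch.utils.data import DataLoader
from ray import train
from ray.train import ScalingConfig
from ray.train.torch import TorchTrainer

def train_func(config):
    model = MyModel()
    # Prepare for distributed (automatic device placement + DDP wrapping)
    model = train.torch.prepare_model(model)
    optimizer = torch.optim.Adam(model.parameters())

    # Build the DataLoader inside train_func (it runs on every worker);
    # prepare_data_loader adds a DistributedSampler and moves batches to the device
    dataloader = DataLoader(my_dataset, batch_size=config["batch_size"], shuffle=True)  # my_dataset: your Dataset
    dataloader = train.torch.prepare_data_loader(dataloader)

    for epoch in range(config["epochs"]):
        if train.get_context().get_world_size() > 1:
            dataloader.sampler.set_epoch(epoch)  # reshuffle shards each epoch
        for batch in dataloader:
            loss = model(batch)
            loss.backward()
            optimizer.step()
            optimizer.zero_grad()

        # Report metrics once per epoch
        train.report({"loss": loss.item(), "epoch": epoch})

# Scale to 8 GPUs
trainer = TorchTrainer(
    train_func,
    train_loop_config={"epochs": 10, "batch_size": 32},
    scaling_config=ScalingConfig(num_workers=8, use_gpu=True)
)
trainer.fit()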

Benefits: The same train_func runs on 1 GPU or 1000 GPUs; only ScalingConfig changes.
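Because prepare_data_loader adds a DistributedSampler, each worker trains on a disjoint slice of the dataset, and one optimizer step consumes per-worker batch size × num_workers samples. The plain-Python sketch below mimics the sampler's non-shuffled split (shuffle=False, drop_last=False: pad by wrapping around, then take every num_workers-th index). The dataset size and batch size are made-up numbers for illustration.

```python
import math

def shard_indices(dataset_len, num_workers, rank):
    """Mimic DistributedSampler(shuffle=False, drop_last=False) for one rank."""
    per_worker = math.ceil(dataset_len / num_workers)
    total = per_worker * num_workers
    indices = list(range(dataset_len))
    # Pad by wrapping around so every rank gets the same number of samples
    padding = total - dataset_len
    indices += (indices * math.ceil(padding / dataset_len))[:padding]
    # Interleaved split: rank r takes r, r + num_workers, r + 2 * num_workers, ...
    return indices[rank:total:num_workers]

dataset_len, num_workers, per_worker_batch = 10, 4, 2
shards = [shard_indices(dataset_len, num_workers, r) for r in range(num_workers)]
print(shards)  # [[0, 4, 8], [1, 5, 9], [2, 6, 0], [3, 7, 1]]
print("global batch size per step:", per_worker_batch * num_workers)  # 8
```

Note the padded samples (0 and 1 appear twice here), and that scaling from 1 to 8 workers multiplies the global batch size by 8 unless you lower the per-worker batch size.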

Workflow 2: HuggingFace Transformers integration

Recent Ray releases run HuggingFace's own Trainer inside a TorchTrainer (the older TransformersTrainer class is deprecated):

from ray.train import ScalingConfig
from ray.train.torch import TorchTrainer
from ray.train.huggingface.transformers import RayTrainReportCallback, prepare_trainer
from transformers import AutoModelForCausalLM, AutoTokenizer, Trainer, TrainingArguments

def train_func(config):
    # Load model and tokenizer
    model = AutoModelForCausalLM.from_pretrained("gpt2")
    tokenizer = AutoTokenizer.from_pretrained("gpt2")

    # Training arguments (HuggingFace API)
    training_args = TrainingArguments(
        output_dir="./output",
        num_train_epochs=3,
        per_device_train_batch_size=8,
        learning_rate=2e-5,
    )

    # Standard HF Trainer; Ray has already set up the distributed process group
    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=train_dataset,  # your tokenized dataset (built with `tokenizer`)
    )
    # Forward HF metrics and checkpoints to Ray Train, then validate the setup
    trainer.add_callback(RayTrainReportCallback())
    trainer = prepare_trainer(trainer)
    trainer.train()

# Scale to multi-node (2 nodes × 8 GPUs = 16 workers, 1 GPU each)
ray_trainer = TorchTrainer(
    train_func,
    scaling_config=ScalingConfig(num_workers=16, use_gpu=True)
)

result = ray_trainer.fit()
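The worker count has to fit the cluster: with one GPU per worker, 16 workers need 16 GPUs, and each worker's GPUs must come from a single node. The helper below is a hypothetical pre-flight check in plain Python (not a Ray API); the per-node GPU inventory is an assumption you would replace with your own numbers, for example from `ray status`.

```python
def fits_cluster(num_workers, gpus_per_worker, gpus_per_node):
    """Hypothetical pre-flight check: can every worker get its GPUs on one node?

    Count how many workers each node can host, then compare the total
    number of worker slots with num_workers.
    """
    slots = sum(n // gpus_per_worker for n in gpus_per_node)
    return slots >= num_workers

# 2 nodes × 8 GPUs, 1 GPU per worker -> 16 worker slots
print(fits_cluster(16, 1, [8, 8]))  # True
print(fits_cluster(16, 3, [8, 8]))  # False: each node fits only 2 such workers (4 total)
```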

Workflow 3: Hyperparameter tuning with Ray Tune

import torch
from ray import train, tune
from ray.train import ScalingConfig
from ray.train.torch import TorchTrainer
from ray.tune.schedulers import ASHAScheduler

def train_func(config):
    # Use hyperparameters from config
    lr = config["lr"]
    batch_size = config["batch_size"]

    model = train.torch.prepare_model(MyModel())
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)

    for epoch in range(10):
        # Training loop (train_epoch is your own per-epoch function)
        loss = train_epoch(model, optimizer, batch_size)
        train.report({"loss": loss, "epoch": epoch})

# Define search space; Tune passes "train_loop_config" to train_func as `config`
param_space = {
    "train_loop_config": {
        "lr": tune.loguniform(1e-5, 1e-2),
        "batch_size": tune.choice([16, 32, 64, 128]),
    }
}

# Run 20 trials (4 GPU workers each) with ASHA early stopping
tuner = tune.Tuner(
    TorchTrainer(
        train_func,
        scaling_config=ScalingConfig(num_workers=4, use_gpu=True)
    ),
    param_space=param_space,
    tune_config=tune.TuneConfig(
        num_samples=20,
        scheduler=ASHAScheduler(metric="loss", mode="min")
    )
)

results = tuner.fit()
best = results.get_best_result(metric="loss", mode="min")
print(f"Best hyperparameters: {best.config['train_loop_config']}")

Result: Trials run in parallel across the cluster as GPUs become free (each trial here needs 4), and ASHA stops unpromising trials early.
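ASHAScheduler saves compute by stopping weak trials early. The plain-Python sketch below shows the underlying idea with synchronous successive halving: at each rung, only the top 1/reduction_factor of the remaining trials continue. Real ASHA promotes trials asynchronously as they reach each rung, and the loss here comes from a made-up function rather than training; reduction_factor=4 matches ASHAScheduler's default, but all numbers are illustrative.

```python
import math

def successive_halving(trials, evaluate, rungs, reduction_factor=4):
    """Synchronous successive halving (a simplified stand-in for ASHA).

    trials: list of configs; evaluate(config, budget) -> loss (lower is better).
    rungs: increasing training budgets, e.g. epochs [1, 4].
    """
    survivors = list(trials)
    for budget in rungs:
        # Rank the remaining trials at this budget and keep the top 1/reduction_factor
        ranked = sorted(survivors, key=lambda cfg: evaluate(cfg, budget))
        survivors = ranked[: max(1, len(ranked) // reduction_factor)]
    return survivors

# Hypothetical search space: 20 log-spaced learning rates from 1e-5 to 1e-2
trials = [{"lr": 10 ** (-5 + 3 * i / 19)} for i in range(20)]

def fake_loss(cfg, budget):
    # Made-up loss surface: best near lr=1e-3, improving with more epochs
    return abs(math.log10(cfg["lr"]) + 3) + 1.0 / budget

best = successive_halving(trials, fake_loss, rungs=[1, 4])
print(best)
```

With 20 trials and reduction_factor=4, 5 trials survive the first rung and 1 survives the second, so most trials only ever train for the smallest budget.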

Workflow 4: Checkpointing and fault tolerance

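import os
import tempfile

import torch
from ray import train
from ray.train import Checkpoint, FailureConfig, RunConfig, ScalingConfig
from ray.train.torch import TorchTrainer

def train_func(config):
    model = MyModel()

    # Try to resume from the latest reported checkpoint
    checkpoint = train.get_checkpoint()
    state = None
    if checkpoint:
        with checkpoint.as_directory() as checkpoint_dir:
            state = torch.load(os.path.join(checkpoint_dir, "model.pt"), map_location="cpu")
        model.load_state_dict(state["model"])

    model = train.torch.prepare_model(model)
    optimizer = torch.optim.Adam(model.parameters())
    start_epoch = 0
    if state is not None:
        optimizer.load_state_dict(state["optimizer"])  # moves state to the model's device
        start_epoch = state["epoch"] + 1  # continue after the saved epoch

    for epoch in range(start_epoch, 100):
        loss = train_epoch(model, optimizer)
        metrics = {"loss": loss, "epoch": epoch}

        # Save checkpoint every 10 epochs; every worker must still call train.report
        if epoch % 10 == 0:
            with tempfile.TemporaryDirectory() as tmpdir:
                checkpoint = None
                if train.get_context().get_world_rank() == 0:
                    # Unwrap DDP so the saved keys load into a plain model on resume
                    raw_model = model.module if hasattr(model, "module") else model
                    torch.save({
                        "model": raw_model.state_dict(),
                        "optimizer": optimizer.state_dict(),
                        "epoch": epoch
                    }, os.path.join(tmpdir, "model.pt"))
                    checkpoint = Checkpoint.from_directory(tmpdir)
                # report() persists the checkpoint before the temp dir is deleted
                train.report(metrics, checkpoint=checkpoint)
        else:
            train.report(metrics)

trainer = TorchTrainer(
    train_func,
    scaling_config=ScalingConfig(num_workers=8, use_gpu=True),
    # Retry up to 3 times; train_func resumes via train.get_checkpoint()
    run_config=RunConfig(failure_config=FailureConfig(max_failures=3))
)

# Restarts from the latest checkpoint if a worker fails
result = trainer.fit()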
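Because checkpoints are written only every 10 epochs, a restart repeats the epochs completed since the last save. The hypothetical plain-Python helper below computes where training resumes and how many finished epochs are redone, assuming saves happen at the end of epochs where epoch % 10 == 0 and that training resumes at the epoch after the saved one (saved epoch + 1). It can help choose an interval that balances checkpoint cost against lost work.

```python
def resume_point(failed_epoch, interval=10):
    """Return (resume_epoch, completed_epochs_redone) for a failure during failed_epoch.

    Assumes a checkpoint is saved at the end of every epoch with
    epoch % interval == 0, and that training resumes at saved_epoch + 1.
    """
    completed = failed_epoch  # epochs 0 .. failed_epoch - 1 finished
    if completed == 0:
        return 0, 0  # nothing saved yet: start from scratch
    last_saved = ((completed - 1) // interval) * interval
    resume_epoch = last_saved + 1
    return resume_epoch, failed_epoch - resume_epoch

print(resume_point(37))  # (31, 6): epochs 31-36 are trained again
print(resume_point(5))   # (1, 4): epoch 0 was saved; epochs 1-4 are redone
```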

Workflow 5: Multi-node training

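import ray
from ray.train import ScalingConfig
from ray.train.torch import TorchTrainer

# Connect to a running Ray cluster (train_func as defined in the workflows above)
ray.init(address="auto")  # Or via Ray Client: ray.init("ray://head-node:10001")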

# Train across 4 nodes × 8 GPUs = 32 workers
trainer = TorchTrainer(
    train_func,
    scaling_config=ScalingConfig(
        num_workers=32,
        use_gpu=True,
        resources_per_worker={"GPU": 1, "CPU": 4},
        placement_strategy="SPREAD"  # Spread across nodes
    )
)

result = trainer.fit()

Launch the Ray cluster (before running the training script):

# On head node
ray start --head --port=6379

# On worker nodes
ray start --address=<head-node-ip>:6379
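placement_strategy controls how Ray places the workers' resource bundles: "PACK" (the ScalingConfig default) fills as few nodes as possible, while "SPREAD" distributes them across nodes on a best-effort basis. The plain-Python sketch below is a simplified, hypothetical model of the two policies for 1-GPU workers (greedy fill vs round-robin); Ray's real scheduler also accounts for CPUs, memory, and other resources.

```python
def place(num_workers, gpus_per_node, strategy):
    """Toy placement of 1-GPU workers onto nodes; returns workers per node."""
    if strategy not in ("PACK", "SPREAD"):
        raise ValueError("strategy must be 'PACK' or 'SPREAD'")
    if num_workers > sum(gpus_per_node):
        raise ValueError("not enough GPUs in the cluster")
    free = list(gpus_per_node)
    counts = [0] * len(free)
    node = 0
    for _ in range(num_workers):
        if strategy == "PACK":
            # Use the first node that still has a free GPU
            node = next(i for i, f in enumerate(free) if f > 0)
        else:
            # SPREAD: round-robin, skipping nodes that are already full
            while free[node] == 0:
                node = (node + 1) % len(free)
        free[node] -= 1
        counts[node] += 1
        if strategy == "SPREAD":
            node = (node + 1) % len(free)  # next worker starts at the next node
    return counts

print(place(8, [8, 8, 8, 8], "PACK"))    # [8, 0, 0, 0]
print(place(8, [8, 8, 8, 8], "SPREAD"))  # [2, 2, 2, 2]
```

With 32 workers on 4 × 8 GPUs both strategies fill every GPU; the choice only changes the layout when the job uses part of the cluster.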

When to use vs alternatives

Use Ray Train when:

  • Training across multiple machines (multi-node)
  • Need hyperparameter tuning at scale
  • Want fault tolerance (auto-restart failed workers)
  • Elastic scaling (add/remove nodes during training)
  • Unified framework (same code for PyTorch/TF/HF)

Key advantages:

  • Multi-node orchestration: the same training script runs on one machine or a cluster
  • Ray Tune integration: trainers plug directly into tune.Tuner for distributed search
  • Fault tolerance: failed worker groups restart and resume from the latest checkpoint
  • Elastic: Add/remove nodes without restarting
  • Framework agnostic: PyTorch, TensorFlow, HuggingFace, XGBoost

Use alternatives instead:

  • Accelerate: Single-node multi-GPU, simpler
  • PyTorch Lightning: High-level abstractions, callbacks
  • DeepSpeed: Maximum performance, complex setup
  • Raw DDP: Maximum control, minimal overhead

Common issues

Issue: Ray cluster not connecting

Check cluster status:

ray status

# Every node should be listed as active, and the GPU total
# should match your cluster (e.g. 4 nodes × 8 GPUs = 32 GPUs)

If not connected:

# Restart head node
ray stop
ray start --head --port=6379 --dashboard-host=0.0.0.0

# Restart worker nodes
ray stop
ray start --address=<head-ip>:6379

Issue: Out of memory

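Lower the per-worker batch size and use gradient accumulation to keep the same effective batch size. Each DDP worker holds a full model replica, so reducing num_workers does not lower per-GPU memory; it helps only when host (CPU) memory is the bottleneck:

scaling_config=ScalingConfig(
    num_workers=4,  # Reduce from 8 (eases host memory, not per-GPU memory)
    use_gpu=True
)

# In train_func, accumulate gradients over several smaller batches
optimizer.zero_grad()
for i, batch in enumerate(dataloader):
    loss = model(batch) / accumulation_steps  # scale so summed grads match one large batch
    loss.backward()

    if (i + 1) % accumulation_steps == 0:
        optimizer.step()
        optimizer.zero_grad()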
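Choosing accumulation_steps is simple arithmetic: effective batch = per-worker batch × accumulation_steps × num_workers. The hypothetical plain-Python helper below (not a Ray API) picks the number of micro-steps that preserves a target global batch after shrinking workers or per-worker batches; the 8 × 32 starting point is a made-up example.

```python
def accumulation_steps_for(target_global_batch, per_worker_batch, num_workers):
    """Hypothetical helper: micro-steps needed to keep the effective batch size.

    effective batch = per_worker_batch * accumulation_steps * num_workers
    """
    per_step = per_worker_batch * num_workers
    if target_global_batch % per_step:
        raise ValueError("target batch must be a multiple of per_worker_batch * num_workers")
    return target_global_batch // per_step

# Before: 8 workers × 32 samples = 256 samples per optimizer step.
# After the OOM fix: 4 workers × 16 samples -> accumulate 4 micro-batches.
print(accumulation_steps_for(256, 16, 4))  # 4
```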

Issue: Slow training

Check if data loading is bottleneck:

import time

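def train_func(config):
    # ... build model, optimizer and dataloader as usual ...
    for epoch in range(config["epochs"]):
        start = time.perf_counter()
        for batch in dataloader:
            data_time = time.perf_counter() - start  # time spent waiting for this batch
            # Train...
            print(f"Data loading: {data_time:.3f}s")
            start = time.perf_counter()  # restart the clock before fetching the next batch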

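If data loading is slow, increase the DataLoader's worker processes (PyTorch's num_workers, separate from Ray's ScalingConfig num_workers):

from torch.utils.data import DataLoader

dataloader = DataLoader(dataset, num_workers=8, pin_memory=True)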
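Each DataLoader worker is a separate process on the same node as its training worker, so raising its num_workers also raises CPU demand. Ray's CPU resources are logical and do not cap this, so a rough budget check helps when sizing resources_per_worker={"CPU": ...}. The plain-Python sketch below is a hypothetical rule of thumb (one main process plus the loader processes per training worker), with a made-up 8-GPU, 64-core node.

```python
def cpus_needed_per_node(train_workers_per_node, loader_workers):
    """Rough CPU count: 1 main process + loader_workers per training worker."""
    return train_workers_per_node * (1 + loader_workers)

def max_loader_workers(cpus_per_node, train_workers_per_node):
    """Largest DataLoader num_workers that avoids oversubscribing the node's CPUs."""
    return max(0, cpus_per_node // train_workers_per_node - 1)

# Hypothetical node: 8 GPUs (8 training workers) and 64 CPU cores
print(cpus_needed_per_node(8, 8))  # 72 > 64: oversubscribed
print(max_loader_workers(64, 8))   # 7
```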

Advanced topics

Multi-node setup: See references/multi-node.md for Ray cluster deployment on AWS, GCP, Kubernetes, and SLURM.

Hyperparameter tuning: See references/hyperparameter-tuning.md for Ray Tune integration, search algorithms (Optuna, HyperOpt), and population-based training.

Custom training loops: See references/custom-loops.md for advanced Ray Train usage, custom backends, and integration with other frameworks.

Hardware requirements

  • Single node: 1+ GPUs (or CPUs)
  • Multi-node: 2+ machines with network connectivity
  • Cloud: AWS, GCP, Azure (Ray autoscaling)
  • On-prem: Kubernetes, SLURM clusters

Supported accelerators:

  • NVIDIA GPUs (CUDA)
  • AMD GPUs (ROCm)
  • TPUs (Google Cloud)
  • CPUs
