rocm-kernels

Provides guidance for writing and benchmarking optimized Triton kernels for AMD GPUs (MI355X, R9700) on ROCm, targeting HuggingFace diffusers (LTX-Video, SD3,…)

npx skills add https://github.com/huggingface/kernels --skill rocm-kernels

ROCm Triton Kernels for Diffusers & Transformers

This skill provides patterns and guidance for developing optimized Triton kernels targeting AMD GPUs (MI355X, R9700) on ROCm, for use with HuggingFace diffusers (LTX-Video, SD3, FLUX) and transformers libraries.

Quick Start

Diffusers (LTX-Video)

Inject optimized kernels into the LTX-Video pipeline:

import os
# Set Triton's AMD backend knobs before any Triton kernel is compiled
os.environ['TRITON_HIP_USE_BLOCK_PINGPONG'] = '1'
os.environ['TRITON_HIP_USE_ASYNC_COPY'] = '1'

import torch
from diffusers import LTXPipeline

pipe = LTXPipeline.from_pretrained("Lightricks/LTX-Video", torch_dtype=torch.bfloat16)
pipe.to("cuda")  # ROCm builds of PyTorch expose HIP devices under the "cuda" device type
inject_optimized_kernels(pipe)  # kernel-injection helper (not defined in this snippet); call BEFORE CPU offloading
pipe.enable_model_cpu_offload()

For a minimal integration example (~150 lines):

python scripts/ltx_kernel_injection_example.py

Isolated Kernel Micro-benchmarks

# All 4 kernels: correctness + performance + bandwidth
python scripts/benchmark_kernels.py

# Single kernel
python scripts/benchmark_kernels.py --kernel rmsnorm
python scripts/benchmark_kernels.py --kernel rope
python scripts/benchmark_kernels.py --kernel geglu
python scripts/benchmark_kernels.py --kernel adaln

End-to-End Pipeline Benchmark

# Compare baseline vs Triton vs torch.compile
python scripts/benchmark_e2e.py --mode all

# Quick test
python scripts/benchmark_e2e.py --mode triton --num-frames 9 --steps 5

# Save results for comparison
python scripts/benchmark_e2e.py --mode all --output-json results.json

Target Model: LTX-Video

Architecture Overview

Component | Class | Has Weight | Count | Kernel
transformer_blocks.*.norm1 | RMSNorm | No (elementwise_affine=False) | 56 | RMSNorm
transformer_blocks.*.norm2 | RMSNorm | No | 56 | RMSNorm
transformer_blocks.*.attn1.norm_q | torch.nn.RMSNorm | Yes | 28 | RMSNorm
transformer_blocks.*.attn1.norm_k | torch.nn.RMSNorm | Yes | 28 | RMSNorm
transformer_blocks.*.ff | FeedForward | - | 28 | GELU (not GEGLU!)
Rotary position encoding | LTXVideoRotaryPosEmbed | - | 1 | RoPE 3D

Total RMSNorm modules: 168 (56 with weights, 112 without)

Target Kernels

Kernel | Use Case | Input Layout | Key Challenge
RMSNorm | Normalization | [..., hidden_size] | Weight may be None; 168 instances
RoPE 3D | Video position encoding | [batch, t*h*w, heads, head_dim] | 3D → temporal + spatial decomposition
GEGLU | Gated activation (SD3/FLUX) | [batch, seq, 2*hidden] → [batch, seq, hidden] | Gate/value split
AdaLN | Conditioned normalization (DiT) | norm(x) * weight * (1+scale) + shift | Fused norm + condition

Supported Hardware

GPU | Architecture | Wave Size | LDS/CU | Mem BW | Key Feature | Verified
MI355X | CDNA4 (gfx950) | Wave64 | 160 KB | 8 TB/s | 8 XCDs, XCD swizzle for GEMM | Yes
R9700 | RDNA4 (gfx1201) | Wave32 | 64 KB | ~608 GB/s | 256B cacheline, inference-focused | Yes

See MI355X guide | R9700 guide

When This Skill Applies

Use this skill when:

  • Writing Triton kernels for RMSNorm, RoPE, GEGLU, AdaLN on AMD GPUs
  • Integrating custom kernels with diffusers pipelines (LTX-Video, SD3, FLUX)
  • Benchmarking kernel performance against PyTorch baseline on ROCm
  • Optimizing existing kernels for MI355X or R9700 architecture
  • Debugging ROCm/HIP-specific kernel issues

Critical ROCm Constraints

Things That DON'T Work on AMD

# FORBIDDEN - CUDA only, NOT available on ROCm
tl.libdevice.tanh(x)          # Use manual formula below
tl.libdevice.log1p(x)         # Use: tl.log(1.0 + x)
tl.math.tanh(x)               # Also NOT available on ROCm Triton

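# Manual tanh (ONLY reliable method on ROCm). The form 1 - 2/(e2x + 1) is
# equivalent to (e2x - 1)/(e2x + 1) but saturates to 1.0 instead of
# inf/inf = NaN when exp(2x) overflows for large x:
e2x = tl.exp(2.0 * x)
tanh_x = 1.0 - 2.0 / (e2x + 1.0)

# FORBIDDEN - Triton language limitations (not ROCm-specific, but they apply here too)
break / continue               # Use: tl.where()
min(a, b) / max(a, b)          # Use: tl.minimum(a, b) / tl.maximum(a, b)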
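A CPU-only sketch in plain Python (not Triton) of why the manual tanh needs care: on the GPU, fp32 exp returns inf once its argument exceeds ln(FLT_MAX) ≈ 88.72, and (e2x - 1)/(e2x + 1) then evaluates to inf/inf = NaN. The helper `exp_fp32_like` is a hypothetical stand-in that mimics that overflow (Python's math.exp raises instead); the rearranged form 1 - 2/(e2x + 1) saturates cleanly to ±1.

```python
import math

# ln(FLT_MAX): above this, IEEE fp32 exp() overflows to +inf on the GPU
FP32_EXP_LIMIT = math.log(3.4028234663852886e38)


def exp_fp32_like(v: float) -> float:
    """Hypothetical helper: mimic fp32 exp overflow (math.exp would raise OverflowError)."""
    return math.inf if v > FP32_EXP_LIMIT else math.exp(v)


def tanh_naive(x: float) -> float:
    # (e^{2x} - 1) / (e^{2x} + 1): becomes inf/inf = NaN once e^{2x} overflows
    e2x = exp_fp32_like(2.0 * x)
    return (e2x - 1.0) / (e2x + 1.0)


def tanh_safe(x: float) -> float:
    # 1 - 2 / (e^{2x} + 1): e^{2x} = inf gives 1 - 0 = 1, e^{2x} -> 0 gives 1 - 2 = -1
    e2x = exp_fp32_like(2.0 * x)
    return 1.0 - 2.0 / (e2x + 1.0)


if __name__ == "__main__":
    for x in (-50.0, -1.0, 0.5, 10.0, 50.0):
        print(f"x={x:6.1f}  naive={tanh_naive(x)!r:>22}  safe={tanh_safe(x)!r:>22}  ref={math.tanh(x)!r}")
```

The same rearrangement applies unchanged inside a Triton kernel, since it only uses tl.exp, addition, and division.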

Mandatory Environment Variables

import os
os.environ['TRITON_HIP_USE_BLOCK_PINGPONG'] = '1'
os.environ['TRITON_HIP_USE_ASYNC_COPY'] = '1'

Core Kernel Implementations

1. RMSNorm (Core Optimization Target)

Row-wise reduction pattern. 168 instances in LTX-Video, ~5% of total compute.

CRITICAL: Do NOT autotune BLOCK_D. Autotune may pick BLOCK_D < D, causing partial row processing and wrong results. Always compute BLOCK_D = triton.next_power_of_2(D) in the Python wrapper.
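A plain-Python simulation (no GPU needed) of what goes wrong: the kernel launches one program per row and only touches offsets from tl.arange(0, BLOCK_D), so with BLOCK_D < D the tail of the row is never loaded, normalized, or stored, and the variance is computed from a partial sum while still dividing by D. `simulate_row` is a hypothetical model of that masking, not the real kernel.

```python
import math


def next_power_of_2(n: int) -> int:
    """Same result as triton.next_power_of_2 for n >= 1."""
    return 1 << (n - 1).bit_length()


def simulate_row(row, block_d, eps=1e-6):
    """Hypothetical model of one rmsnorm_kernel program: only offsets < BLOCK_D are touched."""
    d = len(row)
    visible = row[:min(block_d, d)]             # masked load of offs < min(BLOCK_D, D)
    variance = sum(v * v for v in visible) / d  # kernel still divides by the full D
    rms_inv = 1.0 / math.sqrt(variance + eps)
    out = [math.nan] * d                        # NaN stands in for torch.empty_like garbage
    for i, v in enumerate(visible):
        out[i] = v * rms_inv
    return out


def reference_row(row, eps=1e-6):
    rms_inv = 1.0 / math.sqrt(sum(v * v for v in row) / len(row) + eps)
    return [v * rms_inv for v in row]


row = [float(i % 7 - 3) for i in range(3072)]        # D = 3072, not a power of two
good = simulate_row(row, next_power_of_2(len(row)))  # BLOCK_D = 4096 covers the row
bad = simulate_row(row, 1024)                         # a tuner-chosen BLOCK_D < D
ref = reference_row(row)
print(max(abs(a - b) for a, b in zip(good, ref)))     # 0.0: full row processed
print(sum(math.isnan(v) for v in bad))                # 2048 elements never written
```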

import torch
import triton
import triton.language as tl


@triton.jit
def rmsnorm_kernel(
    x_ptr, weight_ptr, out_ptr,
    stride_x, D,
    eps: tl.constexpr,
    HAS_WEIGHT: tl.constexpr,
    BLOCK_D: tl.constexpr,
):
    row = tl.program_id(0)
    offs = tl.arange(0, BLOCK_D)
    mask = offs < D
    x = tl.load(x_ptr + row * stride_x + offs, mask=mask, other=0.0).to(tl.float32)

    variance = tl.sum(x * x, axis=0) / D
    rms_inv = tl.rsqrt(variance + eps)

    if HAS_WEIGHT:
        w = tl.load(weight_ptr + offs, mask=mask, other=1.0).to(tl.float32)
        out = x * rms_inv * w
    else:
        out = x * rms_inv

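    # x was upcast to fp32 above, so cast to the output tensor's element type, not x.dtype
    tl.store(out_ptr + row * stride_x + offs, out.to(out_ptr.dtype.element_ty), mask=mask)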


def triton_rmsnorm(x, weight=None, eps=1e-6):
    x_2d = x.contiguous().view(-1, x.shape[-1])
    out = torch.empty_like(x_2d)
    M, D = x_2d.shape
    has_weight = weight is not None
    if not has_weight:
        weight = torch.empty(0, device=x.device)

    BLOCK_D = triton.next_power_of_2(D)
    num_warps = 4 if BLOCK_D <= 1024 else (8 if BLOCK_D <= 4096 else 16)
    rmsnorm_kernel[(M,)](
        x_2d, weight, out, x_2d.stride(0), D, eps, has_weight,
        BLOCK_D=BLOCK_D, num_warps=num_warps, num_stages=2,
    )
    return out.view_as(x)
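The wrapper's launch heuristic can be pulled out as plain Python so it is easy to unit-test or sweep without a GPU. The name `rmsnorm_launch_config` is hypothetical; the thresholds mirror triton_rmsnorm above, with BLOCK_D always covering the full row and num_warps growing with the block so per-thread work stays modest as rows get wider.

```python
def next_power_of_2(n: int) -> int:
    """Equivalent to triton.next_power_of_2 for n >= 1."""
    return 1 << (n - 1).bit_length()


def rmsnorm_launch_config(d: int) -> tuple[int, int]:
    """Hypothetical helper: (BLOCK_D, num_warps) exactly as triton_rmsnorm computes them."""
    block_d = next_power_of_2(d)  # never smaller than D: one program covers a full row
    if block_d <= 1024:
        num_warps = 4
    elif block_d <= 4096:
        num_warps = 8
    else:
        num_warps = 16
    return block_d, num_warps


for d in (64, 1024, 2048, 3072, 8192):
    print(d, rmsnorm_launch_config(d))
```

The GEGLU and AdaLN wrappers below use the same thresholds, so one helper like this could serve all three.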

LTX-Video pitfall: Weight may be None!

has_weight = hasattr(module, 'weight') and module.weight is not None

2. RoPE 3D (Video Position Encoding)

Element-wise pattern. LTX-Video splits head_dim into temporal + spatial components.

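CRITICAL: cos/sin have shape [seq_len, head_dim]. When the grid flattens the batch dimension into axis 0 (batch * seq_len programs), index cos/sin with pid_s % seq_len; otherwise, with batch > 1, every program past the first batch reads beyond the end of cos/sin and crashes the GPU (memory access fault).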
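A small plain-Python check of that address arithmetic (no GPU): `cos_row_offsets` is a hypothetical helper that reproduces the kernel's `cos_ptr + seq_idx * head_dim` offsets for every program on the flattened grid, with and without the wrap.

```python
def cos_row_offsets(batch: int, seq_len: int, head_dim: int, wrap: bool) -> list[int]:
    """Hypothetical helper: element offset of the cos/sin row each program (pid_s) reads."""
    offsets = []
    for pid_s in range(batch * seq_len):        # grid axis 0 = batch * seq_len
        seq_idx = pid_s % seq_len if wrap else pid_s
        offsets.append(seq_idx * head_dim)      # cos_ptr + seq_idx * head_dim
    return offsets


batch, seq_len, head_dim = 2, 4096, 128
table_elems = seq_len * head_dim                # cos/sin hold seq_len rows only

unwrapped = cos_row_offsets(batch, seq_len, head_dim, wrap=False)
wrapped = cos_row_offsets(batch, seq_len, head_dim, wrap=True)
print(max(unwrapped) >= table_elems)  # True: the second batch reads past the table (OOB)
print(max(wrapped) < table_elems)     # True: every program stays inside the table
```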

@triton.jit
def rope_3d_kernel(
    qk_ptr, cos_ptr, sin_ptr, out_ptr,
    seq_len, num_heads, head_dim,
    stride_s, stride_h, stride_d,
    BLOCK_HD: tl.constexpr,
):
    pid_s = tl.program_id(0)  # batch * seq_len
    pid_h = tl.program_id(1)  # head index
    half_dim = head_dim // 2
    offs = tl.arange(0, BLOCK_HD)
    mask = offs < half_dim

    base = pid_s * stride_s + pid_h * stride_h
    x0 = tl.load(qk_ptr + base + offs, mask=mask, other=0.0).to(tl.float32)
    x1 = tl.load(qk_ptr + base + half_dim + offs, mask=mask, other=0.0).to(tl.float32)

    seq_idx = pid_s % seq_len  # wrap for batch > 1
    cos_val = tl.load(cos_ptr + seq_idx * head_dim + offs, mask=mask, other=1.0).to(tl.float32)
    sin_val = tl.load(sin_ptr + seq_idx * head_dim + offs, mask=mask, other=0.0).to(tl.float32)

    out0 = x0 * cos_val - x1 * sin_val
    out1 = x0 * sin_val + x1 * cos_val

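    # x0/x1 were upcast to fp32, so store in the output tensor's element type
    tl.store(out_ptr + base + offs, out0.to(out_ptr.dtype.element_ty), mask=mask)
    tl.store(out_ptr + base + half_dim + offs, out1.to(out_ptr.dtype.element_ty), mask=mask)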


def triton_rope_3d(qk, cos, sin):
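    qk = qk.contiguous()
    cos, sin = cos.contiguous(), sin.contiguous()  # kernel indexes cos/sin as row-major [seq_len, head_dim]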
    out = torch.empty_like(qk)
    batch, seq_len, num_heads, head_dim = qk.shape
    qk_flat = qk.view(batch * seq_len, num_heads, head_dim)
    out_flat = out.view(batch * seq_len, num_heads, head_dim)
    BLOCK_HD = triton.next_power_of_2(head_dim // 2)
    num_warps = 4 if BLOCK_HD <= 64 else 8
    rope_3d_kernel[(batch * seq_len, num_heads)](
        qk_flat, cos, sin, out_flat,
        seq_len, num_heads, head_dim,
        qk_flat.stride(0), qk_flat.stride(1), qk_flat.stride(2),
        BLOCK_HD=BLOCK_HD, num_warps=num_warps, num_stages=2,
    )
    return out
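For correctness checks, a plain-Python reference of what rope_3d_kernel computes for one (token, head) vector: the two halves of head_dim are rotated pairwise, (x0, x1) → (x0·cos − x1·sin, x0·sin + x1·cos), using the first head_dim/2 entries of that token's cos/sin row. The function name is hypothetical, and the test row assumes (as an illustration) that each cos/sin row is the half-width table repeated twice.

```python
import math


def rope_half_split_ref(x, cos_row, sin_row):
    """Hypothetical reference: rotate x (length head_dim) in the kernel's half-split layout."""
    half = len(x) // 2
    out = [0.0] * len(x)
    for i in range(half):
        x0, x1 = x[i], x[half + i]
        c, s = cos_row[i], sin_row[i]  # the kernel reads only the first half of cos/sin
        out[i] = x0 * c - x1 * s
        out[half + i] = x0 * s + x1 * c
    return out


head_dim = 8
angles = [0.1 * (i + 1) for i in range(head_dim // 2)]
cos_row = [math.cos(a) for a in angles] * 2  # assumption: half-width table repeated twice
sin_row = [math.sin(a) for a in angles] * 2
x = [1.0, 2.0, 3.0, 4.0, 0.5, -1.0, 0.0, 2.5]

y = rope_half_split_ref(x, cos_row, sin_row)
# A rotation preserves the norm of every (x0, x1) pair, hence of the whole vector
print(math.isclose(sum(v * v for v in x), sum(v * v for v in y)))
```

Norm preservation and the identity case (cos = 1, sin = 0) make cheap invariants to check against the kernel output.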

3. GEGLU (For SD3/FLUX, NOT LTX-Video)

Element-wise gated activation. Input [batch, seq, 2*hidden] → Output [batch, seq, hidden].

Same block-size rule as RMSNorm: compute BLOCK_H dynamically with triton.next_power_of_2(hidden_size); do NOT autotune it.

@triton.jit
def geglu_kernel(
    input_ptr, output_ptr,
    stride_in, stride_out, hidden_size,
    BLOCK_H: tl.constexpr,
):
    row = tl.program_id(0)
    offs = tl.arange(0, BLOCK_H)
    mask = offs < hidden_size

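    # Layout assumption: first half = gate (GELU applied), second half = value.
    # diffusers' GEGLU applies GELU to the *second* chunk; swap the two offsets
    # if the output must match that module.
    gate = tl.load(input_ptr + row * stride_in + offs, mask=mask, other=0.0).to(tl.float32)
    value = tl.load(input_ptr + row * stride_in + hidden_size + offs, mask=mask, other=0.0).to(tl.float32)

    # GELU tanh approximation with manual tanh (tl.math.tanh NOT available on ROCm).
    # 1 - 2/(e2x + 1) avoids the inf/inf = NaN of (e2x - 1)/(e2x + 1) for large gate values.
    k = 0.7978845608028654  # sqrt(2/pi)
    tanh_arg = k * (gate + 0.044715 * gate * gate * gate)
    e2x = tl.exp(2.0 * tanh_arg)
    tanh_val = 1.0 - 2.0 / (e2x + 1.0)
    gate_gelu = 0.5 * gate * (1.0 + tanh_val)
    result = gate_gelu * value

    # gate was upcast to fp32, so store in the output tensor's element type
    tl.store(output_ptr + row * stride_out + offs, result.to(output_ptr.dtype.element_ty), mask=mask)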


def triton_geglu(x):
    x = x.contiguous()
    *batch_dims, double_h = x.shape
    hidden_size = double_h // 2
    x_2d = x.view(-1, double_h)
    M = x_2d.shape[0]
    out = torch.empty(M, hidden_size, device=x.device, dtype=x.dtype)
    BLOCK_H = triton.next_power_of_2(hidden_size)
    num_warps = 4 if BLOCK_H <= 1024 else (8 if BLOCK_H <= 4096 else 16)
    geglu_kernel[(M,)](
        x_2d, out, x_2d.stride(0), out.stride(0), hidden_size,
        BLOCK_H=BLOCK_H, num_warps=num_warps, num_stages=2,
    )
    return out.view(*batch_dims, hidden_size)
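A plain-Python reference for geglu_kernel's math, usable as a test oracle and to see how far the tanh approximation sits from exact (erf-based) GELU. It follows the kernel's layout assumption (first half = gate, second half = value); `geglu_ref`, `gelu_tanh`, and `gelu_exact` are hypothetical names.

```python
import math


def gelu_tanh(x: float) -> float:
    """Tanh-approximate GELU with the same constants as geglu_kernel."""
    k = math.sqrt(2.0 / math.pi)
    return 0.5 * x * (1.0 + math.tanh(k * (x + 0.044715 * x ** 3)))


def gelu_exact(x: float) -> float:
    """Exact GELU: x * Phi(x), with Phi the standard normal CDF."""
    return 0.5 * x * (1.0 + math.erf(x / math.sqrt(2.0)))


def geglu_ref(row, gelu=gelu_tanh):
    """Hypothetical reference: row = [gate | value] (length 2*hidden) -> gelu(gate) * value."""
    hidden = len(row) // 2
    gate, value = row[:hidden], row[hidden:]
    return [gelu(g) * v for g, v in zip(gate, value)]


xs = [i / 10.0 for i in range(-60, 61)]
max_gap = max(abs(gelu_tanh(x) - gelu_exact(x)) for x in xs)
print(f"max |tanh-approx - exact| on [-6, 6]: {max_gap:.2e}")
print(geglu_ref([1.0, -2.0, 3.0, 0.5]))
```

Comparing the kernel against `geglu_ref(row, gelu=gelu_exact)` instead would mix the approximation error into the kernel error, so a tolerance check should use the tanh variant.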

Warning: LTX-Video uses GELU, NOT GEGLU. GEGLU is for SD3/FLUX.

4. AdaLN (Adaptive Layer Normalization for DiT)

Fused normalization + conditioning: norm(x) * weight * (1 + scale) + shift

Same BLOCK_D rule as RMSNorm: compute it dynamically with triton.next_power_of_2(D); do NOT autotune it.

@triton.jit
def adaln_kernel(
    x_ptr, weight_ptr, scale_ptr, shift_ptr, out_ptr,
    stride_x, stride_cond, D,
    eps: tl.constexpr,
    BLOCK_D: tl.constexpr,
):
    row = tl.program_id(0)
    offs = tl.arange(0, BLOCK_D)
    mask = offs < D
    x = tl.load(x_ptr + row * stride_x + offs, mask=mask, other=0.0).to(tl.float32)

    variance = tl.sum(x * x, axis=0) / D
    rms_inv = tl.rsqrt(variance + eps)
    x_norm = x * rms_inv

    w = tl.load(weight_ptr + offs, mask=mask, other=1.0).to(tl.float32)
    scale = tl.load(scale_ptr + row * stride_cond + offs, mask=mask, other=0.0).to(tl.float32)
    shift = tl.load(shift_ptr + row * stride_cond + offs, mask=mask, other=0.0).to(tl.float32)

    out = x_norm * w * (1.0 + scale) + shift
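    # x was upcast to fp32, so cast to the output tensor's element type rather than x.dtype
    tl.store(out_ptr + row * stride_x + offs, out.to(out_ptr.dtype.element_ty), mask=mask)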


def triton_adaln(x, weight, scale, shift, eps=1e-6):
    x_flat = x.contiguous().view(-1, x.shape[-1])
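    # expand_as accepts broadcastable conditioning (e.g. [B, 1, D] per-sample scale/shift);
    # the kernel reads one scale/shift row per x row, so both need M rows after flattening
    scale_flat = scale.expand_as(x).contiguous().view(-1, x.shape[-1])
    shift_flat = shift.expand_as(x).contiguous().view(-1, x.shape[-1])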
    out = torch.empty_like(x_flat)
    M, D = x_flat.shape
    BLOCK_D = triton.next_power_of_2(D)
    num_warps = 4 if BLOCK_D <= 1024 else (8 if BLOCK_D <= 4096 else 16)
    adaln_kernel[(M,)](
        x_flat, weight, scale_flat, shift_flat, out,
        x_flat.stride(0), scale_flat.stride(0), D, eps,
        BLOCK_D=BLOCK_D, num_warps=num_warps, num_stages=2,
    )
    return out.view_as(x)
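A plain-Python reference for AdaLN, handy as a test oracle. With scale = shift = 0 it reduces to weighted RMSNorm, which gives a cheap sanity check. In DiT blocks scale/shift are commonly one vector per sample broadcast over all tokens; the sketch models that by reusing one (scale, shift) pair for every row of a sample. The function names are hypothetical.

```python
import math


def adaln_row_ref(x, weight, scale, shift, eps=1e-6):
    """Hypothetical reference: norm(x) * weight * (1 + scale) + shift, norm = RMS (as in adaln_kernel)."""
    d = len(x)
    rms_inv = 1.0 / math.sqrt(sum(v * v for v in x) / d + eps)
    return [xi * rms_inv * w * (1.0 + sc) + sh for xi, w, sc, sh in zip(x, weight, scale, shift)]


def adaln_ref(tokens_per_sample, weight, scale_per_sample, shift_per_sample, eps=1e-6):
    """tokens_per_sample[b][t] is a row; scale/shift have one row per sample b (broadcast over t)."""
    return [
        [adaln_row_ref(row, weight, scale_per_sample[b], shift_per_sample[b], eps) for row in rows]
        for b, rows in enumerate(tokens_per_sample)
    ]


x = [[[1.0, 2.0, 3.0, 4.0], [0.5, 0.5, -1.0, 2.0]]]  # 1 sample, 2 tokens, D = 4
w = [1.0, 0.5, 2.0, 1.0]
zeros = [[0.0] * 4]
out = adaln_ref(x, w, zeros, zeros)
print(out[0][0])
```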

Diffusers Integration

See diffusers-integration.md for the complete guide.

Minimal Integration Pattern

import torch
from diffusers import LTXPipeline


def patch_rmsnorm_modules(model):
    """Patch all RMSNorm modules (diffusers and torch.nn) to use the custom Triton kernel."""
    for name, module in model.named_modules():
        # Match by class name: catches both diffusers' RMSNorm and torch.nn.RMSNorm
        if type(module).__name__ == 'RMSNorm':
            eps = getattr(module, 'eps', None)  # torch.nn.RMSNorm allows eps=None
            has_weight = getattr(module, 'weight', None) is not None

            def make_forward(mod, epsilon, use_weight):
                def forward(x):
                    # eps=None follows torch.nn.RMSNorm semantics: machine epsilon of x's dtype
                    e = epsilon if epsilon is not None else torch.finfo(x.dtype).eps
                    # weight=None selects the kernel's HAS_WEIGHT=False path; no ones tensor needed
                    return triton_rmsnorm(x, mod.weight if use_weight else None, eps=e)
                return forward

            module.forward = make_forward(module, eps, has_weight)

pipe = LTXPipeline.from_pretrained("Lightricks/LTX-Video", torch_dtype=torch.bfloat16)
pipe.to("cuda")
patch_rmsnorm_modules(pipe.transformer)
pipe.enable_model_cpu_offload()

Diffusers Critical Pitfalls

  1. RMSNorm weight may be None — LTX-Video uses elementwise_affine=False
  2. Diffusers RMSNorm != torch.nn.RMSNorm — Use type(module).__name__ not isinstance()
  3. LTX-Video uses GELU, not GEGLU — Don't patch GEGLU for LTX-Video
  4. Inject BEFORE CPU offloading: call inject_kernels() first, then enable_model_cpu_offload()
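Why pitfall 2 bites, reduced to plain Python: diffusers ships its own RMSNorm class, distinct from torch.nn.RMSNorm, so an isinstance check against one misses the other even though both are named 'RMSNorm'. The two classes below are hypothetical stand-ins built with type(), and the module names mirror the LTX-Video table above.

```python
# Hypothetical stand-ins: two distinct classes that share the class name "RMSNorm"
TorchRMSNorm = type("RMSNorm", (), {})      # plays the role of torch.nn.RMSNorm
DiffusersRMSNorm = type("RMSNorm", (), {})  # plays the role of diffusers' RMSNorm

modules = {
    "transformer_blocks.0.norm1": DiffusersRMSNorm(),
    "transformer_blocks.0.attn1.norm_q": TorchRMSNorm(),
}

by_isinstance = [n for n, m in modules.items() if isinstance(m, TorchRMSNorm)]
by_name = [n for n, m in modules.items() if type(m).__name__ == "RMSNorm"]
print(by_isinstance)  # misses norm1
print(by_name)        # finds both
```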

Performance Expectations

Micro-benchmark Results (MI355X, BF16)

Kernel | Avg Speedup | Best-Config Speedup (shape) | Status
RMSNorm | 1.71x | 2.44x ([4×4096×3072]) | PASS
RoPE 3D | 1.21x | 1.52x ([2×4096×16×128]) | PASS
GEGLU | 1.43x | 2.13x ([4×4096×8192]) | PASS
AdaLN | 2.22x | 2.77x ([4×4096×3072]) | PASS

RMSNorm bandwidth utilization: 3554 GB/s (MI355X theoretical: 8 TB/s, ~44%).
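The utilization figures are achieved bandwidth divided by theoretical peak. A sketch of the usual accounting for a memory-bound RMSNorm, assuming the kernel reads x and the weight once and writes the output once; the benchmark script's exact byte model is not shown here, so `rmsnorm_bytes` is an illustrative, hypothetical helper.

```python
def rmsnorm_bytes(m: int, d: int, elem_bytes: int = 2, has_weight: bool = True) -> int:
    """Hypothetical byte model for one RMSNorm call: read x, write out, read weight once."""
    moved = 2 * m * d * elem_bytes   # read [m, d] input + write [m, d] output
    if has_weight:
        moved += d * elem_bytes      # weight vector, assumed to hit cache after the first row
    return moved


def achieved_gbps(num_bytes: int, seconds: float) -> float:
    return num_bytes / seconds / 1e9


def utilization(achieved: float, peak: float) -> float:
    return achieved / peak


# [4x4096x3072] in BF16 -> m = 4*4096 rows of d = 3072
nbytes = rmsnorm_bytes(4 * 4096, 3072)
print(nbytes / 1e6, "MB per call")
print(f"MI355X: {utilization(3554, 8000):.0%}  R9700: {utilization(483, 608):.0%}")
```

Dividing `nbytes` by a measured kernel time with `achieved_gbps` gives the GB/s figure to compare against the peak.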

End-to-End LTX-Video (MI355X, 25 frames, 30 steps)

Mode | Time (s) | Per Step (s) | Peak Mem (GB) | Speedup
baseline | 1.20 | 0.040 | 18.58 | 1.00x
triton | 0.98 | 0.033 | 18.58 | 1.22x
torch.compile | 0.78 | 0.026 | 18.58 | 1.54x

Key finding: the MI355X Triton E2E speedup (22%) is much higher than the H100 CUDA reference (6%), likely because MI355X's default PyTorch RMSNorm path has more room for optimization, so the patched ops account for a larger share of baseline step time there.
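A back-of-envelope way to connect kernel and end-to-end speedups is Amdahl's law: if a fraction f of baseline step time is spent in the patched ops and they become s times faster, the whole step becomes 1 / ((1 - f) + f / s) times faster. Solving for f shows how large a share the patched ops would need. This is an idealized model (it ignores launch gaps, overlap, and any other differences between modes), not a measurement.

```python
def amdahl_speedup(f: float, s: float) -> float:
    """Overall speedup when a fraction f of runtime is accelerated by factor s."""
    return 1.0 / ((1.0 - f) + f / s)


def implied_fraction(e2e: float, s: float) -> float:
    """Fraction f that would yield end-to-end speedup e2e given kernel speedup s (Amdahl inverted)."""
    return (1.0 - 1.0 / e2e) / (1.0 - 1.0 / s)


# MI355X figures from the tables above: 1.22x end-to-end, 1.71x average RMSNorm kernel speedup
f = implied_fraction(1.22, 1.71)
print(f"implied share of baseline time in patched ops: {f:.0%}")
print(f"round trip: {amdahl_speedup(f, 1.71):.2f}x")
```

Under these simplifications, the MI355X numbers would need the patched ops to take a large share of baseline step time, well above the ~5% RMSNorm share quoted for H100; treat that as a rough consistency check rather than a profile.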

Micro-benchmark Results (R9700, BF16)

Kernel | Avg Speedup | Best-Config Speedup (shape) | Status
RMSNorm | 2.90x | 3.97x ([1×8192×2048]) | PASS
RoPE 3D | 2.09x | 2.38x ([1×1024×16×64]) | PASS
GEGLU | 1.69x | 1.93x ([2×1024×8192]) | PASS
AdaLN | 3.00x | 3.67x ([4×4096×3072]) | PASS

RMSNorm bandwidth utilization: 483 GB/s (R9700 theoretical: ~608 GB/s, ~79%).

R9700 speedups are higher than MI355X because PyTorch's default RDNA4 backend is less mature, leaving more room for Triton optimization. The bandwidth utilization (79%) is also significantly better than MI355X (44%).

End-to-End LTX-Video (R9700, 25 frames, 30 steps)

Mode | Time (s) | Per Step (s) | Peak Mem (GB) | Speedup
baseline (mean of 3) | 6.91 | 0.231 | 18.58 | 1.00x
triton (mean of 3) | 6.10 | 0.203 | 18.58 | 1.13x
torch.compile (single run) | 5.05 | 0.168 | 18.58 | 1.37x

Reviewer-facing benchmark files for this comparison live in examples/ltx-video-benchmark/, including:

  • Summary table with gen_time_s, time_per_step_s, peak_memory_gb, and speedup
  • Consolidated JSON results in examples/ltx-video-benchmark/benchmark_results.json
  • OpenCode run outputs in examples/ltx-video-benchmark/trace/opencode_live/results.json
  • OpenCode parsed trace in examples/ltx-video-benchmark/trace/opencode_live/opencode_trace_result.json
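The speedup column is a simple ratio of the raw generation times. A sketch that recomputes it from a list of run records; the record layout (a `mode` key plus the `gen_time_s` field named above) is an assumption for illustration, not the actual schema of benchmark_results.json, and `summarize` is a hypothetical helper.

```python
def summarize(records: list[dict], baseline_mode: str = "baseline") -> dict[str, dict]:
    """Hypothetical helper: speedup of each mode vs. the baseline, from gen_time_s (assumed layout)."""
    base = next(r["gen_time_s"] for r in records if r["mode"] == baseline_mode)
    return {
        r["mode"]: {
            "gen_time_s": r["gen_time_s"],
            "speedup": round(base / r["gen_time_s"], 2),  # >1.0 means faster than baseline
        }
        for r in records
    }


# Values from the R9700 end-to-end table above (25 frames, 30 steps)
runs = [
    {"mode": "baseline", "gen_time_s": 6.91},
    {"mode": "triton", "gen_time_s": 6.10},
    {"mode": "torch.compile", "gen_time_s": 5.05},
]
for mode, row in summarize(runs).items():
    print(mode, row)
```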

R9700 Additional Validation

Test | Result
Transformers injection (TinyLlama 1.1B) | PASS: 45 RMSNorm modules patched, 99.9 tokens/s
HuggingFace Kernels Hub integration | PASS: Hub kernel loads and runs on ROCm
Local Triton vs Hub kernel (small shape) | Local 5.92x vs Hub 1.27x (lower launch overhead)
Local Triton vs Hub kernel (large shape) | Local 3.59x vs Hub 3.57x (comparable)
num_warps sweep (2/4/8/16/32) | Default heuristic (4/8/16) is near-optimal; num_warps=32 is always worst
rocprof kernel fusion analysis | Triton fuses 4 PyTorch kernels (pow + mean + rsqrt + mul) into 1

CUDA Reference (H100, for comparison)

Shape | Custom (ms) | PyTorch (ms) | Speedup
[1×1024×2048] | 0.019 | 0.065 | 3.37x
[2×4096×3072] | 0.087 | 0.208 | 2.41x
H100 E2E: ~6% (RMSNorm is ~5% of total compute).

Optimization Targets

Kernel | MI355X | R9700 | Target | Priority
RMSNorm | 1.71x | 2.90x | >3x (R9700) | P0: MI355X bandwidth utilization (44% → 60%+)
AdaLN | 2.22x | 3.00x | >3.5x (R9700) | P1: already strong on both
GEGLU | 1.43x | 1.69x | >2x | P1: tanh overhead
RoPE 3D | 1.21x | 2.09x | >2.5x (R9700) | P2: small head_dim launch overhead

Common Issues on ROCm

Issue | Symptom | Fix
Autotune BLOCK_D | Wrong results (max_abs 4-8+) | Never autotune BLOCK_D; use triton.next_power_of_2(D)
RoPE batch OOB | GPU crash (memory access fault) | Use pid_s % seq_len for cos/sin indexing
tl.libdevice | Not found on AMD | Use manual math formulas
tl.tanh / tl.math.tanh | Not on ROCm | Manual: e2x = exp(2x); tanh = 1 - 2/(e2x + 1) (overflow-safe)
Python min/max | Runtime error | tl.minimum() / tl.maximum()
LDS overflow | HIP OOM | Reduce num_stages to 2
Weight is None | AttributeError | Check elementwise_affine
isinstance() miss | RMSNorm not patched | Use type(module).__name__

See troubleshooting.md for all common issues.

Performance Profiling

rocprof --stats python your_kernel.py
rocprofv3 -i metrics.txt python your_kernel.py
rocm-bandwidth-test
rocminfo | grep -E "Name|Compute Unit|Wavefront"
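To turn a stats CSV into a per-kernel time share (for example, how much GPU time the Triton RMSNorm kernels take), a small parser is enough. The column names `Name` and `TotalDurationNs` are assumptions about the stats CSV your rocprof version writes; check the header of your file and pass the right names. Both helpers are hypothetical, and the sample rows are made up purely to exercise the parser.

```python
import csv
import io


def time_by_kernel(stats_csv: str, name_col: str = "Name", ns_col: str = "TotalDurationNs") -> dict[str, int]:
    """Hypothetical helper: total kernel time (ns) per kernel name from a stats CSV (assumed columns)."""
    totals: dict[str, int] = {}
    for rec in csv.DictReader(io.StringIO(stats_csv)):
        totals[rec[name_col]] = totals.get(rec[name_col], 0) + int(rec[ns_col])
    return totals


def share_matching(totals: dict[str, int], needle: str) -> float:
    """Fraction of total kernel time spent in kernels whose name contains `needle`."""
    all_ns = sum(totals.values())
    hit_ns = sum(ns for name, ns in totals.items() if needle in name)
    return hit_ns / all_ns if all_ns else 0.0


# Made-up sample rows, only to exercise the parser; not measured data
sample = """Name,Calls,TotalDurationNs
rmsnorm_kernel,168,4000
gemm_kernel,500,90000
rope_3d_kernel,28,6000
"""
totals = time_by_kernel(sample)
print(f"{share_matching(totals, 'rmsnorm'):.1%}")
```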

