rocm-kernels by huggingface

Provides guidance for writing and benchmarking optimized Triton kernels for AMD GPUs (MI355X, R9700) on ROCm, targeting HuggingFace diffusers (LTX-Video, SD3,…

npx skills add https://github.com/huggingface/kernels --skill rocm-kernels

ROCm Triton Kernels for Diffusers & Transformers

This skill provides patterns and guidance for developing optimized Triton kernels targeting AMD GPUs (MI355X, R9700) on ROCm, for use with HuggingFace diffusers (LTX-Video, SD3, FLUX) and transformers libraries.

Quick Start

Diffusers (LTX-Video)

Inject optimized kernels into the LTX-Video pipeline:

import os
# Set the ROCm Triton knobs before triton is imported or any kernel compiles
os.environ['TRITON_HIP_USE_BLOCK_PINGPONG'] = '1'
os.environ['TRITON_HIP_USE_ASYNC_COPY'] = '1'

import torch
from diffusers import LTXPipeline

pipe = LTXPipeline.from_pretrained("Lightricks/LTX-Video", torch_dtype=torch.bfloat16)
pipe.to("cuda")  # ROCm builds of PyTorch expose AMD GPUs through the "cuda" device API (HIP)
inject_optimized_kernels(pipe)  # this skill's injection helper (not a diffusers API); call BEFORE CPU offloading
pipe.enable_model_cpu_offload()

For a minimal integration example (~150 lines):

python scripts/ltx_kernel_injection_example.py

Isolated Kernel Micro-benchmarks

# All 4 kernels: correctness + performance + bandwidth
python scripts/benchmark_kernels.py

# Single kernel
python scripts/benchmark_kernels.py --kernel rmsnorm
python scripts/benchmark_kernels.py --kernel rope
python scripts/benchmark_kernels.py --kernel geglu
python scripts/benchmark_kernels.py --kernel adaln

End-to-End Pipeline Benchmark

# Compare baseline vs Triton vs torch.compile
python scripts/benchmark_e2e.py --mode all

# Quick test
python scripts/benchmark_e2e.py --mode triton --num-frames 9 --steps 5

# Save results for comparison
python scripts/benchmark_e2e.py --mode all --output-json results.json

Target Model: LTX-Video

Architecture Overview

| Component | Class | Has Weight | Count | Kernel |
|---|---|---|---|---|
| transformer_blocks.*.norm1 | RMSNorm | No (elementwise_affine=False) | 56 | RMSNorm |
| transformer_blocks.*.norm2 | RMSNorm | No | 56 | RMSNorm |
| transformer_blocks.*.attn1.norm_q | torch.nn.RMSNorm | Yes | 28 | RMSNorm |
| transformer_blocks.*.attn1.norm_k | torch.nn.RMSNorm | Yes | 28 | RMSNorm |
| transformer_blocks.*.ff | FeedForward | - | 28 | GELU (not GEGLU!) |
| Rotary position encoding | LTXVideoRotaryPosEmbed | - | 1 | RoPE 3D |

Total RMSNorm modules: 168 (56 with weights, 112 without)

Target Kernels

| Kernel | Use Case | Input Layout | Key Challenge |
|---|---|---|---|
| RMSNorm | Normalization | [..., hidden_size] | Weight may be None; 168 instances |
| RoPE 3D | Video position encoding | [batch, t*h*w, heads, head_dim] | 3D → temporal + spatial decomposition |
| GEGLU | Gated activation (SD3/FLUX) | [batch, seq, 2*hidden] → [batch, seq, hidden] | Gate/value split |
| AdaLN | Conditioned normalization (DiT) | norm(x) * weight * (1+scale) + shift | Fused norm + condition |

Supported Hardware

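| GPU | Architecture | Wave Size | LDS/CU | Mem BW | Key Feature | Verified |
|---|---|---|---|---|---|---|
| MI355X | CDNA4 (gfx950) | Wave64 | 160 KB | 8 TB/s | 8 XCDs × 32 CUs, XCD swizzle for GEMM | Yes |
| R9700 | RDNA4 (gfx1201) | Wave32 | 64 KB | ~608 GB/s | 256B cacheline, inference-focused | Yes |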

See the MI355X guide and the R9700 guide.
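Because Triton's num_warps counts wavefronts, the same launch configuration maps to different thread counts on the two GPUs. A small sketch of that arithmetic, assuming the wave sizes from the table above and the num_warps heuristic used by the wrappers below; `launch_shape` is a hypothetical helper, not a Triton API.

```python
WAVE_SIZE = {"MI355X": 64, "R9700": 32}  # Wave64 (CDNA) vs Wave32 (RDNA4), per the table


def default_num_warps(block: int) -> int:
    """num_warps heuristic used by the RMSNorm/GEGLU/AdaLN wrappers in this guide."""
    return 4 if block <= 1024 else (8 if block <= 4096 else 16)


def launch_shape(gpu: str, block: int) -> tuple[int, int, int]:
    """Hypothetical helper: (num_warps, threads per program, elements per thread)."""
    num_warps = default_num_warps(block)
    threads = num_warps * WAVE_SIZE[gpu]  # on AMD, one Triton "warp" is one wavefront
    # Blocks smaller than the thread count leave lanes holding duplicate elements
    return num_warps, threads, max(1, block // threads)


for gpu in WAVE_SIZE:
    print(gpu, launch_shape(gpu, 4096))
```

This is arithmetic only; it does not predict which num_warps value is fastest on a given GPU.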

When This Skill Applies

Use this skill when:

  • Writing Triton kernels for RMSNorm, RoPE, GEGLU, AdaLN on AMD GPUs
  • Integrating custom kernels with diffusers pipelines (LTX-Video, SD3, FLUX)
  • Benchmarking kernel performance against the PyTorch baseline on ROCm
  • Optimizing existing kernels for the MI355X or R9700 architectures
  • Debugging ROCm/HIP-specific kernel issues

Critical ROCm Constraints

Things That DON'T Work on AMD

# FORBIDDEN - CUDA only, NOT available on ROCm
tl.libdevice.tanh(x)          # Use the manual formula below
tl.libdevice.log1p(x)         # Use: tl.log(1.0 + x) (less accurate for |x| << 1)
tl.math.tanh(x)               # Also NOT available on ROCm Triton

# Manual tanh (the only reliable method on ROCm). 1 - 2/(e2x + 1) equals
# (e2x - 1)/(e2x + 1), but stays finite when exp(2x) overflows to inf
# (fp32 overflows for x > ~44, where the quotient form gives inf/inf = NaN):
e2x = tl.exp(2.0 * x)
tanh_x = 1.0 - 2.0 / (e2x + 1.0)

# FORBIDDEN - Triton limitations on ROCm
break / continue               # Use: tl.where()
min(a, b) / max(a, b)          # Use: tl.minimum(a, b) / tl.maximum(a, b)
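A host-side check of the overflow behavior, comparing the quotient form (e2x - 1)/(e2x + 1) with the algebraically equal form 1 - 2/(e2x + 1). It assumes `tl.exp` overflows to inf as IEEE floats do; Python's `math.exp` raises instead, so the hypothetical `_exp_ieee` wrapper emulates that. Python floats are fp64, so the overflow appears past x ≈ 355 instead of x ≈ 44 in fp32, but the failure mode is the same.

```python
import math


def _exp_ieee(v: float) -> float:
    """math.exp that returns inf on overflow (like GPU math) instead of raising."""
    try:
        return math.exp(v)
    except OverflowError:
        return math.inf


def tanh_quotient(x: float) -> float:
    e2x = _exp_ieee(2.0 * x)
    return (e2x - 1.0) / (e2x + 1.0)  # inf / inf -> nan once exp overflows


def tanh_stable(x: float) -> float:
    e2x = _exp_ieee(2.0 * x)
    return 1.0 - 2.0 / (e2x + 1.0)  # 2 / inf -> 0.0, so the result saturates at 1.0


for x in (-400.0, -1.0, 0.0, 0.5, 400.0):
    print(x, tanh_quotient(x), tanh_stable(x), math.tanh(x))
```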

Mandatory Environment Variables

import os
# Set before importing triton (at the latest, before the first kernel compiles)
os.environ['TRITON_HIP_USE_BLOCK_PINGPONG'] = '1'
os.environ['TRITON_HIP_USE_ASYNC_COPY'] = '1'
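If these settings live in a shared module, a small guard helps catch the ordering mistake. A sketch, assuming the knobs are read when kernels compile (so they must be set before that); `set_rocm_triton_knobs` is a hypothetical name, and `setdefault` keeps any value already exported in the shell.

```python
import os
import sys
import warnings

ROCM_TRITON_KNOBS = {
    "TRITON_HIP_USE_BLOCK_PINGPONG": "1",
    "TRITON_HIP_USE_ASYNC_COPY": "1",
}


def set_rocm_triton_knobs(env=os.environ):
    """Hypothetical helper: set the knobs unless a value was already chosen."""
    if "triton" in sys.modules:
        warnings.warn("triton is already imported; set these before any kernel compiles")
    for key, value in ROCM_TRITON_KNOBS.items():
        env.setdefault(key, value)  # respect user overrides such as ...=0
    return env
```

Respecting pre-set values makes A/B runs with a knob disabled possible without editing code.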

Core Kernel Implementations

1. RMSNorm (Core Optimization Target)

Row-wise reduction pattern. 168 instances in LTX-Video, ~5% of total compute.

CRITICAL: Do NOT autotune BLOCK_D. Autotune may pick BLOCK_D < D, causing partial row processing and wrong results. Always compute BLOCK_D = triton.next_power_of_2(D) in the Python wrapper.
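A plain-Python emulation of the masked load makes the failure concrete: with BLOCK_D < D the single program sums only the first BLOCK_D elements but still divides by D, so the variance (and therefore the output scale) is wrong. `masked_row_variance` is a hypothetical helper; `next_power_of_2` returns the same value as `triton.next_power_of_2` for n >= 1.

```python
def next_power_of_2(n: int) -> int:
    """Smallest power of two >= n (n >= 1), like triton.next_power_of_2."""
    return 1 << (n - 1).bit_length()


def masked_row_variance(row: list[float], block: int) -> float:
    """Emulate one RMSNorm program: offs = arange(0, block), mask = offs < D."""
    d = len(row)
    sum_sq = sum(row[i] * row[i] for i in range(block) if i < d)
    return sum_sq / d  # the kernel always divides by the full D


row = [1.0] * 3072  # D = 3072, one of the benchmark shapes in this guide
print(masked_row_variance(row, next_power_of_2(3072)))  # BLOCK_D = 4096 covers the row
print(masked_row_variance(row, 1024))                   # an autotuned BLOCK_D < D
```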

import torch
import triton
import triton.language as tl


@triton.jit
def rmsnorm_kernel(
    x_ptr, weight_ptr, out_ptr,
    stride_x, D,
    eps: tl.constexpr,
    HAS_WEIGHT: tl.constexpr,
    BLOCK_D: tl.constexpr,
):
    # One program normalizes one row; BLOCK_D >= D keeps the whole row in registers
    row = tl.program_id(0)
    offs = tl.arange(0, BLOCK_D)
    mask = offs < D
    x = tl.load(x_ptr + row * stride_x + offs, mask=mask, other=0.0).to(tl.float32)

    # Masked lanes load 0.0, so they add nothing to the sum of squares
    variance = tl.sum(x * x, axis=0) / D
    rms_inv = tl.rsqrt(variance + eps)

    if HAS_WEIGHT:
        w = tl.load(weight_ptr + offs, mask=mask, other=1.0).to(tl.float32)
        out = x * rms_inv * w
    else:
        out = x * rms_inv

    # x was upcast to fp32 above, so cast back to the output element type (e.g. bf16)
    tl.store(out_ptr + row * stride_x + offs, out.to(out_ptr.dtype.element_ty), mask=mask)


def triton_rmsnorm(x, weight=None, eps=1e-6):
    x_2d = x.contiguous().view(-1, x.shape[-1])
    out = torch.empty_like(x_2d)  # contiguous, same row stride as x_2d
    M, D = x_2d.shape
    has_weight = weight is not None
    if not has_weight:
        # Placeholder pointer; never dereferenced when HAS_WEIGHT=False
        weight = torch.empty(0, device=x.device)

    # Never autotune BLOCK_D: it must cover the full row
    BLOCK_D = triton.next_power_of_2(D)
    num_warps = 4 if BLOCK_D <= 1024 else (8 if BLOCK_D <= 4096 else 16)
    rmsnorm_kernel[(M,)](
        x_2d, weight, out, x_2d.stride(0), D, eps, has_weight,
        BLOCK_D=BLOCK_D, num_warps=num_warps, num_stages=2,
    )
    return out.view_as(x)

LTX-Video pitfall: Weight may be None!

has_weight = hasattr(module, 'weight') and module.weight is not None
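A slow but readable reference is handy when comparing kernel output on small rows, and it makes the None-weight case explicit. This is a host-side sketch (hypothetical name `rmsnorm_reference`), not the benchmark script's own reference.

```python
import math
from typing import Optional, Sequence


def rmsnorm_reference(row: Sequence[float], weight: Optional[Sequence[float]] = None,
                      eps: float = 1e-6) -> list[float]:
    """Pure-Python RMSNorm for one row, matching both HAS_WEIGHT branches of the kernel."""
    d = len(row)
    rms_inv = 1.0 / math.sqrt(sum(v * v for v in row) / d + eps)
    if weight is None:  # elementwise_affine=False: no scale parameter
        return [v * rms_inv for v in row]
    return [v * rms_inv * w for v, w in zip(row, weight)]
```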

2. RoPE 3D (Video Position Encoding)

Element-wise pattern. LTX-Video splits head_dim into temporal + spatial components.

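CRITICAL: cos/sin have shape [seq_len, head_dim]. When the grid flattens the batch dimension into its first axis (batch * seq_len programs), index cos/sin with pid_s % seq_len; otherwise any batch > 1 reads past the end of cos/sin and crashes the GPU with an out-of-bounds memory access.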

import torch
import triton
import triton.language as tl


@triton.jit
def rope_3d_kernel(
    qk_ptr, cos_ptr, sin_ptr, out_ptr,
    seq_len, num_heads, head_dim,
    stride_s, stride_h, stride_d,  # stride_d unused: assumes unit stride along head_dim
    BLOCK_HD: tl.constexpr,
):
    pid_s = tl.program_id(0)  # flattened batch * seq_len index
    pid_h = tl.program_id(1)  # head index
    half_dim = head_dim // 2
    offs = tl.arange(0, BLOCK_HD)
    mask = offs < half_dim

    # Rotate-half layout: the first half of head_dim pairs with the second half
    base = pid_s * stride_s + pid_h * stride_h
    x0 = tl.load(qk_ptr + base + offs, mask=mask, other=0.0).to(tl.float32)
    x1 = tl.load(qk_ptr + base + half_dim + offs, mask=mask, other=0.0).to(tl.float32)

    # cos/sin are [seq_len, head_dim] and shared across the batch
    seq_idx = pid_s % seq_len  # wrap for batch > 1 (avoids OOB reads)
    cos_val = tl.load(cos_ptr + seq_idx * head_dim + offs, mask=mask, other=1.0).to(tl.float32)
    sin_val = tl.load(sin_ptr + seq_idx * head_dim + offs, mask=mask, other=0.0).to(tl.float32)

    out0 = x0 * cos_val - x1 * sin_val
    out1 = x0 * sin_val + x1 * cos_val

    # Cast from fp32 back to the output element type
    tl.store(out_ptr + base + offs, out0.to(out_ptr.dtype.element_ty), mask=mask)
    tl.store(out_ptr + base + half_dim + offs, out1.to(out_ptr.dtype.element_ty), mask=mask)


def triton_rope_3d(qk, cos, sin):
    qk = qk.contiguous()
    cos = cos.contiguous()  # kernel assumes a row stride of head_dim
    sin = sin.contiguous()
    out = torch.empty_like(qk)
    batch, seq_len, num_heads, head_dim = qk.shape
    qk_flat = qk.view(batch * seq_len, num_heads, head_dim)
    out_flat = out.view(batch * seq_len, num_heads, head_dim)
    BLOCK_HD = triton.next_power_of_2(head_dim // 2)
    num_warps = 4 if BLOCK_HD <= 64 else 8
    rope_3d_kernel[(batch * seq_len, num_heads)](
        qk_flat, cos, sin, out_flat,
        seq_len, num_heads, head_dim,
        qk_flat.stride(0), qk_flat.stride(1), qk_flat.stride(2),
        BLOCK_HD=BLOCK_HD, num_warps=num_warps, num_stages=2,
    )
    return out
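The indexing rule from the CRITICAL note can be checked on the host: with a flattened grid, pid_s runs up to batch * seq_len - 1, and the modulo keeps the cos/sin row inside [0, seq_len). The sketch below (hypothetical helper names) also spells out the rotate-half math the kernel applies to one (token, head) pair.

```python
def rope_rotate_half(x, cos_row, sin_row):
    """Reference for one (token, head): same math as rope_3d_kernel."""
    half = len(x) // 2
    x0, x1 = x[:half], x[half:]
    c, s = cos_row[:half], sin_row[:half]  # the kernel reads only the first half of cos/sin
    out0 = [a * ci - b * si for a, b, ci, si in zip(x0, x1, c, s)]
    out1 = [a * si + b * ci for a, b, ci, si in zip(x0, x1, c, s)]
    return out0 + out1


def cos_sin_rows(batch: int, seq_len: int) -> list[int]:
    """Row of cos/sin that each flattened program pid_s reads."""
    return [pid_s % seq_len for pid_s in range(batch * seq_len)]


print(cos_sin_rows(2, 4))
print(rope_rotate_half([1.0, 0.0], [0.0, 0.0], [1.0, 1.0]))  # 90-degree rotation
```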

3. GEGLU (For SD3/FLUX, NOT LTX-Video)

Element-wise gated activation. Input [batch, seq, 2*hidden] → Output [batch, seq, hidden].

Same block-size rule as RMSNorm: compute BLOCK_H from the shape, do NOT autotune it.

import torch
import triton
import triton.language as tl


@triton.jit
def geglu_kernel(
    input_ptr, output_ptr,
    stride_in, stride_out, hidden_size,
    BLOCK_H: tl.constexpr,
):
    row = tl.program_id(0)
    offs = tl.arange(0, BLOCK_H)
    mask = offs < hidden_size

    # Convention here: first half of the last dim = gate (GELU applied), second half = value.
    # diffusers' GEGLU chunks as (hidden_states, gate) and applies exact (erf) GELU to the
    # SECOND half, so check the target module's convention before patching it.
    gate = tl.load(input_ptr + row * stride_in + offs, mask=mask, other=0.0).to(tl.float32)
    value = tl.load(input_ptr + row * stride_in + hidden_size + offs, mask=mask, other=0.0).to(tl.float32)

    # GELU tanh approximation with a manual tanh (tl.math.tanh NOT available on ROCm).
    # 1 - 2/(e^{2t} + 1) equals tanh(t) but stays finite when exp overflows for large gates.
    k = 0.7978845608028654  # sqrt(2/pi)
    tanh_arg = k * (gate + 0.044715 * gate * gate * gate)
    e2x = tl.exp(2.0 * tanh_arg)
    tanh_val = 1.0 - 2.0 / (e2x + 1.0)
    gate_gelu = 0.5 * gate * (1.0 + tanh_val)
    result = gate_gelu * value

    tl.store(output_ptr + row * stride_out + offs, result.to(output_ptr.dtype.element_ty), mask=mask)


def triton_geglu(x):
    x = x.contiguous()
    *batch_dims, double_h = x.shape
    hidden_size = double_h // 2
    x_2d = x.view(-1, double_h)
    M = x_2d.shape[0]
    out = torch.empty(M, hidden_size, device=x.device, dtype=x.dtype)
    # Compute BLOCK_H from the shape; never autotune it
    BLOCK_H = triton.next_power_of_2(hidden_size)
    num_warps = 4 if BLOCK_H <= 1024 else (8 if BLOCK_H <= 4096 else 16)
    geglu_kernel[(M,)](
        x_2d, out, x_2d.stride(0), out.stride(0), hidden_size,
        BLOCK_H=BLOCK_H, num_warps=num_warps, num_stages=2,
    )
    return out.view(*batch_dims, hidden_size)

Warning: LTX-Video uses GELU, NOT GEGLU. GEGLU is for SD3/FLUX.
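On the host, the kernel's GELU can be compared with the exact erf form to see what the tanh approximation costs in accuracy. A plain-Python sketch with hypothetical helper names; it keeps the kernel's convention that the first half of the last dimension is the gate.

```python
import math


def gelu_tanh(x: float) -> float:
    """tanh-approximated GELU, the formula geglu_kernel implements."""
    k = math.sqrt(2.0 / math.pi)
    return 0.5 * x * (1.0 + math.tanh(k * (x + 0.044715 * x ** 3)))


def gelu_exact(x: float) -> float:
    """erf-based GELU (what F.gelu computes with approximate='none')."""
    return 0.5 * x * (1.0 + math.erf(x / math.sqrt(2.0)))


def geglu_reference(row: list[float]) -> list[float]:
    """First half = gate (GELU applied), second half = value, as in geglu_kernel."""
    h = len(row) // 2
    return [gelu_tanh(g) * v for g, v in zip(row[:h], row[h:])]


worst = max(abs(gelu_tanh(i / 100) - gelu_exact(i / 100)) for i in range(-600, 601))
print(f"max |tanh approx - exact| on [-6, 6]: {worst:.2e}")
```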

4. AdaLN (Adaptive Layer Normalization for DiT)

Fused normalization + conditioning: norm(x) * weight * (1 + scale) + shift

Same BLOCK_D rule: compute dynamically.

import torch
import triton
import triton.language as tl


@triton.jit
def adaln_kernel(
    x_ptr, weight_ptr, scale_ptr, shift_ptr, out_ptr,
    stride_x, stride_cond, D,
    eps: tl.constexpr,
    BLOCK_D: tl.constexpr,
):
    row = tl.program_id(0)
    offs = tl.arange(0, BLOCK_D)
    mask = offs < D
    x = tl.load(x_ptr + row * stride_x + offs, mask=mask, other=0.0).to(tl.float32)

    # RMS normalization, fused with the conditioning below (one pass over x)
    variance = tl.sum(x * x, axis=0) / D
    rms_inv = tl.rsqrt(variance + eps)
    x_norm = x * rms_inv

    # scale/shift are read per row, so the wrapper must supply one row per token
    w = tl.load(weight_ptr + offs, mask=mask, other=1.0).to(tl.float32)
    scale = tl.load(scale_ptr + row * stride_cond + offs, mask=mask, other=0.0).to(tl.float32)
    shift = tl.load(shift_ptr + row * stride_cond + offs, mask=mask, other=0.0).to(tl.float32)

    out = x_norm * w * (1.0 + scale) + shift
    tl.store(out_ptr + row * stride_x + offs, out.to(out_ptr.dtype.element_ty), mask=mask)


def triton_adaln(x, weight, scale, shift, eps=1e-6):
    D = x.shape[-1]
    x_flat = x.contiguous().view(-1, D)
    # Broadcast per-sample conditioning (e.g. [batch, 1, D]) to x's shape first; otherwise
    # scale/shift have fewer rows than x_flat and the kernel reads out of bounds.
    # (This copies when broadcasting is needed.)
    scale_flat = scale.expand_as(x).reshape(-1, D).contiguous()
    shift_flat = shift.expand_as(x).reshape(-1, D).contiguous()
    out = torch.empty_like(x_flat)
    M = x_flat.shape[0]
    BLOCK_D = triton.next_power_of_2(D)
    num_warps = 4 if BLOCK_D <= 1024 else (8 if BLOCK_D <= 4096 else 16)
    adaln_kernel[(M,)](
        x_flat, weight.contiguous(), scale_flat, shift_flat, out,
        x_flat.stride(0), scale_flat.stride(0), D, eps,
        BLOCK_D=BLOCK_D, num_warps=num_warps, num_stages=2,
    )
    return out.view_as(x)
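AdaLN conditioning is usually per sample rather than per token, so the row-to-conditioning mapping is where shape bugs hide. A plain-Python reference for checking small cases, assuming x is flattened to [batch * tokens, D] and scale/shift hold one row per sample (`adaln_reference` is a hypothetical name); the `r // tokens_per_sample` lookup is what broadcasting a [batch, 1, D] tensor to x's shape amounts to.

```python
import math


def adaln_reference(x_rows, weight, scale, shift, tokens_per_sample, eps=1e-6):
    """norm(x) * weight * (1 + scale) + shift, with one scale/shift row per sample.

    x_rows: flattened [batch * tokens, D] rows; scale/shift: [batch, D] rows.
    """
    out = []
    for r, row in enumerate(x_rows):
        b = r // tokens_per_sample  # sample this token belongs to
        d = len(row)
        rms_inv = 1.0 / math.sqrt(sum(v * v for v in row) / d + eps)
        out.append([
            v * rms_inv * w * (1.0 + sc) + sh
            for v, w, sc, sh in zip(row, weight, scale[b], shift[b])
        ])
    return out
```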

Diffusers Integration

See diffusers-integration.md for the complete guide.

Minimal Integration Pattern

import torch
from diffusers import LTXPipeline


def patch_rmsnorm_modules(model):
    """Patch all RMSNorm modules to use the custom Triton kernel."""
    for name, module in model.named_modules():
        # Match by class name so diffusers' RMSNorm and torch.nn.RMSNorm are both caught
        if type(module).__name__ == 'RMSNorm':
            # torch.nn.RMSNorm stores eps=None when unset; fall back to 1e-6
            eps = getattr(module, 'eps', None)
            if eps is None:
                eps = 1e-6

            def make_forward(mod, epsilon):
                # Factory binds this module and eps now (avoids late binding in the loop)
                def forward(x):
                    # weight is None when elementwise_affine=False; triton_rmsnorm then
                    # takes the HAS_WEIGHT=False path (no ones tensor allocated per call)
                    return triton_rmsnorm(x, getattr(mod, 'weight', None), eps=epsilon)
                return forward
            module.forward = make_forward(module, eps)


pipe = LTXPipeline.from_pretrained("Lightricks/LTX-Video", torch_dtype=torch.bfloat16)
pipe.to("cuda")
patch_rmsnorm_modules(pipe.transformer)  # patch BEFORE enabling CPU offload
pipe.enable_model_cpu_offload()
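The `make_forward` factory in the patch function matters: a closure defined directly in the loop would see only the loop variable's final value, so every patched module would share the last module's weight and eps. A stripped-down illustration in plain Python (function names are illustrative only):

```python
def make_forwards_naive(eps_values):
    forwards = []
    for eps in eps_values:
        forwards.append(lambda: eps)  # captures the variable, not its current value
    return forwards


def make_forwards_factory(eps_values):
    def make_forward(epsilon):
        def forward():
            return epsilon  # bound when make_forward is called
        return forward
    return [make_forward(eps) for eps in eps_values]


print([f() for f in make_forwards_naive([1e-5, 1e-6])])
print([f() for f in make_forwards_factory([1e-5, 1e-6])])
```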

Diffusers Critical Pitfalls

  1. RMSNorm weight may be None: LTX-Video uses elementwise_affine=False
  2. Diffusers RMSNorm != torch.nn.RMSNorm: match on type(module).__name__, not isinstance()
  3. LTX-Video uses GELU, not GEGLU: don't patch GEGLU for LTX-Video
  4. Inject BEFORE CPU offloading: call inject_optimized_kernels() first, then enable_model_cpu_offload()
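Pitfall 2 in miniature: two unrelated classes that share the name `RMSNorm`. The classes below are stand-ins, not the real diffusers or torch types; `isinstance` against one of them misses the other, while comparing `type(module).__name__` matches both.

```python
class RMSNorm:  # stand-in for diffusers' RMSNorm
    pass


class TorchRMSNorm:  # stand-in for torch.nn.RMSNorm, an unrelated class
    pass


TorchRMSNorm.__name__ = "RMSNorm"  # torch.nn.RMSNorm's class is also named "RMSNorm"

modules = [RMSNorm(), TorchRMSNorm()]
by_isinstance = [isinstance(m, RMSNorm) for m in modules]
by_name = [type(m).__name__ == "RMSNorm" for m in modules]
print(by_isinstance, by_name)
```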

Performance Expectations

Micro-benchmark Results (MI355X, BF16)

| Kernel | Avg Speedup | Best Config Speedup | Status |
|---|---|---|---|
| RMSNorm | 1.71x | 2.44x ([4×4096×3072]) | PASS |
| RoPE 3D | 1.21x | 1.52x ([2×4096×16×128]) | PASS |
| GEGLU | 1.43x | 2.13x ([4×4096×8192]) | PASS |
| AdaLN | 2.22x | 2.77x ([4×4096×3072]) | PASS |

RMSNorm bandwidth utilization: 3554 GB/s (MI355X theoretical: 8 TB/s, ~44%).
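The utilization figure is achieved bandwidth divided by peak bandwidth. A sketch of that accounting, assuming the kernel's minimum DRAM traffic (one read of x, one write of the output, one read of the weight vector); the benchmark script's exact byte count may differ. For [4×4096×3072] in BF16 without a weight, that minimum is about 201 MB per call.

```python
def rmsnorm_bytes_moved(rows: int, hidden: int, elem_bytes: int = 2,
                        has_weight: bool = True) -> int:
    """Minimum DRAM traffic: read x once, write out once, read the weight once."""
    traffic = 2 * rows * hidden * elem_bytes
    if has_weight:
        traffic += hidden * elem_bytes
    return traffic


def achieved_gbs(bytes_moved: int, seconds: float) -> float:
    """Effective bandwidth in GB/s (10**9 bytes per second)."""
    return bytes_moved / seconds / 1e9


def utilization(achieved: float, peak: float) -> float:
    return achieved / peak
```

Plugging in the figures quoted in this guide, utilization(3554, 8000) is about 0.44 and utilization(483, 608) about 0.79, matching the stated percentages.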

End-to-End LTX-Video (MI355X, 25 frames, 30 steps)

| Mode | Time (s) | Per Step (s) | Peak Mem (GB) | Speedup |
|---|---|---|---|---|
| baseline | 1.20 | 0.040 | 18.58 | 1.00x |
| triton | 0.98 | 0.033 | 18.58 | 1.22x |
| torch.compile | 0.78 | 0.026 | 18.58 | 1.54x |

Key finding: the MI355X Triton E2E speedup (22%) is much higher than the H100 CUDA reference (6%), likely because the default PyTorch RMSNorm path on MI355X leaves more room for optimization.

Micro-benchmark Results (R9700, BF16)

| Kernel | Avg Speedup | Best Config Speedup | Status |
|---|---|---|---|
| RMSNorm | 2.90x | 3.97x ([1×8192×2048]) | PASS |
| RoPE 3D | 2.09x | 2.38x ([1×1024×16×64]) | PASS |
| GEGLU | 1.69x | 1.93x ([2×1024×8192]) | PASS |
| AdaLN | 3.00x | 3.67x ([4×4096×3072]) | PASS |

RMSNorm bandwidth utilization: 483 GB/s (R9700 theoretical: ~608 GB/s, ~79%).

R9700 speedups are higher than on MI355X, likely because PyTorch's default RDNA4 backend is less mature and leaves more room for Triton optimization. RMSNorm bandwidth utilization is also much better on R9700 (79% vs. 44% on MI355X).

End-to-End LTX-Video (R9700, 25 frames, 30 steps)

| Mode | Time (s) | Per Step (s) | Peak Mem (GB) | Speedup |
|---|---|---|---|---|
| baseline (mean of 3) | 6.91 | 0.231 | 18.58 | 1.00x |
| triton (mean of 3) | 6.10 | 0.203 | 18.58 | 1.13x |
| torch.compile (single run) | 5.05 | 0.168 | 18.58 | 1.37x |
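The Per Step and Speedup columns in the E2E tables follow from the total times: per step is total time divided by the 30 denoising steps, and speedup is baseline time divided by the mode's time. A small sketch that recomputes them (hypothetical helper `summarize`); recomputing from rounded totals can differ in the last digit, e.g. 6.91 / 30 gives 0.230 where the R9700 table lists 0.231, presumably from unrounded means.

```python
def summarize(total_seconds: dict[str, float], steps: int, baseline: str = "baseline"):
    """Map each mode to (per-step seconds, speedup vs. baseline), rounded like the tables."""
    base = total_seconds[baseline]
    return {
        mode: (round(t / steps, 3), round(base / t, 2))
        for mode, t in total_seconds.items()
    }


mi355x = summarize({"baseline": 1.20, "triton": 0.98, "torch.compile": 0.78}, steps=30)
r9700 = summarize({"baseline": 6.91, "triton": 6.10, "torch.compile": 5.05}, steps=30)
print(mi355x)
print(r9700)
```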

Reviewer-facing benchmark files for this comparison live in examples/ltx-video-benchmark/, including:

  • Summary table with gen_time_s, time_per_step_s, peak_memory_gb, and speedup
  • Consolidated JSON results in examples/ltx-video-benchmark/benchmark_results.json
  • OpenCode run outputs in examples/ltx-video-benchmark/trace/opencode_live/results.json
  • OpenCode parsed trace in examples/ltx-video-benchmark/trace/opencode_live/opencode_trace_result.json

R9700 Additional Validation

| Test | Result |
|---|---|
| Transformers injection (TinyLlama 1.1B) | PASS: 45 RMSNorm modules patched, 99.9 tokens/s |
| HuggingFace Kernels Hub integration | PASS: Hub kernel loads and runs on ROCm |
| Local Triton vs Hub kernel (small shape) | Local 5.92x vs Hub 1.27x (lower launch overhead) |
| Local Triton vs Hub kernel (large shape) | Local 3.59x vs Hub 3.57x (comparable) |
| num_warps sweep (2/4/8/16/32) | Default heuristic (4/8/16) is near-optimal; nw=32 is always worst |
| rocprof kernel fusion analysis | Triton fuses 4 PyTorch kernels (pow+mean+rsqrt+mul) into 1 |

CUDA Reference (H100, for comparison)

| Shape | Custom (ms) | PyTorch (ms) | Speedup |
|---|---|---|---|
| [1×1024×2048] | 0.019 | 0.065 | 3.37x |
| [2×4096×3072] | 0.087 | 0.208 | 2.41x |

H100 E2E: ~6% (RMSNorm is ~5% of total compute).

Optimization Targets

| Kernel | MI355X | R9700 | Target | Priority |
|---|---|---|---|---|
| RMSNorm | 1.71x | 2.90x | >3x (R9700) | P0: MI355X bandwidth util (44% → 60%+) |
| AdaLN | 2.22x | 3.00x | >3.5x (R9700) | P1: already strong on both |
| GEGLU | 1.43x | 1.69x | >2x | P1: tanh overhead |
| RoPE 3D | 1.21x | 2.09x | >2.5x (R9700) | P2: small head_dim launch overhead |

Common Issues on ROCm

| Issue | Symptom | Fix |
|---|---|---|
| Autotune BLOCK_D | Wrong results (max_abs 4-8+) | Never autotune BLOCK_D; use triton.next_power_of_2(D) |
| RoPE batch OOB | GPU crash (memory access fault) | Use pid_s % seq_len for cos/sin indexing |
| tl.libdevice | Not found on AMD | Use manual math formulas |
| tl.tanh / tl.math.tanh | Not available on ROCm | Manual: e2x = exp(2x); tanh = 1 - 2/(e2x + 1) (equals (e2x-1)/(e2x+1) but avoids inf/inf NaN for large x) |
| Python min/max | Runtime error | tl.minimum() / tl.maximum() |
| LDS overflow | HIP OOM | Reduce num_stages to 2 |
| Weight is None | AttributeError | Check elementwise_affine |
| isinstance() miss | RMSNorm not patched | Use type(module).__name__ |

See troubleshooting.md for all common issues.

Performance Profiling

rocprof --stats python your_kernel.py
rocprofv3 -i metrics.txt -- python your_kernel.py
rocm-bandwidth-test
rocminfo | grep -E "Name|Compute Unit|Wavefront"

