PyTorch-FX Shaper
An MCP server that provides tensor shapes for converting PyTorch code to einsum and einops.
Agent Shaper
Agent Shaper extracts per-module tensor shape metadata from any PyTorch nn.Module and uses it to annotate source files — either with descriptive shape comments or by rewriting operations as torch.einsum / einops.
How it works
- Shape extraction (fx_utils/get_fx_data.py): runs torch.export + ShapeProp on your module to capture the shape of every intermediate tensor in the forward pass of every workspace-defined module, not just the inputs.
- Manual annotation (fx_utils/manual_annotate.py): inserts the extracted shapes as inline comments at the end of each relevant source line.
- LLM annotation (fx_utils/llm_annotate.py): feeds each manually annotated module class to an LLM in parallel. Two modes:
  - COMMENT: rewrites comments with descriptive dimension names (e.g. batch_size, seq_len, n_embd) and plain-English explanations of each transformation.
  - EINSUM: rewrites the entire module, replacing matmuls and attention operations with torch.einsum and collapsing intermediate reshapes where possible.
- Diff viewer (fx_utils/diff_viewer.py): review LLM changes before accepting them, either in a Streamlit UI or directly in VS Code's native diff editor.
- MCP server (agent_shaper/mcp_server.py): exposes the full shape-extraction and rewrite-validation pipeline as MCP tools consumable by Claude Code or any MCP-compatible AI assistant.
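The shape-extraction step can be illustrated with a minimal sketch. It uses torch.fx.symbolic_trace together with ShapeProp instead of torch.export (which get_fx_data.py uses), and TinyMLP and trace_shapes are hypothetical names for illustration only, not part of Agent Shaper's API.

```python
import torch
import torch.nn as nn
from torch.fx import symbolic_trace
from torch.fx.passes.shape_prop import ShapeProp


class TinyMLP(nn.Module):
    """Hypothetical toy module used only to demonstrate shape propagation."""

    def __init__(self, d_in: int, d_hidden: int):
        super().__init__()
        self.fc1 = nn.Linear(d_in, d_hidden)
        self.fc2 = nn.Linear(d_hidden, d_in)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        h = torch.relu(self.fc1(x))
        return self.fc2(h) + x


def trace_shapes(module: nn.Module, *example_args) -> dict[str, tuple[int, ...]]:
    """Return {fx_node_name: shape} for every FX node that produced a single tensor."""
    gm = symbolic_trace(module)
    # ShapeProp runs the graph on real inputs and stores node.meta["tensor_meta"]
    ShapeProp(gm).propagate(*example_args)
    shapes = {}
    for node in gm.graph.nodes:
        meta = node.meta.get("tensor_meta")
        # Tuple outputs store a tuple of metadata; keep only single-tensor nodes
        if meta is not None and hasattr(meta, "shape"):
            shapes[node.name] = tuple(meta.shape)
    return shapes


if __name__ == "__main__":
    for name, shape in trace_shapes(TinyMLP(8, 32), torch.randn(2, 5, 8)).items():
        print(f"{name:10s} {shape}")
```

Every intermediate node (fc1, relu, fc2, add) gets a shape, not just the input, which is the information the annotators build on.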
All modules in a file are processed in parallel via asyncio. Everything outside the module classes (imports, dataclasses, config objects) is preserved unchanged in the output file.
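The per-class parallelism can be pictured with asyncio.gather. This is a minimal sketch under the assumption that each class is sent as an independent request; annotate_class is a hypothetical stand-in for the LLM call and only simulates latency.

```python
import asyncio


async def annotate_class(class_name: str, source: str) -> str:
    """Hypothetical stand-in for one LLM request; it just tags the source."""
    await asyncio.sleep(0.01)  # simulate network latency
    return f"# rewritten by LLM\n{source}"


async def annotate_all(classes: dict[str, str]) -> dict[str, str]:
    """Send one request per class and await them all concurrently."""
    names = list(classes)
    results = await asyncio.gather(*(annotate_class(n, classes[n]) for n in names))
    # gather() returns results in input order, so they line up with names
    return dict(zip(names, results))


if __name__ == "__main__":
    out = asyncio.run(annotate_all({"MLP": "class MLP: ...", "Block": "class Block: ..."}))
    print(out["Block"])
```

Because only the class bodies go through annotate_all, code outside them (imports, dataclasses, configs) can be copied to the output file untouched.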
Project structure
agent_shaper/
    fx_utils/
        get_fx_data.py          # shape extraction via torch.export + ShapeProp
        manual_annotate.py      # inline shape comment insertion
        llm_annotate.py         # LLM-powered rewrite (COMMENT or EINSUM mode)
        diff_viewer.py          # Streamlit diff UI
    mcp_server.py               # MCP tool server for AI-assisted rewrites
examples/
    transformer/                # GPT-2, LLaMA, Qwen3, Swin Transformer
        model.py                # nanoGPT reference implementation
        model_einsum.py         # einsum-rewritten version
        llama.py / llama_einsum.py
        qwen3.py / qwen3_einsum.py
        swin_transformer.py / swin_transformer_einsum.py
        run_transformer.py      # example forward pass
    alignment/                  # standalone alignment loss references
        alignment_losses.py     # DPO/IPO/SimPO losses (reference einsum style)
        dpo_losses_einsum.py
        jsd.py / jsd_einsum.py
        opd.py / opd_einsum.py
    direct_alignment/           # full training-ready RLHF loss implementations
        loss.py                 # DPO, IPO, SimPO, ORPO, KTO, APO-zero, APO-down
        train.py                # training loop
        data.py                 # preference dataset loading
        config.py               # training configuration
Each *_einsum.py file is the einsum/einops rewrite of its original counterpart, validated to produce outputs that match the original within an absolute tolerance of 1e-5 (atol=1e-5).
Installation
Requires Python 3.12+.
python -m venv .venv
source .venv/bin/activate
pip install -e .
Usage
MCP server (recommended for AI-assisted rewrites)
The MCP server is the primary interface for using Agent Shaper with Claude Code. It exposes shape extraction and rewrite validation as tools the model can call directly.
Add it to your Claude Code MCP config:
{
  "mcpServers": {
    "agent-shaper": {
      "command": "/path/to/.venv/bin/python",
      "args": ["-m", "agent_shaper.mcp_server"]
    }
  }
}
Available tools:
| Tool | Purpose |
|---|---|
| get_annotated_sources | Run a setup script, trace the model, and return shape-annotated source for the requested files |
| get_fx_shapes | Return raw FX shape data for modules and functions |
| validate_rewrite | Check that a rewritten nn.Module class produces outputs matching the original (atol=1e-5) |
| validate_rewrite_function | Check that a rewritten standalone function produces outputs matching the original |
| validate_file_rewrite | Validate all classes in a rewritten file in one call |
| save_fixtures | Capture and persist forward-pass fixtures for later testing |
| generate_test_files | Generate pytest files that assert output identity against saved fixtures |
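The kind of check validate_rewrite performs on a class can be sketched as follows. This is an assumption about its mechanics, not its actual code: the rewritten class keeps the original's parameter names so the state_dict transfers, both modules run on the same inputs, and outputs are compared with torch.allclose. ProjOriginal, ProjEinsum, and check_module_rewrite are hypothetical.

```python
import torch
import torch.nn as nn


class ProjOriginal(nn.Module):
    def __init__(self, d_in: int, d_out: int):
        super().__init__()
        self.proj = nn.Linear(d_in, d_out, bias=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.proj(x)  # (B, T, d_in) -> (B, T, d_out)


class ProjEinsum(nn.Module):
    """Rewrite keeps the attribute name 'proj' so load_state_dict succeeds."""

    def __init__(self, d_in: int, d_out: int):
        super().__init__()
        self.proj = nn.Linear(d_in, d_out, bias=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # nn.Linear computes x @ W.T, with W of shape (d_out, d_in)
        return torch.einsum("btc,oc->bto", x, self.proj.weight)


def check_module_rewrite(original: nn.Module, rewritten: nn.Module,
                         example_args: tuple, atol: float = 1e-5) -> bool:
    # strict=True (the default) raises if the rewrite renamed any parameter
    rewritten.load_state_dict(original.state_dict())
    original.eval()
    rewritten.eval()
    with torch.no_grad():
        return torch.allclose(original(*example_args), rewritten(*example_args), atol=atol)
```

Usage: `check_module_rewrite(ProjOriginal(16, 8), ProjEinsum(16, 8), (torch.randn(2, 4, 16),))`.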
Typical rewrite workflow:
- Call get_annotated_sources with a setup script (which assigns model, example_args, and optionally dim_names) and the list of source files to annotate.
- Read the shape-annotated snippets returned for each class and function.
- Rewrite the code using einsum/einops.
- Call validate_rewrite (for classes) or validate_rewrite_function (for standalone functions) to confirm that the rewrite matches the original numerically before writing it to disk.
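As an illustration of the rewrite step, here is a typical multi-head attention score computation and an einsum version that folds both transposes into the contraction. Both functions are hypothetical examples written for this sketch; the closing torch.allclose comparison mirrors, under that assumption, what validate_rewrite_function confirms.

```python
import torch


def scores_original(q: torch.Tensor, k: torch.Tensor, n_head: int) -> torch.Tensor:
    # q, k: (B, T, C) with C = n_head * head_dim
    B, T, C = q.shape
    q = q.view(B, T, n_head, C // n_head).transpose(1, 2)  # (B, H, T, D)
    k = k.view(B, T, n_head, C // n_head).transpose(1, 2)  # (B, H, T, D)
    return q @ k.transpose(-2, -1)                         # (B, H, T, T)


def scores_einsum(q: torch.Tensor, k: torch.Tensor, n_head: int) -> torch.Tensor:
    B, T, C = q.shape
    q = q.view(B, T, n_head, C // n_head)  # (B, T, H, D)
    k = k.view(B, T, n_head, C // n_head)  # (B, S, H, D), S indexes key positions
    # One einsum replaces both transposes and the matmul
    return torch.einsum("bthd,bshd->bhts", q, k)  # (B, H, T, S)


if __name__ == "__main__":
    torch.manual_seed(0)
    q, k = torch.randn(2, 5, 12), torch.randn(2, 5, 12)
    print(torch.allclose(scores_original(q, k, 3), scores_einsum(q, k, 3), atol=1e-5))
```

The einsum subscripts make the head/time layout explicit, which is exactly what the shape annotations help an assistant get right.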
Setup script contract: the script must assign the following names (dim_names is optional):
- model: an nn.Module instance to trace
- example_args: a tuple of example tensors matching forward()'s signature
- dim_names (optional): a dict mapping symbolic dim names to integer values (e.g. {"B": 4, "T": 16}) so shapes show (B, T) instead of (4, 16)
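A setup script satisfying this contract might look like the sketch below. It assumes the tool executes the script and reads the module-level names model, example_args, and dim_names; TinySelfAttention is a hypothetical toy model, and any nn.Module works in its place.

```python
# setup_tiny.py: hypothetical setup script for get_annotated_sources
import torch
import torch.nn as nn


class TinySelfAttention(nn.Module):
    def __init__(self, n_embd: int):
        super().__init__()
        self.qkv = nn.Linear(n_embd, 3 * n_embd)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        q, k, v = self.qkv(x).chunk(3, dim=-1)                 # each (B, T, C)
        att = (q @ k.transpose(-2, -1)) / q.shape[-1] ** 0.5   # (B, T, T)
        return att.softmax(dim=-1) @ v                         # (B, T, C)


B, T, C = 4, 16, 32

# The names required by the setup-script contract:
model = TinySelfAttention(C)
example_args = (torch.randn(B, T, C),)
dim_names = {"B": B, "T": T, "C": C}  # distinct values, so no shared-name ambiguity
```

Choosing distinct sizes for each dimension keeps the annotations unambiguous.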
Manual shape annotation
import torch
from agent_shaper.transformer.model import GPT, GPTConfig
from agent_shaper.fx_utils.manual_annotate import annotate_module_source
cfg = GPTConfig(block_size=32, vocab_size=256, n_layer=2, n_head=2, n_embd=64, dropout=0.0, bias=True)
B, T = 2, 16
example_args = (
    torch.randint(0, cfg.vocab_size, (B, T), dtype=torch.long),  # idx (input token ids)
    torch.randint(0, cfg.vocab_size, (B, T), dtype=torch.long),  # targets
)
annotated = annotate_module_source(
    GPT(cfg),
    example_args,
    dim_names={"B": B, "T": T},
    output_dir="annotated_output",  # writes annotated files here; omit to just get the dict back
)
Or run the built-in smoke test directly:
python -m agent_shaper.fx_utils.manual_annotate
Output files are written to annotated_output/ preserving the original relative path structure.
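Conceptually, manual annotation appends a shape comment to each relevant source line. The sketch below assumes the shapes have already been mapped to 1-based line numbers; add_shape_comments is a hypothetical helper, not the implementation in manual_annotate.py.

```python
def add_shape_comments(source: str, shapes_by_line: dict[int, str], col: int = 40) -> str:
    """Append '# <shape>' to the given 1-based lines, padding code to a fixed column."""
    out = []
    for lineno, line in enumerate(source.splitlines(), start=1):
        shape = shapes_by_line.get(lineno)
        if shape is not None:
            # Left-justify the code so the comments line up vertically
            line = f"{line.rstrip():<{col}}  # {shape}"
        out.append(line)
    return "\n".join(out)


if __name__ == "__main__":
    src = "def f(x, w):\n    h = x @ w\n    return h.sum(-1)"
    print(add_shape_comments(src, {2: "(B, T, C)", 3: "(B, T)"}))
```

Lines without an entry in shapes_by_line are passed through unchanged.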
LLM annotation
Set environment variables first:
export OPENAI_API_KEY=sk-...
export OPENAI_MODEL=gpt-4o
export OPENAI_BASE_URL=https://your-proxy/v1 # optional; omit for default OpenAI
import asyncio
import torch
from agent_shaper.transformer.model import GPT, GPTConfig
from agent_shaper.fx_utils.llm_annotate import llm_annotate_module_source, AnnotationMode
cfg = GPTConfig(block_size=32, vocab_size=256, n_layer=2, n_head=2, n_embd=64, dropout=0.0, bias=True)
B, T = 2, 16
example_args = (
    torch.randint(0, cfg.vocab_size, (B, T), dtype=torch.long),  # idx (input token ids)
    torch.randint(0, cfg.vocab_size, (B, T), dtype=torch.long),  # targets
)

async def main():
    annotated = await llm_annotate_module_source(
        GPT(cfg),
        example_args,
        mode=AnnotationMode.COMMENT,  # or AnnotationMode.EINSUM
        dim_names={"B": B, "T": T},
        output_dir="llm_annotated_output",
    )

asyncio.run(main())
Or run the built-in smoke test:
python -m agent_shaper.fx_utils.llm_annotate
Reviewing diffs
By default (open_in_vscode=True), VS Code opens automatically once the LLM-annotated file is generated, showing a side-by-side diff of the original and rewritten files. Pass open_in_vscode=False to suppress this.
You can also use the Streamlit diff viewer for a browser-based review:
.venv/bin/streamlit run agent_shaper/fx_utils/diff_viewer.py
Enter the path to the original file on the left and the generated file (e.g. llm_annotated_output/agent_shaper/transformer/model.py) on the right. The viewer renders a syntax-highlighted unified diff.
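If you only need a plain-text diff (for example in CI), Python's standard difflib produces the same unified format without syntax highlighting. This is a minimal sketch; unified_text_diff and unified_file_diff are hypothetical helpers, not functions from diff_viewer.py.

```python
import difflib
from pathlib import Path


def unified_text_diff(original: str, rewritten: str,
                      fromfile: str = "original", tofile: str = "rewritten",
                      context: int = 3) -> str:
    """Unified diff of two source strings, with `context` unchanged lines around each hunk."""
    return "".join(difflib.unified_diff(
        original.splitlines(keepends=True),
        rewritten.splitlines(keepends=True),
        fromfile=fromfile, tofile=tofile, n=context,
    ))


def unified_file_diff(original_path: str, rewritten_path: str, context: int = 3) -> str:
    """Same as unified_text_diff, but reads both files from disk."""
    return unified_text_diff(Path(original_path).read_text(), Path(rewritten_path).read_text(),
                             original_path, rewritten_path, context)
```

Usage: `print(unified_file_diff("agent_shaper/transformer/model.py", "llm_annotated_output/agent_shaper/transformer/model.py"))`.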
dim_names
The optional dim_names parameter maps symbolic names to their concrete values in the example run. This lets the shape annotations show (B, T, n_embd) instead of (2, 16, 64). When two names share the same value (e.g. B=2 and n_head=2), the annotation shows B/n_head.
dim_names = {"B": 2, "T": 16, "n_embd": 64, "n_head": 2}
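The name-substitution rule can be sketched as a small function: each concrete size is replaced by every name that maps to it, joined with "/", and sizes with no name stay numeric. annotate_shape is a hypothetical helper written to mirror the behavior described above, not the project's implementation.

```python
def annotate_shape(shape: tuple[int, ...], dim_names: dict[str, int]) -> str:
    """Render a concrete shape using symbolic names; shared sizes become 'A/B'."""
    names_by_value: dict[int, list[str]] = {}
    for name, value in dim_names.items():
        names_by_value.setdefault(value, []).append(name)
    parts = ["/".join(names_by_value[d]) if d in names_by_value else str(d) for d in shape]
    return "(" + ", ".join(parts) + ")"


if __name__ == "__main__":
    dim_names = {"B": 2, "T": 16, "n_embd": 64, "n_head": 2}
    print(annotate_shape((2, 16, 64), dim_names))  # (B/n_head, T, n_embd)
```

Picking distinct example sizes for every named dimension avoids the B/n_head ambiguity entirely.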
Calling get_module_shapes directly
from agent_shaper.fx_utils import get_module_shapes, TensorInfo
module_infos = get_module_shapes(model, example_args, dim_names=dim_names)
for info in module_infos:
    print(info.class_name, info.source_file, info.line_start, info.line_end)
    for t in info.tensors:
        print(" ", t.name, t.shape, t.annotated_shape)
Each ModuleInfo contains:
- class_name, module_origin, source_file, line_start, line_end
- parameters: a list of TensorInfo for the nn.Parameter entries defined in __init__
- tensors: a list of TensorInfo for every intermediate FX node in the forward pass
Repeated module instances with identical shape sequences (e.g. transformer blocks) are deduplicated to one entry.
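Deduplication by shape sequence can be sketched as keying each entry on its class name plus the tuple of its tensor shapes and keeping the first occurrence. ModuleInfoSketch is a hypothetical, trimmed-down stand-in for ModuleInfo, and the exact key the library uses is an assumption.

```python
from dataclasses import dataclass, field


@dataclass
class ModuleInfoSketch:
    """Hypothetical stand-in for ModuleInfo with only the fields needed here."""
    class_name: str
    tensor_shapes: list[tuple[int, ...]] = field(default_factory=list)


def dedupe_modules(infos: list[ModuleInfoSketch]) -> list[ModuleInfoSketch]:
    """Keep the first instance of each (class_name, shape sequence) pair, in order."""
    seen: set[tuple] = set()
    unique = []
    for info in infos:
        key = (info.class_name, tuple(info.tensor_shapes))
        if key not in seen:
            seen.add(key)
            unique.append(info)
    return unique
```

With n_layer identical transformer blocks, this collapses them to a single entry, while a block that sees different shapes would be kept separately.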
Examples
The examples/ folder contains worked rewrites across several model families, each paired with a validated einsum/einops version:
- examples/transformer/: GPT-2, LLaMA, Qwen3, Swin Transformer. Each model has a *_einsum.py counterpart rewritten with torch.einsum and einops.
- examples/alignment/: standalone alignment loss functions (DPO, IPO, SimPO, JSD, OPD) in both original and einsum form.
- examples/direct_alignment/: production-style RLHF loss implementations (DPO, cDPO, IPO, SimPO, ORPO, KTO, APO-zero, APO-down) as nn.Module classes with a full training loop. The loss.py file uses einops.einsum and einops.reduce throughout, validated against the gather-based originals at atol=1e-5.
All *_einsum.py rewrites were validated using the validate_rewrite / validate_rewrite_function MCP tools.
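The gather-to-einsum pattern behind the alignment-loss rewrites can be sketched on per-token log-probabilities: gather picks logp[b, t, labels[b, t]], and an einsum against a one-hot label tensor contracts the vocabulary axis to the same values. This sketch uses torch.einsum rather than einops to stay dependency-free, and both functions are hypothetical illustrations, not the repository's loss.py.

```python
import torch
import torch.nn.functional as F


def token_logps_gather(logits: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
    # logits: (B, T, V), labels: (B, T) -> per-token log-probs (B, T)
    logp = logits.log_softmax(dim=-1)
    return torch.gather(logp, dim=-1, index=labels.unsqueeze(-1)).squeeze(-1)


def token_logps_einsum(logits: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
    logp = logits.log_softmax(dim=-1)
    onehot = F.one_hot(labels, num_classes=logits.shape[-1]).to(logp.dtype)  # (B, T, V)
    # Contracting V selects logp[b, t, labels[b, t]]
    return torch.einsum("btv,btv->bt", logp, onehot)


if __name__ == "__main__":
    torch.manual_seed(0)
    logits, labels = torch.randn(2, 5, 11), torch.randint(0, 11, (2, 5))
    print(torch.allclose(token_logps_gather(logits, labels),
                         token_logps_einsum(logits, labels), atol=1e-5))
```

The einsum form is easier to read alongside shape annotations, at the cost of materializing a (B, T, V) one-hot tensor, which matters for large vocabularies.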