TensorRT-LLM Code Contribution Best Practices
Contribution Process
1. Developer Workflow
- Commit the changes. Never commit using an NVIDIA internal email address!
- Push changes to a branch on the personal fork:
```bash
git push -u <user> <local-branch>:<remote-branch>
```
- Create a PR from the fork branch into upstream (typically targeting ).
2. Coding Guidelines
TRT-LLM coding style is defined in
. Key highlights:
C++: Allman brace style, 4-space indent, 120 char line limit, camelCase for variables/methods, PascalCase for types,
prefix for member variables,
prefix for constants, Doxygen for API docs, smart pointers over raw,
over
, no C-style casts.
Python: snake_case for files/functions/variables, PascalCase for classes, UPPER_SNAKE_CASE for constants, 4-space indent, Google-style docstrings, narrow
clauses, Pydantic
for user-facing config classes (no custom
).
3. Pre-commit Setup
```bash
pip install pre-commit
pre-commit install
```
Pre-commit runs automatically on every
. Hooks include: isort, yapf, autoflake, clang-format, cmake-format, codespell, ruff, ruff-format, mdformat, and others. If hooks modify files, stage and commit them again.
4. DCO Sign-off (Required)
All commits must be signed off to certify the contribution under the
Developer Certificate of Origin:
```bash
git commit -s -m "Add cool feature."
```
This appends
Signed-off-by: Your Name <your@email.com>
to the commit message. PRs containing unsigned commits will not be accepted.
IMPORTANT: Never sign off commits using an NVIDIA internal email address!
Pre-Implementation Checklist
Before writing any code, complete these steps:
1. Survey Existing Infrastructure
Search before building. TRT-LLM is a large codebase with many reusable components. Before implementing something from scratch, search for existing utilities:
```bash
# Before writing a new attention computation
grep -r "TrtllmAttention\|create_attention\|scaled_dot_product" tensorrt_llm/_torch/
# Before writing a new compiled helper
grep -r "maybe_compile\|maybe_compiled_" tensorrt_llm/_torch/utils.py
# Before writing a custom RoPE
grep -r "RotaryEmbedding\|rotary_emb\|rope" tensorrt_llm/_torch/modules/
# Before writing a new cache management pattern
grep -r "mla_rope_append_paged_kv\|append_paged_kv" tensorrt_llm/_torch/
```
Trace existing forward methods. Before writing a new
method, read all existing forward methods in the class and understand what each one does. Often an existing method already implements the computation you need, and you just need to set up the right state (e.g., create an attribute, adjust a guard) to dispatch to it.
```bash
# Find all forward methods in a class
grep -n "def forward" tensorrt_llm/_torch/modules/attention.py
# Then READ each one to understand what it does
```
Lesson learned: On the short-seq MHA branch (30 commits, ~250 lines written then deleted), the attention computation went through 4 rewrites: per-sequence SDPA loop → batched SDPA with pad_sequence → custom TrtllmAttention backend → deletion in favor of the already-existing forward_context_default(). The final approach was +10 lines: a guard check plus a dispatch to an existing method. Similarly,
was discovered only after a standalone
wrapper was written and then removed.
Anti-pattern: Parallel reimplementation. Before writing a new
method, trace what existing forward methods do. The new method may already be implemented. In the MLA case,
forward_context_short_mha
reimplemented
nearly line-for-line before being deleted.
2. Check Parallelism Dimensions
When adding a new code path, verify correctness under ALL parallelism modes:
| Dimension | Guard | Why |
|---|---|---|
| Tensor Parallelism (TP) | | Head counts are sharded |
| Pipeline Parallelism (PP) | | Layers may be on different ranks |
| Context Parallelism (CP) | | Sequence is split across ranks — tokens are not all local |
| Expert Parallelism (EP) | | MoE experts distributed |
Lesson learned: The short-seq MHA path assumed all tokens were local, which breaks under Context Parallelism. The
guard was added as a fix in a later commit instead of being part of the initial design.
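One way to make such guards part of the initial design is to gather the parallelism sizes in one place and have the new path state which of them it supports. The following is a minimal sketch only: `ParallelConfig` and `short_path_allowed` are hypothetical names for illustration, not TRT-LLM APIs, and the only condition taken from this guide is that the path must be disabled when context parallelism is active.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class ParallelConfig:
    """Hypothetical container for parallelism sizes (not a TRT-LLM class)."""
    tp_size: int = 1
    pp_size: int = 1
    cp_size: int = 1
    ep_size: int = 1


def short_path_allowed(cfg: ParallelConfig) -> bool:
    """Return True only when the new path's assumptions hold.

    The path in this sketch assumes every token of a sequence is local,
    so it is disabled whenever context parallelism splits sequences.
    """
    return cfg.cp_size == 1


print(short_path_allowed(ParallelConfig(tp_size=8)))  # True
print(short_path_allowed(ParallelConfig(cp_size=2)))  # False
```

Writing the guard as a function of the whole configuration makes it harder to forget a dimension: each field has to be either used or consciously ignored.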
3. Think About Threshold/Guard Semantics
When gating a code path with a threshold:
- What does the threshold measure? (per-sequence metric? total batch metric?)
- What does the cost of the path scale with? (per-sequence? total tokens? quadratic in something?)
- Do these match? If cost scales with total tokens, the threshold should check total tokens, not per-sequence max.
Lesson learned: The initial implementation checked
(longest single sequence) against the threshold, but the cost of the short-seq path scales with total packed tokens. A batch of 100 short sequences could incorrectly trigger the path.
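The mismatch is easy to reproduce with plain integers. The sketch below assumes the path is taken when the measured quantity is at or below the threshold; the function names and the sequence lengths are made up for illustration, and 10240 is the threshold value mentioned later in this guide.

```python
def use_short_path_per_seq(seq_lens, threshold):
    """Mismatched guard: checks the longest single sequence."""
    return threshold > 0 and max(seq_lens) <= threshold


def use_short_path_total(seq_lens, threshold):
    """Matched guard: checks total packed tokens, which is what cost scales with."""
    return threshold > 0 and sum(seq_lens) <= threshold


seq_lens = [512] * 100  # 100 short sequences, 51,200 packed tokens in total
threshold = 10240

print(use_short_path_per_seq(seq_lens, threshold))  # True: path wrongly taken
print(use_short_path_total(seq_lens, threshold))    # False: path correctly skipped
```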
4. Check RoPE State
When adding attention code paths:
- Is True (caller handles RoPE) or False (rope_fusion, backend handles RoPE)?
- Does your path apply RoPE? Will that cause double-application?
- Do you need to handle both RoPE states or can you gate to one?
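To see why double application is a correctness bug rather than a small numerical error, consider a toy model of RoPE acting on a single feature pair, represented as a complex number. This is an illustrative sketch, not TRT-LLM code; `apply_rope` and `check_rope_applied_once` are hypothetical helpers.

```python
import cmath


def apply_rope(pair: complex, position: int, freq: float) -> complex:
    """Rotate one feature pair by the angle position * freq."""
    return pair * cmath.exp(1j * position * freq)


def check_rope_applied_once(caller_applies_rope: bool,
                            backend_applies_rope: bool) -> None:
    """Raise unless exactly one side is responsible for RoPE."""
    if caller_applies_rope and backend_applies_rope:
        raise ValueError("RoPE would be applied twice")
    if not caller_applies_rope and not backend_applies_rope:
        raise ValueError("RoPE would never be applied")


x = complex(1.0, 0.5)
once = apply_rope(x, position=3, freq=0.5)
twice = apply_rope(once, position=3, freq=0.5)

# Applying RoPE twice at position 3 encodes position 6 instead.
print(cmath.isclose(twice, apply_rope(x, position=6, freq=0.5)))  # True
print(cmath.isclose(twice, once))                                 # False
```

A token rotated twice looks like a token at a different position, so attention scores are silently computed against the wrong relative offsets.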
5. Trace Method Limitations
Understand what a method does NOT handle. When reusing an existing method, fully trace the dispatch chain above it. A method may be correct for one scenario but miss edge cases handled by a higher-level dispatcher.
Example: forward_context_default()
handles fresh prefill with no cached KV tokens. But when there are cached KV tokens (chunked context), it silently ignores them — causing a correctness bug. The fix was to call
instead, which dispatches to:
- forward_context_with_chunked_prefill (SM100+, chunked context)
- forward_context_with_cached_kv (SM90 fallback, or cached context)
- (fresh prefill, no cached tokens)
Checklist for reusing a method:
- What does this method handle?
- What does it NOT handle? (cached tokens? chunked prefill? specific hardware?)
- Is there a higher-level dispatcher that routes to this method for the right cases?
- Should I call the dispatcher instead of the method directly?
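The difference between calling a handler directly and calling its dispatcher can be sketched with a toy class. The three handler names come from the text above; `ToyAttention`, `dispatch_context`, the argument names, and the exact routing conditions are assumptions made for illustration and are a simplified reading of the list above, not the real implementation.

```python
class ToyAttention:
    """Toy stand-in for an attention module; not TRT-LLM code."""

    def forward_context_default(self, num_cached_tokens: int) -> str:
        # Handles fresh prefill only; cached tokens would be silently ignored.
        return "default"

    def forward_context_with_cached_kv(self, num_cached_tokens: int) -> str:
        return "cached_kv"

    def forward_context_with_chunked_prefill(self, num_cached_tokens: int) -> str:
        return "chunked_prefill"

    def dispatch_context(self, num_cached_tokens: int, sm_version: int,
                         chunked: bool) -> str:
        """Hypothetical dispatcher: route to the handler that covers the case."""
        if num_cached_tokens == 0:
            return self.forward_context_default(num_cached_tokens)
        if chunked and sm_version >= 100:
            return self.forward_context_with_chunked_prefill(num_cached_tokens)
        return self.forward_context_with_cached_kv(num_cached_tokens)


attn = ToyAttention()
print(attn.dispatch_context(0, sm_version=90, chunked=False))     # default
print(attn.dispatch_context(128, sm_version=100, chunked=True))   # chunked_prefill
print(attn.dispatch_context(128, sm_version=90, chunked=True))    # cached_kv
```

Calling `forward_context_default` directly corresponds to always taking the first branch, which is only correct when there are no cached tokens.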
6. Check Hardware-Specific Behavior
The same algorithm can have different numerical properties across SM versions. FMHA kernels may use different internal implementations (e.g., online softmax merge on SM90 vs single-pass on SM100+) that produce different accuracy characteristics.
Lesson learned: The SM90 (Hopper) FMHA kernel's online softmax merge for chunked prefill diverged from the single-pass reference by ~0.45 max diff — unacceptable for a correctness-critical path. The fix was to gate chunked prefill behind
(Blackwell+) and fall back to
forward_context_with_cached_kv
on SM90.
When to check:
- Any new attention code path that uses fused kernels
- Any path that changes how attention is split/chunked (chunked prefill, context parallelism)
- When accuracy tolerances are tight and the path crosses hardware generations
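A max-difference check against a reference is the simplest way to catch this kind of divergence. The sketch below uses plain Python lists so it runs anywhere; the helper names, the sample values, and the 0.02 tolerance are illustrative assumptions, and only the ~0.45 divergence figure comes from the lesson above.

```python
def max_abs_diff(actual, reference):
    """Largest element-wise absolute difference between two equal-length lists."""
    if len(actual) != len(reference):
        raise ValueError("length mismatch")
    return max(abs(a - r) for a, r in zip(actual, reference))


def within_tolerance(actual, reference, atol):
    """True when no element deviates from the reference by more than atol."""
    return max_abs_diff(actual, reference) <= atol


reference = [0.10, -0.20, 0.30]
close = [0.10, -0.21, 0.30]      # small rounding-level deviation
diverged = [0.10, -0.20, 0.75]   # one element is about 0.45 off

print(within_tolerance(close, reference, atol=0.02))     # True
print(within_tolerance(diverged, reference, atol=0.02))  # False
```

Running the same comparison on each SM version you target shows whether a single tolerance is realistic or whether the path needs a hardware gate.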
Implementation Workflow
Use the Right Abstraction Level
Choose backends from this priority list:
- Existing forward method (e.g., ): may already implement what you need; just set up state and dispatch
- Existing fused backend (e.g., , ): handles packed sequences, variable lengths, KV cache natively
- PyTorch fused ops (e.g., F.scaled_dot_product_attention): good for prototyping but requires manual batching/padding
- Manual implementation: last resort, only when no existing backend fits
Use the Right Dispatch Abstraction Level
When dispatching to an existing method, use the highest-level dispatch point that provides the right abstraction. Don't bypass dispatch layers — you'll miss edge cases.
| Abstraction Level | Example | Handles |
|---|---|---|
| Top-level dispatcher | | Chunked prefill, cached KV, fresh prefill, SM-version gating |
| Specific handler | forward_context_default() | Fresh prefill only |
| Backend directly | | Nothing beyond raw attention |
Lesson learned: The initial short-seq MHA implementation called
forward_context_default()
directly. This worked for fresh prefill but silently dropped cached KV tokens during chunked context. Switching to
(which dispatches to
forward_context_with_cached_kv
or
forward_context_with_chunked_prefill
as appropriate) fixed the bug with a 1-line change.
Prefer Reusing Existing Attributes Over Creating New Ones
When adding a new code path, check if an existing attribute can serve double duty:
```python
# BAD: parallel attribute alongside existing one
self._short_seq_mha = create_attention(...)  # separate from self.mha
# Then need special handling everywhere self.mha is referenced

# GOOD: reuse existing attribute with conditional initialization
if should_use_dense_mha:
    self.mha = create_attention(...)  # replaces None for DSA models
# Existing code paths that check self.mha just work
```
Lesson learned: The short-seq MHA initially used
as a separate attribute to "preserve the assertion that
". Later, it was realized the assertion itself should change (
) and
could be reused.
Run Pre-Commit Before Every Commit
Always run pre-commit run --all-files
before committing. The short-seq MHA branch had a 377-line formatting-only commit (commit 15/19) that existed solely because pre-commit wasn't run on earlier commits. This is wasted reviewer attention and pollutes
.
```bash
# Before every commit:
pre-commit run --all-files
git add -u  # stage any auto-formatted files
git commit -s -m "..."
```
Apply torch.compile Judiciously
| Pattern | Use ? | Why |
|---|---|---|
| Fused math (RoPE rotation, GELU) | Yes | Fuses multiple element-wise ops into one kernel |
| of computed tensors | Use | Already exists as a utility |
| Pure metadata ops (split, view, expand, reshape) | No | These are zero-cost; compile adds overhead |
| Mixed metadata + compute | Extract the compute part | Compile only what benefits from fusion |
Extract Shared Logic Immediately
When a condition appears in more than one place, extract it into a helper method in the same commit. Don't wait for a later refactoring commit.
```python
# BAD: same 5-condition check in two places
if (threshold > 0 and not apply_rotary and cp_size == 1 and ...):  # site 1
    ...
if (threshold > 0 and not apply_rotary and cp_size == 1 and ...):  # site 2
    ...

# GOOD: extract immediately
def _should_use_short_mha(self, ...):
    return (threshold > 0 and not apply_rotary and cp_size == 1 and ...)
```
Feature Flags for Complex Optimizations
Complex optimizations with multiple guards, edge cases, and hardware-specific behavior should ship disabled by default. Let users opt-in via environment variable after testing.
```python
import os

# Pattern: disabled by default (threshold=0), opt-in via env var
_threshold_str = os.environ.get('TRTLLM_MLA_SHORT_SEQ_MHA_THRESHOLD', '0')
self.short_seq_mha_threshold = int(_threshold_str)
```
Lesson learned: The short-seq MHA optimization was initially enabled by default (threshold=10240) at commit 8 but had 18 more correctness fixes over the next 22 commits before being disabled by default at commit 26. Complex optimizations accumulate edge cases (chunked context, SM90 accuracy, threshold semantics) that may not be discovered until broad testing.
When to disable by default:
- The optimization has 3+ guard conditions
- It touches attention/correctness-critical paths
- It has hardware-specific behavior (different SM versions)
- It hasn't been tested in full CI across all configurations
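An opt-in flag is only safe if a malformed value fails loudly instead of silently enabling or disabling the path. The following sketch wraps the environment lookup shown above in a small validating helper. `read_threshold` and its `environ` parameter are hypothetical names added for illustration and testability; only the variable name and the default of 0 come from this guide.

```python
import os

ENV_NAME = "TRTLLM_MLA_SHORT_SEQ_MHA_THRESHOLD"


def read_threshold(environ=os.environ, name=ENV_NAME) -> int:
    """Parse the opt-in threshold; 0 (the default) keeps the path disabled."""
    raw = environ.get(name, "0")
    try:
        value = int(raw)
    except ValueError:
        raise ValueError(f"{name} must be an integer, got {raw!r}") from None
    if value < 0:
        raise ValueError(f"{name} must be >= 0, got {value}")
    return value


print(read_threshold(environ={}))                   # 0
print(read_threshold(environ={ENV_NAME: "10240"}))  # 10240
```

Passing the mapping explicitly lets a unit test exercise every case without mutating the process environment.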
Update All References When Changing Semantics
When changing what a variable/threshold means, grep for ALL references:
```bash
# After changing threshold from max_seq_len to total_packed_tokens:
grep -rn "max_seq_len\|max_ctx_seq_len\|short.*seq.*threshold" tests/ tensorrt_llm/
```
Update comments, docstrings, test descriptions, and variable names in the same commit.
Testing Strategy
When to Write Tests
| Phase | What to test | Why |
|---|---|---|
| During prototyping | Minimal smoke test only | Validates basic plumbing without coupling to implementation details |
| After implementation stabilizes | Full correctness suite | Avoid rewriting tests with each iteration |
| After optimization changes | Add regression tests for the specific optimization | Catches if the optimization breaks something |
Lesson learned: Tests were written before the attention backend was settled, then required 5 separate fix/update commits as the implementation evolved through 4 rewrites. The 770-line test file needed immediate fixing (device placement, weight layout bugs) because it was never run before committing.
Common Test Gotchas in TRT-LLM
- Non-Module children aren't moved by : If a module has attributes that aren't subclasses (e.g., DSATrtllmAttention.indexer), won't move their parameters. Move them explicitly.
- Weight layout differs from HuggingFace: Model loading transforms weights. Initialize test weights in the loaded layout (check for load functions), not the HuggingFace checkpoint layout.
- Background threads from cache managers: and similar create threads that outlive tests. Add pytestmark = pytest.mark.threadleak(enabled=False) at the module level.
- misses non-Module attributes: When copying weights for A/B comparison tests, explicitly copy parameters from non-Module children (like indexer weights).
- Attention metadata construction: Use the test fixtures/helpers already in the codebase (check tests/unittest/_torch/attention/ for patterns) rather than building from scratch.
Test Consolidation
After implementation stabilizes, aggressively prune tests to a minimal set where each parametrized case exercises a distinct code path.
Pattern:
- During development, write comprehensive tests (many parametrized cases covering all combinations)
- After implementation stabilizes, identify which code paths each test case exercises
- Merge cases that exercise the same code path; remove redundant cases
- Extract shared test helpers (, , ) to reduce duplication
Lesson learned: The short-seq MHA test file peaked at 1394 lines with 21 parametrized cases, then was consolidated to 665 lines with 10 cases covering the same 6 code paths. Three separate cleanup commits were needed because consolidation wasn't done in one pass. Do consolidation as a single deliberate pass.
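The path-coverage analysis behind consolidation can be done mechanically: label each parametrized case with the code path it exercises and keep one case per label. The sketch below is illustrative only; the case fields, the path labels, and `code_path_of` are hypothetical and would need to mirror the real guards of the code under test.

```python
def consolidate(cases, code_path_of):
    """Keep the first case seen for each distinct code path."""
    kept = {}
    for case in cases:
        kept.setdefault(code_path_of(case), case)
    return list(kept.values())


def code_path_of(case, threshold=10240):
    """Hypothetical labelling: mirror the guards of the code under test."""
    if case["total_tokens"] > threshold:
        return "long_path"
    return "short_cached" if case["cached"] > 0 else "short_fresh"


cases = [
    {"total_tokens": 64, "cached": 0},
    {"total_tokens": 128, "cached": 0},    # same path as the first case
    {"total_tokens": 64, "cached": 32},
    {"total_tokens": 20000, "cached": 0},
]

minimal = consolidate(cases, code_path_of)
print(len(cases), "->", len(minimal))  # 4 -> 3
```

The surviving cases can then be fed to a single parametrized test, and any case dropped this way is known to be redundant with respect to the labelled paths.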
Test on Multiple Hardware Targets
When testing attention kernels or fused operations, verify on multiple SM versions. The same kernel can have different numerical properties across hardware generations.
- SM90 (Hopper): Online softmax merge in FMHA — can diverge from reference
- SM100+ (Blackwell): Single-pass FMHA — tighter numerical accuracy
- Use guards to skip or adjust tests per hardware
Commit Hygiene
During Development
Commit freely — small, frequent commits help track progress and enable bisection.
Before PR Submission
Squash fix-on-fix chains using interactive rebase:
```bash
# Fold fix commits into the commits they fix
git rebase -i $(git merge-base HEAD main)
```
Target commit structure for a PR:
- Core implementation — the new feature with all guards and edge cases
- Additional optimizations — one commit per distinct optimization
- Tests — comprehensive test suite
- Refactoring (optional) — cleanup that's separate from the feature
Anti-patterns to Avoid
| Anti-pattern | What happens | Prevention |
|---|---|---|
| Fix-on-fix chains (A → fix A → fix fix A) | Noisy history, hard to review | Squash before PR |
| Add-then-revert (add X → revert X) | Wasted reviewer attention | Survey existing utilities first |
| Modify shared utility then revert (edit rotary_embedding.py → revert) | Pollutes unrelated files | Check if existing code paths handle it |
| Create compiled helper then inline it (add @maybe_compile → remove) | Churn | Profile first; only compile proven bottlenecks |
| Semantic change + behavior change in one commit | Hard to bisect regressions | Separate bug fixes from feature changes |
| Stale comment fix as separate commit | Shows the comment wasn't updated with the code change | Update comments in the same commit as the code |
PR Title Format (Conventional Commits)
For breaking API changes, use
as the type to alert reviewers.
For NVIDIA developers, prefix with JIRA number or NVBUG ID:
[TRTLLM-5516] perf: description
[nvbug/5334370] fix: description
Examples:
feat: Add support for starcoder-v2 FP8 base + FP16/BF16 LoRA
BREAKING CHANGE: Set default max batch size to 2048
chore: Remove version from plugins .so
None: Stringized enums for better error msgs
fix https://github.com/NVIDIA/TensorRT-LLM/issues/700: a Memory leak issue in C++ runtime
[TRTLLM-5516] perf: Replicate dummy request for cuda graph padding
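A quick local check can catch malformed titles before a PR is opened. The pattern below is inferred solely from the examples listed above and is an assumption, not the project's official validation rule; `PR_TITLE` and `is_valid_title` are hypothetical names.

```python
import re

# Inferred from the examples above only; not an official rule.
PR_TITLE = re.compile(
    r"^(?:\[(?:[A-Z]+-\d+|nvbug/\d+)\] )?"  # optional [TRTLLM-1234] or [nvbug/1234]
    r"(?:BREAKING CHANGE|[A-Za-z]+)"         # type, e.g. feat, fix, perf, chore, None
    r"(?: https?://\S+)?"                    # optional issue URL after the type
    r": \S.*$"                               # colon, space, then the summary
)


def is_valid_title(title: str) -> bool:
    """Return True if the title matches the shape of the examples above."""
    return PR_TITLE.match(title) is not None


print(is_valid_title("[TRTLLM-5516] perf: Replicate dummy request for cuda graph padding"))  # True
print(is_valid_title("Add cool feature."))                                                   # False
```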
PR Description
Address these points in the PR description:
- Background/motivation: Why is the change necessary?
- Summary: Summarize the changes in one paragraph.
- Size justification: If the PR is large, explain why it cannot be broken into multiple PRs.
- Impact assessment: Potential performance or functional impacts. Flag risks for reviewers.
- Related PRs: Link to any related PRs.
PR Conciseness
- Avoid committing commented-out code.
- Each PR should address a single concern. If there are several unrelated fixes, open separate PRs and indicate dependencies in the descriptions.
API Stability Tests
Some APIs are protected by the API stability test suite. If your PR breaks a protected API, the stability tests will fail with "API stability validation failed". In this case, request review from the API code owners.
Quantified Impact of Common Mistakes
From the short-seq MHA branch (30 commits → net 2 files changed):
| Mistake | Commits wasted | Lines written & deleted | Root cause |
|---|---|---|---|
| Reimplementing existing forward method | 4 (commits 1,5,6,17) | ~150 lines | Didn't read |
| Custom RoPE handling | 5 (commits 1,13,16,17,18) | ~100 lines | Didn't trace how fused kernel handles RoPE |
| Tests before stable implementation | 5 (commits 3,4,8,11,15) | ~200 lines of churn | Tests coupled to implementation details |
| Compiled helpers created then removed | 4 (commits 10,12,13,18) | ~60 lines | Premature optimization without profiling |
| Style-only commit | 1 (commit 15) | 377 lines reformatted | Pre-commit not run on earlier commits |
| Stale comment fixes | 2 (commits 11,18) | ~15 lines | Comments not updated with code changes |
| Calling method directly instead of dispatcher | 3 (commits 21,23,30) | ~20 lines | Didn't trace dispatch chain |
| Not testing on SM90 | 1 (commit 30) | ~10 lines | Assumed uniform numerical behavior across SM versions |
| Enabled by default too early | 2 (commits 8,26) | ~5 lines | Shipped threshold=10240 before edge cases were found |
| Threshold semantics drift in chunked context | 1 (commit 28) | ~10 lines | doesn't account for cached tokens |
| Redundant test parametrizations | 3 (commits 24,25,27) | ~730 lines pruned | Tests written incrementally without path-coverage analysis |
Total waste: ~24 of 30 commits were fixes/reverts/cleanups of earlier work on the same branch. The final net change is ~200 lines in attention.py and ~665 lines in tests — achievable in ~4-5 clean commits.
Review Readiness Checklist
Before marking a PR ready for review: