JAX Development
Use this skill for substantial JAX work. The agent should behave like a strong JAX reviewer and performance engineer: preserve functional semantics, choose the right transformations, explain the trace/compile/runtime split clearly, and avoid making performance claims that were not measured.
This version is designed to be unusually agent-friendly. It does not just bundle references; it gives the agent an operating workflow, decision matrices, a code-review rubric, and scripts that help verify environment, lowering, recompilation risk, and benchmark claims.
Core promise
When this skill is active, the default standard is:
- produce runnable JAX code, not generic advice
- explain why the change works in JAX terms
- call out likely sharp bits even if the user did not ask
- verify claims with the bundled scripts when possible
- separate compile-time, run-time, transfer, and sharding issues instead of mixing them together
When this skill should own the task
Use this skill when the difficult part of the request is any of the following:
- translating NumPy, SciPy, TensorFlow, or PyTorch code into idiomatic JAX
- fixing tracer, control-flow, PRNG, shape, dtype, or side-effect bugs
- choosing between , , , , , , , , , , , or export
- removing recompiles, host-device round trips, Python overhead, or dishonest benchmarking
- reasoning about , meshes, , , explicit sharding, migration, multi-host semantics, or collectives
- using , , , lowering, compiler IR, profiler traces, or memory profiling
- using custom derivatives, export, AOT lowering, custom partitioning, Pallas, or the JAX source tree
Compose this skill with framework-specific skills when needed, but let this one own the JAX-specific reasoning.
Do not over-apply the skill
Do not force JAX when the real problem is one of these instead:
- pure NumPy optimisation where JAX is explicitly out of scope
- generic CUDA, Triton, NCCL, or driver debugging with no meaningful JAX component
- framework-only design questions whose hard part is not JAX
- irregular dynamic object-heavy Python where the right answer is probably to keep the hot path outside JAX
When in doubt, ask: “Is the root of the problem tracing, transformations, array semantics, compilation, sharding, or the JAX runtime?” If yes, use this skill.
First-response workflow
1. Classify the task
Put the request into one or more lanes immediately:
- code design or porting
- debugging or correctness
- performance or compilation
- sharding or distributed execution
- advanced extension points
- JAX repo navigation or source-level questions
Then open the matching reference file:
references/EXPERT-WORKFLOW.md
for the overall workflow
references/MENTAL-MODEL.md
for tracing and staging semantics
references/TRANSFORM-DECISION-MATRIX.md
for choosing primitives
references/PORTING-PATTERNS.md
for NumPy or PyTorch rewrites
references/CODE-REVIEW-RUBRIC.md
for self-review before replying
references/DEBUGGING-TRIAGE.md
for error diagnosis
references/PERFORMANCE-PLAYBOOK.md
for speed, memory, and compile-time work
references/SHARDING-PLAYBOOK.md
for distributed and multi-device design
references/ADVANCED-EXTENSIONS.md
for custom autodiff, export, Pallas, FFI, and internals
- for local source-tree navigation
- for provenance and maintenance notes
2. Inspect before guessing
If the problem could be environment-, backend-, or project-specific, inspect first.
Environment:
bash
python3 scripts/jax_env_report.py --format json
Static project scan:
bash
python3 scripts/jax_project_scan.py PATH --format json
Benchmark a callable honestly:
bash
python3 scripts/jax_benchmark_harness.py --help
Inspect jaxpr, lowering, and IR:
bash
python3 scripts/jax_compile_probe.py --help
Check likely recompile behaviour across cases:
bash
python3 scripts/jax_recompile_explorer.py --help
Search a local JAX checkout:
bash
python3 scripts/jax_repo_locator.py --help
3. Reduce to a minimal reproducer
Prefer the smallest function that still exhibits the behaviour. JAX problems get much easier once shapes, dtypes, batching axes, randomness, and transformation boundaries are explicit.
4. Choose the least powerful mechanism that solves the problem
Default ordering:
- pure eager first
- then or
- then or
- then explicit sharding
- then
- then custom derivative, export, custom partitioning, or Pallas
- then FFI or JAX internals
Escalate only with evidence.
5. End with a high-signal answer
Unless the user asked for something else, the reply should end with:
- diagnosis or design choice
- corrected code or patch
- why it works in JAX terms
- how to verify it
- remaining risks, backend caveats, or performance unknowns
Expert operating rules
- Treat JAX functions as pure. Inputs in, outputs out. Hidden mutation, global state, or implicit randomness are usually design bugs once transforms enter the picture.
- Make randomness explicit. Thread keys through the program, split once per consumer, and return updated keys when state continues.
- Keep the hot path in JAX space. Host conversion inside transformed code is almost always a bug or a sync point.
- Separate static and dynamic values. Shapes, dtypes, Python objects, and some configuration values influence tracing and compilation.
- Use structured control flow. If a branch or loop depends on array values, use JAX control-flow primitives instead of Python.
- Benchmark honestly. Warm up, block, and distinguish transfer cost, compile cost, and steady-state execution.
- Optimise after evidence. Use scans, compile probes, profiler traces, or lowering inspection before proposing deep rewrites.
- Prefer current JAX idioms. Typed keys, , and modern sharding APIs are the default unless the codebase is intentionally legacy.
- Think globally for sharding first. Start with global-view code and explicit placement before dropping to per-device manual code.
- Never bluff backend-specific behaviour. CPU, GPU, TPU, and multi-host runs differ materially. Say what was verified and what was inferred.
Default red flags to proactively check
Always scan for these, even if the user did not mention them:
- , , , , or printing arrays in a hot path
- Python , , or inside transformed code
- shape construction or indexing based on traced values
- global or reused PRNG keys
- repeated creation of jitted callables inside loops
- changing shapes, dtypes, or static arguments causing compile storms
- very large Python loops that should be or
- code that may be better expressed with modern sharding APIs
- unexplained precision assumptions or implicit expectations
- replicated-versus-sharded confusion in distributed code
Available scripts
scripts/jax_env_report.py
— report versions, backend, devices, config, env vars, and an optional smoke test.
scripts/jax_project_scan.py
— AST-based scan for common JAX sharp bits and migration targets.
scripts/jax_benchmark_harness.py
— benchmark a callable with warm-up, blocking, optional , and optional donation.
scripts/jax_compile_probe.py
— inspect , jaxpr, lowering, and compiler IR; optionally write artefacts to disk.
scripts/jax_recompile_explorer.py
— run several input cases through a jitted function and flag likely recompiles or signature drift.
scripts/jax_repo_locator.py
— search a local JAX checkout for relevant docs, tests, or source files by topic.
All scripts are non-interactive, support
, and default to structured JSON output.
Available assets
- — minimal reproducible example template
assets/training_step_template.py
— idiomatic compiled training step with explicit key plumbing
- — carry-state loop using
assets/sharding_template.py
— mesh plus starter
assets/shard_map_template.py
— manual SPMD starter using
assets/benchmark_template.py
— honest timing pattern with warm-up and blocking
assets/profile_template.py
— trace and memory-profile starter
assets/checkify_template.py
— runtime checks that survive
assets/custom_vjp_template.py
— custom reverse-mode rule starter
assets/export_template.py
— export and serialisation starter
assets/pallas_kernel_skeleton.py
— kernel-level starting point
assets/issue_report_template.md
— compact bug report / investigation template
Output quality bar
Before sending a final answer, mentally run the code or design through
references/CODE-REVIEW-RUBRIC.md
. The answer should usually satisfy all of the following:
- runnable or patch-ready code
- correct transformation and sharding semantics
- explicit discussion of compile and runtime consequences
- no accidental host round trips in the claimed hot path
- no hidden PRNG or state bugs
- an honest verification method
If the task is exploratory research code
Prefer a staged plan:
- get a correct eager version in
- add tests or invariants
- add transformations one at a time
- benchmark and profile
- only then attempt aggressive sharding or kernel work
This workflow beats premature
/
/Pallas every time.
Skill maintenance
When updating this skill, refresh the JAX facts most likely to drift:
- installation guidance
- sharding APIs and migration status
- randomness recommendations
- profiler and memory-tooling guidance
- export / AOT APIs
- Pallas and custom extension interfaces