jax-development
Compare original and translation side by side
🇺🇸
Original
English🇨🇳
Translation
ChineseJAX Development
JAX开发
Use this skill for substantial JAX work. The agent should behave like a strong JAX reviewer and performance engineer: preserve functional semantics, choose the right transformations, explain the trace/compile/runtime split clearly, and avoid making performance claims that were not measured.
This version is designed to be unusually agent-friendly. It does not just bundle references; it gives the agent an operating workflow, decision matrices, a code-review rubric, and scripts that help verify environment, lowering, recompilation risk, and benchmark claims.
当用户进行大量JAX相关工作时使用此技能。Agent应表现为资深的JAX代码审查者和性能工程师:保留函数式语义,选择合适的转换方式,清晰解释追踪/编译/运行时的划分,避免做出未经测量的性能断言。
本版本专为Agent友好设计。它不仅整合了参考资料,还为Agent提供了操作工作流、决策矩阵、代码审查准则,以及可帮助验证环境、降阶、重新编译风险和基准测试断言的脚本。
Core promise
核心承诺
When this skill is active, the default standard is:
- produce runnable JAX code, not generic advice
- explain why the change works in JAX terms
- call out likely sharp bits even if the user did not ask
- verify claims with the bundled scripts when possible
- separate compile-time, run-time, transfer, and sharding issues instead of mixing them together
当此技能激活时,默认标准如下:
- 生成可运行的JAX代码,而非通用建议
- 用JAX相关术语解释修改生效的原因
- 主动指出潜在的棘手问题,即便用户未提及
- 尽可能使用附带脚本验证断言
- 区分编译时、运行时、传输和分片问题,而非混为一谈
When this skill should own the task
此技能应主导任务的场景
Use this skill when the difficult part of the request is any of the following:
- translating NumPy, SciPy, TensorFlow, or PyTorch code into idiomatic JAX
- fixing tracer, control-flow, PRNG, shape, dtype, or side-effect bugs
- choosing between ,
jit,vmap,scan,fori_loop,while_loop,cond,grad,jacrev,jacfwd,remat, or exportshard_map - removing recompiles, host-device round trips, Python overhead, or dishonest benchmarking
- reasoning about , meshes,
jax.Array,PartitionSpec, explicit sharding,NamedShardingmigration, multi-host semantics, or collectivespmap - using ,
jax.debug.print,checkify, lowering, compiler IR, profiler traces, or memory profilingmake_jaxpr - using custom derivatives, export, AOT lowering, custom partitioning, Pallas, or the JAX source tree
Compose this skill with framework-specific skills when needed, but let this one own the JAX-specific reasoning.
当请求的难点属于以下任一情况时,使用此技能:
- 将NumPy、SciPy、TensorFlow或PyTorch代码转换为符合JAX风格的代码
- 修复追踪器、控制流、PRNG、形状、数据类型或副作用相关bug
- 在、
jit、vmap、scan、fori_loop、while_loop、cond、grad、jacrev、jacfwd、remat或export之间做选择shard_map - 减少重新编译、主机-设备往返、Python开销或不真实的基准测试
- 分析、Mesh、
jax.Array、PartitionSpec、显式分片、NamedSharding迁移、多主机语义或集合操作pmap - 使用、
jax.debug.print、checkify、降阶、编译器IR、性能分析器追踪或内存分析make_jaxpr - 使用自定义导数、export、AOT降阶、自定义分片、Pallas或JAX源码树
必要时可与框架特定技能组合使用,但让此技能主导JAX相关的逻辑推理。
Do not over-apply the skill
不要过度应用此技能
Do not force JAX when the real problem is one of these instead:
- pure NumPy optimisation where JAX is explicitly out of scope
- generic CUDA, Triton, NCCL, or driver debugging with no meaningful JAX component
- framework-only design questions whose hard part is not JAX
- irregular dynamic object-heavy Python where the right answer is probably to keep the hot path outside JAX
When in doubt, ask: “Is the root of the problem tracing, transformations, array semantics, compilation, sharding, or the JAX runtime?” If yes, use this skill.
当实际问题属于以下情况时,不要强行使用JAX:
- 纯NumPy优化场景,且明确排除JAX
- 通用CUDA、Triton、NCCL或驱动调试,且无实质性JAX相关内容
- 仅框架设计问题,且难点与JAX无关
- 不规则的动态对象密集型Python代码,正确的解决方案可能是将热路径放在JAX之外
如有疑问,可询问:“问题的根源是否在于追踪、转换、数组语义、编译、分片或JAX运行时?”如果是,则使用此技能。
First-response workflow
首次响应工作流
1. Classify the task
1. 分类任务
Put the request into one or more lanes immediately:
- code design or porting
- debugging or correctness
- performance or compilation
- sharding or distributed execution
- advanced extension points
- JAX repo navigation or source-level questions
Then open the matching reference file:
- for the overall workflow
references/EXPERT-WORKFLOW.md - for tracing and staging semantics
references/MENTAL-MODEL.md - for choosing primitives
references/TRANSFORM-DECISION-MATRIX.md - for NumPy or PyTorch rewrites
references/PORTING-PATTERNS.md - for self-review before replying
references/CODE-REVIEW-RUBRIC.md - for error diagnosis
references/DEBUGGING-TRIAGE.md - for speed, memory, and compile-time work
references/PERFORMANCE-PLAYBOOK.md - for distributed and multi-device design
references/SHARDING-PLAYBOOK.md - for custom autodiff, export, Pallas, FFI, and internals
references/ADVANCED-EXTENSIONS.md - for local source-tree navigation
references/REPO-MAP.md - for provenance and maintenance notes
references/SOURCES.md
立即将请求归入一个或多个类别:
- 代码设计或移植
- 调试或正确性验证
- 性能或编译优化
- 分片或分布式执行
- 高级扩展点
- JAX代码仓库导航或源码级问题
然后打开对应的参考文件:
- :整体工作流
references/EXPERT-WORKFLOW.md - :追踪和暂存语义
references/MENTAL-MODEL.md - :原语选择
references/TRANSFORM-DECISION-MATRIX.md - :NumPy或PyTorch代码重写
references/PORTING-PATTERNS.md - :回复前的自我审查
references/CODE-REVIEW-RUBRIC.md - :错误诊断
references/DEBUGGING-TRIAGE.md - :速度、内存和编译时优化
references/PERFORMANCE-PLAYBOOK.md - :分布式和多设备设计
references/SHARDING-PLAYBOOK.md - :自定义自动微分、export、Pallas、FFI和内部机制
references/ADVANCED-EXTENSIONS.md - :本地源码树导航
references/REPO-MAP.md - :来源和维护说明
references/SOURCES.md
2. Inspect before guessing
2. 先检查再猜测
If the problem could be environment-, backend-, or project-specific, inspect first.
Environment:
bash
python3 scripts/jax_env_report.py --format jsonStatic project scan:
bash
python3 scripts/jax_project_scan.py PATH --format jsonBenchmark a callable honestly:
bash
python3 scripts/jax_benchmark_harness.py --helpInspect jaxpr, lowering, and IR:
bash
python3 scripts/jax_compile_probe.py --helpCheck likely recompile behaviour across cases:
bash
python3 scripts/jax_recompile_explorer.py --helpSearch a local JAX checkout:
bash
python3 scripts/jax_repo_locator.py --help如果问题可能与环境、后端或项目特定相关,先进行检查。
环境检查:
bash
python3 scripts/jax_env_report.py --format json静态项目扫描:
bash
python3 scripts/jax_project_scan.py PATH --format json真实基准测试可调用对象:
bash
python3 scripts/jax_benchmark_harness.py --help检查jaxpr、降阶和IR:
bash
python3 scripts/jax_compile_probe.py --help检查不同场景下可能的重新编译行为:
bash
python3 scripts/jax_recompile_explorer.py --help搜索本地JAX代码副本:
bash
python3 scripts/jax_repo_locator.py --help3. Reduce to a minimal reproducer
3. 简化为最小复现示例
Prefer the smallest function that still exhibits the behaviour. JAX problems get much easier once shapes, dtypes, batching axes, randomness, and transformation boundaries are explicit.
优先选择仍能体现问题的最小函数。一旦形状、数据类型、批处理轴、随机性和转换边界明确,JAX问题会变得容易得多。
4. Choose the least powerful mechanism that solves the problem
4. 选择解决问题的最低复杂度机制
Default ordering:
- pure eager first
jax.numpy - then or
jitvalue_and_grad - then or
vmapscan - then explicit sharding
- then
shard_map - then custom derivative, export, custom partitioning, or Pallas
- then FFI or JAX internals
Escalate only with evidence.
默认优先级:
- 优先使用纯即时模式
jax.numpy - 其次是或
jitvalue_and_grad - 然后是或
vmapscan - 接着是显式分片
- 再是
shard_map - 然后是自定义导数、export、自定义分片或Pallas
- 最后是FFI或JAX内部机制
仅在有证据时才升级复杂度。
5. End with a high-signal answer
5. 给出高信息量的答案
Unless the user asked for something else, the reply should end with:
- diagnosis or design choice
- corrected code or patch
- why it works in JAX terms
- how to verify it
- remaining risks, backend caveats, or performance unknowns
除非用户另有要求,回复应包含:
- 诊断结果或设计选择
- 修正后的代码或补丁
- 用JAX术语解释其生效原因
- 验证方法
- 剩余风险、后端限制或性能未知项
Expert operating rules
专家操作规则
- Treat JAX functions as pure. Inputs in, outputs out. Hidden mutation, global state, or implicit randomness are usually design bugs once transforms enter the picture.
- Make randomness explicit. Thread keys through the program, split once per consumer, and return updated keys when state continues.
- Keep the hot path in JAX space. Host conversion inside transformed code is almost always a bug or a sync point.
- Separate static and dynamic values. Shapes, dtypes, Python objects, and some configuration values influence tracing and compilation.
- Use structured control flow. If a branch or loop depends on array values, use JAX control-flow primitives instead of Python.
- Benchmark honestly. Warm up, block, and distinguish transfer cost, compile cost, and steady-state execution.
- Optimise after evidence. Use scans, compile probes, profiler traces, or lowering inspection before proposing deep rewrites.
- Prefer current JAX idioms. Typed keys, , and modern sharding APIs are the default unless the codebase is intentionally legacy.
jax.Array - Think globally for sharding first. Start with global-view code and explicit placement before dropping to per-device manual code.
- Never bluff backend-specific behaviour. CPU, GPU, TPU, and multi-host runs differ materially. Say what was verified and what was inferred.
- 将JAX函数视为纯函数:输入进,输出出。当引入转换时,隐藏的突变、全局状态或隐式随机性通常是设计缺陷。
- 明确随机性:在程序中传递密钥,每个消费者拆分一次,当状态持续时返回更新后的密钥。
- 将热路径保留在JAX空间内:转换代码中的主机转换几乎总是bug或同步点。
- 区分静态和动态值:形状、数据类型、Python对象和某些配置值会影响追踪和编译。
- 使用结构化控制流:如果分支或循环依赖数组值,使用JAX控制流原语而非Python控制流。
- 真实基准测试:预热、阻塞,并区分传输成本、编译成本和稳态执行成本。
- 基于证据优化:在提出深度重写之前,使用scan、编译探针、性能分析器追踪或降阶检查。
- 优先使用当前JAX惯用写法:除非代码库是有意保留旧版本,否则默认使用类型化密钥、和现代分片API。
jax.Array - 先全局考虑分片:先从全局视角代码和显式布局开始,再转向每个设备的手动代码。
- 不要虚构后端特定行为:CPU、GPU、TPU和多主机运行存在实质性差异。说明已验证的内容和推断的内容。
Default red flags to proactively check
需主动检查的默认危险信号
Always scan for these, even if the user did not mention them:
- ,
np.asarray,.item(),.tolist(), or printing arrays in a hot pathjax.device_get - Python ,
if, orforinside transformed codewhile - shape construction or indexing based on traced values
- global or reused PRNG keys
- repeated creation of jitted callables inside loops
- changing shapes, dtypes, or static arguments causing compile storms
- very large Python loops that should be or
scanfori_loop - code that may be better expressed with modern sharding APIs
pmap - unexplained precision assumptions or implicit expectations
x64 - replicated-versus-sharded confusion in distributed code
即使用户未提及,也要始终检查以下内容:
- 热路径中的、
np.asarray、.item()、.tolist()或数组打印操作jax.device_get - 转换代码中的Python 、
if或for循环while - 基于追踪值的形状构造或索引
- 全局或重复使用的PRNG密钥
- 循环内重复创建jit可调用对象
- 形状、数据类型或静态参数变化导致的编译风暴
- 应使用或
scan的大型Python循环fori_loop - 可通过现代分片API更好实现的代码
pmap - 未解释的精度假设或隐式预期
x64 - 分布式代码中复制与分片的混淆
Available scripts
可用脚本
- — report versions, backend, devices, config, env vars, and an optional smoke test.
scripts/jax_env_report.py - — AST-based scan for common JAX sharp bits and migration targets.
scripts/jax_project_scan.py - — benchmark a callable with warm-up, blocking, optional
scripts/jax_benchmark_harness.py, and optional donation.jit - — inspect
scripts/jax_compile_probe.py, jaxpr, lowering, and compiler IR; optionally write artefacts to disk.eval_shape - — run several input cases through a jitted function and flag likely recompiles or signature drift.
scripts/jax_recompile_explorer.py - — search a local JAX checkout for relevant docs, tests, or source files by topic.
scripts/jax_repo_locator.py
All scripts are non-interactive, support , and default to structured JSON output.
--help- — 报告版本、后端、设备、配置、环境变量,以及可选的冒烟测试。
scripts/jax_env_report.py - — 基于AST扫描常见JAX棘手问题和迁移目标。
scripts/jax_project_scan.py - — 对可调用对象进行基准测试,支持预热、阻塞、可选
scripts/jax_benchmark_harness.py和可选捐赠。jit - — 检查
scripts/jax_compile_probe.py、jaxpr、降阶和编译器IR;可选将产物写入磁盘。eval_shape - — 通过jit函数运行多个输入案例,标记可能的重新编译或签名漂移。
scripts/jax_recompile_explorer.py - — 按主题搜索本地JAX代码副本中的相关文档、测试或源文件。
scripts/jax_repo_locator.py
所有脚本均为非交互式,支持,默认输出结构化JSON。
--helpAvailable assets
可用资源
- — minimal reproducible example template
assets/mre_template.py - — idiomatic compiled training step with explicit key plumbing
assets/training_step_template.py - — carry-state loop using
assets/scan_template.pylax.scan - — mesh plus
assets/sharding_template.pystarterNamedSharding - — manual SPMD starter using
assets/shard_map_template.pyjax.shard_map - — honest timing pattern with warm-up and blocking
assets/benchmark_template.py - — trace and memory-profile starter
assets/profile_template.py - — runtime checks that survive
assets/checkify_template.pyjit - — custom reverse-mode rule starter
assets/custom_vjp_template.py - — export and serialisation starter
assets/export_template.py - — kernel-level starting point
assets/pallas_kernel_skeleton.py - — compact bug report / investigation template
assets/issue_report_template.md
- — 最小复现示例模板
assets/mre_template.py - — 符合JAX风格的编译训练步骤模板,包含明确的密钥传递
assets/training_step_template.py - — 使用
assets/scan_template.py的带状态循环模板lax.scan - — Mesh加
assets/sharding_template.py入门模板NamedSharding - — 使用
assets/shard_map_template.py的手动SPMD入门模板jax.shard_map - — 包含预热和阻塞的真实计时模板
assets/benchmark_template.py - — 追踪和内存分析入门模板
assets/profile_template.py - — 可在
assets/checkify_template.py后保留的运行时检查模板jit - — 自定义反向模式规则入门模板
assets/custom_vjp_template.py - — 导出和序列化入门模板
assets/export_template.py - — 内核级入门模板
assets/pallas_kernel_skeleton.py - — 简洁的bug报告/调查模板
assets/issue_report_template.md
Output quality bar
输出质量标准
Before sending a final answer, mentally run the code or design through . The answer should usually satisfy all of the following:
references/CODE-REVIEW-RUBRIC.md- runnable or patch-ready code
- correct transformation and sharding semantics
- explicit discussion of compile and runtime consequences
- no accidental host round trips in the claimed hot path
- no hidden PRNG or state bugs
- an honest verification method
发送最终答案前,对照在脑中过一遍代码或设计。答案通常应满足以下所有要求:
references/CODE-REVIEW-RUBRIC.md- 可运行或可直接应用的补丁代码
- 正确的转换和分片语义
- 明确讨论编译和运行时影响
- 声称的热路径中无意外的主机往返操作
- 无隐藏的PRNG或状态bug
- 真实的验证方法
If the task is exploratory research code
如果任务是探索性研究代码
Prefer a staged plan:
- get a correct eager version in
jax.numpy - add tests or invariants
- add transformations one at a time
- benchmark and profile
- only then attempt aggressive sharding or kernel work
This workflow beats premature //Pallas every time.
jitpmap优先采用分阶段方案:
- 在中实现正确的即时模式版本
jax.numpy - 添加测试或不变量
- 逐个添加转换
- 进行基准测试和性能分析
- 仅在此时尝试激进的分片或内核工作
此工作流每次都优于过早使用//Pallas。
jitpmapSkill maintenance
技能维护
When updating this skill, refresh the JAX facts most likely to drift:
- installation guidance
- sharding APIs and migration status
pmap - randomness recommendations
- profiler and memory-tooling guidance
- export / AOT APIs
- Pallas and custom extension interfaces
更新此技能时,刷新最可能变化的JAX相关信息:
- 安装指南
- 分片API和迁移状态
pmap - 随机性建议
- 性能分析器和内存工具指南
- export / AOT API
- Pallas和自定义扩展接口