jax-development

JAX Development

Use this skill for substantial JAX work. The agent should behave like a strong JAX reviewer and performance engineer: preserve functional semantics, choose the right transformations, explain the trace/compile/runtime split clearly, and avoid making performance claims that were not measured.
This version is designed to be unusually agent-friendly. It does not just bundle references; it gives the agent an operating workflow, decision matrices, a code-review rubric, and scripts that help verify environment, lowering, recompilation risk, and benchmark claims.

Core promise

When this skill is active, the default standard is:
  1. produce runnable JAX code, not generic advice
  2. explain why the change works in JAX terms
  3. call out likely sharp bits even if the user did not ask
  4. verify claims with the bundled scripts when possible
  5. separate compile-time, run-time, transfer, and sharding issues instead of mixing them together

When this skill should own the task

Use this skill when the difficult part of the request is any of the following:
  • translating NumPy, SciPy, TensorFlow, or PyTorch code into idiomatic JAX
  • fixing tracer, control-flow, PRNG, shape, dtype, or side-effect bugs
  • choosing between jit, vmap, scan, fori_loop, while_loop, cond, grad, jacrev, jacfwd, remat, shard_map, or export
  • removing recompiles, host-device round trips, Python overhead, or dishonest benchmarking
  • reasoning about jax.Array, meshes, PartitionSpec, NamedSharding, explicit sharding, pmap migration, multi-host semantics, or collectives
  • using jax.debug.print, checkify, make_jaxpr, lowering, compiler IR, profiler traces, or memory profiling
  • using custom derivatives, export, AOT lowering, custom partitioning, Pallas, or the JAX source tree
Compose this skill with framework-specific skills when needed, but let this one own the JAX-specific reasoning.

Do not over-apply the skill

Do not force JAX when the real problem is one of these instead:
  • pure NumPy optimisation where JAX is explicitly out of scope
  • generic CUDA, Triton, NCCL, or driver debugging with no meaningful JAX component
  • framework-only design questions whose hard part is not JAX
  • irregular dynamic object-heavy Python where the right answer is probably to keep the hot path outside JAX
When in doubt, ask: “Is the root of the problem tracing, transformations, array semantics, compilation, sharding, or the JAX runtime?” If yes, use this skill.

First-response workflow

1. Classify the task

Put the request into one or more lanes immediately:
  • code design or porting
  • debugging or correctness
  • performance or compilation
  • sharding or distributed execution
  • advanced extension points
  • JAX repo navigation or source-level questions
Then open the matching reference file:
  • references/EXPERT-WORKFLOW.md for the overall workflow
  • references/MENTAL-MODEL.md for tracing and staging semantics
  • references/TRANSFORM-DECISION-MATRIX.md for choosing primitives
  • references/PORTING-PATTERNS.md for NumPy or PyTorch rewrites
  • references/CODE-REVIEW-RUBRIC.md for self-review before replying
  • references/DEBUGGING-TRIAGE.md for error diagnosis
  • references/PERFORMANCE-PLAYBOOK.md for speed, memory, and compile-time work
  • references/SHARDING-PLAYBOOK.md for distributed and multi-device design
  • references/ADVANCED-EXTENSIONS.md for custom autodiff, export, Pallas, FFI, and internals
  • references/REPO-MAP.md for local source-tree navigation
  • references/SOURCES.md for provenance and maintenance notes

2. Inspect before guessing

If the problem could be environment-, backend-, or project-specific, inspect first.
Environment:
```bash
python3 scripts/jax_env_report.py --format json
```
Static project scan:
```bash
python3 scripts/jax_project_scan.py PATH --format json
```
Benchmark a callable honestly:
```bash
python3 scripts/jax_benchmark_harness.py --help
```
Inspect jaxpr, lowering, and IR:
```bash
python3 scripts/jax_compile_probe.py --help
```
Check likely recompile behaviour across cases:
```bash
python3 scripts/jax_recompile_explorer.py --help
```
Search a local JAX checkout:
```bash
python3 scripts/jax_repo_locator.py --help
```
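When the bundled scripts are not at hand, a similar first look can be taken directly in Python. The following is a minimal sketch, not the contents of any bundled script: the function f and the env dictionary are hypothetical names chosen for illustration, and only public JAX calls are used.

```python
import jax
import jax.numpy as jnp


def f(x):
    # Hypothetical function under inspection.
    return jnp.tanh(x) * 2.0


x = jnp.ones((3,), dtype=jnp.float32)  # explicit shape and dtype

# Environment summary: version, active backend, visible devices.
env = {
    "jax_version": jax.__version__,
    "backend": jax.default_backend(),
    "device_count": jax.device_count(),
}

# Abstract output shape/dtype without running the computation.
out_shape = jax.eval_shape(f, x)

# The jaxpr shows what tracing recorded.
jaxpr = jax.make_jaxpr(f)(x)

# Lowering produces the compiler input IR as text (no execution yet).
lowered = jax.jit(f).lower(x)
ir_text = lowered.as_text()

print(env)
print(jaxpr)
```

Reading the jaxpr first and the lowered IR second keeps the trace, compile, and run stages separate when reporting findings.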

3. Reduce to a minimal reproducer

Prefer the smallest function that still exhibits the behaviour. JAX problems get much easier once shapes, dtypes, batching axes, randomness, and transformation boundaries are explicit.
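As an illustration of what such a reproducer might look like, here is a minimal sketch that isolates one common failure: a Python branch on an array value that works eagerly but fails under jit. The function name keep_if_positive is hypothetical, and the example assumes a JAX version in which this failure is raised as a subclass of jax.errors.ConcretizationTypeError.

```python
import jax
import jax.numpy as jnp


def keep_if_positive(x):
    # Python `if` on an array value: fine eagerly, not traceable under jit.
    if x.sum() > 0:
        return x
    return jnp.zeros_like(x)


x = jnp.ones((4,), dtype=jnp.float32)  # explicit shape and dtype

# Eager call: the condition is a concrete value, so this succeeds.
eager_out = keep_if_positive(x)

# Transformation boundary made explicit: the same call under jit.
try:
    jax.jit(keep_if_positive)(x)
    failed_under_jit = False
except jax.errors.ConcretizationTypeError:
    failed_under_jit = True

print("eager ok:", eager_out.shape, "| fails under jit:", failed_under_jit)
```

With the input fixed and the transformation applied in exactly one place, the report can state precisely which boundary triggers the problem.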

4. Choose the least powerful mechanism that solves the problem

Default ordering:
  • pure eager jax.numpy first
  • then jit or value_and_grad
  • then vmap or scan
  • then explicit sharding
  • then shard_map
  • then custom derivative, export, custom partitioning, or Pallas
  • then FFI or JAX internals
Escalate only with evidence.
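The first three rungs of that ordering can be sketched on a toy problem. This is an illustrative example under assumed names (loss, w, x, y are hypothetical), showing the same pure function used eagerly, then under jit with value_and_grad, then under vmap.

```python
import jax
import jax.numpy as jnp


def loss(w, x, y):
    # Mean squared error of a linear model; pure function of its inputs.
    pred = jnp.dot(x, w)
    return jnp.mean((pred - y) ** 2)


w = jnp.array([1.0, -1.0])
x = jnp.array([[1.0, 2.0], [3.0, 4.0]])
y = jnp.array([0.0, 1.0])

# Rung 1: plain eager jax.numpy.
eager_value = loss(w, x, y)

# Rung 2: compile the loss together with its gradient w.r.t. w.
loss_and_grad = jax.jit(jax.value_and_grad(loss))
value, grad_w = loss_and_grad(w, x, y)

# Rung 3: map over the batch axis of x and y, broadcasting w.
per_example_loss = jax.vmap(loss, in_axes=(None, 0, 0))(w, x, y)

print(eager_value, value, grad_w, per_example_loss)
```

Each rung reuses the unchanged loss function, so a disagreement between rungs points at the transformation rather than at the model code.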

5. End with a high-signal answer

Unless the user asked for something else, the reply should end with:
  • diagnosis or design choice
  • corrected code or patch
  • why it works in JAX terms
  • how to verify it
  • remaining risks, backend caveats, or performance unknowns

Expert operating rules

  1. Treat JAX functions as pure. Inputs in, outputs out. Hidden mutation, global state, or implicit randomness are usually design bugs once transforms enter the picture.
  2. Make randomness explicit. Thread keys through the program, split once per consumer, and return updated keys when state continues.
  3. Keep the hot path in JAX space. Host conversion inside transformed code is almost always a bug or a sync point.
  4. Separate static and dynamic values. Shapes, dtypes, Python objects, and some configuration values influence tracing and compilation.
  5. Use structured control flow. If a branch or loop depends on array values, use JAX control-flow primitives instead of Python.
  6. Benchmark honestly. Warm up, block, and distinguish transfer cost, compile cost, and steady-state execution.
  7. Optimise after evidence. Use scans, compile probes, profiler traces, or lowering inspection before proposing deep rewrites.
  8. Prefer current JAX idioms. Typed keys, jax.Array, and modern sharding APIs are the default unless the codebase is intentionally legacy.
  9. Think globally for sharding first. Start with global-view code and explicit placement before dropping to per-device manual code.
  10. Never bluff backend-specific behaviour. CPU, GPU, TPU, and multi-host runs differ materially. Say what was verified and what was inferred.
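Rules 2 and 5 can be sketched together in a few lines. The example below is illustrative only: noisy_step and normalise_if_large are hypothetical names, and it assumes a JAX version that provides typed keys through jax.random.key.

```python
import jax
import jax.numpy as jnp
from jax import lax


def noisy_step(key, x):
    # Rule 2: split once for the single consumer, return the carried key.
    key, noise_key = jax.random.split(key)
    noise = jax.random.normal(noise_key, x.shape, dtype=x.dtype)
    return key, x + 0.1 * noise


@jax.jit
def normalise_if_large(x):
    # Rule 5: the branch depends on array values, so use lax.cond
    # instead of a Python `if`.
    norm = jnp.linalg.norm(x)
    return lax.cond(norm > 1.0, lambda v: v / norm, lambda v: v, x)


key = jax.random.key(0)  # typed key (assumes a recent JAX release)
x0 = jnp.zeros((3,), dtype=jnp.float32)

key, x1 = noisy_step(key, x0)
key, x2 = noisy_step(key, x1)  # uses the updated key, so fresh noise

large = normalise_if_large(jnp.array([3.0, 4.0]))
small = normalise_if_large(jnp.array([0.3, 0.4]))
```

Because the key is passed in and returned, the two steps draw independent noise, and the function stays pure under jit, vmap, or scan.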

Default red flags to proactively check

Always scan for these, even if the user did not mention them:
  • np.asarray, .item(), .tolist(), jax.device_get, or printing arrays in a hot path
  • Python if, for, or while inside transformed code
  • shape construction or indexing based on traced values
  • global or reused PRNG keys
  • repeated creation of jitted callables inside loops
  • changing shapes, dtypes, or static arguments causing compile storms
  • very large Python loops that should be scan or fori_loop
  • pmap code that may be better expressed with modern sharding APIs
  • unexplained precision assumptions or implicit x64 expectations
  • replicated-versus-sharded confusion in distributed code
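One way to make the shape-driven recompile flag concrete is to count traces with a Python side effect, which runs only while JAX traces the function. This is a small diagnostic sketch with hypothetical names (trace_count, scaled_sum), not a replacement for the bundled recompile explorer.

```python
import jax
import jax.numpy as jnp

trace_count = 0


@jax.jit
def scaled_sum(x):
    global trace_count
    trace_count += 1  # Python side effect: executes during tracing only
    return (2.0 * x).sum()


scaled_sum(jnp.ones(4))
scaled_sum(jnp.zeros(4))  # same shape and dtype: served from the cache
traces_same_shape = trace_count

scaled_sum(jnp.ones(5))  # new shape: traced and compiled again
traces_new_shape = trace_count

print(traces_same_shape, traces_new_shape)
```

A counter that keeps climbing in steady state suggests that shapes, dtypes, or static arguments are drifting between calls.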

Available scripts

  • scripts/jax_env_report.py: report versions, backend, devices, config, env vars, and an optional smoke test.
  • scripts/jax_project_scan.py: AST-based scan for common JAX sharp bits and migration targets.
  • scripts/jax_benchmark_harness.py: benchmark a callable with warm-up, blocking, optional jit, and optional donation.
  • scripts/jax_compile_probe.py: inspect eval_shape, jaxpr, lowering, and compiler IR; optionally write artefacts to disk.
  • scripts/jax_recompile_explorer.py: run several input cases through a jitted function and flag likely recompiles or signature drift.
  • scripts/jax_repo_locator.py: search a local JAX checkout for relevant docs, tests, or source files by topic.
All scripts are non-interactive, support --help, and default to structured JSON output.

Available assets

  • assets/mre_template.py: minimal reproducible example template
  • assets/training_step_template.py: idiomatic compiled training step with explicit key plumbing
  • assets/scan_template.py: carry-state loop using lax.scan
  • assets/sharding_template.py: mesh plus NamedSharding starter
  • assets/shard_map_template.py: manual SPMD starter using jax.shard_map
  • assets/benchmark_template.py: honest timing pattern with warm-up and blocking
  • assets/profile_template.py: trace and memory-profile starter
  • assets/checkify_template.py: runtime checks that survive jit
  • assets/custom_vjp_template.py: custom reverse-mode rule starter
  • assets/export_template.py: export and serialisation starter
  • assets/pallas_kernel_skeleton.py: kernel-level starting point
  • assets/issue_report_template.md: compact bug report / investigation template
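To give a sense of what "runtime checks that survive jit" can mean in practice, here is a short sketch using jax.experimental.checkify. It is not the content of assets/checkify_template.py, which is not reproduced here; safe_log and the error message are hypothetical.

```python
import jax
import jax.numpy as jnp
from jax.experimental import checkify


def safe_log(x):
    # The check is staged into the computation instead of being dropped.
    checkify.check(jnp.all(x > 0), "x must be positive")
    return jnp.log(x)


# checkify wraps the function so it returns (error, output); jit goes outside.
checked_log = jax.jit(checkify.checkify(safe_log))

ok_err, ok_out = checked_log(jnp.array([1.0, 2.0]))
bad_err, bad_out = checked_log(jnp.array([-1.0, 2.0]))

print(ok_err.get())   # None when no check failed
print(bad_err.get())  # error message string when a check failed
```

Calling err.throw() instead of err.get() raises on the host once the result is available, which keeps the compiled function free of Python-level assertions.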

Output quality bar

Before sending a final answer, mentally run the code or design through references/CODE-REVIEW-RUBRIC.md. The answer should usually satisfy all of the following:
  • runnable or patch-ready code
  • correct transformation and sharding semantics
  • explicit discussion of compile and runtime consequences
  • no accidental host round trips in the claimed hot path
  • no hidden PRNG or state bugs
  • an honest verification method
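For performance claims, an honest verification method usually means timing the first call separately from steady state and blocking on the result. The helper below is a minimal sketch under that assumption; time_callable and matmul_chain are hypothetical names, and this is not the bundled benchmark harness.

```python
import time

import jax
import jax.numpy as jnp


def time_callable(fn, *args, iters=5):
    """Return (first_call_seconds, median_steady_state_seconds)."""
    # First call: includes tracing and compilation for a jitted function.
    start = time.perf_counter()
    jax.block_until_ready(fn(*args))
    first_call = time.perf_counter() - start

    # Later calls: block on the result so asynchronous dispatch
    # does not make the timing look faster than it is.
    samples = []
    for _ in range(iters):
        start = time.perf_counter()
        jax.block_until_ready(fn(*args))
        samples.append(time.perf_counter() - start)
    samples.sort()
    return first_call, samples[len(samples) // 2]


@jax.jit
def matmul_chain(a):
    return a @ a @ a


a = jnp.ones((64, 64), dtype=jnp.float32)
first_call, steady = time_callable(matmul_chain, a)
print(f"first call: {first_call:.6f}s, steady state: {steady:.6f}s")
```

Reporting both numbers, along with the backend they were measured on, keeps compile cost and run cost from being conflated.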

If the task is exploratory research code

Prefer a staged plan:
  1. get a correct eager version in jax.numpy
  2. add tests or invariants
  3. add transformations one at a time
  4. benchmark and profile
  5. only then attempt aggressive sharding or kernel work
This workflow beats premature jit/pmap/Pallas every time.
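The first three steps of the staged plan might look like the sketch below: a plain reference, a check against it, and then one transformation at a time. The exponential moving average and the names ema_reference and ema_scan are hypothetical choices for illustration.

```python
import jax
import jax.numpy as jnp
from jax import lax


def ema_reference(xs, decay):
    # Step 1: correct eager version with an ordinary Python loop.
    state = jnp.zeros((), dtype=xs.dtype)
    outs = []
    for x in xs:
        state = decay * state + (1.0 - decay) * x
        outs.append(state)
    return jnp.stack(outs)


def ema_scan(xs, decay):
    # Step 3a: same recurrence expressed with lax.scan.
    def step(state, x):
        new_state = decay * state + (1.0 - decay) * x
        return new_state, new_state

    init = jnp.zeros((), dtype=xs.dtype)
    _, outs = lax.scan(step, init, xs)
    return outs


xs = jnp.arange(1.0, 6.0)
expected = ema_reference(xs, 0.9)

# Step 2: the reference doubles as the test oracle.
scanned = ema_scan(xs, 0.9)
assert jnp.allclose(scanned, expected)

# Step 3b: add jit only after the scan version matches.
jitted = jax.jit(ema_scan)(xs, 0.9)
assert jnp.allclose(jitted, expected)
```

Keeping the eager reference around makes later benchmarking and sharding work easier to validate, since every optimised variant can be compared against it.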

Skill maintenance

When updating this skill, refresh the JAX facts most likely to drift:
  • installation guidance
  • sharding APIs and pmap migration status
  • randomness recommendations
  • profiler and memory-tooling guidance
  • export / AOT APIs
  • Pallas and custom extension interfaces