tao-train-image-classification
Compare original and translation side by side
🇺🇸
Original
English🇨🇳
Translation
ChineseClassification PyT
Classification PyT
PyTorch image classification. Supports a wide range of backbones (FAN, EfficientNet, ResNet, etc.) with distillation and quantization for deployment.
Set model.backbone.pretrained_backbone_path for backbone weights or train.pretrained_model_path for full model.
For TAO Deploy TensorRT actions (, TensorRT , and TensorRT ), read first. Deploy spec templates live in this skill's folder with the prefix.
gen_trt_engineevaluateinferencereferences/tao-deploy-image-classification.mdreferences/spec_template_deploy_*.yaml基于PyTorch的图像分类,支持多种骨干网络(FAN、EfficientNet、ResNet等),并带有用于部署的蒸馏和量化功能。
设置model.backbone.pretrained_backbone_path以指定骨干网络权重,或设置train.pretrained_model_path以指定完整模型权重。
对于TAO Deploy TensorRT操作(、TensorRT 和TensorRT ),请先阅读。部署规格模板位于本技能的文件夹中,前缀为。
gen_trt_engineevaluateinferencereferences/tao-deploy-image-classification.mdreferences/spec_template_deploy_*.yamlDataclass Schemas
数据类模式(Dataclass Schemas)
Generated TAO Core schemas are packaged in , with listing available actions. Each generated schema also emits from the schema top-level field. AutoML enablement is declared at the model layer in via . Runnable AutoML still requires and to exist and parse. Use the packaged train schema for , , defaults, min/max bounds, enums, option weights, math conditions, dependencies, and popular parameters. Do not expect at runtime; maintainers regenerate schemas/templates before packaging the skill bank.
schemas/<action>.schema.jsonschemas/manifest.jsonreferences/spec_template_<action>.yamldefaultreferences/skill_info.yamlautoml_enabledschemas/train.schema.jsonreferences/spec_template_train.yamlautoml_default_parametersautoml_disabled_parameters~/tao-core生成的TAO Core模式打包在中,列出了可用操作。每个生成的模式还会从模式顶层的字段生成。AutoML支持在的模型层通过声明。可运行的AutoML仍要求和存在且可解析。使用打包的训练模式来配置、、默认值、最小/最大范围、枚举、选项权重、数学条件、依赖关系以及常用参数。运行时不要依赖;维护人员会在打包技能库前重新生成模式/模板。
schemas/<action>.schema.jsonschemas/manifest.jsondefaultreferences/spec_template_<action>.yamlreferences/skill_info.yamlautoml_enabledschemas/train.schema.jsonreferences/spec_template_train.yamlautoml_default_parametersautoml_disabled_parameters~/tao-coreTrain Action Policy
训练操作策略(Train Action Policy)
This model is AutoML-enabled at the model layer. Before handling any train-stage request, read and resolve the run override from either an explicit value or the user's workflow request. Treat phrases like "turn off AutoML", "disable AutoML", "no HPO", or "plain training" as for this run only; otherwise default to . When , , and both and are packaged, route the train action through by default with this model's . Preserve workflow/application overrides for datasets, specs, output directories, GPU/platform settings, parent checkpoints, and . Use direct model training only when or the packaged train schema/template is missing; in the missing-schema case, report that AutoML is enabled but not runnable for this model until schemas are generated.
references/skill_info.yamlautoml_policyautoml_policy: offautoautoml_policy: autoautoml_enabled: trueschemas/train.schema.jsonreferences/spec_template_train.yamltao-skill-bank:tao-run-automlskill_dirautoml_policyautoml_policy: offNon-train actions such as , , , and deploy flows stay in this model skill. The per-run override does not change model metadata.
evaluateinferenceexportautoml_policy该模型在模型层支持AutoML。处理任何训练阶段请求前,请阅读,并通过显式的值或用户的工作流请求解析运行覆盖配置。将“turn off AutoML”、“disable AutoML”、“no HPO”或“plain training”等短语视为本次运行的;否则默认设置为。当、,且和均已打包时,默认将训练操作通过路由,并传入该模型的。保留数据集、规格、输出目录、GPU/平台设置、父检查点和的工作流/应用覆盖配置。仅当或打包的训练模式/模板缺失时,才使用直接模型训练;在模式缺失的情况下,需报告该模型已启用AutoML但无法运行,直到生成模式为止。
references/skill_info.yamlautoml_policyautoml_policy: offautoautoml_policy: autoautoml_enabled: trueschemas/train.schema.jsonreferences/spec_template_train.yamltao-skill-bank:tao-run-automlskill_dirautoml_policyautoml_policy: off非训练操作(如、、和部署流程)仍在本模型技能中处理。每次运行的覆盖配置不会改变模型元数据。
evaluateinferenceexportautoml_policyTraining Requirements
训练要求
- Dataset type: image_classification
- Formats: classification_pyt
- Monitoring metric: val_acc_1
- 数据集类型: image_classification
- 格式: classification_pyt
- 监控指标: val_acc_1
Per-Action Dataset Requirements
各操作的数据集要求
| Action | Spec Key | Source | Files | List? |
|---|---|---|---|---|
| distill | dataset.train_dataset.images_dir | train_datasets | images_train.tar.gz | No |
| distill | dataset.classes_file | train_datasets | classes.txt | No |
| distill | dataset.val_dataset.images_dir | eval_dataset | images_val.tar.gz | No |
| evaluate | dataset.val_dataset.images_dir | eval_dataset | images_val.tar.gz | No |
| evaluate | dataset.classes_file | eval_dataset | classes.txt | No |
| evaluate | dataset.test_dataset.images_dir | inference_dataset | images_test.tar.gz | No |
| export | dataset.root_dir | train_datasets | No | |
| inference | dataset.val_dataset.images_dir | eval_dataset | images_val.tar.gz | No |
| inference | dataset.classes_file | eval_dataset | classes.txt | No |
| inference | dataset.test_dataset.images_dir | inference_dataset | images_test.tar.gz | No |
| quantize | dataset.train_dataset.images_dir | train_datasets | images_train.tar.gz | No |
| quantize | dataset.classes_file | train_datasets | classes.txt | No |
| quantize | dataset.val_dataset.images_dir | eval_dataset | images_val.tar.gz | No |
| quantize | dataset.quant_calibration_dataset.images_dir | calibration_dataset | images_train.tar.gz | No |
| train | dataset.train_dataset.images_dir | train_datasets | images_train.tar.gz | No |
| train | dataset.classes_file | train_datasets | classes.txt | No |
| train | dataset.val_dataset.images_dir | eval_dataset | images_val.tar.gz | No |
| 操作 | 规格键 | 来源 | 文件 | 是否为列表? |
|---|---|---|---|---|
| distill | dataset.train_dataset.images_dir | train_datasets | images_train.tar.gz | 否 |
| distill | dataset.classes_file | train_datasets | classes.txt | 否 |
| distill | dataset.val_dataset.images_dir | eval_dataset | images_val.tar.gz | 否 |
| evaluate | dataset.val_dataset.images_dir | eval_dataset | images_val.tar.gz | 否 |
| evaluate | dataset.classes_file | eval_dataset | classes.txt | 否 |
| evaluate | dataset.test_dataset.images_dir | inference_dataset | images_test.tar.gz | 否 |
| export | dataset.root_dir | train_datasets | 否 | |
| inference | dataset.val_dataset.images_dir | eval_dataset | images_val.tar.gz | 否 |
| inference | dataset.classes_file | eval_dataset | classes.txt | 否 |
| inference | dataset.test_dataset.images_dir | inference_dataset | images_test.tar.gz | 否 |
| quantize | dataset.train_dataset.images_dir | train_datasets | images_train.tar.gz | 否 |
| quantize | dataset.classes_file | train_datasets | classes.txt | 否 |
| quantize | dataset.val_dataset.images_dir | eval_dataset | images_val.tar.gz | 否 |
| quantize | dataset.quant_calibration_dataset.images_dir | calibration_dataset | images_train.tar.gz | 否 |
| train | dataset.train_dataset.images_dir | train_datasets | images_train.tar.gz | 否 |
| train | dataset.classes_file | train_datasets | classes.txt | 否 |
| train | dataset.val_dataset.images_dir | eval_dataset | images_val.tar.gz | 否 |
Typical Spec Overrides
典型规格覆盖配置
Data source overrides are mandatory for every action — the agent MUST construct data source paths from the Per-Action Dataset Requirements table above and include them in .
spec_overridespython
S3_TRAIN = "s3://bucket/data/train"
S3_EVAL = "s3://bucket/data/eval"train (mandatory data sources):
python
{
"train.num_epochs": 2,
"train.validation_interval": 2,
"train.checkpoint_interval": 2,
"train.num_gpus": 1,
"dataset.train_dataset.images_dir": f"{S3_TRAIN}/images_train.tar.gz",
"dataset.classes_file": f"{S3_TRAIN}/classes.txt",
"dataset.val_dataset.images_dir": f"{S3_EVAL}/images_val.tar.gz",
}export (mandatory data sources):
python
{
"export.input_height": 224,
"export.input_width": 224,
"dataset.root_dir": f"{S3_TRAIN}",
}gen_trt_engine:
python
{
"gen_trt_engine.tensorrt.data_type": "fp16",
}inference (mandatory data sources):
python
{
"dataset.batch_size": 1,
"dataset.val_dataset.images_dir": f"{S3_EVAL}/images_val.tar.gz",
"dataset.classes_file": f"{S3_EVAL}/classes.txt",
"dataset.test_dataset.images_dir": f"{S3_EVAL}/images_test.tar.gz",
}distill (mandatory data sources):
python
{
"dataset.train_dataset.images_dir": f"{S3_TRAIN}/images_train.tar.gz",
"dataset.classes_file": f"{S3_TRAIN}/classes.txt",
"dataset.val_dataset.images_dir": f"{S3_EVAL}/images_val.tar.gz",
}evaluate (mandatory data sources):
python
{
"dataset.val_dataset.images_dir": f"{S3_EVAL}/images_val.tar.gz",
"dataset.classes_file": f"{S3_EVAL}/classes.txt",
"dataset.test_dataset.images_dir": f"{S3_EVAL}/images_test.tar.gz",
}quantize (mandatory data sources):
python
{
"dataset.train_dataset.images_dir": f"{S3_TRAIN}/images_train.tar.gz",
"dataset.classes_file": f"{S3_TRAIN}/classes.txt",
"dataset.val_dataset.images_dir": f"{S3_EVAL}/images_val.tar.gz",
"dataset.quant_calibration_dataset.images_dir": f"{S3_TRAIN}/images_train.tar.gz",
}数据源覆盖配置对每个操作都是必填项——Agent必须根据上述“各操作的数据集要求”表格构建数据源路径,并将其包含在中。
spec_overridespython
S3_TRAIN = "s3://bucket/data/train"
S3_EVAL = "s3://bucket/data/eval"train(必填数据源):
python
{
"train.num_epochs": 2,
"train.validation_interval": 2,
"train.checkpoint_interval": 2,
"train.num_gpus": 1,
"dataset.train_dataset.images_dir": f"{S3_TRAIN}/images_train.tar.gz",
"dataset.classes_file": f"{S3_TRAIN}/classes.txt",
"dataset.val_dataset.images_dir": f"{S3_EVAL}/images_val.tar.gz",
}export(必填数据源):
python
{
"export.input_height": 224,
"export.input_width": 224,
"dataset.root_dir": f"{S3_TRAIN}",
}gen_trt_engine:
python
{
"gen_trt_engine.tensorrt.data_type": "fp16",
}inference(必填数据源):
python
{
"dataset.batch_size": 1,
"dataset.val_dataset.images_dir": f"{S3_EVAL}/images_val.tar.gz",
"dataset.classes_file": f"{S3_EVAL}/classes.txt",
"dataset.test_dataset.images_dir": f"{S3_EVAL}/images_test.tar.gz",
}distill(必填数据源):
python
{
"dataset.train_dataset.images_dir": f"{S3_TRAIN}/images_train.tar.gz",
"dataset.classes_file": f"{S3_TRAIN}/classes.txt",
"dataset.val_dataset.images_dir": f"{S3_EVAL}/images_val.tar.gz",
}evaluate(必填数据源):
python
{
"dataset.val_dataset.images_dir": f"{S3_EVAL}/images_val.tar.gz",
"dataset.classes_file": f"{S3_EVAL}/classes.txt",
"dataset.test_dataset.images_dir": f"{S3_EVAL}/images_test.tar.gz",
}quantize(必填数据源):
python
{
"dataset.train_dataset.images_dir": f"{S3_TRAIN}/images_train.tar.gz",
"dataset.classes_file": f"{S3_TRAIN}/classes.txt",
"dataset.val_dataset.images_dir": f"{S3_EVAL}/images_val.tar.gz",
"dataset.quant_calibration_dataset.images_dir": f"{S3_TRAIN}/images_train.tar.gz",
}Eval Dataset
评估数据集(Eval Dataset)
Optional. Validation images are provided as a separate tar alongside training images.
可选。验证图像作为单独的压缩包与训练图像一起提供。
Important Parameters
重要参数
- dataset.num_classes: Number of classes. Default 20. Must match the number of subdirectories in your image tarballs.
- model.backbone.type: Default fan_small_12_p4_hybrid. Supported backbones and their head in_channels (from model_params_mapping.py): FAN: fan_tiny, fan_small_12_p4_hybrid, fan_base_16_p4_hybrid, fan_large_16_p4_hybrid. GCViT: gcvit_tiny through gcvit_large. FasterViT: fastervit_0 through fastervit_6. ViT/EVA/DINO: vit_large_patch14_dinov2, eva02_large_patch14, etc. SigLIP-CLIPA: ViT-H-14-SigLIP-CLIPA-224, etc. Some backbones require non-default input resolution (384, 512, 768).
- dataset.classes_file: Path to classes.txt listing class names.
- train.optim.lr: Learning rate. Default 6e-5.
- dataset.img_size: Input image size. Default 224.
- dataset.batch_size: Per-GPU batch size. Default 8.
- dataset.num_classes:类别数量,默认值为20。必须与图像压缩包中的子目录数量匹配。
- model.backbone.type:默认值为fan_small_12_p4_hybrid。支持的骨干网络及其头部输入通道(来自model_params_mapping.py):FAN: fan_tiny、fan_small_12_p4_hybrid、fan_base_16_p4_hybrid、fan_large_16_p4_hybrid;GCViT: gcvit_tiny至gcvit_large;FasterViT: fastervit_0至fastervit_6;ViT/EVA/DINO: vit_large_patch14_dinov2、eva02_large_patch14等;SigLIP-CLIPA: ViT-H-14-SigLIP-CLIPA-224等。部分骨干网络需要非默认的输入分辨率(384、512、768)。
- dataset.classes_file:列出类别的classes.txt文件路径。
- train.optim.lr:学习率,默认值为6e-5。
- dataset.img_size:输入图像尺寸,默认值为224。
- dataset.batch_size:单GPU批次大小,默认值为8。
Multi-GPU / Multi-Node
多GPU / 多节点
Launch method: Lightning-managed (single process, Lightning spawns workers).
python| Spec Key | Description | Default |
|---|---|---|
| Number of GPUs | 1 |
| GPU device indices | [0] |
| Number of nodes | 1 |
- Multi-GPU strategy:
ddp_find_unused_parameters_true - No fsdp support
Multi-node env vars (set by orchestrator): , , , , .
WORLD_SIZENODE_RANKMASTER_ADDRMASTER_PORTNUM_GPU_PER_NODE启动方式: Lightning托管(单个进程,Lightning生成工作进程)。
python| 规格键 | 描述 | 默认值 |
|---|---|---|
| GPU数量 | 1 |
| GPU设备索引 | [0] |
| 节点数量 | 1 |
- 多GPU策略:
ddp_find_unused_parameters_true - 不支持fsdp
多节点环境变量(由编排器设置):、、、、。
WORLD_SIZENODE_RANKMASTER_ADDRMASTER_PORTNUM_GPU_PER_NODEHardware
硬件要求
Minimum 1 GPU(s), recommended 2 GPU(s). 16GB+ (V100 or A100) VRAM per GPU. Classification is generally lightweight. Most backbones at 224x224 fit well on 16GB GPUs with batch_size=8.
最少需要1块GPU,推荐2块GPU。每块GPU需配备16GB及以上显存(V100或A100)。分类任务通常轻量化,大多数骨干网络在224x224分辨率下,批次大小设为8时可适配16GB显存的GPU。
Error Patterns
错误模式
CUDA out of memory: Reduce batch_size or use a smaller backbone.
num_classes mismatch: Ensure dataset.num_classes matches the actual class directories in your image tarballs and classes.txt.
Empty class directory: Every class in classes.txt must have at least one image in the corresponding subdirectory.
CUDA内存不足:减小批次大小或使用更小的骨干网络。
num_classes不匹配:确保dataset.num_classes与图像压缩包和classes.txt中的实际类别目录数量一致。
类别目录为空:classes.txt中的每个类别在对应的子目录中必须至少包含一张图像。
Spec Param / Parent Model Inference
规格参数 / 父模型推理
Model-specific inference mappings belong in this MD file, not in . Generated runners should read this section and apply the mappings with SDK helpers before . This mirrors the old microservices flow.
config.jsoncreate_job()infer_params.pyInference mappings from TAO Core :
classification_pyt.config.json| Action | Spec Field | Inference Function | Meaning |
|---|---|---|---|
| distill | | | model file inferred from the parent job results folder |
| distill | | | current job results directory |
| evaluate | | | model file inferred from the parent job results folder |
| evaluate | | | current job results directory |
| export | | | model file inferred from the parent job results folder |
| export | | | output ONNX path |
| export | | | current job results directory |
| gen_trt_engine | | | model file inferred from the parent job results folder |
| gen_trt_engine | | | output TensorRT engine path |
| gen_trt_engine | | | current job results directory |
| inference | | | model file inferred from the parent job results folder |
| inference | | | model file inferred from the parent job results folder |
| inference | | | current job results directory |
| quantize | | | model file inferred from the parent job results folder |
| quantize | | | current job results directory |
| train | | | PTM when no resume checkpoint exists |
| train | | | current job results directory |
| train | | | PTM when no resume checkpoint exists |
| train | | | model file inferred from the current job results folder |
For or , pass the upstream train/export/AutoML child job id as . The SDK lists the parent result folder, filters checkpoint artifacts, and returns the selected model file or folder. Do not add these mappings back to and do not patch generated runner scripts to guess checkpoint paths.
parent_modelparent_model_folderparent_job_idconfig.json模型特定的推理映射应放在此MD文件中,而非。生成的运行器应读取本节内容,并在调用前使用SDK助手应用映射。这与旧微服务的流程一致。
config.jsoncreate_job()infer_params.py来自TAO Core 的推理映射:
classification_pyt.config.json| 操作 | 规格字段 | 推理函数 | 含义 |
|---|---|---|---|
| distill | | | 从父任务结果文件夹推断出的模型文件 |
| distill | | | 当前任务结果目录 |
| evaluate | | | 从父任务结果文件夹推断出的模型文件 |
| evaluate | | | 当前任务结果目录 |
| export | | | 从父任务结果文件夹推断出的模型文件 |
| export | | | 输出ONNX路径 |
| export | | | 当前任务结果目录 |
| gen_trt_engine | | | 从父任务结果文件夹推断出的模型文件 |
| gen_trt_engine | | | 输出TensorRT引擎路径 |
| gen_trt_engine | | | 当前任务结果目录 |
| inference | | | 从父任务结果文件夹推断出的模型文件 |
| inference | | | 从父任务结果文件夹推断出的模型文件 |
| inference | | | 当前任务结果目录 |
| quantize | | | 从父任务结果文件夹推断出的模型文件 |
| quantize | | | 当前任务结果目录 |
| train | | | 无恢复检查点时的预训练模型(PTM) |
| train | | | 当前任务结果目录 |
| train | | | 无恢复检查点时的预训练模型(PTM) |
| train | | | 从当前任务结果文件夹推断出的模型文件 |
对于或,将上游训练/导出/AutoML子任务ID作为传入。SDK会列出父结果文件夹,过滤检查点工件,并返回选定的模型文件或文件夹。不要将这些映射添加回,也不要修改生成的运行器脚本以猜测检查点路径。
parent_modelparent_model_folderparent_job_idconfig.jsonDeployment
部署
- tao-deploy-image-classification — Classification PyT deploy workflow for TensorRT engine generation, TensorRT evaluation, and TensorRT inference using TAO Deploy.
- tao-deploy-image-classification — 使用TAO Deploy进行TensorRT引擎生成、TensorRT评估和TensorRT推理的Classification PyT部署工作流。