tao-train-metric-learning-recognition
Compare original and translation side by side
🇺🇸
Original
English🇨🇳
Translation
ChineseML Recog
ML Recog
Metric learning recognition for fine-grained visual recognition. Learns embeddings for retrieval-based matching (e.g., retail product recognition). Uses triplet/contrastive losses.
Set model.pretrained_model_path for pretrained backbone.
For TAO Deploy TensorRT actions (, TensorRT , and TensorRT ), read first. Deploy spec templates live in this skill's folder with the prefix.
gen_trt_engineevaluateinferencereferences/tao-deploy-metric-learning-recognition.mdreferences/spec_template_deploy_*.yaml基于度量学习的识别,用于细粒度视觉识别。学习用于基于检索的匹配(例如零售商品识别)的嵌入向量,采用三元组损失/对比损失。
设置model.pretrained_model_path以指定预训练骨干网络。
对于TAO Deploy TensorRT动作(、TensorRT 和TensorRT ),请先阅读。部署配置模板位于该技能的文件夹中,前缀为。
gen_trt_engineevaluateinferencereferences/tao-deploy-metric-learning-recognition.mdreferences/spec_template_deploy_*.yamlDataclass Schemas
数据类模式
Generated TAO Core schemas are packaged in , with listing available actions. Each generated schema also emits from the schema top-level field. AutoML enablement is declared at the model layer in via . Runnable AutoML still requires and to exist and parse. Use the packaged train schema for , , defaults, min/max bounds, enums, option weights, math conditions, dependencies, and popular parameters. Do not expect at runtime; maintainers regenerate schemas/templates before packaging the skill bank.
schemas/<action>.schema.jsonschemas/manifest.jsonreferences/spec_template_<action>.yamldefaultreferences/skill_info.yamlautoml_enabledschemas/train.schema.jsonreferences/spec_template_train.yamlautoml_default_parametersautoml_disabled_parameters~/tao-core生成的TAO Core模式打包在中,列出了可用动作。每个生成的模式还会从模式顶层字段生成。AutoML支持在的模型层通过声明。可运行的AutoML仍要求和存在且可解析。使用打包的训练模式来配置、、默认值、最小/最大边界、枚举、选项权重、数学条件、依赖关系和常用参数。运行时不要依赖;维护人员会在打包技能库前重新生成模式/模板。
schemas/<action>.schema.jsonschemas/manifest.jsondefaultreferences/spec_template_<action>.yamlreferences/skill_info.yamlautoml_enabledschemas/train.schema.jsonreferences/spec_template_train.yamlautoml_default_parametersautoml_disabled_parameters~/tao-coreTrain Action Policy
训练动作策略
This model is AutoML-enabled at the model layer. Before handling any train-stage request, read and resolve the run override from either an explicit value or the user's workflow request. Treat phrases like "turn off AutoML", "disable AutoML", "no HPO", or "plain training" as for this run only; otherwise default to . When , , and both and are packaged, route the train action through by default with this model's . Preserve workflow/application overrides for datasets, specs, output directories, GPU/platform settings, parent checkpoints, and . Use direct model training only when or the packaged train schema/template is missing; in the missing-schema case, report that AutoML is enabled but not runnable for this model until schemas are generated.
references/skill_info.yamlautoml_policyautoml_policy: offautoautoml_policy: autoautoml_enabled: trueschemas/train.schema.jsonreferences/spec_template_train.yamltao-skill-bank:tao-run-automlskill_dirautoml_policyautoml_policy: offNon-train actions such as , , , and deploy flows stay in this model skill. The per-run override does not change model metadata.
evaluateinferenceexportautoml_policy该模型在模型层支持AutoML。处理任何训练阶段请求前,请读取,并通过显式的值或用户的工作流请求确定运行覆盖配置。将“turn off AutoML”、“disable AutoML”、“no HPO”或“plain training”等短语视为本次运行的;否则默认使用。当、且和已打包时,默认将训练动作通过路由,并传入该模型的。保留数据集、配置、输出目录、GPU/平台设置、父检查点和的工作流/应用覆盖配置。仅当或打包的训练模式/模板缺失时,才使用直接模型训练;若模式缺失,需报告该模型已启用AutoML但无法运行,直至生成模式。
references/skill_info.yamlautoml_policyautoml_policy: offautoautoml_policy: autoautoml_enabled: trueschemas/train.schema.jsonreferences/spec_template_train.yamltao-skill-bank:tao-run-automlskill_dirautoml_policyautoml_policy: off非训练动作(如、、和部署流程)仍在该模型技能中执行。每次运行的覆盖配置不会更改模型元数据。
evaluateinferenceexportautoml_policyTraining Requirements
训练要求
- Dataset type: ml_recog
- Formats: default
- Monitoring metric: val Precision at Rank 1
- 数据集类型: ml_recog
- 格式: default
- 监控指标: 验证集Rank 1精度
Per-Action Dataset Requirements
各动作的数据集要求
| Action | Spec Key | Source | Files | List? |
|---|---|---|---|---|
| evaluate | dataset.val_dataset | train_datasets | reference: metric_learning_recognition/retail-product-checkout-dataset_classification_demo/unknown_classes/reference.tar.gz, query: metric_learning_recognition/retail-product-checkout-dataset_classification_demo/unknown_classes/test.tar.gz | No |
| gen_trt_engine | gen_trt_engine.tensorrt.calibration.cal_image_dir | calibration_dataset | metric_learning_recognition/retail-product-checkout-dataset_classification_demo/known_classes/test.tar.gz | Yes |
| inference | dataset.val_dataset | train_datasets | reference: metric_learning_recognition/retail-product-checkout-dataset_classification_demo/unknown_classes/reference.tar.gz, query: | No |
| inference | inference.input_path | train_datasets | metric_learning_recognition/retail-product-checkout-dataset_classification_demo/unknown_classes/test.tar.gz | No |
| train | dataset.train_dataset | train_datasets | metric_learning_recognition/retail-product-checkout-dataset_classification_demo/known_classes/train.tar.gz | No |
| train | dataset.val_dataset | train_datasets | reference: metric_learning_recognition/retail-product-checkout-dataset_classification_demo/known_classes/reference.tar.gz, query: metric_learning_recognition/retail-product-checkout-dataset_classification_demo/known_classes/val.tar.gz | No |
| 动作 | 配置键 | 来源 | 文件 | 是否为列表? |
|---|---|---|---|---|
| evaluate | dataset.val_dataset | train_datasets | reference: metric_learning_recognition/retail-product-checkout-dataset_classification_demo/unknown_classes/reference.tar.gz, query: metric_learning_recognition/retail-product-checkout-dataset_classification_demo/unknown_classes/test.tar.gz | 否 |
| gen_trt_engine | gen_trt_engine.tensorrt.calibration.cal_image_dir | calibration_dataset | metric_learning_recognition/retail-product-checkout-dataset_classification_demo/known_classes/test.tar.gz | 是 |
| inference | dataset.val_dataset | train_datasets | reference: metric_learning_recognition/retail-product-checkout-dataset_classification_demo/unknown_classes/reference.tar.gz, query: | 否 |
| inference | inference.input_path | train_datasets | metric_learning_recognition/retail-product-checkout-dataset_classification_demo/unknown_classes/test.tar.gz | 否 |
| train | dataset.train_dataset | train_datasets | metric_learning_recognition/retail-product-checkout-dataset_classification_demo/known_classes/train.tar.gz | 否 |
| train | dataset.val_dataset | train_datasets | reference: metric_learning_recognition/retail-product-checkout-dataset_classification_demo/known_classes/reference.tar.gz, query: metric_learning_recognition/retail-product-checkout-dataset_classification_demo/known_classes/val.tar.gz | 否 |
Typical Spec Overrides
典型配置覆盖
Data source overrides are mandatory for every action — the agent MUST construct data source paths from the Per-Action Dataset Requirements table above and include them in .
spec_overridespython
S3_TRAIN = "s3://bucket/data/train"train (mandatory data sources):
python
{
"train.num_epochs": 30,
"train.checkpoint_interval": 10,
"train.validation_interval": 10,
"train.num_gpus": 1,
"dataset.train_dataset": f"{S3_TRAIN}/metric_learning_recognition/retail-product-checkout-dataset_classification_demo/known_classes/train.tar.gz",
"dataset.val_dataset": {"reference": f"{S3_TRAIN}/metric_learning_recognition/retail-product-checkout-dataset_classification_demo/known_classes/reference.tar.gz", "query": f"{S3_TRAIN}/metric_learning_recognition/retail-product-checkout-dataset_classification_demo/known_classes/val.tar.gz"},
}gen_trt_engine (mandatory data sources):
python
{
"gen_trt_engine.tensorrt.data_type": "INT8",
"gen_trt_engine.tensorrt.calibration.cal_image_dir": [f"{S3_TRAIN}/metric_learning_recognition/retail-product-checkout-dataset_classification_demo/known_classes/test.tar.gz"],
}evaluate (mandatory data sources):
python
{
"dataset.val_dataset": {"reference": f"{S3_TRAIN}/metric_learning_recognition/retail-product-checkout-dataset_classification_demo/unknown_classes/reference.tar.gz", "query": f"{S3_TRAIN}/metric_learning_recognition/retail-product-checkout-dataset_classification_demo/unknown_classes/test.tar.gz"},
}inference (mandatory data sources):
python
{
"dataset.val_dataset": {"reference": f"{S3_TRAIN}/metric_learning_recognition/retail-product-checkout-dataset_classification_demo/unknown_classes/reference.tar.gz"},
"inference.input_path": f"{S3_TRAIN}/metric_learning_recognition/retail-product-checkout-dataset_classification_demo/unknown_classes/test.tar.gz",
}数据源覆盖对每个动作都是必填项——Agent必须根据上述各动作数据集要求表构建数据源路径,并将其包含在中。
spec_overridespython
S3_TRAIN = "s3://bucket/data/train"train(必填数据源):
python
{
"train.num_epochs": 30,
"train.checkpoint_interval": 10,
"train.validation_interval": 10,
"train.num_gpus": 1,
"dataset.train_dataset": f"{S3_TRAIN}/metric_learning_recognition/retail-product-checkout-dataset_classification_demo/known_classes/train.tar.gz",
"dataset.val_dataset": {"reference": f"{S3_TRAIN}/metric_learning_recognition/retail-product-checkout-dataset_classification_demo/known_classes/reference.tar.gz", "query": f"{S3_TRAIN}/metric_learning_recognition/retail-product-checkout-dataset_classification_demo/known_classes/val.tar.gz"},
}gen_trt_engine(必填数据源):
python
{
"gen_trt_engine.tensorrt.data_type": "INT8",
"gen_trt_engine.tensorrt.calibration.cal_image_dir": [f"{S3_TRAIN}/metric_learning_recognition/retail-product-checkout-dataset_classification_demo/known_classes/test.tar.gz"],
}evaluate(必填数据源):
python
{
"dataset.val_dataset": {"reference": f"{S3_TRAIN}/metric_learning_recognition/retail-product-checkout-dataset_classification_demo/unknown_classes/reference.tar.gz", "query": f"{S3_TRAIN}/metric_learning_recognition/retail-product-checkout-dataset_classification_demo/unknown_classes/test.tar.gz"},
}inference(必填数据源):
python
{
"dataset.val_dataset": {"reference": f"{S3_TRAIN}/metric_learning_recognition/retail-product-checkout-dataset_classification_demo/unknown_classes/reference.tar.gz"},
"inference.input_path": f"{S3_TRAIN}/metric_learning_recognition/retail-product-checkout-dataset_classification_demo/unknown_classes/test.tar.gz",
}Eval Dataset
评估数据集
Required. Evaluation requires reference and query datasets for retrieval metrics.
必填项。评估需要参考数据集和查询数据集来计算检索指标。
Important Parameters
重要参数
- model.backbone: Default resnet_50. Options: resnet_50, resnet_101, fan_small, fan_base, fan_large, fan_tiny, nvdinov2_vit_large_legacy.
- model.feat_dim: Embedding dimension. Default 256. Output feature vector size for similarity matching.
- train.batch_size: Per-GPU batch size. Default 4. val_batch_size also 4.
- dataset.num_instance: Instances per identity in a batch (P/K sampling). Default 4. Controls how many images of the same class appear together.
- train.optim.trunk.base_lr: Learning rate for the trunk (backbone). Default 3.5e-4 (Adam).
- train.optim.embedder.base_lr: Learning rate for the embedding head. Default 3.5e-4.
- train.optim.triplet_loss_margin: Margin for triplet loss. Default 0.3. smooth_loss=True by default.
- train.optim.miner_function_margin: Hard mining margin. Default 0.1. Controls pair mining difficulty.
- train.optim.steps: LR decay steps. Default [40, 70] with gamma=0.1.
- dataset.train_dataset: Path to training images organized in class folders.
- dataset.val_dataset: Dict with 'reference' and 'query' keys pointing to ImageNet-format directories for retrieval evaluation.
- model.backbone: 默认值为resnet_50。可选值:resnet_50, resnet_101, fan_small, fan_base, fan_large, fan_tiny, nvdinov2_vit_large_legacy。
- model.feat_dim: 嵌入向量维度。默认值256。用于相似度匹配的输出特征向量大小。
- train.batch_size: 单GPU批次大小。默认值4。val_batch_size同样为4。
- dataset.num_instance: 一个批次中每个类别的样本数量(P/K采样)。默认值4。控制同一类别的图像在批次中出现的数量。
- train.optim.trunk.base_lr: 骨干网络的学习率。默认值3.5e-4(Adam优化器)。
- train.optim.embedder.base_lr: 嵌入头的学习率。默认值3.5e-4。
- train.optim.triplet_loss_margin: 三元组损失的边界值。默认值0.3。默认开启smooth_loss=True。
- train.optim.miner_function_margin: 难样本挖掘的边界值。默认值0.1。控制样本对挖掘的难度。
- train.optim.steps: 学习率衰减步骤。默认值[40, 70],gamma=0.1。
- dataset.train_dataset: 按类别文件夹组织的训练图像路径。
- dataset.val_dataset: 包含'reference'和'query'键的字典,指向用于检索评估的ImageNet格式目录。
Multi-GPU / Multi-Node
多GPU / 多节点
Launch method: Lightning-managed (single process, Lightning spawns workers).
python| Spec Key | Description | Default |
|---|---|---|
| Number of GPUs | 1 |
| GPU device indices | [0] |
- Strategy: (Lightning picks best strategy automatically)
auto - No explicit or
num_nodesconfig — single-node orienteddistributed_strategy
启动方式: Lightning管理(单个进程,Lightning生成工作进程)。
python| 配置键 | 描述 | 默认值 |
|---|---|---|
| GPU数量 | 1 |
| GPU设备索引 | [0] |
- 策略: (Lightning自动选择最佳策略)
auto - 无显式或
num_nodes配置——面向单节点场景distributed_strategy
Hardware
硬件要求
Minimum 1 GPU(s), recommended 2 GPU(s). 16GB+ VRAM per GPU. Metric learning benefits from larger batch sizes for better triplet sampling but is otherwise moderate on memory.
最低要求1块GPU,推荐2块GPU。每块GPU需16GB以上显存。度量学习受益于更大的批次大小以优化三元组采样,但对内存的其他要求适中。
Error Patterns
错误模式
Reference/query mismatch: Ensure reference and query datasets share compatible class namespaces for evaluation.
参考/查询数据集不匹配: 确保评估时参考数据集和查询数据集的类别命名空间兼容。
Spec Param / Parent Model Inference
配置参数 / 父模型推理
Model-specific inference mappings belong in this MD file, not in . Generated runners should read this section and apply the mappings with SDK helpers before . This mirrors the old microservices flow.
config.jsoncreate_job()infer_params.pyInference mappings from TAO Core :
ml_recog.config.json| Action | Spec Field | Inference Function | Meaning |
|---|---|---|---|
| evaluate | | | model file inferred from the parent job results folder |
| evaluate | | | model file inferred from the parent job results folder |
| evaluate | | | current job results directory |
| export | | | model file inferred from the parent job results folder |
| export | | | output ONNX path |
| export | | | current job results directory |
| gen_trt_engine | | | model file inferred from the parent job results folder |
| gen_trt_engine | | | calibration cache path |
| gen_trt_engine | | | output TensorRT engine path |
| gen_trt_engine | | | current job results directory |
| inference | | | model file inferred from the parent job results folder |
| inference | | | model file inferred from the parent job results folder |
| inference | | | current job results directory |
| train | | | PTM when no resume checkpoint exists |
| train | | | current job results directory |
| train | | | model file inferred from the current job results folder |
For or , pass the upstream train/export/AutoML child job id as . The SDK lists the parent result folder, filters checkpoint artifacts, and returns the selected model file or folder. Do not add these mappings back to and do not patch generated runner scripts to guess checkpoint paths.
parent_modelparent_model_folderparent_job_idconfig.json模型特定的推理映射应记录在此MD文件中,而非。生成的运行器应读取本节内容,并在前使用SDK助手应用映射。这与旧微服务的流程一致。
config.jsoncreate_job()infer_params.py来自TAO Core 的推理映射:
ml_recog.config.json| 动作 | 配置字段 | 推理函数 | 含义 |
|---|---|---|---|
| evaluate | | | 从父作业结果文件夹推断出的模型文件 |
| evaluate | | | 从父作业结果文件夹推断出的模型文件 |
| evaluate | | | 当前作业的结果目录 |
| export | | | 从父作业结果文件夹推断出的模型文件 |
| export | | | 输出ONNX路径 |
| export | | | 当前作业的结果目录 |
| gen_trt_engine | | | 从父作业结果文件夹推断出的模型文件 |
| gen_trt_engine | | | 校准缓存路径 |
| gen_trt_engine | | | 输出TensorRT引擎路径 |
| gen_trt_engine | | | 当前作业的结果目录 |
| inference | | | 从父作业结果文件夹推断出的模型文件 |
| inference | | | 从父作业结果文件夹推断出的模型文件 |
| inference | | | 当前作业的结果目录 |
| train | | | 无恢复检查点时使用的预训练模型 |
| train | | | 当前作业的结果目录 |
| train | | | 从当前作业结果文件夹推断出的模型文件 |
对于或,将上游训练/导出/AutoML子作业ID作为传入。SDK会列出父结果文件夹,过滤检查点工件,并返回选定的模型文件或文件夹。请勿将这些映射添加回,也不要修改生成的运行器脚本以猜测检查点路径。
parent_modelparent_model_folderparent_job_idconfig.jsonDeployment
部署
- tao-deploy-metric-learning-recognition — MLRecog deploy workflow for TensorRT engine generation, TensorRT evaluation, and TensorRT inference using TAO Deploy.
- tao-deploy-metric-learning-recognition —— 用于TensorRT引擎生成、TensorRT评估和TensorRT推理的MLRecog部署工作流,基于TAO Deploy实现。