Segment Anything Model (SAM)
Comprehensive guide to using Meta AI's Segment Anything Model for zero-shot image segmentation.
When to use SAM
Use SAM when:
- Need to segment any object in images without task-specific training
- Building interactive annotation tools with point/box prompts
- Generating training data for other vision models
- Need zero-shot transfer to new image domains
- Building object detection/segmentation pipelines
- Processing medical, satellite, or domain-specific images
Key features:
- Zero-shot segmentation: Works on any image domain without fine-tuning
- Flexible prompts: Points, bounding boxes, or previous masks
- Automatic segmentation: Generate all object masks automatically
- High quality: Trained on 1.1 billion masks from 11 million images
- Multiple model sizes: ViT-B (fastest), ViT-L, ViT-H (most accurate)
- ONNX export: Deploy in browsers and edge devices
Use alternatives instead:
- YOLO/Detectron2: For real-time object detection with classes
- Mask2Former: For semantic/panoptic segmentation with categories
- GroundingDINO + SAM: For text-prompted segmentation
- SAM 2: For video segmentation tasks
Quick start
Installation
```bash
# From GitHub
pip install git+https://github.com/facebookresearch/segment-anything.git

# Optional dependencies
pip install opencv-python pycocotools matplotlib

# Or use HuggingFace transformers
pip install transformers
```
Download checkpoints

```bash
# ViT-H (largest, most accurate) - 2.4GB
wget https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth

# ViT-L (medium) - 1.2GB
wget https://dl.fbaipublicfiles.com/segment_anything/sam_vit_l_0b3195.pth

# ViT-B (smallest, fastest) - 375MB
wget https://dl.fbaipublicfiles.com/segment_anything/sam_vit_b_01ec64.pth
```
Basic usage with SamPredictor
```python
import cv2
import numpy as np
from segment_anything import sam_model_registry, SamPredictor

# Load model
sam = sam_model_registry["vit_h"](checkpoint="sam_vit_h_4b8939.pth")
sam.to(device="cuda")

# Create predictor
predictor = SamPredictor(sam)

# Set image (computes embeddings once)
image = cv2.imread("image.jpg")
image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
predictor.set_image(image)

# Predict with point prompts
input_point = np.array([[500, 375]])  # (x, y) coordinates
input_label = np.array([1])  # 1 = foreground, 0 = background
masks, scores, logits = predictor.predict(
    point_coords=input_point,
    point_labels=input_label,
    multimask_output=True  # Returns 3 mask options
)

# Select best mask
best_mask = masks[np.argmax(scores)]
```
HuggingFace Transformers
```python
import torch
from PIL import Image
from transformers import SamModel, SamProcessor

# Load model and processor
model = SamModel.from_pretrained("facebook/sam-vit-huge")
processor = SamProcessor.from_pretrained("facebook/sam-vit-huge")
model.to("cuda")

# Process image with point prompt
image = Image.open("image.jpg")
input_points = [[[450, 600]]]  # Batch of points
inputs = processor(image, input_points=input_points, return_tensors="pt")
inputs = {k: v.to("cuda") for k, v in inputs.items()}

# Generate masks
with torch.no_grad():
    outputs = model(**inputs)

# Post-process masks to original size
masks = processor.image_processor.post_process_masks(
    outputs.pred_masks.cpu(),
    inputs["original_sizes"].cpu(),
    inputs["reshaped_input_sizes"].cpu()
)
```
Core concepts
Model architecture
```
SAM Architecture:
┌─────────────────┐     ┌─────────────────┐     ┌─────────────────┐
│  Image Encoder  │────▶│ Prompt Encoder  │────▶│  Mask Decoder   │
│      (ViT)      │     │ (Points/Boxes)  │     │  (Transformer)  │
└─────────────────┘     └─────────────────┘     └─────────────────┘
        │                       │                       │
 Image Embeddings       Prompt Embeddings         Masks + IoU
  (computed once)          (per prompt)           predictions
```

Model variants
| Model | Checkpoint | Size | Speed | Accuracy |
|---|---|---|---|---|
| ViT-H | `sam_vit_h_4b8939.pth` | 2.4 GB | Slowest | Best |
| ViT-L | `sam_vit_l_0b3195.pth` | 1.2 GB | Medium | Good |
| ViT-B | `sam_vit_b_01ec64.pth` | 375 MB | Fastest | Good |
Prompt types
| Prompt | Description | Use Case |
|---|---|---|
| Point (foreground) | Click on object | Single object selection |
| Point (background) | Click outside object | Exclude regions |
| Bounding box | Rectangle around object | Larger objects |
| Previous mask | Low-res mask input | Iterative refinement |
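Point prompts boil down to two parallel arrays: (x, y) coordinates and 1/0 labels. A small helper for packing separate foreground and background click lists into that shape (a hypothetical convenience function, not part of the SAM API):

```python
import numpy as np

def pack_point_prompts(foreground, background=()):
    """Pack (x, y) click lists into the arrays predictor.predict expects.

    Returns (point_coords, point_labels), with label 1 for foreground
    clicks and label 0 for background clicks.
    """
    fg = [list(p) for p in foreground]
    bg = [list(p) for p in background]
    point_coords = np.array(fg + bg)
    point_labels = np.array([1] * len(fg) + [0] * len(bg))
    return point_coords, point_labels

coords, labels = pack_point_prompts([(500, 375), (600, 400)], [(450, 300)])
```

The two arrays can then be passed directly as `point_coords` and `point_labels`.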
Interactive segmentation
Point prompts
```python
# Single foreground point
input_point = np.array([[500, 375]])
input_label = np.array([1])
masks, scores, logits = predictor.predict(
    point_coords=input_point,
    point_labels=input_label,
    multimask_output=True
)

# Multiple points (foreground + background)
input_points = np.array([[500, 375], [600, 400], [450, 300]])
input_labels = np.array([1, 1, 0])  # 2 foreground, 1 background
masks, scores, logits = predictor.predict(
    point_coords=input_points,
    point_labels=input_labels,
    multimask_output=False  # Single mask when prompts are clear
)
```
Box prompts
```python
# Bounding box [x1, y1, x2, y2]
input_box = np.array([425, 600, 700, 875])
masks, scores, logits = predictor.predict(
    box=input_box,
    multimask_output=False
)
```
Combined prompts
```python
# Box + points for precise control
masks, scores, logits = predictor.predict(
    point_coords=np.array([[500, 375]]),
    point_labels=np.array([1]),
    box=np.array([400, 300, 700, 600]),
    multimask_output=False
)
```
Iterative refinement
```python
# Initial prediction
masks, scores, logits = predictor.predict(
    point_coords=np.array([[500, 375]]),
    point_labels=np.array([1]),
    multimask_output=True
)

# Refine with an additional point, using the previous mask
masks, scores, logits = predictor.predict(
    point_coords=np.array([[500, 375], [550, 400]]),
    point_labels=np.array([1, 0]),  # Add background point
    mask_input=logits[np.argmax(scores)][None, :, :],  # Use best mask
    multimask_output=False
)
```
Automatic mask generation
Basic automatic segmentation
```python
from segment_anything import SamAutomaticMaskGenerator

# Create generator
mask_generator = SamAutomaticMaskGenerator(sam)

# Generate all masks
masks = mask_generator.generate(image)

# Each mask contains:
# - segmentation: binary mask
# - bbox: [x, y, w, h]
# - area: pixel count
# - predicted_iou: quality score
# - stability_score: robustness score
# - point_coords: generating point
```

Customized generation
```python
mask_generator = SamAutomaticMaskGenerator(
    model=sam,
    points_per_side=32,               # Grid density (more = more masks)
    pred_iou_thresh=0.88,             # Quality threshold
    stability_score_thresh=0.95,      # Stability threshold
    crop_n_layers=1,                  # Multi-scale crops
    crop_n_points_downscale_factor=2,
    min_mask_region_area=100,         # Remove tiny masks
)
masks = mask_generator.generate(image)
```

Filtering masks
```python
# Sort by area (largest first)
masks = sorted(masks, key=lambda x: x['area'], reverse=True)

# Filter by predicted IoU
high_quality = [m for m in masks if m['predicted_iou'] > 0.9]

# Filter by stability score
stable_masks = [m for m in masks if m['stability_score'] > 0.95]
```
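Even after threshold filtering, masks generated from neighboring grid points can be near-duplicates (the generator applies its own box-based NMS internally, but mask-level overlap can survive it). A greedy mask-IoU deduplication sketch, assuming the list is already sorted best-first; `dedup_masks` is a hypothetical helper, not a library function:

```python
import numpy as np

def mask_iou(a, b):
    """IoU between two boolean masks of the same shape."""
    inter = np.logical_and(a, b).sum()
    union = np.logical_or(a, b).sum()
    return inter / union if union else 0.0

def dedup_masks(masks, iou_thresh=0.9):
    """Greedily drop masks that overlap an already-kept mask above iou_thresh.

    masks: list of dicts with a boolean 'segmentation' array, assumed
    pre-sorted by quality (e.g. by 'predicted_iou', descending).
    """
    kept = []
    for m in masks:
        if all(mask_iou(m['segmentation'], k['segmentation']) < iou_thresh
               for k in kept):
            kept.append(m)
    return kept
```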
Batched inference
Multiple images
```python
# Process multiple images efficiently
images = [cv2.imread(f"image_{i}.jpg") for i in range(10)]
all_masks = []
for image in images:
    predictor.set_image(image)
    masks, _, _ = predictor.predict(
        point_coords=np.array([[500, 375]]),
        point_labels=np.array([1]),
        multimask_output=True
    )
    all_masks.append(masks)
```
Multiple prompts per image
```python
# Process multiple prompts efficiently (one image encoding)
predictor.set_image(image)

# Batch of point prompts
points = [
    np.array([[100, 100]]),
    np.array([[200, 200]]),
    np.array([[300, 300]])
]
all_masks = []
for point in points:
    masks, scores, _ = predictor.predict(
        point_coords=point,
        point_labels=np.array([1]),
        multimask_output=True
    )
    all_masks.append(masks[np.argmax(scores)])
```
ONNX deployment
Export model
```bash
python scripts/export_onnx_model.py \
    --checkpoint sam_vit_h_4b8939.pth \
    --model-type vit_h \
    --output sam_onnx.onnx \
    --return-single-mask
```

Use ONNX model
```python
import onnxruntime

# Load ONNX model
ort_session = onnxruntime.InferenceSession("sam_onnx.onnx")

# Run inference (image embeddings computed separately)
masks = ort_session.run(
    None,
    {
        "image_embeddings": image_embeddings,
        "point_coords": point_coords,
        "point_labels": point_labels,
        "mask_input": np.zeros((1, 1, 256, 256), dtype=np.float32),
        "has_mask_input": np.array([0], dtype=np.float32),
        "orig_im_size": np.array([h, w], dtype=np.float32)
    }
)
```
Common workflows
Workflow 1: Annotation tool
```python
import cv2

# Load model
predictor = SamPredictor(sam)
predictor.set_image(image)

def on_click(event, x, y, flags, param):
    if event == cv2.EVENT_LBUTTONDOWN:
        # Foreground point
        masks, scores, _ = predictor.predict(
            point_coords=np.array([[x, y]]),
            point_labels=np.array([1]),
            multimask_output=True
        )
        # Display best mask
        display_mask(masks[np.argmax(scores)])
```
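The `display_mask` helper above is left undefined. One minimal way to implement it is to alpha-blend the mask into the frame as a colored overlay; a numpy sketch (the `overlay_mask` name, color, and alpha are arbitrary choices, not part of the SAM API):

```python
import numpy as np

def overlay_mask(image, mask, color=(30, 144, 255), alpha=0.5):
    """Blend a boolean mask into an RGB uint8 image as a translucent overlay."""
    out = image.astype(np.float32)  # float copy for blending
    out[mask] = (1 - alpha) * out[mask] + alpha * np.array(color, np.float32)
    return out.astype(np.uint8)
```

A `display_mask` callback could then show `overlay_mask(image, mask)` with `cv2.imshow`.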
Workflow 2: Object extraction
```python
def extract_object(image, point):
    """Extract object at point with transparent background."""
    predictor.set_image(image)
    masks, scores, _ = predictor.predict(
        point_coords=np.array([point]),
        point_labels=np.array([1]),
        multimask_output=True
    )
    best_mask = masks[np.argmax(scores)]
    # Create RGBA output
    rgba = np.zeros((image.shape[0], image.shape[1], 4), dtype=np.uint8)
    rgba[:, :, :3] = image
    rgba[:, :, 3] = best_mask * 255
    return rgba
```

Workflow 3: Medical image segmentation
```python
# Process medical images (grayscale to RGB)
medical_image = cv2.imread("scan.png", cv2.IMREAD_GRAYSCALE)
rgb_image = cv2.cvtColor(medical_image, cv2.COLOR_GRAY2RGB)
predictor.set_image(rgb_image)

# Segment region of interest
masks, scores, _ = predictor.predict(
    box=np.array([x1, y1, x2, y2]),  # ROI bounding box
    multimask_output=True
)
```
Output format
Mask data structure
```python
# SamAutomaticMaskGenerator output
{
    "segmentation": np.ndarray,  # H×W binary mask
    "bbox": [x, y, w, h],        # Bounding box
    "area": int,                 # Pixel count
    "predicted_iou": float,      # 0-1 quality score
    "stability_score": float,    # 0-1 robustness score
    "crop_box": [x, y, w, h],    # Generation crop region
    "point_coords": [[x, y]],    # Input point
}
```
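The `bbox` and `area` fields are redundant with the mask itself and can be recomputed from it, which is handy after post-processing masks. A numpy sketch using the same XYWH convention as above (`mask_to_bbox_area` is a hypothetical helper, not a library function):

```python
import numpy as np

def mask_to_bbox_area(mask):
    """Compute ([x, y, w, h], area) from a boolean H×W mask."""
    ys, xs = np.nonzero(mask)
    if len(xs) == 0:
        return [0, 0, 0, 0], 0
    x, y = xs.min(), ys.min()
    w, h = xs.max() - x + 1, ys.max() - y + 1
    return [int(x), int(y), int(w), int(h)], int(mask.sum())
```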
COCO RLE format
```python
from pycocotools import mask as mask_utils

# Encode mask to RLE
rle = mask_utils.encode(np.asfortranarray(mask.astype(np.uint8)))
rle["counts"] = rle["counts"].decode("utf-8")

# Decode RLE to mask
decoded_mask = mask_utils.decode(rle)
```
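The idea behind COCO RLE: flatten the mask in column-major (Fortran) order and store alternating run lengths of 0s and 1s, starting with a 0-run. The sketch below implements that uncompressed form to illustrate the layout; the real `counts` string produced by pycocotools is additionally byte-packed, so use the library for interchange:

```python
import numpy as np

def rle_encode(mask):
    """Uncompressed COCO-style RLE: run lengths of the Fortran-order
    flattened mask, alternating 0-runs and 1-runs (first run counts 0s)."""
    flat = mask.ravel(order="F").astype(np.uint8)
    runs, prev, count = [], 0, 0
    for v in flat:
        if v == prev:
            count += 1
        else:
            runs.append(count)
            prev, count = v, 1
    runs.append(count)
    return {"size": list(mask.shape), "counts": runs}

def rle_decode(rle):
    """Inverse of rle_encode: rebuild the boolean H×W mask."""
    h, w = rle["size"]
    flat = np.zeros(h * w, dtype=np.uint8)
    pos, val = 0, 0
    for run in rle["counts"]:
        flat[pos:pos + run] = val
        pos += run
        val = 1 - val
    return flat.reshape((h, w), order="F").astype(bool)
```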
Performance optimization
GPU memory
```python
# Use smaller model for limited VRAM
sam = sam_model_registry["vit_b"](checkpoint="sam_vit_b_01ec64.pth")

# Process images in batches;
# clear CUDA cache between large batches
torch.cuda.empty_cache()
```
Speed optimization
```python
# Use half precision
sam = sam.half()

# Reduce points for automatic generation
mask_generator = SamAutomaticMaskGenerator(
    model=sam,
    points_per_side=16,  # Default is 32
)

# Use ONNX for deployment:
# export with --return-single-mask for faster inference
```

Common issues
| Issue | Solution |
|---|---|
| Out of memory | Use ViT-B model, reduce image size |
| Slow inference | Use ViT-B, reduce points_per_side |
| Poor mask quality | Try different prompts, use box + points |
| Edge artifacts | Use stability_score filtering |
| Small objects missed | Increase points_per_side |
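On the `stability_score` filter: per the SAM paper, the score is the IoU between the mask binarized slightly above and slightly below the cutoff on the model's low-resolution mask logits, so masks whose boundaries shift with small threshold changes score low. A numpy sketch of that computation (`offset` plays the role of the generator's `stability_score_offset` parameter):

```python
import numpy as np

def stability_score(logits, threshold=0.0, offset=1.0):
    """IoU between the mask thresholded high vs. low around `threshold`."""
    high = logits > (threshold + offset)  # strict binarization
    low = logits > (threshold - offset)   # lenient binarization
    inter = np.logical_and(high, low).sum()
    union = np.logical_or(high, low).sum()
    return inter / union if union else 1.0
```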
References
- Advanced Usage - Batching, fine-tuning, integration
- Troubleshooting - Common issues and solutions
Resources
- GitHub: https://github.com/facebookresearch/segment-anything
- Paper: https://arxiv.org/abs/2304.02643
- Demo: https://segment-anything.com
- SAM 2 (Video): https://github.com/facebookresearch/segment-anything-2
- HuggingFace: https://huggingface.co/facebook/sam-vit-huge