一、引言
在盘算机视觉范畴,图像分割是一个核心且具有寻衅性的使命,旨在将图像中的差别物体或区域举行划分和识别,广泛应用于主动驾驶、医学影像分析、安防监控等范畴。Segment Anything Model(SAM)由 Meta AI 实验室发布,其引入了基于 Prompt 的交互式分割能力,显着提拔了图像分割的灵活性和泛化能力。
SAM 通过在海量且多样化的数据集上练习,具备处置惩罚未见过对象类别和场景的能力。这一特性使其在学术界引发广泛研究,如模型轻量化、范畴适应、多模态融合等方向;在工业界也迅速应用于医学图像分析中的肿瘤检测、遥感图像处置惩罚中的卫星图像分析以及视频处置惩罚中的目标跟踪等范畴。
对于盘算机视觉爱好者和开发者而言,深入分析 SAM 的源码有助于理解其核心技术原理,为进一步创新和应用提供基础。本文将从原理到源码徐徐解析 SAM 的工作机制,探索实在现高效图像编码、提示编码及掩码解码的过程。
二、SAM 原理速览
(一)核心概念
SAM 的核心在于其 “分割一切” 的能力,这一能力依靠于基于 Prompt 的分割计谋。Prompt 可以是点(points)、框(boxes)、掩码(masks)或文本(text),为模型提供分割目标的关键信息。例如,点击图像中的一个点,SAM 能够识别并分割该点所在物体;给定一个框,SAM 则专注于框内物体的分割。
(二)模型架构
SAM 的架构包含三个核心组件:Image Encoder(图像编码器)、Prompt Encoder(提示编码器) 和 Mask Decoder(掩码解码器),其协作方式如下图所示:
- Image Encoder:将输入图像转换为高维特性表示,通常使用预练习的 Vision Transformer(ViT),如 ViT-H、ViT-L、ViT-B。例如,对于分辨率为 1024 × 1024 1024 \times 1024 1024×1024 的图像,经过 Patch Embedding 操作划分为 16 × 16 16 \times 16 16×16 的 patches,特性图尺寸缩小为原来的 1 16 \frac{1}{16} 161,通道数从 3 映射到 768。
- Prompt Encoder:处置惩罚差别类型的 Prompt,将其编码为与图像嵌入兼容的特性。对点和框使用位置编码(Positional Encoding);对文本使用 CLIP 文本编码器;对掩码则通过轻量级卷积网络编码。
- Mask Decoder:融合图像嵌入和提示嵌入,通过 Transformer 结构解码为分割掩码,默认天生 3 个候选掩码并按置信度排序。
三、源码结构总览
(一)代码目次解析
SAM 的代码目次结构如下:
- segment-anything/
- ├── assets # 示例图片等资源
- ├── demo # 前端部署代码
- ├── notebooks # Jupyter Notebook 示例
- ├── script # 模型导出脚本
- ├── segment_anything # 核心代码目录
- │ ├── build_sam.py # 模型构建脚本
- │ ├── config.py # 配置文件
- │ ├── mask_decoder.py # 掩码解码器实现
- │ ├── model_registry.py # 模型注册模块
- │ ├── predictor.py # 预测接口
- │ ├── sam.py # SAM 整体结构
- │ ├── sam_arch.py # 架构细节
- │ ├── utils.py # 工具函数
- │ └── automatic_mask_generator.py # 自动掩码生成
- └── setup.py # 安装脚本
复制代码 segment_anything 是核心目次,后续分析将聚焦于此。
(二)关键文件与模块
- build_sam.py:界说模型构建函数,支持差别版本(如 vit_h、vit_l、vit_b)。示例代码:
- def build_sam_vit_h(checkpoint=None):
- # 构建 vit_h 版本的 SAM 模型
- return _build_sam(
- encoder_embed_dim=1280, # 编码器嵌入维度
- encoder_depth=32, # 编码器层数
- encoder_num_heads=16, # 注意力头数
- encoder_global_attn_indexes=[7, 15, 23, 31], # 全局注意力层索引
- checkpoint=checkpoint, # 预训练权重文件路径
- )
复制代码
- predictor.py:提供预测接口,set_image 处置惩罚图像预处置惩罚,predict 根据提示天生掩码。核心代码:
- def set_image(self, image: np.ndarray) -> None:
- # 检查输入图像维度和通道数是否符合要求
- if image.ndim != 3 or image.shape[2] not in [3, 4]:
- raise ValueError("Image must be 3D with 3 or 4 channels")
- # 应用图像变换(如缩放、归一化)
- input_image = self.transform.apply_image(image)
- # 转换为 PyTorch 张量并调整维度为 [1, C, H, W]
- input_image_torch = torch.as_tensor(input_image, device=self.device)
- self.set_torch_image(input_image_torch.permute(2, 0, 1)[None, :, :, :], image.shape[:2])
复制代码
- automatic_mask_generator.py:主动天生所有物体掩码,基于点提示网格。核心代码:
- def generate(self, image: np.ndarray) -> List[Dict[str, Any]]:
- # 预处理输入图像
- input_image = self.model.preprocess(image)
- # 使用无梯度计算图像嵌入
- with torch.no_grad():
- image_embedding = self.model.image_encoder(input_image)
- # 生成点提示网格
- points = self._generate_points(image.shape[:2])
- all_masks, all_scores = [], []
- # 按批次处理点提示并预测掩码
- for i in range(0, len(points), self.points_per_batch):
- batch_points = points[i:i + self.points_per_batch]
- batch_masks, batch_scores = self._predict_masks(image_embedding, batch_points)
- all_masks.extend(batch_masks)
- all_scores.extend(batch_scores)
- # 返回掩码和对应置信度列表
- return [{"segmentation": m, "score": s} for m, s in zip(all_masks, all_scores)]
复制代码 四、核心代码深度分析
(一)Image Encoder 源码解析
- Patch Embedding:将图像划分为 patches 并映射为特性向量:
- class PatchEmbed(nn.Module):
- def __init__(self, kernel_size=(16, 16), stride=(16, 16), in_chans=3, embed_dim=768):
- # 初始化 Patch Embedding 模块
- super().__init__()
- # 定义卷积层,将图像划分为 patches 并映射到嵌入维度
- self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size, stride)
-
- def forward(self, x: torch.Tensor) -> torch.Tensor:
- # 对输入图像进行卷积操作,生成特征图
- x = self.proj(x)
- # 调整维度顺序为 [B, H/16, W/16, C],适配 Transformer 输入
- return x.permute(0, 2, 3, 1)
复制代码 输入图像 [ B , 3 , H , W ] [B, 3, H, W] [B,3,H,W] 转换为 [ B , H 16 , W 16 , 768 ] [B, \frac{H}{16}, \frac{W}{16}, 768] [B,16H,16W,768]。
- Transformer Encoder:堆叠多个 Transformer Block 提取特性:
- class Block(nn.Module):
- def __init__(self, dim, num_heads, mlp_ratio):
- # 初始化 Transformer Block
- super().__init__()
- # 第一层归一化
- self.norm1 = nn.LayerNorm(dim)
- # 多头注意力模块
- self.attn = Attention(dim, num_heads)
- # 第二层归一化
- self.norm2 = nn.LayerNorm(dim)
- # 多层感知机,隐藏层维度为 dim * mlp_ratio
- self.mlp = Mlp(dim, int(dim * mlp_ratio))
-
- def forward(self, x):
- # 自注意力计算并残差连接
- x = x + self.attn(self.norm1(x))
- # MLP 计算并残差连接
- x = x + self.mlp(self.norm2(x))
- return x
复制代码 (二)Prompt Encoder 源码解析
编码差别类型提示:
- class PromptEncoder(nn.Module):
- def __init__(self, embed_dim, image_embedding_size, input_image_size):
- # 初始化 Prompt Encoder 模块
- super().__init__()
- # 位置编码层,使用随机高斯矩阵生成
- self.pe_layer = PositionEmbeddingRandom(embed_dim // 2)
- # 提示嵌入层,处理点和框提示
- self.prompt_embed_layer = PromptEmbedding(embed_dim, input_image_size, image_embedding_size)
-
- def forward(self, points=None, boxes=None, masks=None):
- # 初始化稀疏嵌入张量,维度为 [batch_size, 0, embed_dim]
- sparse_embed = torch.zeros(bs, 0, self.embed_dim, device=self.device)
- if points:
- # 提取点提示的坐标和标签
- coords, labels = points
- # 生成点嵌入并拼接到稀疏嵌入中
- sparse_embed = torch.cat([sparse_embed, self.prompt_embed_layer.point_embedding(coords, labels)], dim=1)
- # 返回稀疏嵌入和密集嵌入(未完全展示 masks 处理部分)
- return sparse_embed, dense_embed
复制代码 五、实战演练
(一)环境搭建与配置
- 安装 Python:版本 ≥ 3.8。
- 安装 PyTorch:
- pip install torch==2.0.1+cu117 torchvision==0.15.2+cu117 -f https://download.pytorch.org/whl/cu117
复制代码 - 安装 SAM:
- pip install -U "git+https://github.com/facebookresearch/segment-anything.git"
复制代码 (二)代码运行与结果分析
示例代码:
- from segment_anything import sam_model_registry, SamPredictor
- import cv2
- import numpy as np
- import matplotlib.pyplot as plt
- # 加载 vit_b 版本的 SAM 模型并移动到 GPU
- sam = sam_model_registry["vit_b"](checkpoint="sam_vit_b_01ec64.pth").to("cuda")
- # 初始化预测器
- predictor = SamPredictor(sam)
- # 读取图像并转换为 RGB 格式
- image = cv2.cvtColor(cv2.imread("test_image.jpg"), cv2.COLOR_BGR2RGB)
- # 设置输入图像,进行预处理和特征提取
- predictor.set_image(image)
- # 定义点提示坐标和标签(1 表示前景)
- masks, scores, _ = predictor.predict(point_coords=np.array([[500, 375]]), point_labels=np.array([1]), multimask_output=True)
- # 遍历掩码和分数,显示结果
- for i, (mask, score) in enumerate(zip(masks, scores)):
- plt.imshow(image) # 显示原始图像
- plt.imshow(mask, alpha=0.6) # 显示掩码,透明度为 0.6
- plt.title(f"Mask {i+1}, Score: {score:.3f}") # 设置标题,显示掩码编号和置信度
- plt.show() # 显示图像
复制代码 结果分析:SAM 根据点提示天生多个掩码,分数反映置信度。高分掩码通常更准确,低分掩码可能包含错误分割。
六、总结与展望
SAM 通过高效的图像编码、提示编码和掩码解码实现了灵活的图像分割。未来,其在医学影像、主动驾驶和视频处置惩罚中的应用潜力巨大。技术发展方向包括模型轻量化、多模态融合和范畴适应,为盘算机视觉带来更多可能性。
延伸阅读
- AI Agent 系列文章
- 盘算机视觉系列文章
- 呆板学习核心算法系列文章
- 深度学习系列文章
免责声明:如果侵犯了您的权益,请联系站长,我们会及时删除侵权内容,谢谢合作!更多信息从访问主页:qidao123.com:ToB企服之家,中国第一个企服评测及商务社交产业平台。 |