通过类似数据蒸馏或自动学习采样的方法,更加高效地学习良品数据分布 ...

打印 上一主题 下一主题

主题 1548|帖子 1548|积分 4644

马上注册,结交更多好友,享用更多功能,让你轻松玩转社区。

您需要 登录 才可以下载或查看,没有账号?立即注册

x
好的,我们先聚焦第一个突破点:
   通过类似数据蒸馏或自动学习采样的方法,更加高效地学习良品数据分布。
  这里我提供一个完备的代码示例:
Masked图像重建 + 残差热力图

这属于自监督蒸馏方法的一个变体:


  • 利用一个 预练习MAE模子(或轻量ViT)对正常样本进行遮挡重建
  • 用重建图与原图的残差来反映“异常程度”

✅ 示例环境依靠

  1. pip install timm einops torchvision matplotlib
复制代码

✅ 完备代码(以MVTec中的图像为例)

  1. import torch
  2. import torch.nn as nn
  3. import torchvision.transforms as T
  4. from torchvision.utils import save_image
  5. from torchvision.datasets.folder import default_loader
  6. from einops import rearrange
  7. import timm
  8. import matplotlib.pyplot as plt
  9. import os
  10. from glob import glob
  11. from PIL import Image
  12. import numpy as np
  13. # ---------------------------
  14. # 模型定义:ViT作为Encoder + 简单Decoder
  15. # ---------------------------
  16. class MAE(nn.Module):
  17.     def __init__(self, encoder_name='vit_base_patch16_224', mask_ratio=0.4):
  18.         super().__init__()
  19.         self.encoder = timm.create_model(encoder_name, pretrained=True)
  20.         self.mask_ratio = mask_ratio
  21.         self.patch_size = self.encoder.patch_embed.patch_size[0]
  22.         self.num_patches = self.encoder.patch_embed.num_patches
  23.         self.embed_dim = self.encoder.embed_dim
  24.         self.decoder = nn.Sequential(
  25.             nn.Linear(self.embed_dim, self.embed_dim),
  26.             nn.GELU(),
  27.             nn.Linear(self.embed_dim, self.patch_size**2 * 3)
  28.         )
  29.     def forward(self, x):
  30.         B, C, H, W = x.shape
  31.         x_patch = self.encoder.patch_embed(x)  # [B, num_patches, dim]
  32.         B, N, D = x_patch.shape
  33.         # 随机遮挡
  34.         rand_idx = torch.rand(B, N).argsort(dim=1)
  35.         num_keep = int(N * (1 - self.mask_ratio))
  36.         keep_idx = rand_idx[:, :num_keep]
  37.         x_keep = torch.gather(x_patch, 1, keep_idx.unsqueeze(-1).expand(-1, -1, D))
  38.         x_encoded = self.encoder.blocks(x_keep)
  39.         x_decoded = self.decoder(x_encoded)
  40.         # 恢复顺序(只对keep部分重建)
  41.         output = torch.zeros(B, N, self.patch_size**2 * 3).to(x.device)
  42.         output.scatter_(1, keep_idx.unsqueeze(-1).expand(-1, -1, self.patch_size**2 * 3), x_decoded)
  43.         output = rearrange(output, 'b n (p c) -> b c (h p) (w p)',
  44.                            p=self.patch_size, c=3, h=int(H/self.patch_size), w=int(W/self.patch_size))
  45.         return output
  46. # ---------------------------
  47. # 数据加载 + 预处理
  48. # ---------------------------
  49. transform = T.Compose([
  50.     T.Resize((224, 224)),
  51.     T.ToTensor(),
  52.     T.Normalize([0.5]*3, [0.5]*3)
  53. ])
  54. inv_transform = T.Compose([
  55.     T.Normalize(mean=[-1]*3, std=[2]*3)
  56. ])
  57. def load_images(path):
  58.     files = sorted(glob(os.path.join(path, '*.png')) + glob(os.path.join(path, '*.jpg')))
  59.     images = []
  60.     for f in files:
  61.         img = default_loader(f)
  62.         images.append(transform(img))
  63.     return torch.stack(images)
  64. # ---------------------------
  65. # 测试图像 → 重建图像 → 残差热图
  66. # ---------------------------
  67. def visualize_anomaly(original, recon, save_path='result.png'):
  68.     residual = (original - recon).abs().sum(dim=1, keepdim=True)
  69.     residual = residual / residual.max()
  70.     fig, axs = plt.subplots(1, 3, figsize=(12, 4))
  71.     axs[0].imshow(inv_transform(original[0]).permute(1, 2, 0).cpu().numpy())
  72.     axs[0].set_title('Original')
  73.     axs[1].imshow(inv_transform(recon[0]).permute(1, 2, 0).cpu().numpy())
  74.     axs[1].set_title('Reconstruction')
  75.     axs[2].imshow(residual[0, 0].cpu().numpy(), cmap='hot')
  76.     axs[2].set_title('Anomaly Map')
  77.     for ax in axs: ax.axis('off')
  78.     plt.tight_layout()
  79.     plt.savefig(save_path)
  80.     plt.close()
  81. # ---------------------------
  82. # 主程序执行
  83. # ---------------------------
  84. if __name__ == '__main__':
  85.     device = 'cuda' if torch.cuda.is_available() else 'cpu'
  86.     model = MAE().to(device)
  87.     model.eval()
  88.     # 替换为 MVTec / VisA 任一类别路径
  89.     image_dir = './mvtec/bottle/good/'  # 只加载良品图像
  90.     images = load_images(image_dir).to(device)
  91.     with torch.no_grad():
  92.         for i in range(min(5, len(images))):
  93.             input_img = images[i:i+1]
  94.             recon_img = model(input_img)
  95.             visualize_anomaly(input_img, recon_img, f'output_{i}.png')
复制代码

✅ 示例输出(保存为output_0.png等):



  • 左:原图
  • 中:重建图(模子“明白的良品”)
  • 右:异常热图(残差)
在正常样本上,残差图应接近0;如果输入的是异常图像,则对应区域将出现高相应。

✅ 可扩展方向

模块可扩展优化Encoder更换为轻量ViT(如 vit_tiny_patch16_224)Mask计谋利用结构化遮挡(如Block Mask)提拔重建挑战异常图像输入异常样本(如MVTec测试集中defect图)验证泛化本领练习到场重建loss微调,进步良品建模精度
如果你希望我进一步扩展为:


  • 支持少量异常图像的快速修正版本;
  • 或到场自动样本选择机制;
很好,我们继续在上一套基于自监督重建(MAE)方法的基础上,
为其 封装 Gradio Demo,以实现更直观的异常检测体验。

✅ 新增功能目标


  • 上传恣意图片(良品或异常图)
  • 及时体现:

    • 原图
    • 模子重建图
    • 残差热力图(高相应 = 异常区域)


✅ 完备代码(附Gradio界面)

  1. import torch
  2. import torch.nn as nn
  3. import torchvision.transforms as T
  4. from torchvision.utils import save_image
  5. from torchvision.datasets.folder import default_loader
  6. from einops import rearrange
  7. import timm
  8. import gradio as gr
  9. import numpy as np
  10. from PIL import Image
  11. import matplotlib.pyplot as plt
  12. import io
  13. # ---------------------------
  14. # 模型定义(同上)
  15. # ---------------------------
  16. class MAE(nn.Module):
  17.     def __init__(self, encoder_name='vit_base_patch16_224', mask_ratio=0.4):
  18.         super().__init__()
  19.         self.encoder = timm.create_model(encoder_name, pretrained=True)
  20.         self.mask_ratio = mask_ratio
  21.         self.patch_size = self.encoder.patch_embed.patch_size[0]
  22.         self.num_patches = self.encoder.patch_embed.num_patches
  23.         self.embed_dim = self.encoder.embed_dim
  24.         self.decoder = nn.Sequential(
  25.             nn.Linear(self.embed_dim, self.embed_dim),
  26.             nn.GELU(),
  27.             nn.Linear(self.embed_dim, self.patch_size**2 * 3)
  28.         )
  29.     def forward(self, x):
  30.         B, C, H, W = x.shape
  31.         x_patch = self.encoder.patch_embed(x)
  32.         B, N, D = x_patch.shape
  33.         rand_idx = torch.rand(B, N).argsort(dim=1)
  34.         num_keep = int(N * (1 - self.mask_ratio))
  35.         keep_idx = rand_idx[:, :num_keep]
  36.         x_keep = torch.gather(x_patch, 1, keep_idx.unsqueeze(-1).expand(-1, -1, D))
  37.         x_encoded = self.encoder.blocks(x_keep)
  38.         x_decoded = self.decoder(x_encoded)
  39.         output = torch.zeros(B, N, self.patch_size**2 * 3).to(x.device)
  40.         output.scatter_(1, keep_idx.unsqueeze(-1).expand(-1, -1, self.patch_size**2 * 3), x_decoded)
  41.         output = rearrange(output, 'b n (p c) -> b c (h p) (w p)',
  42.                            p=self.patch_size, c=3, h=int(H/self.patch_size), w=int(W/self.patch_size))
  43.         return output
  44. # ---------------------------
  45. # 预处理 & 后处理
  46. # ---------------------------
  47. transform = T.Compose([
  48.     T.Resize((224, 224)),
  49.     T.ToTensor(),
  50.     T.Normalize([0.5]*3, [0.5]*3)
  51. ])
  52. inv_transform = T.Compose([
  53.     T.Normalize(mean=[-1]*3, std=[2]*3)
  54. ])
  55. def tensor_to_pil(t):
  56.     t = inv_transform(t.squeeze(0)).clamp(0, 1)
  57.     return T.ToPILImage()(t)
  58. def residual_map(orig, recon):
  59.     residual = (orig - recon).abs().sum(dim=1, keepdim=True)
  60.     residual = residual / (residual.max() + 1e-8)
  61.     heat = residual.squeeze().cpu().numpy()
  62.     fig, ax = plt.subplots()
  63.     ax.imshow(heat, cmap='hot')
  64.     ax.axis('off')
  65.     buf = io.BytesIO()
  66.     plt.savefig(buf, format='png')
  67.     plt.close(fig)
  68.     buf.seek(0)
  69.     return Image.open(buf)
  70. # ---------------------------
  71. # 推理函数
  72. # ---------------------------
  73. device = 'cuda' if torch.cuda.is_available() else 'cpu'
  74. model = MAE().to(device)
  75. model.eval()
  76. def infer(img_pil):
  77.     img_tensor = transform(img_pil).unsqueeze(0).to(device)
  78.     with torch.no_grad():
  79.         recon = model(img_tensor)
  80.     recon_img = tensor_to_pil(recon)
  81.     input_img = tensor_to_pil(img_tensor)
  82.     heatmap = residual_map(img_tensor, recon)
  83.     return input_img, recon_img, heatmap
  84. # ---------------------------
  85. # Gradio UI
  86. # ---------------------------
  87. demo = gr.Interface(
  88.     fn=infer,
  89.     inputs=gr.Image(type="pil", label="上传图像"),
  90.     outputs=[
  91.         gr.Image(type="pil", label="原图"),
  92.         gr.Image(type="pil", label="重建图"),
  93.         gr.Image(type="pil", label="残差热图")
  94.     ],
  95.     title="基于良品数据的异常检测(MAE重建)",
  96.     description="上传图像,模型将重建正常区域并生成异常残差热力图"
  97. )
  98. if __name__ == '__main__':
  99.     demo.launch()
复制代码

✅ 利用效果

你可以上传如下类型图像进行及时检测:


  • ✔️ 良品图像:残差图团体应较为平滑,相应值低;
  • 异常图像(如划痕/破损):残差图中异常区域显着发亮(高相应);

✅ 后续扩展建议:

模块可增强重建网络替换为 DRAEM / Reverse Distillation异常评分盘算全图平均残差 + Otsu二值化分割多样本比较支持目录上传并批量可视化迁移微调用少量目标数据 fine-tune 提拔领域鲁棒性
需要我下一步为你实现:


  • ✅ 残差异常评分 + 二值掩码输出?
  • ✅ 支持少量异常样本微调功能?
  • ✅ 用 PatchCore / AnomalyCLIP 替换 MAE 结构?
你可以指定下一个要增强的方向,我这边可以直接给出代码。

免责声明:如果侵犯了您的权益,请联系站长,我们会及时删除侵权内容,谢谢合作!更多信息从访问主页:qidao123.com:ToB企服之家,中国第一个企服评测及商务社交产业平台。
回复

使用道具 举报

0 个回复

倒序浏览

快速回复

您需要登录后才可以回帖 登录 or 立即注册

本版积分规则

九天猎人

论坛元老
这个人很懒什么都没写!
快速回复 返回顶部 返回列表