第 8 期:条件天生 DDPM:让模型“听话”地画图!

打印 上一主题 下一主题

主题 1521|帖子 1521|积分 4563

本期关键词:Conditional DDPM、Class Embedding、Label Control、CIFAR-10 条件天生
什么是条件天生(Conditional Generation)?

在标准的 DDPM 中,我们只是“随机天生”图像。
   如果我想让模型天生「小狗」怎么办?
  这就要给模型添加“引导”——标签或笔墨,这种方式就叫 条件天生(Conditional Generation)
条件扩散的原理是什么?

我们要将种别信息 y 参加到模型中,使预测的噪声满足条件:

也就是说,模型要知道当前是“第几类”的图像,从而引导去噪方向。
实现思绪:


  • 将标签 y 进行嵌入(embedding);
  • 将其与时间步编码、图像特性一起送入网络中。
修改 UNet 支持条件标签

我们对 UNet 加一点“料”——标签 embedding。
  1. class ConditionalUNet(nn.Module):
  2.     def __init__(self, num_classes, time_dim=256):
  3.         super().__init__()
  4.         self.time_embed = nn.Sequential(
  5.             SinusoidalPositionEmbeddings(time_dim),
  6.             nn.Linear(time_dim, time_dim),
  7.             nn.ReLU()
  8.         )
  9.         self.label_embed = nn.Embedding(num_classes, time_dim)
  10.         self.conv0 = nn.Conv2d(3, 64, 3, padding=1)
  11.         self.down1 = Block(64, 128, time_dim)
  12.         self.down2 = Block(128, 256, time_dim)
  13.         self.bot = Block(256, 256, time_dim)
  14.         self.up1 = Block(512, 128, time_dim)
  15.         self.up2 = Block(256, 64, time_dim)
  16.         self.final = nn.Conv2d(64, 3, 1)
  17.     def forward(self, x, t, y):
  18.         t_embed = self.time_embed(t)
  19.         y_embed = self.label_embed(y)
  20.         cond = t_embed + y_embed  # 条件融合
  21.         x0 = self.conv0(x)
  22.         x1 = self.down1(x0, cond)
  23.         x2 = self.down2(x1, cond)
  24.         x3 = self.bot(x2, cond)
  25.         x = self.up1(torch.cat([x3, x2], 1), cond)
  26.         x = self.up2(torch.cat([x, x1], 1), cond)
  27.         return self.final(x)
复制代码
 我们为标签添加了一个 nn.Embedding,并与时间编码相加作为“条件向量”注入。
修改训练函数支持 label

  1. def get_conditional_loss(model, x_0, t, y):
  2.     noise = torch.randn_like(x_0)
  3.     x_t = q_sample(x_0, t, noise)
  4.     pred = model(x_t, t, y)
  5.     return F.mse_loss(pred, noise)
复制代码
训练主循环如下:
  1. for epoch in range(epochs):
  2.     for x, y in dataloader:
  3.         x = x.to(device)
  4.         y = y.to(device)
  5.         t = torch.randint(0, T, (x.size(0),), device=device).long()
  6.         loss = get_conditional_loss(model, x, t, y)
  7.         optimizer.zero_grad()
  8.         loss.backward()
  9.         optimizer.step()
复制代码
条件天生代码:指定种别天生图像!

  1. @torch.no_grad()
  2. def sample_with_labels(model, label, num_samples=16, img_size=32, device='cuda'):
  3.     model.eval()
  4.     x = torch.randn(num_samples, 3, img_size, img_size).to(device)
  5.     y = torch.tensor([label] * num_samples).to(device)
  6.    
  7.     for i in reversed(range(T)):
  8.         t = torch.full((num_samples,), i, device=device, dtype=torch.long)
  9.         noise_pred = model(x, t, y)
  10.         alpha = alphas_cumprod[t][:, None, None, None]
  11.         sqrt_alpha = torch.sqrt(alpha)
  12.         sqrt_one_minus_alpha = torch.sqrt(1 - alpha)
  13.         x_0_pred = (x - sqrt_one_minus_alpha * noise_pred) / sqrt_alpha
  14.         x_0_pred = x_0_pred.clamp(-1, 1)
  15.         if i > 0:
  16.             noise = torch.randn_like(x)
  17.             beta_t = betas[t][:, None, None, None]
  18.             x = sqrt_alpha * x_0_pred + torch.sqrt(beta_t) * noise
  19.         else:
  20.             x = x_0_pred
  21.     return x
复制代码
可视化指定种别的天生图像

  1. samples = sample_with_labels(model, label=3, num_samples=16)  # e.g., cat
  2. samples = (samples.clamp(-1, 1) + 1) / 2
  3. grid = torchvision.utils.make_grid(samples, nrow=4)
  4. plt.figure(figsize=(6, 6))
  5. plt.imshow(grid.permute(1, 2, 0).cpu().numpy())
  6. plt.axis('off')
  7. plt.title("Generated Class 3 (Cat)")
  8. plt.show()
复制代码
CIFAR-10 种别索引(参考)


种别编号种别名称0airplane1automobile2bird3cat4deer5dog6frog7horse8ship9truck 总结

在本期中,我们学习了如何:


  • ✅ 在 UNet 中添加类嵌入;
  • ✅ 修改损失函数以支持标签;
  • ✅ 实现条件采样天生指定种别图像;
  • ✅ 可视化天见效果。
第 9 期预告:「CLIP + Diffusion」文本条件扩散!

下一期我们将解锁 笔墨引导天生图像 的能力,用一句话天生图像!
   “一只戴着墨镜冲浪的柴犬”将成为实际!
  
 

免责声明:如果侵犯了您的权益,请联系站长,我们会及时删除侵权内容,谢谢合作!更多信息从访问主页:qidao123.com:ToB企服之家,中国第一个企服评测及商务社交产业平台。

本帖子中包含更多资源

您需要 登录 才可以下载或查看,没有账号?立即注册

x
回复

使用道具 举报

0 个回复

倒序浏览

快速回复

您需要登录后才可以回帖 登录 or 立即注册

本版积分规则

慢吞云雾缓吐愁

论坛元老
这个人很懒什么都没写!
快速回复 返回顶部 返回列表