Mindspore框架利用扩散模型DDPM天生高分辨率图像|(三)模型训练与推理实践 ...

打印 上一主题 下一主题

主题 560|帖子 560|积分 1680

利用扩散模型DDPM天生高分辨率图像(天生高保真图像项目实践)

Mindspore框架利用扩散模型DDPM天生高分辨率图像|(一)关于denoising diffusion probabilistic model (DDPM)模型
Mindspore框架利用扩散模型DDPM天生高分辨率图像|(二)数据集准备与处理
Mindspore框架利用扩散模型DDPM天生高分辨率图像|(三)模型训练与推理实践

一、扩散模型DDPM模型训练

扩散模型天生新图像是通过反转扩散过程来实现。理想情况下,我们最终会得到一个看起来像是来自真实数据分布的图像。
  1. def p_sample(model, x, t, t_index):
  2.     betas_t = extract(betas, t, x.shape)
  3.     sqrt_one_minus_alphas_cumprod_t = extract(
  4.         sqrt_one_minus_alphas_cumprod, t, x.shape
  5.     )
  6.     sqrt_recip_alphas_t = extract(sqrt_recip_alphas, t, x.shape)
  7.     model_mean = sqrt_recip_alphas_t * (x - betas_t * model(x, t) / sqrt_one_minus_alphas_cumprod_t)
  8.     if t_index == 0:
  9.         return model_mean
  10.     posterior_variance_t = extract(posterior_variance, t, x.shape)
  11.     noise = randn_like(x)
  12.     return model_mean + ops.sqrt(posterior_variance_t) * noise
  13. def p_sample_loop(model, shape):
  14.     b = shape[0]
  15.     # 从纯噪声开始
  16.     img = randn(shape, dtype=None)
  17.     imgs = []
  18.     for i in tqdm(reversed(range(0, timesteps)), desc='sampling loop time step', total=timesteps):
  19.         img = p_sample(model, img, ms.numpy.full((b,), i, dtype=mstype.int32), i)
  20.         imgs.append(img.asnumpy())
  21.     return imgs
  22. def sample(model, image_size, batch_size=16, channels=3):
  23.     return p_sample_loop(model, shape=(batch_size, channels, image_size, image_size))
复制代码
1.1 模型初始化

  1. # 定义动态学习率
  2. lr = nn.cosine_decay_lr(min_lr=1e-7, max_lr=1e-4, total_step=10*3750, step_per_epoch=3750, decay_epoch=10)
  3. # 定义 Unet模型
  4. unet_model = Unet(
  5.     dim=image_size,
  6.     channels=channels,
  7.     dim_mults=(1, 2, 4,)
  8. )
  9. name_list = []
  10. for (name, par) in list(unet_model.parameters_and_names()):
  11.     name_list.append(name)
  12. i = 0
  13. for item in list(unet_model.trainable_params()):
  14.     item.name = name_list[i]
  15.     i += 1
  16. # 定义优化器
  17. optimizer = nn.Adam(unet_model.trainable_params(), learning_rate=lr)
  18. loss_scaler = DynamicLossScaler(65536, 2, 1000)
  19. # 定义前向过程
  20. def forward_fn(data, t, noise=None):
  21.     loss = p_losses(unet_model, data, t, noise)
  22.     return loss
  23. # 计算梯度
  24. grad_fn = ms.value_and_grad(forward_fn, None, optimizer.parameters, has_aux=False)
  25. # 梯度更新
  26. def train_step(data, t, noise):
  27.     loss, grads = grad_fn(data, t, noise)
  28.     optimizer(grads)
  29.     return loss
复制代码
1.2 模型训练

  1. import time
  2. epochs = 10
  3. iterator = dataset.create_tuple_iterator(num_epochs=epochs)
  4. for epoch in range(epochs):
  5.     begin_time = time.time()
  6.     for step, batch in enumerate(iterator):
  7.         unet_model.set_train()
  8.         batch_size = batch[0].shape[0]
  9.         t = randint(0, timesteps, (batch_size,), dtype=ms.int32)
  10.         noise = randn_like(batch[0])
  11.         loss = train_step(batch[0], t, noise)
  12.         if step % 500 == 0:
  13.             print(" epoch: ", epoch, " step: ", step, " Loss: ", loss)
  14.     end_time = time.time()
  15.     times = end_time - begin_time
  16.     print("training time:", times, "s")
  17.     # 展示随机采样效果
  18.     unet_model.set_train(False)
  19.     samples = sample(unet_model, image_size=image_size, batch_size=64, channels=channels)
  20.     plt.imshow(samples[-1][5].reshape(image_size, image_size, channels), cmap="gray")
  21. print("Training Success!")
复制代码

二、模型推理预测

2.1 推理过程

从模型中采样
  1. # 采样64个图片
  2. unet_model.set_train(False)
  3. samples = sample(unet_model, image_size=image_size, batch_size=64, channels=channels)
  4. # 展示一个随机效果
  5. # random_index = 5
  6. # plt.imshow(samples[-1][random_index].reshape(image_size, image_size, channels), cmap="gray")
复制代码
创建去噪过程
  1. import matplotlib.animation as animation
  2. random_index = 53
  3. fig = plt.figure()
  4. ims = []
  5. for i in range(timesteps):
  6.     im = plt.imshow(samples[i][random_index].reshape(image_size, image_size, channels), cmap="gray", animated=True)
  7.     ims.append([im])
  8. animate = animation.ArtistAnimation(fig, ims, interval=50, blit=True, repeat_delay=100)
  9. animate.save('diffusion.gif')
  10. plt.show()
复制代码

三、参考文献


  • The Annotated Diffusion Model
  • 由浅入深了解Diffusion Model
  • Diffusion扩散模型

免责声明:如果侵犯了您的权益,请联系站长,我们会及时删除侵权内容,谢谢合作!更多信息从访问主页:qidao123.com:ToB企服之家,中国第一个企服评测及商务社交产业平台。

本帖子中包含更多资源

您需要 登录 才可以下载或查看,没有账号?立即注册

x
回复

使用道具 举报

0 个回复

倒序浏览

快速回复

您需要登录后才可以回帖 登录 or 立即注册

本版积分规则

反转基因福娃

金牌会员
这个人很懒什么都没写!

标签云

快速回复 返回顶部 返回列表