《昇思25天学习打卡营第24天》

打印 上一主题 下一主题

主题 679|帖子 679|积分 2039

接续上一天的学习任务,我们要继续进行下一步的操纵
构造网络

当处理完数据后,就可以来进行网络的搭建了。按照DCGAN论文中的描述,所有模子权重均应从mean为0,sigma为0.02的正态分布中随机初始化。
接下来相识一下其他内容
生成器

生成器G的功能是将隐向量z映射到数据空间。实践场景中,该功能是通过一系列Conv2dTranspose转置卷积层来完成的,每个层都与BatchNorm2d层和ReLu激活层配对,输出数据会颠末tanh函数,使其返回[-1,1]的数据范围内。
DCGAN论文生成图像如下所示:

通过输入部分中设置的nz、ngf和nc来影响代码中的生成器结构。nz是隐向量z的长度,ngf与通过生成器传播的特征图的巨细有关,nc是输出图像中的通道数。
代码实现
  1. import mindspore as ms
  2. from mindspore import nn, ops
  3. from mindspore.common.initializer import Normal
  4. weight_init = Normal(mean=0, sigma=0.02)
  5. gamma_init = Normal(mean=1, sigma=0.02)
  6. class Generator(nn.Cell):
  7.     """DCGAN网络生成器"""
  8.     def __init__(self):
  9.         super(Generator, self).__init__()
  10.         self.generator = nn.SequentialCell(
  11.             nn.Conv2dTranspose(nz, ngf * 8, 4, 1, 'valid', weight_init=weight_init),
  12.             nn.BatchNorm2d(ngf * 8, gamma_init=gamma_init),
  13.             nn.ReLU(),
  14.             nn.Conv2dTranspose(ngf * 8, ngf * 4, 4, 2, 'pad', 1, weight_init=weight_init),
  15.             nn.BatchNorm2d(ngf * 4, gamma_init=gamma_init),
  16.             nn.ReLU(),
  17.             nn.Conv2dTranspose(ngf * 4, ngf * 2, 4, 2, 'pad', 1, weight_init=weight_init),
  18.             nn.BatchNorm2d(ngf * 2, gamma_init=gamma_init),
  19.             nn.ReLU(),
  20.             nn.Conv2dTranspose(ngf * 2, ngf, 4, 2, 'pad', 1, weight_init=weight_init),
  21.             nn.BatchNorm2d(ngf, gamma_init=gamma_init),
  22.             nn.ReLU(),
  23.             nn.Conv2dTranspose(ngf, nc, 4, 2, 'pad', 1, weight_init=weight_init),
  24.             nn.Tanh()
  25.             )
  26.     def construct(self, x):
  27.         return self.generator(x)
  28. generator = Generator()
复制代码

鉴别器

鉴别器D是一个二分类网络模子,输出判断该图像为真实图的概率。
代码实现
  1. class Discriminator(nn.Cell):
  2.     """DCGAN网络判别器"""
  3.     def __init__(self):
  4.         super(Discriminator, self).__init__()
  5.         self.discriminator = nn.SequentialCell(
  6.             nn.Conv2d(nc, ndf, 4, 2, 'pad', 1, weight_init=weight_init),
  7.             nn.LeakyReLU(0.2),
  8.             nn.Conv2d(ndf, ndf * 2, 4, 2, 'pad', 1, weight_init=weight_init),
  9.             nn.BatchNorm2d(ngf * 2, gamma_init=gamma_init),
  10.             nn.LeakyReLU(0.2),
  11.             nn.Conv2d(ndf * 2, ndf * 4, 4, 2, 'pad', 1, weight_init=weight_init),
  12.             nn.BatchNorm2d(ngf * 4, gamma_init=gamma_init),
  13.             nn.LeakyReLU(0.2),
  14.             nn.Conv2d(ndf * 4, ndf * 8, 4, 2, 'pad', 1, weight_init=weight_init),
  15.             nn.BatchNorm2d(ngf * 8, gamma_init=gamma_init),
  16.             nn.LeakyReLU(0.2),
  17.             nn.Conv2d(ndf * 8, 1, 4, 1, 'valid', weight_init=weight_init),
  18.             )
  19.         self.adv_layer = nn.Sigmoid()
  20.     def construct(self, x):
  21.         out = self.discriminator(x)
  22.         out = out.reshape(out.shape[0], -1)
  23.         return self.adv_layer(out)
  24. discriminator = Discriminator()
复制代码

接下来进入模子训练阶段
模子训练

其中分为几个要素:
丧失函数

当定义了D和G后,接下来将使用MindSpore中定义的二进制交织熵丧失函数BCELoss。
优化器

训练模子:训练鉴别器和训练生成器。



实现模子训练正向逻辑:
  1. def generator_forward(real_imgs, valid):
  2.     # 将噪声采样为发生器的输入
  3.     z = ops.standard_normal((real_imgs.shape[0], nz, 1, 1))
  4.     # 生成一批图像
  5.     gen_imgs = generator(z)
  6.     # 损失衡量发生器绕过判别器的能力
  7.     g_loss = adversarial_loss(discriminator(gen_imgs), valid)
  8.     return g_loss, gen_imgs
  9. def discriminator_forward(real_imgs, gen_imgs, valid, fake):
  10.     # 衡量鉴别器从生成的样本中对真实样本进行分类的能力
  11.     real_loss = adversarial_loss(discriminator(real_imgs), valid)
  12.     fake_loss = adversarial_loss(discriminator(gen_imgs), fake)
  13.     d_loss = (real_loss + fake_loss) / 2
  14.     return d_loss
  15. grad_generator_fn = ms.value_and_grad(generator_forward, None,
  16.                                       optimizer_G.parameters,
  17.                                       has_aux=True)
  18. grad_discriminator_fn = ms.value_and_grad(discriminator_forward, None,
  19.                                           optimizer_D.parameters)
  20. @ms.jit
  21. def train_step(imgs):
  22.     valid = ops.ones((imgs.shape[0], 1), mindspore.float32)
  23.     fake = ops.zeros((imgs.shape[0], 1), mindspore.float32)
  24.     (g_loss, gen_imgs), g_grads = grad_generator_fn(imgs, valid)
  25.     optimizer_G(g_grads)
  26.     d_loss, d_grads = grad_discriminator_fn(imgs, gen_imgs, valid, fake)
  27.     optimizer_D(d_grads)
  28.     return g_loss, d_loss, gen_imgs
复制代码
代码训练

结果展示就不多说了当作品


文末附上打卡时间


免责声明:如果侵犯了您的权益,请联系站长,我们会及时删除侵权内容,谢谢合作!更多信息从访问主页:qidao123.com:ToB企服之家,中国第一个企服评测及商务社交产业平台。
回复

使用道具 举报

0 个回复

倒序浏览

快速回复

您需要登录后才可以回帖 登录 or 立即注册

本版积分规则

刘俊凯

金牌会员
这个人很懒什么都没写!

标签云

快速回复 返回顶部 返回列表