生成对抗网络(GAN)入门与编程实现

打印 上一主题 下一主题

主题 1002|帖子 1002|积分 3006

生成对抗网络(Generative Adversarial Networks, 简称 GAN)自 2014 年由 Ian Goodfellow 等人提出以来,迅速成为呆板学习和深度学习领域的紧张工具之一。GAN 以其在图像生成、风格转换、数据增强等领域的精彩体现,吸引了广泛的研究爱好和应用探索。本文将先容 GAN 的基本概念、工作原理以及怎样通过代码实现一个简单的 GAN 模子。
什么是生成对抗网络(GAN)?

GAN 是一种生成模子,旨在通过学习数据的潜在分布,生成与真实数据相似的样本。它由两个核心部门构成:


  • 生成器(Generator):输入一个随机噪声向量,通过一系列的变换生成假数据,目标是让生成的假数据尽可能靠近真实数据。
  • 判别器(Discriminator):输入真实数据和生成器生成的假数据,输出判定其真实性的概率,目标是尽可能准确地区分真实数据和生成数据。
    二者在训练过程中相互对抗,形成一个博弈过程。

GAN 的工作原理

GAN 的训练过程可以看作是生成器和判别器之间的"零和博弈":

  • 生成器:


  • 输入随机噪声向量                                         z                                  z                     z(通常服从正态分布)。
  • 输出生成的样本                                         G                            (                            z                            )                                  G(z)                     G(z)。
  • 目标是让判别器无法区分                                         G                            (                            z                            )                                  G(z)                     G(z) 和真实数据。

  • 判别器:


  • 输入真实样本                                         x                                  x                     x 和生成器生成的假样本                                         G                            (                            z                            )                                  G(z)                     G(z)。
  • 输出区分真假样本的概率。
  • 目标是最大化对真实样本和生成样本的区分能力。
通过对模子进行训练,生成器逐渐生成更靠近真实分布的样本,而判别器也不断提高其判别能力,直到达到平衡。

完整的训练过程如下:

GAN 的代码实现

接下来,我们通过 PyTorch 实现一个简单的 GAN 模子,生成 MNIST 手写数字图片。

  • 数据加载与预处置惩罚
    MNIST 是一个常用的手写数字数据集,每张图片的巨细为 28x28,灰度范围为 0-1。
  1. # data_loader
  2. transform = transforms.Compose([
  3.         transforms.ToTensor(),
  4.         transforms.Normalize(mean=(0.5), std=(0.5))
  5. ])
  6. train_loader = torch.utils.data.DataLoader(
  7.     datasets.MNIST('data', train=True, download=True, transform=transform),
  8.     batch_size=batch_size, shuffle=True)
复制代码
使用 torchvision 的 datasets.MNIST 下载MNIST数据集。之后,将图片转换为Tensor格式,并对像素值进行归一化(均值0.5,标准差0.5)。

  • 构建生成器与判别器
    生成器和判别器都是多层全连接神经网络。
  1. # G(z)
  2. class generator(nn.Module):
  3.     # initializers
  4.     def __init__(self, input_size=32, n_class = 10):
  5.         super(generator, self).__init__()
  6.         self.fc1 = nn.Linear(input_size, 256)
  7.         self.fc2 = nn.Linear(self.fc1.out_features, 512)
  8.         self.fc3 = nn.Linear(self.fc2.out_features, 1024)
  9.         self.fc4 = nn.Linear(self.fc3.out_features, n_class)
  10.     # forward method
  11.     def forward(self, input):
  12.         x = F.leaky_relu(self.fc1(input), 0.2)
  13.         x = F.leaky_relu(self.fc2(x), 0.2)
  14.         x = F.leaky_relu(self.fc3(x), 0.2)
  15.         x = F.tanh(self.fc4(x))
  16.         x = x.squeeze(-1)
  17.         return x
  18. class discriminator(nn.Module):
  19.     # initializers
  20.     def __init__(self, input_size=32, n_class=10):
  21.         super(discriminator, self).__init__()
  22.         self.fc1 = nn.Linear(input_size, 1024)
  23.         self.fc2 = nn.Linear(self.fc1.out_features, 512)
  24.         self.fc3 = nn.Linear(self.fc2.out_features, 256)
  25.         self.fc4 = nn.Linear(self.fc3.out_features, n_class)
  26.     # forward method
  27.     def forward(self, input):
  28.         x = F.leaky_relu(self.fc1(input), 0.2)
  29.         x = F.dropout(x, 0.3)
  30.         x = F.leaky_relu(self.fc2(x), 0.2)
  31.         x = F.dropout(x, 0.3)
  32.         x = F.leaky_relu(self.fc3(x), 0.2)
  33.         x = F.dropout(x, 0.3)
  34.         x = F.sigmoid(self.fc4(x))
  35.         x = x.squeeze(-1)
  36.         return x
  37.         
  38. # network
  39. G = generator(input_size=100, n_class=28*28)
  40. D = discriminator(input_size=28*28, n_class=1)
复制代码


  • 生成器 (generator):

    • 输入:一个巨细为100的噪声向量。
    • 结构:包罗4个全连接层(fc1到fc4),每层背面跟随一个激活函数:

      • 前三层使用 LeakyReLU 激活函数,最后一层使用 tanh。
      • 输出巨细为 28×28(MNIST图片的尺寸)。

    • 功能:将随机噪声映射为类似于手写数字的图片。

  • 判别器 (discriminator):

    • 输入:展平的MNIST图片(巨细为 28×28)。
    • 结构:包罗4个全连接层(fc1到fc4),每层背面跟随:

      • LeakyReLU 激活函数和 Dropout(用于防止过拟合)。
      • 最后一层使用 sigmoid 激活函数。

    • 输出:一个介于0和1之间的值,表现输入是“真实图片”的概率。


  • 定义训练参数以及丧失函数和优化器
  1. # training parameters
  2. batch_size = 256
  3. lr = 0.0002
  4. train_epoch = 200
  5. device = torch.cuda.is_available()
  6. if device:
  7.     print("running on GPU!")
  8.    
  9. # Binary Cross Entropy loss
  10. BCE_loss = nn.BCELoss()
  11. #move to cuda
  12. if device:
  13.     G.cuda()
  14.     D.cuda()
  15.     BCE_loss = BCE_loss.cuda()
  16. # Adam optimizer
  17. G_optimizer = optim.Adam(G.parameters(), lr=lr)
  18. D_optimizer = optim.Adam(D.parameters(), lr=lr)
  19. 4. 训练过程
  20. 在训练过程中,我们交替训练判别器和生成器。
  21. train_hist = {}
  22. train_hist['D_losses'] = []
  23. train_hist['G_losses'] = []
  24. for epoch in range(train_epoch):
  25.     D_losses = []
  26.     G_losses = []
  27.     #生成任务,不需要标签
  28.     for x_, _ in train_loader:
  29.         #训练图像展平
  30.         x_ = x_.view(-1, 28 * 28)
  31.         mini_batch = x_.size()[0]
  32.         y_real_ = torch.ones(mini_batch)
  33.         y_fake_ = torch.zeros(mini_batch)
  34.         # train discriminator D
  35.         D.zero_grad()
  36.         z_ = torch.randn((mini_batch, 100))
  37.         
  38.         if device:
  39.             x_, y_real_, y_fake_ = x_.cuda(), y_real_.cuda(), y_fake_.cuda()
  40.             z_ = z_.cuda()
  41.         #真数据loss
  42.         D_result = D(x_)
  43.         D_real_loss = BCE_loss(D_result, y_real_)
  44.         D_real_score = D_result
  45.         #假数据loss
  46.         G_result = G(z_)
  47.         D_result = D(G_result)
  48.         D_fake_loss = BCE_loss(D_result, y_fake_)
  49.         D_fake_score = D_result
  50.         D_train_loss = D_real_loss + D_fake_loss
  51.         D_train_loss.backward()
  52.         D_optimizer.step()
  53.         D_losses.append(D_train_loss.item())
  54.         # train generator G
  55.         G.zero_grad()
  56.         # z_ = torch.randn((mini_batch, 100))
  57.         # if device:
  58.         #     z_ = z_.cuda()
  59.         G_result = G(z_)
  60.         D_result = D(G_result)
  61.         G_train_loss = BCE_loss(D_result, y_real_)
  62.         G_train_loss.backward()
  63.         G_optimizer.step()
  64.         G_losses.append(G_train_loss.item())
  65.     print('[%d/%d]: loss_d: %.3f, loss_g: %.3f' % (
  66.         (epoch + 1), train_epoch, torch.mean(torch.FloatTensor(D_losses)), torch.mean(torch.FloatTensor(G_losses))))
  67.     if epoch %10 == 0:
  68.         p = 'MNIST_GAN_results/Random_results/MNIST_GAN_' + str(epoch + 1) + '.png'
  69.         fixed_p = 'MNIST_GAN_results/Fixed_results/MNIST_GAN_' + str(epoch + 1) + '.png'
  70.         show_result((epoch+1), save=True, path=p, isFix=False)
  71.         show_result((epoch+1), save=True, path=fixed_p, isFix=True)
  72.         train_hist['D_losses'].append(torch.mean(torch.FloatTensor(D_losses)))
  73.         train_hist['G_losses'].append(torch.mean(torch.FloatTensor(G_losses)))
复制代码
采用交织熵丧失函数(BCE)盘算Loss,即

此中判别器的loss盘算如下:

生成器的loss盘算如下:


  • 保存模子及数据
    将生成器和判别器的模子参数进行保存,保存训练过程的loss数据。
  1. print("Training finish!... save training results")
  2. torch.save(G.state_dict(), "MNIST_GAN_results/generator_param.pkl")
  3. torch.save(D.state_dict(), "MNIST_GAN_results/discriminator_param.pkl")
  4. with open('MNIST_GAN_results/train_hist.pkl', 'wb') as f:
  5.     pickle.dump(train_hist, f)
复制代码

  • 数据可视化
  1. show_train_hist(train_hist, save=True, path='MNIST_GAN_results/MNIST_GAN_train_hist.png')
  2. images = []
  3. for e in range(train_epoch):
  4.     img_name = 'MNIST_GAN_results/Fixed_results/MNIST_GAN_' + str(e + 1) + '.png'
  5.     images.append(imageio.imread(img_name))
  6. imageio.mimsave('MNIST_GAN_results/generation_animation.gif', images, fps=5)
复制代码

  • 完整代码
  1. import osimport matplotlib.pyplot as pltimport itertoolsimport pickleimport imageioimport torchimport torch.nn as nnimport torch.nn.functional as Fimport torch.optim as optimfrom torchvision import datasets, transforms# from torch.autograd import Variable# G(z)class generator(nn.Module):    # initializers    def __init__(self, input_size=32, n_class = 10):        super(generator, self).__init__()        self.fc1 = nn.Linear(input_size, 256)        self.fc2 = nn.Linear(self.fc1.out_features, 512)        self.fc3 = nn.Linear(self.fc2.out_features, 1024)        self.fc4 = nn.Linear(self.fc3.out_features, n_class)    # forward method    def forward(self, input):        x = F.leaky_relu(self.fc1(input), 0.2)        x = F.leaky_relu(self.fc2(x), 0.2)        x = F.leaky_relu(self.fc3(x), 0.2)        x = F.tanh(self.fc4(x))        x = x.squeeze(-1)        return xclass discriminator(nn.Module):    # initializers    def __init__(self, input_size=32, n_class=10):        super(discriminator, self).__init__()        self.fc1 = nn.Linear(input_size, 1024)        self.fc2 = nn.Linear(self.fc1.out_features, 512)        self.fc3 = nn.Linear(self.fc2.out_features, 256)        self.fc4 = nn.Linear(self.fc3.out_features, n_class)    # forward method    def forward(self, input):        x = F.leaky_relu(self.fc1(input), 0.2)        x = F.dropout(x, 0.3)        x = F.leaky_relu(self.fc2(x), 0.2)        x = F.dropout(x, 0.3)        x = F.leaky_relu(self.fc3(x), 0.2)        x = F.dropout(x, 0.3)        x = F.sigmoid(self.fc4(x))        x = x.squeeze(-1)        return xfixed_z_ = torch.randn((5 * 5, 100))    # fixed noisewith torch.no_grad():    fixed_z_ = fixed_z_.cuda()def show_result(num_epoch, show = False, save = False, path = 'result.png', isFix=False):    z_ = torch.randn((5*5, 100))    with torch.no_grad():        z_ = z_.cuda()    # z_ = Variable(z_.cuda(), volatile=True)    G.eval()    if isFix:        test_images = G(fixed_z_)    else:        test_images = G(z_)    G.train()    size_figure_grid = 5    fig, ax = plt.subplots(size_figure_grid, size_figure_grid, figsize=(5, 5))    for i, j in itertools.product(range(size_figure_grid), range(size_figure_grid)):        ax[i, j].get_xaxis().set_visible(False)        ax[i, j].get_yaxis().set_visible(False)    for k in range(5*5):        i = k // 5        j = k % 5        ax[i, j].cla()        ax[i, j].imshow(test_images[k, :].cpu().data.view(28, 28).numpy(), cmap='gray')    label = 'Epoch {0}'.format(num_epoch)    fig.text(0.5, 0.04, label, ha='center')    plt.savefig(path)    if show:        plt.show()    else:        plt.close()def show_train_hist(hist, show = False, save = False, path = 'Train_hist.png'):    x = range(len(hist['D_losses']))    y1 = hist['D_losses']    y2 = hist['G_losses']    plt.plot(x, y1, label='D_loss')    plt.plot(x, y2, label='G_loss')    plt.xlabel('Epoch')    plt.ylabel('Loss')    plt.legend(loc=4)    plt.grid(True)    plt.tight_layout()    if save:        plt.savefig(path)    if show:        plt.show()    else:        plt.close()# training parametersbatch_size = 256lr = 0.0002train_epoch = 200device = torch.cuda.is_available()if device:    print("running on GPU!")# data_loader
  2. transform = transforms.Compose([
  3.         transforms.ToTensor(),
  4.         transforms.Normalize(mean=(0.5), std=(0.5))
  5. ])
  6. train_loader = torch.utils.data.DataLoader(
  7.     datasets.MNIST('data', train=True, download=True, transform=transform),
  8.     batch_size=batch_size, shuffle=True)
  9. # networkG = generator(input_size=100, n_class=28*28)D = discriminator(input_size=28*28, n_class=1)# Binary Cross Entropy lossBCE_loss = nn.BCELoss()#move to cudaif device:    G.cuda()    D.cuda()    BCE_loss = BCE_loss.cuda()# Adam optimizerG_optimizer = optim.Adam(G.parameters(), lr=lr)D_optimizer = optim.Adam(D.parameters(), lr=lr)# results save folderif not os.path.isdir('MNIST_GAN_results'):    os.mkdir('MNIST_GAN_results')if not os.path.isdir('MNIST_GAN_results/Random_results'):    os.mkdir('MNIST_GAN_results/Random_results')if not os.path.isdir('MNIST_GAN_results/Fixed_results'):    os.mkdir('MNIST_GAN_results/Fixed_results')train_hist = {}train_hist['D_losses'] = []train_hist['G_losses'] = []for epoch in range(train_epoch):    D_losses = []    G_losses = []    #生成任务,不需要标签    for x_, _ in train_loader:        #训练图像展平        x_ = x_.view(-1, 28 * 28)        mini_batch = x_.size()[0]        y_real_ = torch.ones(mini_batch)        y_fake_ = torch.zeros(mini_batch)        # train discriminator D        D.zero_grad()        z_ = torch.randn((mini_batch, 100))                if device:            x_, y_real_, y_fake_ = x_.cuda(), y_real_.cuda(), y_fake_.cuda()            z_ = z_.cuda()        #真数据loss        D_result = D(x_)        D_real_loss = BCE_loss(D_result, y_real_)        D_real_score = D_result        #假数据loss        G_result = G(z_)        D_result = D(G_result)        D_fake_loss = BCE_loss(D_result, y_fake_)        D_fake_score = D_result        D_train_loss = D_real_loss + D_fake_loss        D_train_loss.backward()        D_optimizer.step()        D_losses.append(D_train_loss.item())        # train generator G        G.zero_grad()        # z_ = torch.randn((mini_batch, 100))        # if device:        #     z_ = z_.cuda()        G_result = G(z_)        D_result = D(G_result)        G_train_loss = BCE_loss(D_result, y_real_)        G_train_loss.backward()        G_optimizer.step()        G_losses.append(G_train_loss.item())    print('[%d/%d]: loss_d: %.3f, loss_g: %.3f' % (        (epoch + 1), train_epoch, torch.mean(torch.FloatTensor(D_losses)), torch.mean(torch.FloatTensor(G_losses))))    if epoch %10 == 0:        p = 'MNIST_GAN_results/Random_results/MNIST_GAN_' + str(epoch + 1) + '.png'        fixed_p = 'MNIST_GAN_results/Fixed_results/MNIST_GAN_' + str(epoch + 1) + '.png'        show_result((epoch+1), save=True, path=p, isFix=False)        show_result((epoch+1), save=True, path=fixed_p, isFix=True)        train_hist['D_losses'].append(torch.mean(torch.FloatTensor(D_losses)))    train_hist['G_losses'].append(torch.mean(torch.FloatTensor(G_losses)))print("Training finish!... save training results")
  10. torch.save(G.state_dict(), "MNIST_GAN_results/generator_param.pkl")
  11. torch.save(D.state_dict(), "MNIST_GAN_results/discriminator_param.pkl")
  12. with open('MNIST_GAN_results/train_hist.pkl', 'wb') as f:
  13.     pickle.dump(train_hist, f)
  14. show_train_hist(train_hist, save=True, path='MNIST_GAN_results/MNIST_GAN_train_hist.png')
  15. images = []
  16. for e in range(train_epoch):
  17.     img_name = 'MNIST_GAN_results/Fixed_results/MNIST_GAN_' + str(e + 1) + '.png'
  18.     images.append(imageio.imread(img_name))
  19. imageio.mimsave('MNIST_GAN_results/generation_animation.gif', images, fps=5)
复制代码
训练结果


以上是训练190个epoch后得到的结果,可以看到此中某些图片已经有了数字的模样。这里仅仅是使用了全连接层来搭建模子,如果使用卷积神经网络,效果会有更好的提拔,大家可以尝试一下。
遇到的问题

可以适本地提高batch size来提高训练速率,也可以切换更简单的loss函数来提高训练速率。
建议batch size从底到高慢慢调节,若batch size过高,可能导致模子训练出现问题。

免责声明:如果侵犯了您的权益,请联系站长,我们会及时删除侵权内容,谢谢合作!更多信息从访问主页:qidao123.com:ToB企服之家,中国第一个企服评测及商务社交产业平台。

本帖子中包含更多资源

您需要 登录 才可以下载或查看,没有账号?立即注册

x
回复

使用道具 举报

0 个回复

倒序浏览

快速回复

您需要登录后才可以回帖 登录 or 立即注册

本版积分规则

惊落一身雪

论坛元老
这个人很懒什么都没写!
快速回复 返回顶部 返回列表