生成对抗网络(Generative Adversarial Networks, 简称 GAN)自 2014 年由 Ian Goodfellow 等人提出以来,迅速成为呆板学习和深度学习领域的紧张工具之一。GAN 以其在图像生成、风格转换、数据增强等领域的精彩体现,吸引了广泛的研究爱好和应用探索。本文将先容 GAN 的基本概念、工作原理以及怎样通过代码实现一个简单的 GAN 模子。
什么是生成对抗网络(GAN)?
GAN 是一种生成模子,旨在通过学习数据的潜在分布,生成与真实数据相似的样本。它由两个核心部门构成:
- 生成器(Generator):输入一个随机噪声向量,通过一系列的变换生成假数据,目标是让生成的假数据尽可能靠近真实数据。
- 判别器(Discriminator):输入真实数据和生成器生成的假数据,输出判定其真实性的概率,目标是尽可能准确地区分真实数据和生成数据。
二者在训练过程中相互对抗,形成一个博弈过程。
GAN 的工作原理
GAN 的训练过程可以看作是生成器和判别器之间的"零和博弈":
- 输入随机噪声向量 z z z(通常服从正态分布)。
- 输出生成的样本 G ( z ) G(z) G(z)。
- 目标是让判别器无法区分 G ( z ) G(z) G(z) 和真实数据。
- 输入真实样本 x x x 和生成器生成的假样本 G ( z ) G(z) G(z)。
- 输出区分真假样本的概率。
- 目标是最大化对真实样本和生成样本的区分能力。
通过对模子进行训练,生成器逐渐生成更靠近真实分布的样本,而判别器也不断提高其判别能力,直到达到平衡。
完整的训练过程如下:
GAN 的代码实现
接下来,我们通过 PyTorch 实现一个简单的 GAN 模子,生成 MNIST 手写数字图片。
- 数据加载与预处置惩罚
MNIST 是一个常用的手写数字数据集,每张图片的巨细为 28x28,灰度范围为 0-1。
- # data_loader
- transform = transforms.Compose([
- transforms.ToTensor(),
- transforms.Normalize(mean=(0.5), std=(0.5))
- ])
- train_loader = torch.utils.data.DataLoader(
- datasets.MNIST('data', train=True, download=True, transform=transform),
- batch_size=batch_size, shuffle=True)
复制代码 使用 torchvision 的 datasets.MNIST 下载MNIST数据集。之后,将图片转换为Tensor格式,并对像素值进行归一化(均值0.5,标准差0.5)。
- 构建生成器与判别器
生成器和判别器都是多层全连接神经网络。
- # G(z)
- class generator(nn.Module):
- # initializers
- def __init__(self, input_size=32, n_class = 10):
- super(generator, self).__init__()
- self.fc1 = nn.Linear(input_size, 256)
- self.fc2 = nn.Linear(self.fc1.out_features, 512)
- self.fc3 = nn.Linear(self.fc2.out_features, 1024)
- self.fc4 = nn.Linear(self.fc3.out_features, n_class)
- # forward method
- def forward(self, input):
- x = F.leaky_relu(self.fc1(input), 0.2)
- x = F.leaky_relu(self.fc2(x), 0.2)
- x = F.leaky_relu(self.fc3(x), 0.2)
- x = F.tanh(self.fc4(x))
- x = x.squeeze(-1)
- return x
- class discriminator(nn.Module):
- # initializers
- def __init__(self, input_size=32, n_class=10):
- super(discriminator, self).__init__()
- self.fc1 = nn.Linear(input_size, 1024)
- self.fc2 = nn.Linear(self.fc1.out_features, 512)
- self.fc3 = nn.Linear(self.fc2.out_features, 256)
- self.fc4 = nn.Linear(self.fc3.out_features, n_class)
- # forward method
- def forward(self, input):
- x = F.leaky_relu(self.fc1(input), 0.2)
- x = F.dropout(x, 0.3)
- x = F.leaky_relu(self.fc2(x), 0.2)
- x = F.dropout(x, 0.3)
- x = F.leaky_relu(self.fc3(x), 0.2)
- x = F.dropout(x, 0.3)
- x = F.sigmoid(self.fc4(x))
- x = x.squeeze(-1)
- return x
-
- # network
- G = generator(input_size=100, n_class=28*28)
- D = discriminator(input_size=28*28, n_class=1)
复制代码
- 生成器 (generator):
- 输入:一个巨细为100的噪声向量。
- 结构:包罗4个全连接层(fc1到fc4),每层背面跟随一个激活函数:
- 前三层使用 LeakyReLU 激活函数,最后一层使用 tanh。
- 输出巨细为 28×28(MNIST图片的尺寸)。
- 功能:将随机噪声映射为类似于手写数字的图片。
- 判别器 (discriminator):
- 输入:展平的MNIST图片(巨细为 28×28)。
- 结构:包罗4个全连接层(fc1到fc4),每层背面跟随:
- LeakyReLU 激活函数和 Dropout(用于防止过拟合)。
- 最后一层使用 sigmoid 激活函数。
- 输出:一个介于0和1之间的值,表现输入是“真实图片”的概率。
- # training parameters
- batch_size = 256
- lr = 0.0002
- train_epoch = 200
- device = torch.cuda.is_available()
- if device:
- print("running on GPU!")
-
- # Binary Cross Entropy loss
- BCE_loss = nn.BCELoss()
- #move to cuda
- if device:
- G.cuda()
- D.cuda()
- BCE_loss = BCE_loss.cuda()
- # Adam optimizer
- G_optimizer = optim.Adam(G.parameters(), lr=lr)
- D_optimizer = optim.Adam(D.parameters(), lr=lr)
- 4. 训练过程
- 在训练过程中,我们交替训练判别器和生成器。
- train_hist = {}
- train_hist['D_losses'] = []
- train_hist['G_losses'] = []
- for epoch in range(train_epoch):
- D_losses = []
- G_losses = []
- #生成任务,不需要标签
- for x_, _ in train_loader:
- #训练图像展平
- x_ = x_.view(-1, 28 * 28)
- mini_batch = x_.size()[0]
- y_real_ = torch.ones(mini_batch)
- y_fake_ = torch.zeros(mini_batch)
- # train discriminator D
- D.zero_grad()
- z_ = torch.randn((mini_batch, 100))
-
- if device:
- x_, y_real_, y_fake_ = x_.cuda(), y_real_.cuda(), y_fake_.cuda()
- z_ = z_.cuda()
- #真数据loss
- D_result = D(x_)
- D_real_loss = BCE_loss(D_result, y_real_)
- D_real_score = D_result
- #假数据loss
- G_result = G(z_)
- D_result = D(G_result)
- D_fake_loss = BCE_loss(D_result, y_fake_)
- D_fake_score = D_result
- D_train_loss = D_real_loss + D_fake_loss
- D_train_loss.backward()
- D_optimizer.step()
- D_losses.append(D_train_loss.item())
- # train generator G
- G.zero_grad()
- # z_ = torch.randn((mini_batch, 100))
- # if device:
- # z_ = z_.cuda()
- G_result = G(z_)
- D_result = D(G_result)
- G_train_loss = BCE_loss(D_result, y_real_)
- G_train_loss.backward()
- G_optimizer.step()
- G_losses.append(G_train_loss.item())
- print('[%d/%d]: loss_d: %.3f, loss_g: %.3f' % (
- (epoch + 1), train_epoch, torch.mean(torch.FloatTensor(D_losses)), torch.mean(torch.FloatTensor(G_losses))))
- if epoch %10 == 0:
- p = 'MNIST_GAN_results/Random_results/MNIST_GAN_' + str(epoch + 1) + '.png'
- fixed_p = 'MNIST_GAN_results/Fixed_results/MNIST_GAN_' + str(epoch + 1) + '.png'
- show_result((epoch+1), save=True, path=p, isFix=False)
- show_result((epoch+1), save=True, path=fixed_p, isFix=True)
- train_hist['D_losses'].append(torch.mean(torch.FloatTensor(D_losses)))
- train_hist['G_losses'].append(torch.mean(torch.FloatTensor(G_losses)))
复制代码 采用交织熵丧失函数(BCE)盘算Loss,即

此中判别器的loss盘算如下:

生成器的loss盘算如下:

- 保存模子及数据
将生成器和判别器的模子参数进行保存,保存训练过程的loss数据。
- print("Training finish!... save training results")
- torch.save(G.state_dict(), "MNIST_GAN_results/generator_param.pkl")
- torch.save(D.state_dict(), "MNIST_GAN_results/discriminator_param.pkl")
- with open('MNIST_GAN_results/train_hist.pkl', 'wb') as f:
- pickle.dump(train_hist, f)
复制代码- show_train_hist(train_hist, save=True, path='MNIST_GAN_results/MNIST_GAN_train_hist.png')
- images = []
- for e in range(train_epoch):
- img_name = 'MNIST_GAN_results/Fixed_results/MNIST_GAN_' + str(e + 1) + '.png'
- images.append(imageio.imread(img_name))
- imageio.mimsave('MNIST_GAN_results/generation_animation.gif', images, fps=5)
复制代码- import osimport matplotlib.pyplot as pltimport itertoolsimport pickleimport imageioimport torchimport torch.nn as nnimport torch.nn.functional as Fimport torch.optim as optimfrom torchvision import datasets, transforms# from torch.autograd import Variable# G(z)class generator(nn.Module): # initializers def __init__(self, input_size=32, n_class = 10): super(generator, self).__init__() self.fc1 = nn.Linear(input_size, 256) self.fc2 = nn.Linear(self.fc1.out_features, 512) self.fc3 = nn.Linear(self.fc2.out_features, 1024) self.fc4 = nn.Linear(self.fc3.out_features, n_class) # forward method def forward(self, input): x = F.leaky_relu(self.fc1(input), 0.2) x = F.leaky_relu(self.fc2(x), 0.2) x = F.leaky_relu(self.fc3(x), 0.2) x = F.tanh(self.fc4(x)) x = x.squeeze(-1) return xclass discriminator(nn.Module): # initializers def __init__(self, input_size=32, n_class=10): super(discriminator, self).__init__() self.fc1 = nn.Linear(input_size, 1024) self.fc2 = nn.Linear(self.fc1.out_features, 512) self.fc3 = nn.Linear(self.fc2.out_features, 256) self.fc4 = nn.Linear(self.fc3.out_features, n_class) # forward method def forward(self, input): x = F.leaky_relu(self.fc1(input), 0.2) x = F.dropout(x, 0.3) x = F.leaky_relu(self.fc2(x), 0.2) x = F.dropout(x, 0.3) x = F.leaky_relu(self.fc3(x), 0.2) x = F.dropout(x, 0.3) x = F.sigmoid(self.fc4(x)) x = x.squeeze(-1) return xfixed_z_ = torch.randn((5 * 5, 100)) # fixed noisewith torch.no_grad(): fixed_z_ = fixed_z_.cuda()def show_result(num_epoch, show = False, save = False, path = 'result.png', isFix=False): z_ = torch.randn((5*5, 100)) with torch.no_grad(): z_ = z_.cuda() # z_ = Variable(z_.cuda(), volatile=True) G.eval() if isFix: test_images = G(fixed_z_) else: test_images = G(z_) G.train() size_figure_grid = 5 fig, ax = plt.subplots(size_figure_grid, size_figure_grid, figsize=(5, 5)) for i, j in itertools.product(range(size_figure_grid), range(size_figure_grid)): ax[i, j].get_xaxis().set_visible(False) ax[i, j].get_yaxis().set_visible(False) for k in range(5*5): i = k // 5 j = k % 5 ax[i, j].cla() ax[i, j].imshow(test_images[k, :].cpu().data.view(28, 28).numpy(), cmap='gray') label = 'Epoch {0}'.format(num_epoch) fig.text(0.5, 0.04, label, ha='center') plt.savefig(path) if show: plt.show() else: plt.close()def show_train_hist(hist, show = False, save = False, path = 'Train_hist.png'): x = range(len(hist['D_losses'])) y1 = hist['D_losses'] y2 = hist['G_losses'] plt.plot(x, y1, label='D_loss') plt.plot(x, y2, label='G_loss') plt.xlabel('Epoch') plt.ylabel('Loss') plt.legend(loc=4) plt.grid(True) plt.tight_layout() if save: plt.savefig(path) if show: plt.show() else: plt.close()# training parametersbatch_size = 256lr = 0.0002train_epoch = 200device = torch.cuda.is_available()if device: print("running on GPU!")# data_loader
- transform = transforms.Compose([
- transforms.ToTensor(),
- transforms.Normalize(mean=(0.5), std=(0.5))
- ])
- train_loader = torch.utils.data.DataLoader(
- datasets.MNIST('data', train=True, download=True, transform=transform),
- batch_size=batch_size, shuffle=True)
- # networkG = generator(input_size=100, n_class=28*28)D = discriminator(input_size=28*28, n_class=1)# Binary Cross Entropy lossBCE_loss = nn.BCELoss()#move to cudaif device: G.cuda() D.cuda() BCE_loss = BCE_loss.cuda()# Adam optimizerG_optimizer = optim.Adam(G.parameters(), lr=lr)D_optimizer = optim.Adam(D.parameters(), lr=lr)# results save folderif not os.path.isdir('MNIST_GAN_results'): os.mkdir('MNIST_GAN_results')if not os.path.isdir('MNIST_GAN_results/Random_results'): os.mkdir('MNIST_GAN_results/Random_results')if not os.path.isdir('MNIST_GAN_results/Fixed_results'): os.mkdir('MNIST_GAN_results/Fixed_results')train_hist = {}train_hist['D_losses'] = []train_hist['G_losses'] = []for epoch in range(train_epoch): D_losses = [] G_losses = [] #生成任务,不需要标签 for x_, _ in train_loader: #训练图像展平 x_ = x_.view(-1, 28 * 28) mini_batch = x_.size()[0] y_real_ = torch.ones(mini_batch) y_fake_ = torch.zeros(mini_batch) # train discriminator D D.zero_grad() z_ = torch.randn((mini_batch, 100)) if device: x_, y_real_, y_fake_ = x_.cuda(), y_real_.cuda(), y_fake_.cuda() z_ = z_.cuda() #真数据loss D_result = D(x_) D_real_loss = BCE_loss(D_result, y_real_) D_real_score = D_result #假数据loss G_result = G(z_) D_result = D(G_result) D_fake_loss = BCE_loss(D_result, y_fake_) D_fake_score = D_result D_train_loss = D_real_loss + D_fake_loss D_train_loss.backward() D_optimizer.step() D_losses.append(D_train_loss.item()) # train generator G G.zero_grad() # z_ = torch.randn((mini_batch, 100)) # if device: # z_ = z_.cuda() G_result = G(z_) D_result = D(G_result) G_train_loss = BCE_loss(D_result, y_real_) G_train_loss.backward() G_optimizer.step() G_losses.append(G_train_loss.item()) print('[%d/%d]: loss_d: %.3f, loss_g: %.3f' % ( (epoch + 1), train_epoch, torch.mean(torch.FloatTensor(D_losses)), torch.mean(torch.FloatTensor(G_losses)))) if epoch %10 == 0: p = 'MNIST_GAN_results/Random_results/MNIST_GAN_' + str(epoch + 1) + '.png' fixed_p = 'MNIST_GAN_results/Fixed_results/MNIST_GAN_' + str(epoch + 1) + '.png' show_result((epoch+1), save=True, path=p, isFix=False) show_result((epoch+1), save=True, path=fixed_p, isFix=True) train_hist['D_losses'].append(torch.mean(torch.FloatTensor(D_losses))) train_hist['G_losses'].append(torch.mean(torch.FloatTensor(G_losses)))print("Training finish!... save training results")
- torch.save(G.state_dict(), "MNIST_GAN_results/generator_param.pkl")
- torch.save(D.state_dict(), "MNIST_GAN_results/discriminator_param.pkl")
- with open('MNIST_GAN_results/train_hist.pkl', 'wb') as f:
- pickle.dump(train_hist, f)
- show_train_hist(train_hist, save=True, path='MNIST_GAN_results/MNIST_GAN_train_hist.png')
- images = []
- for e in range(train_epoch):
- img_name = 'MNIST_GAN_results/Fixed_results/MNIST_GAN_' + str(e + 1) + '.png'
- images.append(imageio.imread(img_name))
- imageio.mimsave('MNIST_GAN_results/generation_animation.gif', images, fps=5)
复制代码 训练结果
以上是训练190个epoch后得到的结果,可以看到此中某些图片已经有了数字的模样。这里仅仅是使用了全连接层来搭建模子,如果使用卷积神经网络,效果会有更好的提拔,大家可以尝试一下。
遇到的问题
可以适本地提高batch size来提高训练速率,也可以切换更简单的loss函数来提高训练速率。
建议batch size从底到高慢慢调节,若batch size过高,可能导致模子训练出现问题。
免责声明:如果侵犯了您的权益,请联系站长,我们会及时删除侵权内容,谢谢合作!更多信息从访问主页:qidao123.com:ToB企服之家,中国第一个企服评测及商务社交产业平台。 |