【计算机视觉技术 - 人脸天生】2.GAN网络的构建和练习 ...

打印 上一主题 下一主题

主题 1079|帖子 1079|积分 3237

马上注册,结交更多好友,享用更多功能,让你轻松玩转社区。

您需要 登录 才可以下载或查看,没有账号?立即注册

x
         GAN 是一种常用的优秀的图像天生模子。我们使用了支持条件天生的 cGAN。下面先容简单 cGAN 模子的构建以及练习过程。
2.1 在 model 文件夹中新建 nets.py 文件

  1. import torch
  2. import torch.nn as nn
  3. # 生成器类
  4. class Generator(nn.Module):
  5.     def __init__(self, nz=100, nc=3, ngf=128, num_classes=4):
  6.         super(Generator, self).__init__()
  7.         self.label_emb = nn.Embedding(num_classes, nz)
  8.         self.main = nn.Sequential(
  9.             nn.ConvTranspose2d(nz + nz, ngf * 8, 4, 1, 0, bias=False),
  10.             nn.BatchNorm2d(ngf * 8),
  11.             nn.ReLU(True),
  12.             nn.ConvTranspose2d(ngf * 8, ngf * 4, 4, 2, 1, bias=False),
  13.             nn.BatchNorm2d(ngf * 4),
  14.             nn.ReLU(True),
  15.             nn.ConvTranspose2d(ngf * 4, ngf * 2, 4, 2, 1, bias=False),
  16.             nn.BatchNorm2d(ngf * 2),
  17.             nn.ReLU(True),
  18.             nn.ConvTranspose2d(ngf * 2, ngf, 4, 2, 1, bias=False),
  19.             nn.BatchNorm2d(ngf),
  20.             nn.ReLU(True),
  21.             nn.ConvTranspose2d(ngf, nc, 4, 2, 1, bias=False),
  22.             nn.Tanh()
  23.         )
  24.     def forward(self, z, labels):
  25.         c = self.label_emb(labels).unsqueeze(2).unsqueeze(3)
  26.         x = torch.cat([z, c], 1)
  27.         return self.main(x)
  28. # 判别器类
  29. class Discriminator(nn.Module):
  30.     def __init__(self, nc=3, ndf=64, num_classes=4):
  31.         super(Discriminator, self).__init__()
  32.         self.label_emb = nn.Embedding(num_classes, nc * 64 * 64)
  33.         self.main = nn.Sequential(
  34.             nn.Conv2d(nc + 1, ndf, 4, 2, 1, bias=False),
  35.             nn.LeakyReLU(0.2, inplace=True),
  36.             nn.Conv2d(ndf, ndf * 2, 4, 2, 1, bias=False),
  37.             nn.BatchNorm2d(ndf * 2),
  38.             nn.LeakyReLU(0.2, inplace=True),
  39.             nn.Conv2d(ndf * 2, ndf * 4, 4, 2, 1, bias=False),
  40.             nn.BatchNorm2d(ndf * 4),
  41.             nn.LeakyReLU(0.2, inplace=True),
  42.             nn.Conv2d(ndf * 4, 1, 4, 1, 0, bias=False),
  43.             nn.Sigmoid()
  44.         )
  45.     def forward(self, img, labels):
  46.         c = self.label_emb(labels).view(labels.size(0), 1, 64, 64)
  47.         x = torch.cat([img, c], 1)
  48.         return self.main(x)
复制代码
2.2新建cGAN_net.py

  1. import torch
  2. import torch.nn as nn
  3. from torch.optim import Adam
  4. from torchvision import datasets, transforms
  5. from torch.utils.data import DataLoader
  6. from torch.optim.lr_scheduler import StepLR
  7. # ===========================
  8. # Conditional DCGAN 实现
  9. # ===========================
  10. class cDCGAN:
  11.     def __init__(self, data_root, batch_size, device, latent_dim=100, num_classes=4):
  12.         self.device = device
  13.         self.batch_size = batch_size
  14.         self.latent_dim = latent_dim
  15.         self.num_classes = num_classes
  16.         # 数据加载器
  17.         self.train_loader = self.get_dataloader(data_root)
  18.         # 初始化生成器和判别器
  19.         self.generator = self.build_generator().to(device)
  20.         self.discriminator = self.build_discriminator().to(device)
  21.         # 初始化权重
  22.         self.generator.apply(self.weights_init)
  23.         self.discriminator.apply(self.weights_init)
  24.         # 损失函数和优化器
  25.         self.criterion = nn.BCELoss()
  26.         self.optimizer_G = Adam(self.generator.parameters(), lr=0.0001, betas=(0.5, 0.999))
  27.         self.optimizer_D = Adam(self.discriminator.parameters(), lr=0.0001, betas=(0.5, 0.999))
  28.         # 学习率调度器
  29.         self.scheduler_G = StepLR(self.optimizer_G, step_size=10, gamma=0.5)  # 每10个epoch学习率减半
  30.         self.scheduler_D = StepLR(self.optimizer_D, step_size=10, gamma=0.5)
  31.     def get_dataloader(self, data_root):
  32.         transform = transforms.Compose([
  33.             transforms.Resize(128),
  34.             transforms.ToTensor(),
  35.             transforms.Normalize((0.5,), (0.5,))
  36.         ])
  37.         dataset = datasets.ImageFolder(root=data_root, transform=transform)
  38.         return DataLoader(dataset, batch_size=self.batch_size, shuffle=True,
  39.                           num_workers=8, pin_memory=True, persistent_workers=True)
  40.     @staticmethod
  41.     def weights_init(model):
  42.         """权重初始化"""
  43.         if isinstance(model, (nn.Conv2d, nn.Linear)):
  44.             nn.init.normal_(model.weight.data, 0.0, 0.02)
  45.             if model.bias is not None:
  46.                 nn.init.constant_(model.bias.data, 0)
  47.     def train_step(self, epoch, step, num_epochs):
  48.         """单次训练步骤"""
  49.         self.generator.train()
  50.         self.discriminator.train()
  51.         G_losses, D_losses = [], []
  52.         for i, (real_img, labels) in enumerate(self.train_loader):
  53.             # 确保 real_img 和 labels 在同一设备
  54.             real_img = real_img.to(self.device)
  55.             labels = labels.to(self.device)
  56.             batch_size = real_img.size(0)
  57.             # # 标签   11.19 15:11:12修改
  58.             # valid = torch.ones((batch_size, 1), device=self.device)
  59.             # fake = torch.zeros((batch_size, 1), device=self.device)
  60.             # 标签平滑
  61.             # smooth_valid = torch.full((batch_size, 1), 1, device=self.device)  # 平滑真实标签
  62.             # smooth_fake = torch.full((batch_size, 1), 0, device=self.device)  # 平滑伪造标签
  63.             # smooth_valid = torch.full((batch_size, 1), torch.rand(1).item() * 0.1 + 0.9, device=self.device)
  64.             # smooth_fake = torch.full((batch_size, 1), torch.rand(1).item() * 0.1, device=self.device)
  65.             # smooth_valid = torch.full((batch_size, 1), max(0.7, 1 - epoch * 0.001), device=self.device)
  66.             # smooth_fake = torch.full((batch_size, 1), min(0.3, epoch * 0.001), device=self.device)
  67.             # 动态调整标签范围
  68.             smooth_valid = torch.full((batch_size, 1), max(0.9, 1 - 0.0001 * epoch), device=self.device)
  69.             smooth_fake = torch.full((batch_size, 1), min(0.1, 0.0001 * epoch), device=self.device)
  70.             # 替换以下两处代码
  71.             valid = smooth_valid
  72.             fake = smooth_fake
  73.             # ========== 训练判别器 ==========
  74.             real_pred = self.discriminator(real_img, labels)
  75.             # d_real_loss = self.criterion(real_pred, valid)
  76.             d_real_loss = self.criterion(real_pred, valid - 0.1 * torch.rand_like(valid))
  77.             noise = torch.randn(batch_size, self.latent_dim, device=self.device)
  78.             # gen_labels = torch.randint(0, self.num_classes, (batch_size,), device=self.device)
  79.             gen_labels = torch.randint(0, self.num_classes, (batch_size,), device=self.device) + torch.randint(-1, 2, (
  80.             batch_size,), device=self.device)
  81.             gen_labels = torch.clamp(gen_labels, 0, self.num_classes - 1)  # 确保标签在范围内
  82.             gen_img = self.generator(noise, gen_labels)
  83.             fake_pred = self.discriminator(gen_img.detach(), gen_labels)
  84.             # d_fake_loss = self.criterion(fake_pred, fake)
  85.             d_fake_loss = self.criterion(fake_pred, fake + 0.1 * torch.rand_like(fake))
  86.             d_loss = (d_real_loss + d_fake_loss) / 2
  87.             self.optimizer_D.zero_grad()
  88.             d_loss.backward()
  89.             self.optimizer_D.step()
  90.             D_losses.append(d_loss.item())
  91.             # ========== 训练生成器 ==========
  92.             gen_pred = self.discriminator(gen_img, gen_labels)
  93.             g_loss = self.criterion(gen_pred, valid)
  94.             self.optimizer_G.zero_grad()
  95.             g_loss.backward()
  96.             self.optimizer_G.step()
  97.             G_losses.append(g_loss.item())
  98.             print(f'第 {epoch}/{num_epochs} 轮, Batch {i + 1}/{len(self.train_loader)}, '
  99.                   f'D Loss: {d_loss:.4f}, G Loss: {g_loss:.4f}')
  100.         step += 1
  101.         return G_losses, D_losses, step
  102.     def build_generator(self):
  103.         """生成器"""
  104.         return Generator(latent_dim=self.latent_dim, num_classes=self.num_classes)
  105.     def build_discriminator(self):
  106.         """判别器"""
  107.         return Discriminator(num_classes=self.num_classes)
  108.     def load_model(self, model_path):
  109.         """加载模型权重"""
  110.         checkpoint = torch.load(model_path, map_location=self.device)
  111.         self.generator.load_state_dict(checkpoint['generator_state_dict'])
  112.         self.optimizer_G.load_state_dict(checkpoint['optimizer_G_state_dict'])
  113.         self.discriminator.load_state_dict(checkpoint['discriminator_state_dict'])
  114.         self.optimizer_D.load_state_dict(checkpoint['optimizer_D_state_dict'])
  115.         epoch = checkpoint['epoch']
  116.         print(f"加载了模型权重,起始训练轮次为 {epoch}")
  117.         return epoch
  118.     def save_model(self, epoch, save_path):
  119.         """保存模型"""
  120.         torch.save({
  121.             'epoch': epoch,
  122.             'scheduler_G_state_dict': self.scheduler_G.state_dict(),
  123.             'scheduler_D_state_dict': self.scheduler_D.state_dict(),
  124.             'generator_state_dict': self.generator.state_dict(),
  125.             'optimizer_G_state_dict': self.optimizer_G.state_dict(),
  126.             'discriminator_state_dict': self.discriminator.state_dict(),
  127.             'optimizer_D_state_dict': self.optimizer_D.state_dict(),
  128.         }, save_path)
  129.         print(f"模型已保存至 {save_path}")
  130. # ===========================
  131. # 生成器
  132. # ===========================
  133. class Generator(nn.Module):
  134.     def __init__(self, latent_dim=100, num_classes=4, img_channels=3):
  135.         super(Generator, self).__init__()
  136.         self.latent_dim = latent_dim
  137.         self.label_emb = nn.Embedding(num_classes, num_classes)
  138.         self.init_size = 8
  139.         self.l1 = nn.Linear(latent_dim + num_classes, 256 * self.init_size * self.init_size)
  140.         self.conv_blocks = nn.Sequential(
  141.             nn.BatchNorm2d(256),
  142.             nn.Upsample(scale_factor=2),
  143.             nn.Conv2d(256, 128, 3, padding=1),
  144.             nn.BatchNorm2d(128),
  145.             nn.LeakyReLU(0.2, inplace=True),
  146.             nn.Upsample(scale_factor=2),
  147.             nn.Conv2d(128, 64, 3, padding=1),
  148.             nn.BatchNorm2d(64),
  149.             nn.LeakyReLU(0.2, inplace=True),
  150.             nn.Upsample(scale_factor=2),
  151.             nn.Conv2d(64, 32, 3, padding=1),
  152.             nn.BatchNorm2d(32),
  153.             nn.LeakyReLU(0.2, inplace=True),
  154.             nn.Upsample(scale_factor=2),
  155.             nn.Conv2d(32, img_channels, 3, padding=1),
  156.             nn.Tanh()
  157.         )
  158.     def forward(self, noise, labels):
  159.         labels = labels.to(self.label_emb.weight.device)
  160.         label_embedding = self.label_emb(labels)
  161.         x = torch.cat((noise, label_embedding), dim=1)
  162.         x = self.l1(x).view(x.size(0), 256, self.init_size, self.init_size)
  163.         return self.conv_blocks(x)
  164. # ===========================
  165. # 判别器
  166. # ===========================
  167. class Discriminator(nn.Module):
  168.     def __init__(self, img_channels=3, num_classes=4):
  169.         super(Discriminator, self).__init__()
  170.         self.label_embedding = nn.Embedding(num_classes, img_channels)
  171.         self.model = nn.Sequential(
  172.             nn.Conv2d(img_channels * 2, 64, 4, stride=2, padding=1),
  173.             nn.LeakyReLU(0.2, inplace=True),
  174.             nn.Conv2d(64, 128, 4, stride=2, padding=1),
  175.             nn.BatchNorm2d(128),
  176.             nn.LeakyReLU(0.2, inplace=True),
  177.             nn.Conv2d(128, 256, 4, stride=2, padding=1),
  178.             nn.BatchNorm2d(256),
  179.             nn.LeakyReLU(0.2, inplace=True),
  180.             nn.Conv2d(256, 512, 4, stride=2, padding=1),
  181.             nn.BatchNorm2d(512),
  182.             nn.LeakyReLU(0.2, inplace=True)
  183.         )
  184.         self.output_layer = nn.Sequential(
  185.             nn.Linear(512 * 8 * 8, 1),
  186.             nn.Sigmoid()
  187.         )
  188.     def forward(self, img, labels):
  189.         labels = labels.to(self.label_embedding.weight.device)
  190.         label_embedding = self.label_embedding(labels).unsqueeze(2).unsqueeze(3)
  191.         label_embedding = label_embedding.expand(-1, -1, img.size(2), img.size(3))
  192.         x = torch.cat((img, label_embedding), dim=1)
  193.         x = self.model(x).view(x.size(0), -1)
  194.         return self.output_layer(x)
复制代码
2.3新建cGAN_trainer.py

  1. import os
  2. import torch
  3. import argparse
  4. from cGAN_net import cDCGAN
  5. from utils import plot_loss, plot_result
  6. import time
  7. os.environ['OMP_NUM_THREADS'] = '1'
  8. def main(args):
  9.     # 初始化设备和训练参数
  10.     device = torch.device(args.device if torch.cuda.is_available() else 'cpu')
  11.     model = cDCGAN(data_root=args.data_root, batch_size=args.batch_size, device=device, latent_dim=args.latent_dim)
  12.     # 添加学习率调度器
  13.     scheduler_G = torch.optim.lr_scheduler.StepLR(model.optimizer_G, step_size=10, gamma=0.5)
  14.     scheduler_D = torch.optim.lr_scheduler.StepLR(model.optimizer_D, step_size=10, gamma=0.5)
  15.     start_epoch = 0
  16.     # 如果有保存的模型,加载
  17.     if args.load_model and os.path.exists(args.load_model):
  18.         start_epoch = model.load_model(args.load_model) + 1
  19.         # 恢复调度器状态
  20.         scheduler_G_path = f"{args.load_model}_scheduler_G.pt"
  21.         scheduler_D_path = f"{args.load_model}_scheduler_D.pt"
  22.         if os.path.exists(scheduler_G_path) and os.path.exists(scheduler_D_path):
  23.             scheduler_G.load_state_dict(torch.load(scheduler_G_path))
  24.             scheduler_D.load_state_dict(torch.load(scheduler_D_path))
  25.             print(f"成功恢复调度器状态:{scheduler_G_path}, {scheduler_D_path}")
  26.         else:
  27.             print("未找到调度器状态文件,使用默认调度器设置")
  28.         print(f"从第 {start_epoch} 轮继续训练...")
  29.     print(f"开始训练,从第 {start_epoch + 1} 轮开始...")
  30.     # 创建保存路径
  31.     os.makedirs(args.save_dir, exist_ok=True)
  32.     os.makedirs(os.path.join(args.save_dir, 'log'), exist_ok=True)
  33.     # 训练循环
  34.     D_avg_losses, G_avg_losses = [], []
  35.     for epoch in range(start_epoch, args.epochs):
  36.         G_losses, D_losses, step = model.train_step(epoch, step=0, num_epochs=args.epochs)
  37.         # 计算平均损失
  38.         D_avg_loss = sum(D_losses) / len(D_losses) if D_losses else 0.0
  39.         G_avg_loss = sum(G_losses) / len(G_losses) if G_losses else 0.0
  40.         D_avg_losses.append(D_avg_loss)
  41.         G_avg_losses.append(G_avg_loss)
  42.         # 保存损失曲线图
  43.         plot_loss(start_epoch, args.epochs, D_avg_losses, G_avg_losses, epoch + 1, save=True,
  44.                   save_dir=os.path.join(args.save_dir, "log"))
  45.         # 生成并保存图片
  46.         labels = torch.tensor([0, 1, 2, 3]).to(device)
  47.         if (epoch + 1) % args.save_freq == 0:  # 每隔一定轮次保存生成结果
  48.             z = torch.randn(len(labels), args.latent_dim, device=device)  # 随机生成噪声
  49.             plot_result(model.generator, z, labels, epoch + 1, save_dir=os.path.join(args.save_dir, 'log'))
  50.         # 每10个epoch保存模型
  51.         if (epoch + 1) % args.save_interval == 0:
  52.             timestamp = int(time.time())
  53.             save_path = os.path.join(args.save_dir, f"cgan_epoch_{epoch + 1}_{timestamp}.pth")
  54.             model.save_model(epoch + 1, save_path)
  55.             print(f"第 {epoch + 1} 轮的模型已保存,保存路径为 {save_path}")
  56.         # 更新学习率调度器
  57.         scheduler_G.step()
  58.         scheduler_D.step()
  59. if __name__ == "__main__":
  60.     parser = argparse.ArgumentParser()
  61.     parser.add_argument('--data_root', type=str, default='data/crop128', help="数据集根目录")
  62.     parser.add_argument('--save_dir', type=str, default='./chkpt/cgan_model', help="保存模型的目录")
  63.     parser.add_argument('--load_model', type=str, default=None, help="要加载的模型路径(可选)")
  64.     parser.add_argument('--epochs', type=int, default=1000, help="训练的轮数")
  65.     parser.add_argument('--save_interval', type=int, default=10, help="保存模型检查点的间隔(按轮数)")
  66.     parser.add_argument('--batch_size', type=int, default=64, help="训练的批次大小")
  67.     parser.add_argument('--device', type=str, default='cuda', help="使用的设备(如 cuda 或 cpu)")
  68.     parser.add_argument('--latent_dim', type=int, default=100, help="生成器的潜在空间维度")
  69.     parser.add_argument('--save_freq', type=int, default=1, help="每隔多少轮保存一次生成结果(默认: 1)")
  70.     args = parser.parse_args()
  71.     main(args)
复制代码
效果分析:



2.4中间效果可视化处理

新建utils.py,编写绘制中间效果和中间损失线图的函数,代码如下:
  1. import matplotlib.pyplot as plt
  2. import numpy as np
  3. import os
  4. import torch
  5. def denorm(x):
  6.     out = (x + 1) / 2
  7.     return out.clamp(0, 1)
  8. def plot_loss(start_epoch, num_epochs, d_losses, g_losses, num_epoch, save=False, save_dir='celebA_cDCGAN_results/', show=False):
  9.     """
  10.     绘制损失函数曲线,从 start_epoch 到 num_epochs。
  11.     Args:
  12.         start_epoch: 起始轮次
  13.         num_epochs: 总轮次
  14.         d_losses: 判别器损失列表
  15.         g_losses: 生成器损失列表
  16.         num_epoch: 当前训练轮次
  17.         save: 是否保存绘图
  18.         save_dir: 保存路径
  19.         show: 是否显示绘图
  20.     """
  21.     fig, ax = plt.subplots()
  22.     ax.set_xlim(start_epoch, num_epochs)
  23.     ax.set_ylim(0, max(np.max(g_losses), np.max(d_losses)) * 1.1)
  24.     plt.xlabel(f'Epoch {num_epoch + 1}')
  25.     plt.ylabel('Loss values')
  26.     plt.plot(d_losses, label='Discriminator')
  27.     plt.plot(g_losses, label='Generator')
  28.     plt.legend()
  29.     if save:
  30.         if not os.path.exists(save_dir):
  31.             os.makedirs(save_dir)
  32.         save_fn = os.path.join(save_dir, f'cDCGAN_losses_epoch.png')
  33.         plt.savefig(save_fn)
  34.     if show:
  35.         plt.show()
  36.     else:
  37.         plt.close()
  38. def plot_result(generator, z, labels, epoch, save_dir=None, show=False):
  39.     """
  40.     生成并保存或显示生成的图片结果。
  41.     Args:
  42.         generator: 生成器模型
  43.         z: 随机噪声张量
  44.         labels: 标签张量
  45.         epoch: 当前训练轮数
  46.         save_dir: 保存图片的路径(可选)
  47.         show: 是否显示生成的图片(可选)
  48.     """
  49.     # 调用生成器,生成图像
  50.     generator.eval()  # 设置为评估模式
  51.     with torch.no_grad():
  52.         gen_images = generator(z, labels)  # 同时传入 z 和 labels
  53.     generator.train()  # 恢复训练模式
  54.     # 图像反归一化
  55.     gen_images = denorm(gen_images)
  56.     # 绘制图片
  57.     fig, ax = plt.subplots(1, len(gen_images), figsize=(15, 15))
  58.     for i in range(len(gen_images)):
  59.         ax[i].imshow(gen_images[i].permute(1, 2, 0).cpu().numpy())  # 转换为可显示格式
  60.         ax[i].axis('off')
  61.     # 保存或显示图片
  62.     if save_dir:
  63.         os.makedirs(save_dir, exist_ok=True)
  64.         save_path = os.path.join(save_dir, f'epoch_{epoch}.png')
  65.         plt.savefig(save_path)
  66.     if show:
  67.         plt.show()
  68.     plt.close(fig)
复制代码
执行 cGAN_trainer.py 文件,完成模子练习。


免责声明:如果侵犯了您的权益,请联系站长,我们会及时删除侵权内容,谢谢合作!更多信息从访问主页:qidao123.com:ToB企服之家,中国第一个企服评测及商务社交产业平台。
回复

使用道具 举报

0 个回复

倒序浏览

快速回复

您需要登录后才可以回帖 登录 or 立即注册

本版积分规则

勿忘初心做自己

论坛元老
这个人很懒什么都没写!
快速回复 返回顶部 返回列表