GANs are a widely used and effective family of image generation models. Here we use a cGAN, which supports conditional generation, i.e. producing images of a specified class. The following sections walk through building and training a simple cGAN model.
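For reference, a cGAN conditions both the generator and the discriminator on a class label y and is trained with the standard conditional adversarial objective (Mirza & Osindero, 2014):

$$\min_G \max_D V(D, G) = \mathbb{E}_{x \sim p_{\text{data}}(x)}\big[\log D(x \mid y)\big] + \mathbb{E}_{z \sim p_z(z)}\big[\log\big(1 - D(G(z \mid y) \mid y)\big)\big]$$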
2.1 Create nets.py in the model folder
import torch
import torch.nn as nn


# Generator
class Generator(nn.Module):
    def __init__(self, nz=100, nc=3, ngf=128, num_classes=4):
        super(Generator, self).__init__()
        # Embed the class label into a vector of the same size as the noise
        self.label_emb = nn.Embedding(num_classes, nz)
        self.main = nn.Sequential(
            # Input: (nz + nz) x 1 x 1 -> (ngf*8) x 4 x 4
            nn.ConvTranspose2d(nz + nz, ngf * 8, 4, 1, 0, bias=False),
            nn.BatchNorm2d(ngf * 8),
            nn.ReLU(True),
            # -> (ngf*4) x 8 x 8
            nn.ConvTranspose2d(ngf * 8, ngf * 4, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ngf * 4),
            nn.ReLU(True),
            # -> (ngf*2) x 16 x 16
            nn.ConvTranspose2d(ngf * 4, ngf * 2, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ngf * 2),
            nn.ReLU(True),
            # -> ngf x 32 x 32
            nn.ConvTranspose2d(ngf * 2, ngf, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ngf),
            nn.ReLU(True),
            # -> nc x 64 x 64
            nn.ConvTranspose2d(ngf, nc, 4, 2, 1, bias=False),
            nn.Tanh()
        )

    def forward(self, z, labels):
        # z is expected to have shape (batch, nz, 1, 1)
        c = self.label_emb(labels).unsqueeze(2).unsqueeze(3)
        x = torch.cat([z, c], 1)
        return self.main(x)


# Discriminator
class Discriminator(nn.Module):
    def __init__(self, nc=3, ndf=64, num_classes=4):
        super(Discriminator, self).__init__()
        # Embed the label into a 64x64 map that is concatenated as one extra channel
        self.label_emb = nn.Embedding(num_classes, 64 * 64)
        self.main = nn.Sequential(
            # Input: (nc + 1) x 64 x 64 -> ndf x 32 x 32
            nn.Conv2d(nc + 1, ndf, 4, 2, 1, bias=False),
            nn.LeakyReLU(0.2, inplace=True),
            # -> (ndf*2) x 16 x 16
            nn.Conv2d(ndf, ndf * 2, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ndf * 2),
            nn.LeakyReLU(0.2, inplace=True),
            # -> (ndf*4) x 8 x 8
            nn.Conv2d(ndf * 2, ndf * 4, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ndf * 4),
            nn.LeakyReLU(0.2, inplace=True),
            # -> (ndf*8) x 4 x 4
            nn.Conv2d(ndf * 4, ndf * 8, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ndf * 8),
            nn.LeakyReLU(0.2, inplace=True),
            # -> 1 x 1 x 1
            nn.Conv2d(ndf * 8, 1, 4, 1, 0, bias=False),
            nn.Sigmoid()
        )

    def forward(self, img, labels):
        c = self.label_emb(labels).view(labels.size(0), 1, 64, 64)
        x = torch.cat([img, c], 1)
        return self.main(x).view(-1, 1)
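As a quick sanity check, the following minimal sketch (my own addition, not part of the original post) instantiates both networks and runs a single forward pass; note that the noise tensor must be 4-D, shaped (batch, nz, 1, 1):

import torch
from model.nets import Generator, Discriminator  # adjust this import to where nets.py actually lives

# Hypothetical smoke test: 8 samples, 4 classes, 100-dim noise
G = Generator(nz=100, nc=3, ngf=128, num_classes=4)
D = Discriminator(nc=3, ndf=64, num_classes=4)

z = torch.randn(8, 100, 1, 1)          # noise must be (batch, nz, 1, 1)
labels = torch.randint(0, 4, (8,))     # one class index per sample

fake = G(z, labels)                    # -> (8, 3, 64, 64), values in [-1, 1]
score = D(fake, labels)                # -> (8, 1), probabilities from the Sigmoid
print(fake.shape, score.shape)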
2.2 Create cGAN_net.py
import torch
import torch.nn as nn
from torch.optim import Adam
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
from torch.optim.lr_scheduler import StepLR


# ===========================
# Conditional DCGAN
# ===========================
class cDCGAN:
    def __init__(self, data_root, batch_size, device, latent_dim=100, num_classes=4):
        self.device = device
        self.batch_size = batch_size
        self.latent_dim = latent_dim
        self.num_classes = num_classes
        # Data loader
        self.train_loader = self.get_dataloader(data_root)
        # Build the generator and discriminator
        self.generator = self.build_generator().to(device)
        self.discriminator = self.build_discriminator().to(device)
        # Weight initialization
        self.generator.apply(self.weights_init)
        self.discriminator.apply(self.weights_init)
        # Loss function and optimizers
        self.criterion = nn.BCELoss()
        self.optimizer_G = Adam(self.generator.parameters(), lr=0.0001, betas=(0.5, 0.999))
        self.optimizer_D = Adam(self.discriminator.parameters(), lr=0.0001, betas=(0.5, 0.999))
        # Learning-rate schedulers: halve the learning rate every 10 epochs
        self.scheduler_G = StepLR(self.optimizer_G, step_size=10, gamma=0.5)
        self.scheduler_D = StepLR(self.optimizer_D, step_size=10, gamma=0.5)

    def get_dataloader(self, data_root):
        transform = transforms.Compose([
            transforms.Resize((128, 128)),
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
        ])
        dataset = datasets.ImageFolder(root=data_root, transform=transform)
        return DataLoader(dataset, batch_size=self.batch_size, shuffle=True,
                          num_workers=8, pin_memory=True, persistent_workers=True)

    @staticmethod
    def weights_init(model):
        """Weight initialization"""
        if isinstance(model, (nn.Conv2d, nn.Linear)):
            nn.init.normal_(model.weight.data, 0.0, 0.02)
            if model.bias is not None:
                nn.init.constant_(model.bias.data, 0)

    def train_step(self, epoch, step, num_epochs):
        """One training epoch"""
        self.generator.train()
        self.discriminator.train()
        G_losses, D_losses = [], []
        for i, (real_img, labels) in enumerate(self.train_loader):
            # Make sure real_img and labels are on the same device
            real_img = real_img.to(self.device)
            labels = labels.to(self.device)
            batch_size = real_img.size(0)

            # Hard labels (replaced by the smoothed targets below)
            # valid = torch.ones((batch_size, 1), device=self.device)
            # fake = torch.zeros((batch_size, 1), device=self.device)

            # Label smoothing with targets that drift slightly as training progresses
            smooth_valid = torch.full((batch_size, 1), max(0.9, 1 - 0.0001 * epoch), device=self.device)
            smooth_fake = torch.full((batch_size, 1), min(0.1, 0.0001 * epoch), device=self.device)
            valid = smooth_valid
            fake = smooth_fake

            # ========== Train the discriminator ==========
            real_pred = self.discriminator(real_img, labels)
            # Add a little noise to the real targets
            d_real_loss = self.criterion(real_pred, valid - 0.1 * torch.rand_like(valid))

            noise = torch.randn(batch_size, self.latent_dim, device=self.device)
            # Sample random class labels and jitter them by -1/0/+1, clamped to the valid range
            gen_labels = (torch.randint(0, self.num_classes, (batch_size,), device=self.device)
                          + torch.randint(-1, 2, (batch_size,), device=self.device))
            gen_labels = torch.clamp(gen_labels, 0, self.num_classes - 1)
            gen_img = self.generator(noise, gen_labels)
            fake_pred = self.discriminator(gen_img.detach(), gen_labels)
            # Add a little noise to the fake targets
            d_fake_loss = self.criterion(fake_pred, fake + 0.1 * torch.rand_like(fake))

            d_loss = (d_real_loss + d_fake_loss) / 2
            self.optimizer_D.zero_grad()
            d_loss.backward()
            self.optimizer_D.step()
            D_losses.append(d_loss.item())

            # ========== Train the generator ==========
            gen_pred = self.discriminator(gen_img, gen_labels)
            g_loss = self.criterion(gen_pred, valid)
            self.optimizer_G.zero_grad()
            g_loss.backward()
            self.optimizer_G.step()
            G_losses.append(g_loss.item())

            print(f'Epoch {epoch}/{num_epochs}, Batch {i + 1}/{len(self.train_loader)}, '
                  f'D Loss: {d_loss.item():.4f}, G Loss: {g_loss.item():.4f}')
            step += 1
        return G_losses, D_losses, step

    def build_generator(self):
        """Generator"""
        return Generator(latent_dim=self.latent_dim, num_classes=self.num_classes)

    def build_discriminator(self):
        """Discriminator"""
        return Discriminator(num_classes=self.num_classes)

    def load_model(self, model_path):
        """Load model weights (and, if present, optimizer/scheduler state)"""
        checkpoint = torch.load(model_path, map_location=self.device)
        self.generator.load_state_dict(checkpoint['generator_state_dict'])
        self.optimizer_G.load_state_dict(checkpoint['optimizer_G_state_dict'])
        self.discriminator.load_state_dict(checkpoint['discriminator_state_dict'])
        self.optimizer_D.load_state_dict(checkpoint['optimizer_D_state_dict'])
        if 'scheduler_G_state_dict' in checkpoint:
            self.scheduler_G.load_state_dict(checkpoint['scheduler_G_state_dict'])
            self.scheduler_D.load_state_dict(checkpoint['scheduler_D_state_dict'])
        epoch = checkpoint['epoch']
        print(f"Loaded model weights, resuming from epoch {epoch}")
        return epoch

    def save_model(self, epoch, save_path):
        """Save a checkpoint"""
        torch.save({
            'epoch': epoch,
            'scheduler_G_state_dict': self.scheduler_G.state_dict(),
            'scheduler_D_state_dict': self.scheduler_D.state_dict(),
            'generator_state_dict': self.generator.state_dict(),
            'optimizer_G_state_dict': self.optimizer_G.state_dict(),
            'discriminator_state_dict': self.discriminator.state_dict(),
            'optimizer_D_state_dict': self.optimizer_D.state_dict(),
        }, save_path)
        print(f"Model saved to {save_path}")


# ===========================
# Generator (128x128 output)
# ===========================
class Generator(nn.Module):
    def __init__(self, latent_dim=100, num_classes=4, img_channels=3):
        super(Generator, self).__init__()
        self.latent_dim = latent_dim
        self.label_emb = nn.Embedding(num_classes, num_classes)
        self.init_size = 8
        self.l1 = nn.Linear(latent_dim + num_classes, 256 * self.init_size * self.init_size)
        self.conv_blocks = nn.Sequential(
            nn.BatchNorm2d(256),
            nn.Upsample(scale_factor=2),            # 8 -> 16
            nn.Conv2d(256, 128, 3, padding=1),
            nn.BatchNorm2d(128),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Upsample(scale_factor=2),            # 16 -> 32
            nn.Conv2d(128, 64, 3, padding=1),
            nn.BatchNorm2d(64),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Upsample(scale_factor=2),            # 32 -> 64
            nn.Conv2d(64, 32, 3, padding=1),
            nn.BatchNorm2d(32),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Upsample(scale_factor=2),            # 64 -> 128
            nn.Conv2d(32, img_channels, 3, padding=1),
            nn.Tanh()
        )

    def forward(self, noise, labels):
        labels = labels.to(self.label_emb.weight.device)
        label_embedding = self.label_emb(labels)
        # Concatenate the noise and label embedding, project, then reshape to a feature map
        x = torch.cat((noise, label_embedding), dim=1)
        x = self.l1(x).view(x.size(0), 256, self.init_size, self.init_size)
        return self.conv_blocks(x)


# ===========================
# Discriminator (128x128 input)
# ===========================
class Discriminator(nn.Module):
    def __init__(self, img_channels=3, num_classes=4):
        super(Discriminator, self).__init__()
        # Embed the label into img_channels values, broadcast over the image as extra channels
        self.label_embedding = nn.Embedding(num_classes, img_channels)
        self.model = nn.Sequential(
            nn.Conv2d(img_channels * 2, 64, 4, stride=2, padding=1),   # 128 -> 64
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(64, 128, 4, stride=2, padding=1),                # 64 -> 32
            nn.BatchNorm2d(128),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(128, 256, 4, stride=2, padding=1),               # 32 -> 16
            nn.BatchNorm2d(256),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(256, 512, 4, stride=2, padding=1),               # 16 -> 8
            nn.BatchNorm2d(512),
            nn.LeakyReLU(0.2, inplace=True)
        )
        self.output_layer = nn.Sequential(
            nn.Linear(512 * 8 * 8, 1),
            nn.Sigmoid()
        )

    def forward(self, img, labels):
        labels = labels.to(self.label_embedding.weight.device)
        label_embedding = self.label_embedding(labels).unsqueeze(2).unsqueeze(3)
        label_embedding = label_embedding.expand(-1, -1, img.size(2), img.size(3))
        x = torch.cat((img, label_embedding), dim=1)
        x = self.model(x).view(x.size(0), -1)
        return self.output_layer(x)
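Before wiring these networks into the trainer, it can help to verify the tensor shapes end to end. The following is a minimal sketch of my own (not from the original post); it assumes cGAN_net.py is importable from the working directory. Unlike the nets.py version, the noise here is a flat vector, not a (batch, nz, 1, 1) tensor:

import torch
from cGAN_net import Generator, Discriminator

G = Generator(latent_dim=100, num_classes=4, img_channels=3)
D = Discriminator(img_channels=3, num_classes=4)

z = torch.randn(4, 100)        # flat noise vector per sample
labels = torch.arange(4)       # one sample per class

imgs = G(z, labels)            # expected shape: (4, 3, 128, 128)
scores = D(imgs, labels)       # expected shape: (4, 1)
print(imgs.shape, scores.shape)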
2.3 Create cGAN_trainer.py
import os
import time
import argparse

import torch

from cGAN_net import cDCGAN
from utils import plot_loss, plot_result

os.environ['OMP_NUM_THREADS'] = '1'


def main(args):
    # Device and model
    device = torch.device(args.device if torch.cuda.is_available() else 'cpu')
    model = cDCGAN(data_root=args.data_root, batch_size=args.batch_size, device=device, latent_dim=args.latent_dim)

    # Use the learning-rate schedulers created inside cDCGAN
    scheduler_G = model.scheduler_G
    scheduler_D = model.scheduler_D

    start_epoch = 0
    # Resume from a saved checkpoint if one was given
    # (load_model also restores the optimizer and scheduler state stored by save_model)
    if args.load_model and os.path.exists(args.load_model):
        start_epoch = model.load_model(args.load_model) + 1
        print(f"Resuming training from epoch {start_epoch}...")

    print(f"Starting training at epoch {start_epoch + 1}...")

    # Output directories
    os.makedirs(args.save_dir, exist_ok=True)
    os.makedirs(os.path.join(args.save_dir, 'log'), exist_ok=True)

    # Training loop
    D_avg_losses, G_avg_losses = [], []
    for epoch in range(start_epoch, args.epochs):
        G_losses, D_losses, step = model.train_step(epoch, step=0, num_epochs=args.epochs)

        # Average losses for this epoch
        D_avg_loss = sum(D_losses) / len(D_losses) if D_losses else 0.0
        G_avg_loss = sum(G_losses) / len(G_losses) if G_losses else 0.0
        D_avg_losses.append(D_avg_loss)
        G_avg_losses.append(G_avg_loss)

        # Save the loss curves
        plot_loss(start_epoch, args.epochs, D_avg_losses, G_avg_losses, epoch + 1, save=True,
                  save_dir=os.path.join(args.save_dir, "log"))

        # Generate and save sample images, one per class
        labels = torch.tensor([0, 1, 2, 3]).to(device)
        if (epoch + 1) % args.save_freq == 0:
            z = torch.randn(len(labels), args.latent_dim, device=device)  # random noise
            plot_result(model.generator, z, labels, epoch + 1, save_dir=os.path.join(args.save_dir, 'log'))

        # Save a checkpoint every save_interval epochs
        if (epoch + 1) % args.save_interval == 0:
            timestamp = int(time.time())
            save_path = os.path.join(args.save_dir, f"cgan_epoch_{epoch + 1}_{timestamp}.pth")
            model.save_model(epoch + 1, save_path)
            print(f"Checkpoint for epoch {epoch + 1} saved to {save_path}")

        # Step the learning-rate schedulers
        scheduler_G.step()
        scheduler_D.step()


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument('--data_root', type=str, default='data/crop128', help="dataset root directory")
    parser.add_argument('--save_dir', type=str, default='./chkpt/cgan_model', help="directory for saved models")
    parser.add_argument('--load_model', type=str, default=None, help="checkpoint to resume from (optional)")
    parser.add_argument('--epochs', type=int, default=1000, help="number of training epochs")
    parser.add_argument('--save_interval', type=int, default=10, help="checkpoint interval (in epochs)")
    parser.add_argument('--batch_size', type=int, default=64, help="training batch size")
    parser.add_argument('--device', type=str, default='cuda', help="device to use (e.g. cuda or cpu)")
    parser.add_argument('--latent_dim', type=int, default=100, help="dimension of the generator's latent space")
    parser.add_argument('--save_freq', type=int, default=1, help="how often (in epochs) to save sample images (default: 1)")
    args = parser.parse_args()
    main(args)
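Once a checkpoint has been written, the trained generator can also be reloaded and sampled outside the training loop. The sketch below is my own illustration, not part of the original post: the checkpoint path is a placeholder, it assumes the training data is still available at data_root (the data loader is built in cDCGAN.__init__), and it reuses plot_result from utils.py, which is introduced in section 2.4:

import torch
from cGAN_net import cDCGAN
from utils import plot_result

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = cDCGAN(data_root='data/crop128', batch_size=64, device=device, latent_dim=100)

# Hypothetical checkpoint path; substitute one produced by cGAN_trainer.py
model.load_model('./chkpt/cgan_model/cgan_epoch_100_1700000000.pth')

labels = torch.tensor([0, 1, 2, 3], device=device)   # one image per class
z = torch.randn(len(labels), 100, device=device)
plot_result(model.generator, z, labels, epoch=0, save_dir='./chkpt/cgan_model/samples')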
Analysis of the results:
2.4 Visualizing intermediate results
Create utils.py and write the functions that plot intermediate generated samples and the loss curves:
import os

import matplotlib.pyplot as plt
import numpy as np
import torch


def denorm(x):
    """Map a Tanh output from [-1, 1] back to [0, 1]."""
    out = (x + 1) / 2
    return out.clamp(0, 1)


def plot_loss(start_epoch, num_epochs, d_losses, g_losses, num_epoch, save=False,
              save_dir='celebA_cDCGAN_results/', show=False):
    """
    Plot the loss curves from start_epoch to num_epochs.
    Args:
        start_epoch: first epoch of this run
        num_epochs: total number of epochs
        d_losses: list of discriminator losses (one value per epoch)
        g_losses: list of generator losses (one value per epoch)
        num_epoch: current epoch
        save: whether to save the figure
        save_dir: output directory
        show: whether to display the figure
    """
    fig, ax = plt.subplots()
    ax.set_xlim(start_epoch, num_epochs)
    ax.set_ylim(0, max(np.max(g_losses), np.max(d_losses)) * 1.1)
    plt.xlabel(f'Epoch {num_epoch}')
    plt.ylabel('Loss values')
    epochs = range(start_epoch + 1, start_epoch + len(d_losses) + 1)
    plt.plot(epochs, d_losses, label='Discriminator')
    plt.plot(epochs, g_losses, label='Generator')
    plt.legend()
    if save:
        if not os.path.exists(save_dir):
            os.makedirs(save_dir)
        save_fn = os.path.join(save_dir, 'cDCGAN_losses_epoch.png')
        plt.savefig(save_fn)
    if show:
        plt.show()
    else:
        plt.close()


def plot_result(generator, z, labels, epoch, save_dir=None, show=False):
    """
    Generate images and save and/or display them.
    Args:
        generator: the generator model
        z: random noise tensor
        labels: label tensor
        epoch: current epoch (used in the file name)
        save_dir: directory to save the image grid to (optional)
        show: whether to display the generated images (optional)
    """
    # Run the generator in evaluation mode
    generator.eval()
    with torch.no_grad():
        gen_images = generator(z, labels)  # pass both z and labels
    generator.train()  # back to training mode

    # Undo the [-1, 1] normalization
    gen_images = denorm(gen_images)

    # Draw one image per label
    fig, ax = plt.subplots(1, len(gen_images), figsize=(15, 15))
    for i in range(len(gen_images)):
        ax[i].imshow(gen_images[i].permute(1, 2, 0).cpu().numpy())  # convert to HWC for display
        ax[i].axis('off')

    # Save and/or show
    if save_dir:
        os.makedirs(save_dir, exist_ok=True)
        save_path = os.path.join(save_dir, f'epoch_{epoch}.png')
        plt.savefig(save_path)
    if show:
        plt.show()
    plt.close(fig)
Run cGAN_trainer.py to train the model.
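For example, training could be launched from the command line as follows (a sketch: the paths are the script's defaults, the flags may need adjusting to your setup, and the checkpoint name in the second line is a placeholder):

python cGAN_trainer.py --data_root data/crop128 --save_dir ./chkpt/cgan_model --batch_size 64 --device cuda

# Resume from a saved checkpoint:
python cGAN_trainer.py --load_model ./chkpt/cgan_model/cgan_epoch_100_1700000000.pth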