ToB企服应用市场:ToB评测及商务社交产业平台

标题: Pytorch | 使用PI-FGSM针对CIFAR10上的ResNet分类器举行对抗攻击 [打印本页]

作者: 灌篮少年    时间: 2024-12-27 08:33
标题: Pytorch | 使用PI-FGSM针对CIFAR10上的ResNet分类器举行对抗攻击
之前已经针对CIFAR10练习了多种分类器:
Pytorch | 从零构建AlexNet对CIFAR10举行分类
Pytorch | 从零构建Vgg对CIFAR10举行分类
Pytorch | 从零构建GoogleNet对CIFAR10举行分类
Pytorch | 从零构建ResNet对CIFAR10举行分类
Pytorch | 从零构建MobileNet对CIFAR10举行分类
Pytorch | 从零构建EfficientNet对CIFAR10举行分类
Pytorch | 从零构建ParNet对CIFAR10举行分类
本篇文章我们使用Pytorch实现PI-FGSM对CIFAR10上的ResNet分类器举行攻击.
CIFAR数据集

CIFAR-10数据集是由加拿大高级研究所(CIFAR)网络整理的用于图像识别研究的常用数据集,基本信息如下:

下面是一些示例样本:

PI-FGSM介绍

PI-FGSM(Patch-wise Iterative Fast Gradient Sign Method)是一种针对主流正常练习和防御模型的黑盒攻击算法,旨在生成具有强转移性的对抗样本。该算法通过引入放大因子和投影核,以块(patch)为单元生成对抗噪声,从而进步对抗样本在差别模型间的转移性。
配景和动机


算法原理


算法流程


PI-FGSM代码实现

PI-FGSM算法实现

  1. import torch
  2. import torch.nn as nn
  3. def PI_FGSM(model, criterion, original_images, labels, epsilon, beta=5, kernel_size=3, num_iterations=10):
  4.     """
  5.     PI-FGSM (Patch-wise Iterative Fast Gradient Sign Method)
  6.     参数:
  7.     - model: 要攻击的模型
  8.     - criterion: 损失函数
  9.     - original_images: 原始图像
  10.     - labels: 原始图像的标签
  11.     - epsilon: 扰动幅度
  12.     - beta: 放大因子
  13.     - kernel_size: 投影核大小
  14.     - num_iterations: 迭代次数
  15.    
  16.     返回:
  17.     - perturbed_image: 生成的对抗样本
  18.     """
  19.     # gamma: 投影因子
  20.     gamma = epsilon / num_iterations * beta
  21.     # 初始化累积放大噪声和裁剪噪声
  22.     a = torch.zeros_like(original_images)
  23.     C = torch.zeros_like(original_images)
  24.     perturbed_images = original_images.clone().detach().requires_grad_(True)
  25.     # 定义投影核
  26.     Wp = torch.ones((kernel_size, kernel_size), dtype=torch.float32) / (kernel_size ** 2 - 1)
  27.     Wp[kernel_size // 2, kernel_size // 2] = 0
  28.     Wp = Wp.expand(original_images.size(1), -1, -1).to(original_images.device)
  29.     Wp = Wp.unsqueeze(0)
  30.     for _ in range(num_iterations):
  31.         # 计算梯度
  32.         outputs = model(perturbed_images)
  33.         loss = criterion(outputs, labels)
  34.         
  35.         model.zero_grad()
  36.         loss.backward()
  37.         
  38.         data_grad = perturbed_images.grad.data
  39.         # 更新累积放大噪声
  40.         a = a + beta * (epsilon / num_iterations) * data_grad.sign()
  41.         # 裁剪噪声
  42.         if a.abs().max() >= epsilon:
  43.             C = (a.abs() - epsilon).clamp(0, float('inf')) * a.sign()
  44.             a = a + gamma * torch.nn.functional.conv2d(input=C, weight=Wp, stride=1, padding=kernel_size // 2)
  45.         # 更新对抗样本
  46.         perturbed_images = perturbed_images + beta * (epsilon / num_iterations) * data_grad.sign() + gamma * torch.nn.functional.conv2d(C, Wp, stride=1, padding=kernel_size // 2)
  47.         
  48.         perturbed_images = torch.clamp(perturbed_images, original_images - epsilon, original_images + epsilon)
  49.         perturbed_images = perturbed_images.detach().requires_grad_(True)
  50.     return perturbed_images
复制代码
攻击效果


代码汇总

pifgsm.py

  1. import torch
  2. import torch.nn as nn
  3. def PI_FGSM(model, criterion, original_images, labels, epsilon, beta=5, kernel_size=3, num_iterations=10):
  4.     """
  5.     PI-FGSM (Patch-wise Iterative Fast Gradient Sign Method)
  6.     参数:
  7.     - model: 要攻击的模型
  8.     - criterion: 损失函数
  9.     - original_images: 原始图像
  10.     - labels: 原始图像的标签
  11.     - epsilon: 扰动幅度
  12.     - beta: 放大因子
  13.     - kernel_size: 投影核大小
  14.     - num_iterations: 迭代次数
  15.    
  16.     返回:
  17.     - perturbed_image: 生成的对抗样本
  18.     """
  19.     # gamma: 投影因子
  20.     gamma = epsilon / num_iterations * beta
  21.     # 初始化累积放大噪声和裁剪噪声
  22.     a = torch.zeros_like(original_images)
  23.     C = torch.zeros_like(original_images)
  24.     perturbed_images = original_images.clone().detach().requires_grad_(True)
  25.     # 定义投影核
  26.     Wp = torch.ones((kernel_size, kernel_size), dtype=torch.float32) / (kernel_size ** 2 - 1)
  27.     Wp[kernel_size // 2, kernel_size // 2] = 0
  28.     Wp = Wp.expand(original_images.size(1), -1, -1).to(original_images.device)
  29.     Wp = Wp.unsqueeze(0)
  30.     for _ in range(num_iterations):
  31.         # 计算梯度
  32.         outputs = model(perturbed_images)
  33.         loss = criterion(outputs, labels)
  34.         
  35.         model.zero_grad()
  36.         loss.backward()
  37.         
  38.         data_grad = perturbed_images.grad.data
  39.         # 更新累积放大噪声
  40.         a = a + beta * (epsilon / num_iterations) * data_grad.sign()
  41.         # 裁剪噪声
  42.         if a.abs().max() >= epsilon:
  43.             C = (a.abs() - epsilon).clamp(0, float('inf')) * a.sign()
  44.             a = a + gamma * torch.nn.functional.conv2d(input=C, weight=Wp, stride=1, padding=kernel_size // 2)
  45.         # 更新对抗样本
  46.         perturbed_images = perturbed_images + beta * (epsilon / num_iterations) * data_grad.sign() + gamma * torch.nn.functional.conv2d(C, Wp, stride=1, padding=kernel_size // 2)
  47.         
  48.         perturbed_images = torch.clamp(perturbed_images, original_images - epsilon, original_images + epsilon)
  49.         perturbed_images = perturbed_images.detach().requires_grad_(True)
  50.     return perturbed_images
复制代码
train.py

  1. import torch
  2. import torch.nn as nn
  3. import torchvision
  4. import torchvision.transforms as transforms
  5. from models import ResNet18
  6. # 数据预处理
  7. transform_train = transforms.Compose([
  8.     transforms.RandomCrop(32, padding=4),
  9.     transforms.RandomHorizontalFlip(),
  10.     transforms.ToTensor(),
  11.     transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))
  12. ])
  13. transform_test = transforms.Compose([
  14.     transforms.ToTensor(),
  15.     transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))
  16. ])
  17. # 加载Cifar10训练集和测试集
  18. trainset = torchvision.datasets.CIFAR10(root='./data', train=True, download=False, transform=transform_train)
  19. trainloader = torch.utils.data.DataLoader(trainset, batch_size=128, shuffle=True, num_workers=2)
  20. testset = torchvision.datasets.CIFAR10(root='./data', train=False, download=False, transform=transform_test)
  21. testloader = torch.utils.data.DataLoader(testset, batch_size=100, shuffle=False, num_workers=2)
  22. # 定义设备(GPU或CPU)
  23. device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
  24. # 初始化模型
  25. model = ResNet18(num_classes=10)
  26. model.to(device)
  27. # 定义损失函数和优化器
  28. criterion = nn.CrossEntropyLoss()
  29. optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
  30. if __name__ == "__main__":
  31.     # 训练模型
  32.     for epoch in range(10):  # 可以根据实际情况调整训练轮数
  33.         running_loss = 0.0
  34.         for i, data in enumerate(trainloader, 0):
  35.             inputs, labels = data[0].to(device), data[1].to(device)
  36.             optimizer.zero_grad()
  37.             outputs = model(inputs)
  38.             loss = criterion(outputs, labels)
  39.             loss.backward()
  40.             optimizer.step()
  41.             running_loss += loss.item()
  42.             if i % 100 == 99:
  43.                 print(f'Epoch {epoch + 1}, Batch {i + 1}: Loss = {running_loss / 100}')
  44.                 running_loss = 0.0
  45.     torch.save(model.state_dict(), f'weights/epoch_{epoch + 1}.pth')
  46.     print('Finished Training')
复制代码
advtest.py

  1. import torch
  2. import torch.nn as nn
  3. import torchvision
  4. import torchvision.transforms as transforms
  5. from models import *
  6. from attacks import *
  7. import ssl
  8. import os
  9. from PIL import Image
  10. import matplotlib.pyplot as plt
  11. ssl._create_default_https_context = ssl._create_unverified_context
  12. # 定义数据预处理操作
  13. transform = transforms.Compose(
  14.     [transforms.ToTensor(),
  15.      transforms.Normalize((0.491, 0.482, 0.446), (0.247, 0.243, 0.261))])
  16. # 加载CIFAR10测试集
  17. testset = torchvision.datasets.CIFAR10(root='./data', train=False,
  18.                                        download=False, transform=transform)
  19. testloader = torch.utils.data.DataLoader(testset, batch_size=128,
  20.                                          shuffle=False, num_workers=2)
  21. # 定义设备(GPU优先,若可用)
  22. device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
  23. model = ResNet18(num_classes=10).to(device)
  24. criterion = nn.CrossEntropyLoss()
  25. # 加载模型权重
  26. weights_path = "weights/epoch_10.pth"
  27. model.load_state_dict(torch.load(weights_path, map_location=device))
  28. if __name__ == "__main__":
  29.     # 在测试集上进行FGSM攻击并评估准确率
  30.     model.eval()  # 设置为评估模式
  31.     correct = 0
  32.     total = 0
  33.     epsilon = 16 / 255  # 可以调整扰动强度
  34.     for data in testloader:
  35.         original_images, labels = data[0].to(device), data[1].to(device)
  36.         original_images.requires_grad = True
  37.         
  38.         attack_name = 'PI-FGSM'
  39.         if attack_name == 'FGSM':
  40.             perturbed_images = FGSM(model, criterion, original_images, labels, epsilon)
  41.         elif attack_name == 'BIM':
  42.             perturbed_images = BIM(model, criterion, original_images, labels, epsilon)
  43.         elif attack_name == 'MI-FGSM':
  44.             perturbed_images = MI_FGSM(model, criterion, original_images, labels, epsilon)
  45.         elif attack_name == 'NI-FGSM':
  46.             perturbed_images = NI_FGSM(model, criterion, original_images, labels, epsilon)
  47.         elif attack_name == 'PI-FGSM':
  48.             perturbed_images = PI_FGSM(model, criterion, original_images, labels, epsilon)
  49.         
  50.         perturbed_outputs = model(perturbed_images)
  51.         _, predicted = torch.max(perturbed_outputs.data, 1)
  52.         total += labels.size(0)
  53.         correct += (predicted == labels).sum().item()
  54.     accuracy = 100 * correct / total
  55.     # Attack Success Rate
  56.     ASR = 100 - accuracy
  57.     print(f'Load ResNet Model Weight from {weights_path}')
  58.     print(f'epsilon: {epsilon:.4f}')
  59.     print(f'ASR of {attack_name} : {ASR :.2f}%')
复制代码
免责声明:如果侵犯了您的权益,请联系站长,我们会及时删除侵权内容,谢谢合作!更多信息从访问主页:qidao123.com:ToB企服之家,中国第一个企服评测及商务社交产业平台。




欢迎光临 ToB企服应用市场:ToB评测及商务社交产业平台 (https://dis.qidao123.com/) Powered by Discuz! X3.4