PyTorch中的损失函数:F.nll_loss 与 nn.CrossEntropyLoss

打印 上一主题 下一主题

主题 941|帖子 941|积分 2823

配景介绍

无论是图像分类、文本分类还是其他范例的分类使命,交织熵损失(Cross Entropy Loss)都是最常用的一种损失函数。它衡量的是模型预测的概率分布与真实标签之间的差异。在 PyTorch 中,有两个特别值得注意的实现:F.nll_loss 和 nn.CrossEntropyLoss。
F.nll_loss

什么是负对数似然损失?

F.nll_loss 是负对数似然损失(Negative Log Likelihood Loss),主要用于多类分类题目。它的输入是对数概率(log-probabilities),这意味着在利用 F.nll_loss 之前,我们需要先对模型的输出应用 log_softmax 函数,将原始输出转换为对数概率形式。
  1. import torch
  2. import torch.nn as nn
  3. import torch.nn.functional as F
  4. from torch.utils.data import DataLoader, TensorDataset
  5. # 创建一些虚拟数据
  6. features = torch.randn(100, 20)  # 假设有100个样本,每个样本有20个特征
  7. labels = torch.randint(0, 3, (100,))  # 假设有3个类别
  8. # 创建数据加载器
  9. dataset = TensorDataset(features, labels)
  10. data_loader = DataLoader(dataset, batch_size=10, shuffle=True)
  11. class SimpleModel(nn.Module):
  12.     def __init__(self):
  13.         super(SimpleModel, self).__init__()
  14.         self.fc = nn.Linear(20, 3)  # 输入维度为20,输出维度为3(对应3个类别)
  15.     def forward(self, x):
  16.         return self.fc(x)
  17. model_nll = SimpleModel()
  18. optimizer = torch.optim.SGD(model_nll.parameters(), lr=0.01)
  19. for inputs, targets in data_loader:
  20.     optimizer.zero_grad()  # 清除梯度
  21.     outputs = model_nll(inputs)  # 模型前向传播
  22.     log_softmax_outputs = F.log_softmax(outputs, dim=1)  # 应用 log_softmax
  23.     loss = F.nll_loss(log_softmax_outputs, targets)  # 计算 nll_loss
  24.     loss.backward()  # 反向传播
  25.     optimizer.step()  # 更新权重
  26.     print(f"Batch Loss with F.nll_loss: {loss.item():.4f}")
复制代码
应用场景

由于 F.nll_loss 需要预先盘算 log_softmax,这为用户提供了一定程度的灵活性,尤其是在需要复用 log_softmax 结果的环境下。
nn.CrossEntropyLoss

简化工作流程

相比之下,nn.CrossEntropyLoss 更加直接和易用。它团结了 log_softmax 和 nll_loss 的功能,因此可以直接接受未经归一化的原始输出作为输入,内部主动完成这两个步骤。
  1. import torch
  2. import torch.nn as nn
  3. import torch.nn.functional as F
  4. from torch.utils.data import DataLoader, TensorDataset
  5. # 创建一些虚拟数据
  6. features = torch.randn(100, 20)  # 假设有100个样本,每个样本有20个特征
  7. labels = torch.randint(0, 3, (100,))  # 假设有3个类别
  8. # 创建数据加载器
  9. dataset = TensorDataset(features, labels)
  10. data_loader = DataLoader(dataset, batch_size=10, shuffle=True)
  11. class SimpleModel(nn.Module):
  12.     def __init__(self):
  13.         super(SimpleModel, self).__init__()
  14.         self.fc = nn.Linear(20, 3)  # 输入维度为20,输出维度为3(对应3个类别)
  15.     def forward(self, x):
  16.         return self.fc(x)
  17. model_ce = SimpleModel()
  18. criterion = nn.CrossEntropyLoss()
  19. optimizer = torch.optim.SGD(model_ce.parameters(), lr=0.01)
  20. for inputs, targets in data_loader:
  21.     optimizer.zero_grad()  # 清除梯度
  22.     outputs = model_ce(inputs)  # 模型前向传播
  23.     loss = criterion(outputs, targets)  # 直接计算交叉熵损失,内部包含 log_softmax
  24.     loss.backward()  # 反向传播
  25.     optimizer.step()  # 更新权重
  26.     print(f"Batch Loss with nn.CrossEntropyLoss: {loss.item():.4f}")
复制代码
内部机制

现实上,nn.CrossEntropyLoss = log_softmax + nll_loss 。这种设计简化了用户的代码编写过程,特别是当不需要对中间结果进行额外操纵时。
区别与联系



  • 输入要求:F.nll_loss 要求输入为 log_softmax 后的结果;而 nn.CrossEntropyLoss 可以直接接受未经 softmax 处置惩罚的原始输出。
  • 灵活性:假如需要对 log_softmax 结果进行进一步处置惩罚或调试,那么 F.nll_loss 提供了更大的灵活性。
  • 便捷性:对于大多数用户而言,nn.CrossEntropyLoss 因其简洁性和内置的 log_softmax 步骤,是更方便的选择。

免责声明:如果侵犯了您的权益,请联系站长,我们会及时删除侵权内容,谢谢合作!更多信息从访问主页:qidao123.com:ToB企服之家,中国第一个企服评测及商务社交产业平台。
回复

使用道具 举报

0 个回复

倒序浏览

快速回复

您需要登录后才可以回帖 登录 or 立即注册

本版积分规则

徐锦洪

金牌会员
这个人很懒什么都没写!
快速回复 返回顶部 返回列表