PyTorch `.pth` 转 ONNX:从模型训练到跨平台部署

立聪堂德州十三局店  论坛元老 | 2025-2-18 22:00:47 | 来自手机 | 显示全部楼层 | 阅读模式
打印 上一主题 下一主题

主题 1022|帖子 1022|积分 3066

马上注册,结交更多好友,享用更多功能,让你轻松玩转社区。

您需要 登录 才可以下载或查看,没有账号?立即注册

x
PyTorch .pth 转 ONNX:从模型训练到跨平台部署

在深度学习里,模型的格式决定了它的可用性
如果你是 PyTorch 用户,你大概认识 .pth 文件,它用于存储训练好的模型。
但当你想在不同的环境(如 TensorRT、OpenVINO、ONNX Runtime)部署模型时,.pth 大概并不适用。这时,ONNX(Open Neural Network Exchange)就必不可少。
本文目录:


  • 什么是 .pth 文件?
  • 什么是 .onnx 文件?
  • 为什么要转换?
  • 怎样转换 .pth 到 .onnx?
  • 转换后的好处和潜伏风险

1. 什么是 .pth 文件?

.pth 是 PyTorch 专属的模型权重文件,用于存储:

  • 模型权重(state_dict):仅生存参数,不包罗模型结构。
  • 完整模型:包罗模型结构和权重,适用于直接 torch.save(model, "model.pth") 生存的环境。
在 PyTorch 中,你可以用以下方式加载 .pth:
  1. import torch
  2. from NestedUNet import NestedUNet  # 你的模型类
  3. # 仅保存权重的加载方式
  4. model = NestedUNet(num_classes=2, input_channels=3)
  5. model.load_state_dict(torch.load("best_model.pth"))
  6. model.eval()
复制代码
.pth 文件只能在 PyTorch 运行的环境中利用,不能直接在 TensorFlow、OpenVINO 或 TensorRT 里运行。

2. 什么是 ONNX?

ONNX(Open Neural Network Exchange)是 一个开放的神经网络标准格式,它的目的是:

  • 跨框架兼容:支持 PyTorch、TensorFlow、Keras、MXNet 等。
  • 优化推理:可以用 ONNX Runtime 或 TensorRT 加速推理。
  • 部署机动:支持在 CPU、GPU、FPGA、TPU 等硬件上运行。
ONNX 文件是一个 .onnx 文件,它包罗:


  • 模型的计算图
  • 算子(OPs)界说
  • 模型权重
ONNX 让你可以在不同平台上运行同一个模型,而不必依赖某个特定的深度学习框架。

3. 为什么要转换 .pth 到 .onnx?

转换为 ONNX 主要有以下好处:
跨平台兼容


  • .pth 只能在 PyTorch 里用,而 .onnx 可以在 TensorRT、ONNX Runtime、OpenVINO、CoreML 等多种环境中运行。
推理速率更快


  • ONNX Runtime 利用图优化(Graph Optimization),镌汰计算冗余,提高推理速率。
  • TensorRT 可以将 ONNX 模型编译为高度优化的 GPU 代码,明显提高吞吐量。
支持多种硬件


  • .pth 主要用于 CPU/GPU,而 .onnx 可用于 FPGA、TPU、ARM 装备,如 安卓手机、树莓派、Jetson Nano 等。
更轻量级


  • PyTorch 运行时需要完整的 Python 解释器,而 ONNX 可以直接用 C++/C 代码运行,适用于嵌入式装备。

4. 怎样转换 .pth 到 .onnx?

4.1 安装依赖

在转换前,确保你已安装 PyTorch 和 ONNX:
  1. pip install torch torchvision onnx
复制代码
4.2 编写转换代码

假设你有一个 NestedUNet 训练好的 .pth 文件,转换方式如下:
  1. import torch
  2. import torch.onnx
  3. from NestedUNet import NestedUNet  # 你的模型文件
  4. # 1. 加载 PyTorch 模型
  5. model = NestedUNet(num_classes=2, input_channels=3, deep_supervision=False)
  6. model.load_state_dict(torch.load("best_model.pth"))
  7. model.eval()
  8. # 2. 创建示例输入(确保形状正确)
  9. dummy_input = torch.randn(1, 3, 256, 256)
  10. # 3. 导出为 ONNX
  11. onnx_path = "nested_unet.onnx"
  12. torch.onnx.export(
  13.     model,
  14.     dummy_input,
  15.     onnx_path,
  16.     export_params=True,
  17.     opset_version=11,  # 确保兼容性
  18.     do_constant_folding=True,
  19.     input_names=["input"],
  20.     output_names=["output"],
  21.     dynamic_axes={"input": {0: "batch_size"}, "output": {0: "batch_size"}}
  22. )
  23. print(f"✅ 模型已成功转换为 {onnx_path}")
复制代码
4.3 验证 ONNX

安装 onnxruntime 并测试:
  1. pip install onnxruntime
复制代码
然后运行:
  1. import onnxruntime as ort
  2. import numpy as np
  3. # 加载 ONNX
  4. ort_session = ort.InferenceSession("nested_unet.onnx")
  5. # 生成随机输入
  6. input_data = np.random.randn(1, 3, 256, 256).astype(np.float32)
  7. outputs = ort_session.run(None, {"input": input_data})
  8. print("ONNX 推理结果:", outputs[0].shape)
复制代码

5. 转换后的好处和潜伏风险

5.1 好处

提高推理速率


  • ONNX Runtime 和 TensorRT 可以明显加速推理,尤其是在 GPU 上。
跨平台部署


  • .onnx 可用于 Windows、Linux、安卓、iOS、嵌入式装备。
镌汰依赖


  • 直接用 ONNX Runtime 运行,不需要完整的 PyTorch 依赖。

5.2 大概碰到的标题

ONNX 大概不支持某些 PyTorch 操作


  • PyTorch 的某些自界说操作(如 grid_sample)大概在 ONNX 不支持,需要手动修改模型。
ONNX 的 Upsample 大概需要 align_corners=False


  • 如果 Upsample(scale_factor=2, mode='bilinear', align_corners=True),大概会导致 ONNX 兼容性标题,建议改为:
    1. self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False)
    复制代码
ONNX 在 CPU 上的推理大概比 PyTorch 慢


  • 如果模型没有经过优化,ONNX 大概不会比 PyTorch 快,尤其是在 CPU 上。
TensorRT 需要额外优化


  • 直接用 TensorRT 运行 ONNX 大概会报错,需要 onnx-simplifier:
    1. pip install onnx-simplifier
    2. python -m onnxsim nested_unet.onnx nested_unet_simplified.onnx
    复制代码

6. 对比

比较项.pth (PyTorch).onnx (ONNX)框架依赖仅支持 PyTorch兼容多框架推理速率较慢更快(ONNX Runtime / TensorRT)跨平台性仅支持 PyTorch可在多种装备上运行部署难度需要完整 Python轻量级,适用于嵌入式
回复

使用道具 举报

0 个回复

倒序浏览

快速回复

您需要登录后才可以回帖 登录 or 立即注册

本版积分规则

立聪堂德州十三局店

论坛元老
这个人很懒什么都没写!
快速回复 返回顶部 返回列表