模型摆设实战:PyTorch生产化指南

打印 上一主题 下一主题

主题 946|帖子 946|积分 2838

一、为什么要做模型摆设?

模型摆设是将训练好的模型‌投入现实应用‌的关键步骤,涉及:

  • 模型格式转换(TorchScript/ONNX)
  • 性能优化(量化/剪枝)
  • 构建API服务
  • 移动端集成
 本章使用ResNet18实现图像分类,并演示完整摆设流程。
二、模型转换:TorchScript与ONNX

1. 准备预训练模型

  1. import torch
  2. import torchvision
  3. # 加载预训练模型
  4. model = torchvision.models.resnet18(weights='IMAGENET1K_V1')
  5. model.eval()
  6. # 示例输入
  7. dummy_input = torch.rand(1, 3, 224, 224)
复制代码
 ‌2. 导出为TorchScript
  1. # 方法一:追踪执行路径(适合无控制流模型)
  2. traced_model = torch.jit.trace(model, dummy_input)
  3. torch.jit.save(traced_model, "resnet18_traced.pt")
  4. # 方法二:直接转换(适合含if/for的模型)
  5. scripted_model = torch.jit.script(model)
  6. torch.jit.save(scripted_model, "resnet18_scripted.pt")
  7. # 加载测试
  8. loaded_model = torch.jit.load("resnet18_traced.pt")
  9. output = loaded_model(dummy_input)
  10. print("TorchScript输出形状:", output.shape)  # 应输出torch.Size([1, 1000])
复制代码
3. 导出为ONNX格式

  1. torch.onnx.export(
  2.     model,
  3.     dummy_input,
  4.     "resnet18.onnx",
  5.     input_names=["input"],
  6.     output_names=["output"],
  7.     dynamic_axes={
  8.         'input': {0: 'batch_size'},
  9.         'output': {0: 'batch_size'}
  10.     }
  11. )
  12. # 验证ONNX模型
  13. import onnx
  14. onnx_model = onnx.load("resnet18.onnx")
  15. onnx.checker.check_model(onnx_model)
  16. print("ONNX模型输入输出:")
  17. print(onnx_model.graph.input)
  18. print(onnx_model.graph.output)
复制代码
三、构建API服务

1. 使用FastAPI创建Web服务

  1. from fastapi import FastAPI, File, UploadFile
  2. from PIL import Image
  3. import io
  4. import numpy as np
  5. import torchvision.transforms as transforms
  6. app = FastAPI()
  7. # 加载TorchScript模型
  8. model = torch.jit.load("resnet18_traced.pt")
  9. # 图像预处理
  10. preprocess = transforms.Compose([
  11.     transforms.Resize(256),
  12.     transforms.CenterCrop(224),
  13.     transforms.ToTensor(),
  14.     transforms.Normalize(
  15.         mean=[0.485, 0.456, 0.406],
  16.         std=[0.229, 0.224, 0.225]
  17.     )
  18. ])
  19. @app.post("/predict")
  20. async def predict(image: UploadFile = File(...)):
  21.     # 读取并预处理图像
  22.     image_data = await image.read()
  23.     img = Image.open(io.BytesIO(image_data)).convert("RGB")
  24.     tensor = preprocess(img).unsqueeze(0)
  25.    
  26.     # 执行推理
  27.     with torch.no_grad():
  28.         output = model(tensor)
  29.    
  30.     # 获取预测结果
  31.     _, pred = torch.max(output, 1)
  32.     return {"class_id": int(pred)}
  33. # 运行命令:uvicorn main:app --reload
复制代码
2. 测试API服务

  1. import requests
  2. # 准备测试图片
  3. url = "https://images.unsplash.com/photo-1517849845537-4d257902454a?auto=format&fit=crop&w=224&q=80"
  4. response = requests.get(url)
  5. with open("test_dog.jpg", "wb") as f:
  6.     f.write(response.content)
  7. # 发送预测请求
  8. with open("test_dog.jpg", "rb") as f:
  9.     files = {"image": f}
  10.     response = requests.post("http://localhost:8000/predict", files=files)
  11.     print("预测结果:", response.json())  # 应输出对应类别ID
复制代码
四、移动端摆设(Android/iOS)

1. 转换Core ML格式(iOS)

  1. import coremltools as ct
  2. # 从PyTorch转换
  3. example_input = torch.rand(1, 3, 224, 224)
  4. traced_model = torch.jit.trace(model, example_input)
  5. mlmodel = ct.convert(
  6.     traced_model,
  7.     inputs=[ct.TensorType(shape=example_input.shape)]
  8. )
  9. mlmodel.save("ResNet18.mlmodel")
复制代码
2. 使用PyTorch Mobile(Android)

  1. // Android示例代码(Java)
  2. Module module = Module.load(assetFilePath(this, "resnet18_traced.pt"));
  3. Tensor inputTensor = TensorImageUtils.bitmapToFloat32Tensor(
  4.     bitmap,
  5.     TensorImageUtils.TORCHVISION_NORM_MEAN_RGB,
  6.     TensorImageUtils.TORCHVISION_NORM_STD_RGB
  7. );
  8. Tensor outputTensor = module.forward(IValue.from(inputTensor)).toTensor();
复制代码
五、性能优化技巧

1. 模型量化(淘汰体积/提升速率)

  1. # 动态量化
  2. quantized_model = torch.quantization.quantize_dynamic(
  3.     model, {torch.nn.Linear}, dtype=torch.qint8
  4. )
  5. torch.jit.save(torch.jit.script(quantized_model), "resnet18_quantized.pt")
  6. # 测试量化效果
  7. print("原始模型大小:", sum(p.numel() for p in model.parameters()))
  8. print("量化模型大小:", sum(p.numel() for p in quantized_model.parameters()))
复制代码
2. ONNX Runtime加速推理

  1. import onnxruntime
  2. ort_session = onnxruntime.InferenceSession("resnet18.onnx")
  3. ort_inputs = {ort_session.get_inputs().name: dummy_input.numpy()}
  4. ort_outputs = ort_session.run(None, ort_inputs)
  5. print("ONNX Runtime输出形状:", ort_outputs.shape)
复制代码
六、常见问题解答

Q1:如那边理模型版本兼容性问题?



  • 保持PyTorch版本一致(使用requirements.txt固定版本)
  • 转换时指定opset_version:
  1. torch.onnx.export(..., opset_version=13)
复制代码
Q2:摆设时出现形状不匹配错误?



  • 检查预处理是否与训练时一致
  • 使用Netron可视化模型输入输出:
  1. pip install netron
  2. netron resnet18.onnx
复制代码
Q3:如何监控API性能?



  • 添加中间件记载响应时间:
  1. @app.middleware("http")
  2. async def add_process_time(request, call_next):
  3.     start_time = time.time()
  4.     response = await call_next(request)
  5.     response.headers["X-Process-Time"] = str(time.time() - start_time)
  6.     return response
复制代码
七、小结与下篇预报



  • 本文重点‌:

    • 模型格式转换(TorchScript/ONNX)
    • 构建高并发API服务
    • 移动端摆设与性能优化

  • 下篇预报‌:
    第六篇将深入PyTorch生态,介绍分布式训练与多GPU加速策略,实现工业级训练效率!

免责声明:如果侵犯了您的权益,请联系站长,我们会及时删除侵权内容,谢谢合作!更多信息从访问主页:qidao123.com:ToB企服之家,中国第一个企服评测及商务社交产业平台。
回复

使用道具 举报

0 个回复

倒序浏览

快速回复

您需要登录后才可以回帖 登录 or 立即注册

本版积分规则

用户国营

金牌会员
这个人很懒什么都没写!
快速回复 返回顶部 返回列表