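I. Why Deploy a Model?
Model deployment is the key step that moves a trained model into a real application. It involves:
- Model format conversion (TorchScript/ONNX)
- Performance optimization (quantization/pruning)
- Building an API service
- Mobile integration
This chapter uses ResNet18 for image classification and walks through the complete deployment workflow.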
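II. Model Conversion: TorchScript and ONNX
1. Prepare the pretrained model
- import torch
- import torchvision
- # Load the pretrained model and switch it to inference mode
- model = torchvision.models.resnet18(weights='IMAGENET1K_V1')
- model.eval()
- # Example input: one RGB image of 224x224 pixels
- dummy_input = torch.rand(1, 3, 224, 224)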
2. Export to TorchScript
- # Option 1: trace the execution path (suited to models without control flow)
- traced_model = torch.jit.trace(model, dummy_input)
- torch.jit.save(traced_model, "resnet18_traced.pt")
- # Option 2: script the model directly (suited to models containing if/for statements)
- scripted_model = torch.jit.script(model)
- torch.jit.save(scripted_model, "resnet18_scripted.pt")
- # Load the saved model and test it
- loaded_model = torch.jit.load("resnet18_traced.pt")
- output = loaded_model(dummy_input)
- print("TorchScript output shape:", output.shape)  # expected: torch.Size([1, 1000])
3. Export to ONNX
- torch.onnx.export(
-     model,
-     dummy_input,
-     "resnet18.onnx",
-     input_names=["input"],
-     output_names=["output"],
-     # Mark the batch dimension as dynamic so any batch size is accepted
-     dynamic_axes={
-         'input': {0: 'batch_size'},
-         'output': {0: 'batch_size'}
-     }
- )
- # Validate the ONNX model
- import onnx
- onnx_model = onnx.load("resnet18.onnx")
- onnx.checker.check_model(onnx_model)
- print("ONNX model inputs and outputs:")
- print(onnx_model.graph.input)
- print(onnx_model.graph.output)
III. Building an API Service
1. Create a web service with FastAPI
- import io
- import torch
- import torchvision.transforms as transforms
- from fastapi import FastAPI, File, UploadFile
- from PIL import Image
- app = FastAPI()
- # Load the TorchScript model
- model = torch.jit.load("resnet18_traced.pt")
- model.eval()
- # Image preprocessing (must match the ImageNet preprocessing used in training)
- preprocess = transforms.Compose([
-     transforms.Resize(256),
-     transforms.CenterCrop(224),
-     transforms.ToTensor(),
-     transforms.Normalize(
-         mean=[0.485, 0.456, 0.406],
-         std=[0.229, 0.224, 0.225]
-     )
- ])
- @app.post("/predict")
- async def predict(image: UploadFile = File(...)):
-     # Read and preprocess the image
-     image_data = await image.read()
-     img = Image.open(io.BytesIO(image_data)).convert("RGB")
-     tensor = preprocess(img).unsqueeze(0)  # add the batch dimension
-
-     # Run inference without tracking gradients
-     with torch.no_grad():
-         output = model(tensor)
-
-     # Take the index of the highest score as the prediction
-     _, pred = torch.max(output, 1)
-     return {"class_id": pred.item()}
- # Run with: uvicorn main:app --reload
- # (file uploads in FastAPI require the python-multipart package)
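To make the preprocessing pipeline more concrete, the sketch below reproduces two of its steps in plain Python: the box that a 224-pixel center crop takes from a resized image, and the per-channel normalization applied once pixels have been scaled to [0, 1]. The helper names `center_crop_box` and `normalize_pixel` are hypothetical, and the rounding is an assumption about how the crop offset is computed; treat this as an illustration of the arithmetic, not as torchvision's implementation.

```python
# Illustrative only: plain-Python arithmetic behind CenterCrop and Normalize.
MEAN = [0.485, 0.456, 0.406]  # ImageNet channel means used by the pipeline
STD = [0.229, 0.224, 0.225]   # ImageNet channel standard deviations


def center_crop_box(width, height, size=224):
    """Return (left, top, right, bottom) of a centered size x size crop."""
    left = int(round((width - size) / 2.0))
    top = int(round((height - size) / 2.0))
    return left, top, left + size, top + size


def normalize_pixel(rgb):
    """Scale 0-255 values to [0, 1], then apply (x - mean) / std per channel."""
    return [(v / 255.0 - m) / s for v, m, s in zip(rgb, MEAN, STD)]


if __name__ == "__main__":
    # A 640x480 photo resized so its shorter side is 256 becomes roughly 341x256.
    print(center_crop_box(341, 256))
    # A mid-gray-ish pixel lands close to zero after normalization.
    print(normalize_pixel([124, 116, 104]))
```

A pixel near the dataset mean maps to values close to 0, while a pure black pixel maps to roughly -2.1 in the red channel. This is why a client that skips normalization produces inputs the model has never seen.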
2. Test the API service
- import requests
- # Download a test image
- url = "https://images.unsplash.com/photo-1517849845537-4d257902454a?auto=format&fit=crop&w=224&q=80"
- response = requests.get(url)
- response.raise_for_status()  # stop early if the download failed
- with open("test_dog.jpg", "wb") as f:
-     f.write(response.content)
- # Send a prediction request
- with open("test_dog.jpg", "rb") as f:
-     files = {"image": f}
-     response = requests.post("http://localhost:8000/predict", files=files)
- print("Prediction:", response.json())  # prints the predicted class ID
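The response carries only the arg-max class ID. If the service were extended to report confidence scores, the raw model outputs (logits) would first need to be converted to probabilities. The following is a minimal pure-Python sketch of that post-processing. The function names are hypothetical and the logits are made up; in the actual service the same result would more likely come from `torch.softmax` and `torch.topk`.

```python
import math


def softmax(logits):
    """Numerically stable softmax over a list of floats."""
    peak = max(logits)  # subtracting the max avoids overflow in exp()
    exps = [math.exp(x - peak) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


def top_k(logits, k=3):
    """Return the k most likely (class_id, probability) pairs, best first."""
    probs = softmax(logits)
    ranked = sorted(enumerate(probs), key=lambda pair: pair[1], reverse=True)
    return ranked[:k]


if __name__ == "__main__":
    fake_logits = [0.5, 2.0, -1.0, 3.5]  # made-up scores for four classes
    for class_id, prob in top_k(fake_logits, k=2):
        print(class_id, round(prob, 3))
```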
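IV. Mobile Deployment (Android/iOS)
1. Convert to Core ML (iOS)
- import coremltools as ct
- # Convert from PyTorch via a traced model
- example_input = torch.rand(1, 3, 224, 224)
- traced_model = torch.jit.trace(model, example_input)
- mlmodel = ct.convert(
-     traced_model,
-     inputs=[ct.TensorType(shape=example_input.shape)],
-     # The .mlmodel extension needs the neural network format;
-     # ML Program models are saved as .mlpackage instead
-     convert_to="neuralnetwork"
- )
- mlmodel.save("ResNet18.mlmodel")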
2. Use PyTorch Mobile (Android)
- // Android example (Java); assetFilePath is an app-defined helper that copies the asset to a readable file path
- Module module = Module.load(assetFilePath(this, "resnet18_traced.pt"));
- // Convert the bitmap to a normalized float tensor using the ImageNet mean and std
- Tensor inputTensor = TensorImageUtils.bitmapToFloat32Tensor(
-     bitmap,
-     TensorImageUtils.TORCHVISION_NORM_MEAN_RGB,
-     TensorImageUtils.TORCHVISION_NORM_STD_RGB
- );
- Tensor outputTensor = module.forward(IValue.from(inputTensor)).toTensor();
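V. Performance Optimization Tips
1. Model quantization (smaller size / faster inference)
- import os
- # Dynamic quantization: only nn.Linear layers are converted (in ResNet18, just the final fc layer)
- quantized_model = torch.quantization.quantize_dynamic(
-     model, {torch.nn.Linear}, dtype=torch.qint8
- )
- torch.jit.save(torch.jit.script(quantized_model), "resnet18_quantized.pt")
- # Compare file sizes on disk; parameter counts are misleading because
- # quantized layers store packed weights that parameters() does not report
- print("Original model size (MB):", os.path.getsize("resnet18_scripted.pt") / 1e6)
- print("Quantized model size (MB):", os.path.getsize("resnet18_quantized.pt") / 1e6)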
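To illustrate what storing weights as int8 means, here is a simplified affine quantization round trip in plain Python. It is a sketch under stated assumptions: one scale and zero point for the whole tensor, and hypothetical function names `quantize` and `dequantize`. PyTorch's actual kernels differ in detail.

```python
def quantize(values, num_bits=8):
    """Map floats to signed integers with a per-tensor scale and zero point."""
    qmin, qmax = -(2 ** (num_bits - 1)), 2 ** (num_bits - 1) - 1
    # The representable range must include 0 so that zero maps to an exact integer.
    lo, hi = min(min(values), 0.0), max(max(values), 0.0)
    scale = (hi - lo) / (qmax - qmin) or 1.0  # fall back to 1.0 for all-zero input
    zero_point = int(round(qmin - lo / scale))
    quantized = [
        max(qmin, min(qmax, int(round(v / scale)) + zero_point)) for v in values
    ]
    return quantized, scale, zero_point


def dequantize(quantized, scale, zero_point):
    """Recover approximate floats from the integer representation."""
    return [(q - zero_point) * scale for q in quantized]


if __name__ == "__main__":
    weights = [-1.0, 0.0, 0.5, 1.0]  # made-up weights
    q, scale, zp = quantize(weights)
    print("int8 values:", q)
    print("restored:", [round(w, 4) for w in dequantize(q, scale, zp)])
```

Each weight shrinks from 4 bytes to 1 byte at the cost of a rounding error on the order of the scale. Since ResNet18's fc layer holds 512 x 1000 weights, quantizing only that layer saves roughly 1.5 MB; the convolutional layers are untouched by dynamic quantization.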
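2. Accelerate inference with ONNX Runtime
- import onnxruntime
- ort_session = onnxruntime.InferenceSession("resnet18.onnx")
- # get_inputs() returns a list, so index the first input to read its name
- ort_inputs = {ort_session.get_inputs()[0].name: dummy_input.numpy()}
- ort_outputs = ort_session.run(None, ort_inputs)
- # run() returns a list of arrays, one per model output
- print("ONNX Runtime output shape:", ort_outputs[0].shape)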
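Whether ONNX Runtime is actually faster depends on hardware and model, so it is worth measuring. The helper below is a small, hypothetical benchmarking sketch using only the standard library. It times any zero-argument callable; in this tutorial that could be a lambda wrapping `ort_session.run(...)` or `model(dummy_input)`. The stand-in workload here is made up so the snippet runs on its own.

```python
import statistics
import time


def benchmark(fn, warmup=3, runs=20):
    """Time fn() over several runs and return latency statistics in milliseconds."""
    for _ in range(warmup):  # warm-up calls are excluded from the measurements
        fn()
    timings = []
    for _ in range(runs):
        start = time.perf_counter()
        fn()
        timings.append((time.perf_counter() - start) * 1000.0)
    return {
        "mean_ms": statistics.mean(timings),
        "median_ms": statistics.median(timings),
        "max_ms": max(timings),
        "runs": runs,
    }


if __name__ == "__main__":
    # Stand-in workload; replace with the inference call to be measured.
    stats = benchmark(lambda: sum(i * i for i in range(10_000)))
    print(stats)
```

Running the same helper on both backends with identical input gives a like-for-like comparison; the warm-up runs keep one-time initialization costs out of the numbers.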
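VI. Frequently Asked Questions
Q1: How do I handle model version compatibility issues?
- Keep the PyTorch version consistent across environments (pin versions in requirements.txt)
- Specify opset_version when exporting:
- torch.onnx.export(..., opset_version=13)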
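One way to produce the pinned lines for requirements.txt is to read the versions installed in the environment where the export worked. The sketch below uses the standard library's `importlib.metadata`; the helper name `installed_versions` and the package list are illustrative assumptions.

```python
from importlib import metadata


def installed_versions(packages):
    """Return {package: version or None} for the given distribution names."""
    versions = {}
    for name in packages:
        try:
            versions[name] = metadata.version(name)
        except metadata.PackageNotFoundError:
            versions[name] = None  # package is not installed in this environment
    return versions


if __name__ == "__main__":
    wanted = ["torch", "torchvision", "onnx", "onnxruntime"]
    for name, version in installed_versions(wanted).items():
        # Emit requirements.txt style pins, or a comment for missing packages.
        print(f"{name}=={version}" if version else f"# {name} is not installed")
```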
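Q2: What should I check when a shape mismatch error occurs during deployment?
- Check that preprocessing matches what was used during training
- Use Netron to visualize the model's inputs and outputs:
- pip install netron
- netron resnet18.onnx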
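Shape problems can also be caught in code before the input reaches the model. The sketch below compares a declared input shape against a concrete one and treats named or unknown dimensions as dynamic, mirroring the `batch_size` axis declared during ONNX export. The function `shape_matches` is hypothetical, and the convention that dynamic axes appear as strings or `None` is an assumption about how the declared shape is represented.

```python
def shape_matches(expected, actual):
    """Compare a declared shape with a concrete one.

    Entries in `expected` that are strings or None are treated as dynamic
    axes and match any size; all other entries must match exactly.
    """
    if len(expected) != len(actual):
        return False
    return all(
        exp is None or isinstance(exp, str) or exp == act
        for exp, act in zip(expected, actual)
    )


if __name__ == "__main__":
    declared = ("batch_size", 3, 224, 224)
    print(shape_matches(declared, (8, 3, 224, 224)))   # batched NCHW input
    print(shape_matches(declared, (1, 224, 224, 3)))   # channels-last layout
    print(shape_matches(declared, (3, 224, 224)))      # missing batch dimension
```

The two failing cases are the most common causes of this error in practice: a channels-last image array and a forgotten batch dimension.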
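Q3: How do I monitor API performance?
- import time
- # Middleware that records how long each request takes
- @app.middleware("http")
- async def add_process_time(request, call_next):
-     start_time = time.time()
-     response = await call_next(request)
-     # Expose the elapsed time (in seconds) as a response header
-     response.headers["X-Process-Time"] = str(time.time() - start_time)
-     return response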
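The X-Process-Time header values become more useful once aggregated. As a rough sketch, the hypothetical helper below turns a collection of header strings (seconds) into millisecond statistics, using a nearest-rank 95th percentile. How the values are collected (client logs, a proxy, a test script) is left open, and the sample values are made up.

```python
import math
import statistics


def summarize_latencies(header_values):
    """Turn X-Process-Time header strings (seconds) into millisecond statistics."""
    millis = sorted(float(v) * 1000.0 for v in header_values)
    if not millis:
        raise ValueError("no latency samples provided")
    # Nearest-rank p95: the smallest sample with at least 95% of samples at or below it.
    rank = max(1, math.ceil(0.95 * len(millis)))
    return {
        "count": len(millis),
        "mean_ms": statistics.mean(millis),
        "p95_ms": millis[rank - 1],
        "max_ms": millis[-1],
    }


if __name__ == "__main__":
    samples = ["0.010", "0.020", "0.030", "0.040"]  # made-up header values
    print(summarize_latencies(samples))
```

Tail latency (p95 or max) usually says more about user experience than the mean, since a few slow requests can hide behind a healthy average.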
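VII. Summary and Next Installment
- Key points of this article:
- Model format conversion (TorchScript/ONNX)
- Building an API service with FastAPI
- Mobile deployment and performance optimization
- Next installment:
Part 6 will dive into the PyTorch ecosystem, covering distributed training and multi-GPU acceleration strategies for industrial-grade training efficiency!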