【大模型开辟】ONNX 格式的大模型在 Android 上的部署与测试 ...

打印 上一主题 下一主题

主题 1815|帖子 1815|积分 5445

马上注册,结交更多好友,享用更多功能,让你轻松玩转社区。

您需要 登录 才可以下载或查看,没有账号?立即注册

x
以下内容将以 ONNX 格式的大模型在 Android 上的部署与测试为核心,提供一套可运行的示例(基于 Android Studio/Gradle),并联合代码进行详细讲授。最后会给出一些针对在移动设备上部署 ONNX 推理的优化方法和将来发起。

目次


  • 整体流程概述
  • 准备工作
    2.1 ONNX 模型准备
    2.2 Android 项目准备
  • 在 Android 上利用 ONNX Runtime
    3.1 添加依赖
    3.2 项目结构阐明
    3.3 代码示例
  • 运行与测试示例
  • 优化方法
    5.1 模型压缩与量化
    5.2 算子融合与图优化
    5.3 硬件加快接口
  • 将来发起

1. 整体流程概述


  • 模型转换:将训练好的大模型从 PyTorch/TensorFlow 等框架导出为 ONNX 格式。
  • 接入 ONNX Runtime for Android:在 Android 应用中,通过 onnxruntime-android 库进行模型推理。
  • 编写推理逻辑:在代码中加载 ONNX 模型文件,准备输入张量,执行推理,并获取输出。
  • 部署到手机并测试:将应用安装到 Android 设备上,测试推理速率、准确率等指标。
如果模型非常大,通常需进行模型剪枝、量化或其他优化。看前面优化文章

2. 准备工作

2.1 ONNX 模型准备

假设我们有一个 NLP 或 CV 的预训练模型(如 GPT、BERT、ResNet、YOLO 等),而且已经将其转换为 model.onnx 文件。


  • 如果你利用的是 PyTorch,可以通过 torch.onnx.export 导出;
  • 如果是 TensorFlow,可以借助 tf2onnx 或 TensorFlow 官方工具进行转换。
例如,PyTorch 导出示例(仅供参考):
  1. import torch
  2. import torchvision
  3. # 示例:导出一个 pretrained ResNet18
  4. model = torchvision.models.resnet18(pretrained=True)
  5. model.eval()
  6. dummy_input = torch.randn(1, 3, 224, 224)
  7. torch.onnx.export(
  8.     model,
  9.     dummy_input,
  10.     "model.onnx",
  11.     input_names=["input"],
  12.     output_names=["output"],
  13.     opset_version=11
  14. )
复制代码
导出完成后,会在本地得到一个 model.onnx 文件。
2.2 Android 项目准备


  • Android Studio:发起版本 4.0 以上。
  • Android Gradle Plugin:发起版本与 Android Studio 保持一致。
  • 最低 SDK 要求:一般 minSdkVersion 发起设置为 21 或更高,以支持大部分 NNAPI / 硬件加快。
  • 设备:至少需要拥有 ARM64 架构、足够的 RAM;若模型较大,需要高端手机或者淘汰模型规模。

3. 在 Android 上利用 ONNX Runtime

3.1 添加依赖

ONNX Runtime 官方已提供 Android AAR 包,可在 Gradle 中直接添加依赖。
在 app/build.gradle 中,添加类似以下内容:
  1. android {
  2.     // 其他配置
  3.     compileOptions {
  4.         sourceCompatibility JavaVersion.VERSION_1_8
  5.         targetCompatibility JavaVersion.VERSION_1_8
  6.     }
  7.     // 如果需要Kotlin,确保启用合适的编译选项
  8. }
  9. // 在dependencies中添加
  10. dependencies {
  11.     implementation 'org.onnxruntime:onnxruntime-android:1.14.1'
  12. }
复制代码
  版本号可根据 ONNX Runtime 官方发布 来更新(此处以 1.14.1 为例)。
  3.2 项目结构阐明

假设项目结构如下(只列关键文件):
  1. MyOnnxApp/
  2.   ├── app/
  3.   │   ├── src/
  4.   │   │   ├── main/
  5.   │   │   │   ├── AndroidManifest.xml
  6.   │   │   │   ├── java/com/example/myonnxapp/
  7.   │   │   │   │   ├── MainActivity.java
  8.   │   │   │   ├── assets/
  9.   │   │   │   │   ├── model.onnx   (ONNX文件)
  10.   │   │   │   ├── res/
  11.   │   │   │   │   └── layout/activity_main.xml
  12.   │   ├── build.gradle
  13.   ├── settings.gradle
  14.   └── build.gradle
复制代码
关键点:


  • 将 model.onnx 放入 app/src/main/assets 目次,以便在运行时能读取模型文件。
  • MainActivity 或其他类中加载模型并执行推理。
3.3 代码示例

下面是一个简单的 Java 版本示例(Kotlin 同理),演示如安在 Android 上初始化 ONNX Runtime、加载模型并进行一次推理。这里假设输入是 [1, 3, 224, 224] 的图像张量(如典型的 ImageNet 模型),根据模型现实情况更换。
3.3.1 MainActivity.java

  1. package com.example.myonnxapp;
  2. import androidx.appcompat.app.AppCompatActivity;
  3. import android.os.Bundle;
  4. import android.widget.TextView;
  5. import org.jetbrains.annotations.Nullable;
  6. import org.json.JSONObject;
  7. import org.tensorflow.lite.DataType;
  8. import java.io.IOException;
  9. import java.io.InputStream;
  10. import java.nio.FloatBuffer;
  11. import java.util.Arrays;
  12. import ai.onnxruntime.*;
  13. public class MainActivity extends AppCompatActivity {
  14.     private TextView resultText;
  15.     private OrtEnvironment env;
  16.     private OrtSession session;
  17.     @Override
  18.     protected void onCreate(Bundle savedInstanceState) {
  19.         super.onCreate(savedInstanceState);
  20.         setContentView(R.layout.activity_main);
  21.         resultText = findViewById(R.id.result_text);
  22.         // 初始化 ONNX Runtime
  23.         try {
  24.             initOnnxRuntime();
  25.             // 执行推理
  26.             float[] outputScores = runInference();
  27.             // 显示结果
  28.             resultText.setText("Inference Output: " + Arrays.toString(outputScores));
  29.         } catch (Exception e) {
  30.             e.printStackTrace();
  31.             resultText.setText("Error: " + e.getMessage());
  32.         }
  33.     }
  34.     private void initOnnxRuntime() throws OrtException {
  35.         // 创建 ORT 环境
  36.         env = OrtEnvironment.getEnvironment();
  37.         // 构建 SessionOptions
  38.         OrtSession.SessionOptions sessionOptions = new OrtSession.SessionOptions();
  39.         // 可选: 使用 CPU 或 NNAPI 等加速,如果需要,可启用如下:
  40.         // sessionOptions.addNnapi();
  41.         
  42.         // 从assets加载模型
  43.         try {
  44.             InputStream modelStream = getAssets().open("model.onnx");
  45.             byte[] modelBytes = new byte[modelStream.available()];
  46.             modelStream.read(modelBytes);
  47.             session = env.createSession(modelBytes, sessionOptions);
  48.         } catch (IOException ioException) {
  49.             throw new RuntimeException("Failed to load model from assets", ioException);
  50.         }
  51.     }
  52.     private float[] runInference() throws OrtException {
  53.         // 准备输入张量
  54.         // 假设输入大小 [1, 3, 224, 224],数据类型 float32
  55.         float[] inputData = new float[1 * 3 * 224 * 224];
  56.         // 这里示例: 全部填充随机值 or 0.5f
  57.         // 实际中可来自图像预处理
  58.         for (int i = 0; i < inputData.length; i++) {
  59.             inputData[i] = 0.5f;
  60.         }
  61.         // ONNX Runtime需要将Java数组包装成OnnxTensor
  62.         long[] inputShape = new long[]{1, 3, 224, 224};
  63.         OnnxTensor inputTensor = OnnxTensor.createTensor(env, inputData, inputShape);
  64.         // 准备输入名 (与导出时的 input_names 对应)
  65.         String inputName = session.getInputNames().iterator().next();
  66.         // 运行会话
  67.         OrtSession.Result result = session.run(Collections.singletonMap(inputName, inputTensor));
  68.         // 假设输出名为 "output",或者取 getOutputNames() 的第一个
  69.         String outputName = session.getOutputNames().iterator().next();
  70.         float[][] outputRaw = (float[][]) result.get(0).getValue();
  71.         // 此时 outputRaw 可能为 [1, num_classes],示例中只返回数组
  72.         float[] outputScores = outputRaw[0];
  73.         // 释放资源
  74.         inputTensor.close();
  75.         result.close();
  76.         return outputScores;
  77.     }
  78.     @Override
  79.     protected void onDestroy() {
  80.         super.onDestroy();
  81.         // 关闭 Session 和 Env,避免内存泄漏
  82.         if (session != null) {
  83.             try {
  84.                 session.close();
  85.             } catch (OrtException e) {
  86.                 e.printStackTrace();
  87.             }
  88.         }
  89.         if (env != null) {
  90.             try {
  91.                 env.close();
  92.             } catch (OrtException e) {
  93.                 e.printStackTrace();
  94.             }
  95.         }
  96.     }
  97. }
复制代码
3.3.2 activity_main.xml

  1. <?xml version="1.0" encoding="utf-8"?>
  2. <LinearLayout
  3.     xmlns:android="http://schemas.android.com/apk/res/android"
  4.     android:layout_width="match_parent"
  5.     android:layout_height="match_parent"
  6.     android:orientation="vertical"
  7.     android:gravity="center">
  8.    
  9.     <TextView
  10.         android:id="@+id/result_text"
  11.         android:layout_width="wrap_content"
  12.         android:layout_height="wrap_content"
  13.         android:text="ONNX Test"
  14.         android:textSize="20sp"/>
  15. </LinearLayout>
复制代码
以上示例中:

  • initOnnxRuntime(): 初始化 ONNX Runtime 情况,加载 assets 中的 model.onnx。
  • runInference(): 构造一个假输入,并调用 session.run() 获取推理结果。
  • 若要在手机上处理处罚现实图像或文本,需要在推理前进行预处理处罚(图像缩放、归一化,或分词、构造输入 ID 等)。
  • 若模型输出不止一个,需根据真实的模型输着名(session.getOutputNames()可获取)来取出对应的张量。

4. 运行与测试示例


  • 将 ONNX 模型放入 assets:确保 model.onnx 存在并可读取。
  • 编译并运行:连接 Android 设备(或利用模拟器,但大型模型推理更适合真机),点击 Run,查看日志或界面的输出信息。
  • 查抄结果:若模型是分类网络,可以对 outputScores 做 argmax,得到类别索引,并在界面中表现。
在真实应用中,你可以:


  • 通过手机摄像头或相册加载图像 -> 预处理处罚 -> 输入模型 -> 输出猜测结果。
  • 如果是语言模型或其他结构,也需要对应的数据预处理处罚与后处理处罚流程。

5. 优化方法

对于大模型,在移动端或嵌入式设备上的推理可能存在 内存、速率、功耗 等瓶颈。可从以下几个角度进行优化。
5.1 模型压缩与量化


  • Post-Training Quantization:将 FP32 转为 INT8;配合校准数据,可明显淘汰模型体积并提升推理速率。
  • Knowledge Distillation:训练一个尺寸更小的“学生模型”,在移动端部署。
  • Pruning / Sparsity:移除不重要的权重或通道(需要硬件和库支持稀疏加快)。
ONNX Runtime 也支持量化后的模型推理,但需确保量化算子在该版本的 ORT for Android 上可用。
5.2 算子融合与图优化


  • ONNX Graph Optimizer:在导出后,可对 ONNX 图进行融合、消除冗余节点等处理处罚。ONNX Runtime 默认会进行部分优化。
  • 淘汰不必要的操作:将预处理处罚/后处理处罚逻辑只管简化,或者放到端上原生代码里去执行。
5.3 硬件加快接口


  • NNAPI(Android):可通过 sessionOptions.addNnapi() 启用 Android NNAPI,让系统层面自动调度 GPU/NPU。
  • GPU Delegate:ONNX Runtime 提供部分 GPU 后端支持,但兼容度可能不及 TensorFlow Lite GPU Delegate。
  • DSP / NPU 厂商库:某些芯片厂商提供自定义加快库,可将 ONNX 模型进一步编译成特定格式。

6. 将来发起


  • 更灵活的肴杂部署
    对于超大模型,可以考虑在云端服务器执行大部分推理或粗特征提取,只在移动端做小模型的精调或快速推理。
  • 分片与流式推理
    在内存特别有限时,可以将模型分成多段,分批加载盘算。
  • 连续关注 ONNX Runtime 更新
    随着新版本的推出,硬件加快和量化等特性会不断完善。
  • 联合边缘专用硬件
    如果有专用设备(如 Google Coral TPU、NVIDIA Jetson NX、ARM Ethos 等),可考虑将 ONNX 模型部署到相应 SDK 中,大幅提升性能。
  • 联合自动化 NAS
    如果对精度和性能要求极高,可利用神经网络架构搜索(NAS)探求更适合移动设备的模型结构,在保持结果的同时极大降低推理成本。

总结

通过以上步骤,各人可以将 ONNX 格式的大模型在 Android 设备上进行推理测试,并联合 ONNX Runtime 的接口进行快速部署。对于在移动端部署大模型,发起在精度与资源之间做充实的衡量,并利用量化、剪枝、蒸馏等方法进行模型优化。联合硬件加快与工具链的进一步发展,移动端也能承载越来越强盛的 AI 本事,满意更多的现实业务需求。
哈佛博后带小白玩转呆板学习】 哔哩哔哩_bilibili
总课时超400+,时长75+小时

免责声明:如果侵犯了您的权益,请联系站长,我们会及时删除侵权内容,谢谢合作!更多信息从访问主页:qidao123.com:ToB企服之家,中国第一个企服评测及商务社交产业平台。
回复

使用道具 举报

0 个回复

倒序浏览

快速回复

您需要登录后才可以回帖 登录 or 立即注册

本版积分规则

我爱普洱茶

论坛元老
这个人很懒什么都没写!
快速回复 返回顶部 返回列表