马上注册,结交更多好友,享用更多功能,让你轻松玩转社区。
您需要 登录 才可以下载或查看,没有账号?立即注册
x
以下内容将以 ONNX 格式的大模型在 Android 上的部署与测试为核心,提供一套可运行的示例(基于 Android Studio/Gradle),并联合代码进行详细讲授。最后会给出一些针对在移动设备上部署 ONNX 推理的优化方法和将来发起。
目次
- 整体流程概述
- 准备工作
2.1 ONNX 模型准备
2.2 Android 项目准备
- 在 Android 上利用 ONNX Runtime
3.1 添加依赖
3.2 项目结构阐明
3.3 代码示例
- 运行与测试示例
- 优化方法
5.1 模型压缩与量化
5.2 算子融合与图优化
5.3 硬件加快接口
- 将来发起
1. 整体流程概述
- 模型转换:将训练好的大模型从 PyTorch/TensorFlow 等框架导出为 ONNX 格式。
- 接入 ONNX Runtime for Android:在 Android 应用中,通过 onnxruntime-android 库进行模型推理。
- 编写推理逻辑:在代码中加载 ONNX 模型文件,准备输入张量,执行推理,并获取输出。
- 部署到手机并测试:将应用安装到 Android 设备上,测试推理速率、准确率等指标。
如果模型非常大,通常需进行模型剪枝、量化或其他优化。看前面优化文章
2. 准备工作
2.1 ONNX 模型准备
假设我们有一个 NLP 或 CV 的预训练模型(如 GPT、BERT、ResNet、YOLO 等),而且已经将其转换为 model.onnx 文件。
- 如果你利用的是 PyTorch,可以通过 torch.onnx.export 导出;
- 如果是 TensorFlow,可以借助 tf2onnx 或 TensorFlow 官方工具进行转换。
例如,PyTorch 导出示例(仅供参考):
- import torch
- import torchvision
- # 示例:导出一个 pretrained ResNet18
- model = torchvision.models.resnet18(pretrained=True)
- model.eval()
- dummy_input = torch.randn(1, 3, 224, 224)
- torch.onnx.export(
- model,
- dummy_input,
- "model.onnx",
- input_names=["input"],
- output_names=["output"],
- opset_version=11
- )
复制代码 导出完成后,会在本地得到一个 model.onnx 文件。
2.2 Android 项目准备
- Android Studio:发起版本 4.0 以上。
- Android Gradle Plugin:发起版本与 Android Studio 保持一致。
- 最低 SDK 要求:一般 minSdkVersion 发起设置为 21 或更高,以支持大部分 NNAPI / 硬件加快。
- 设备:至少需要拥有 ARM64 架构、足够的 RAM;若模型较大,需要高端手机或者淘汰模型规模。
3. 在 Android 上利用 ONNX Runtime
3.1 添加依赖
ONNX Runtime 官方已提供 Android AAR 包,可在 Gradle 中直接添加依赖。
在 app/build.gradle 中,添加类似以下内容:
- android {
- // 其他配置
- compileOptions {
- sourceCompatibility JavaVersion.VERSION_1_8
- targetCompatibility JavaVersion.VERSION_1_8
- }
- // 如果需要Kotlin,确保启用合适的编译选项
- }
- // 在dependencies中添加
- dependencies {
- implementation 'org.onnxruntime:onnxruntime-android:1.14.1'
- }
复制代码 版本号可根据 ONNX Runtime 官方发布 来更新(此处以 1.14.1 为例)。
3.2 项目结构阐明
假设项目结构如下(只列关键文件):
- MyOnnxApp/
- ├── app/
- │ ├── src/
- │ │ ├── main/
- │ │ │ ├── AndroidManifest.xml
- │ │ │ ├── java/com/example/myonnxapp/
- │ │ │ │ ├── MainActivity.java
- │ │ │ ├── assets/
- │ │ │ │ ├── model.onnx (ONNX文件)
- │ │ │ ├── res/
- │ │ │ │ └── layout/activity_main.xml
- │ ├── build.gradle
- ├── settings.gradle
- └── build.gradle
复制代码 关键点:
- 将 model.onnx 放入 app/src/main/assets 目次,以便在运行时能读取模型文件。
- MainActivity 或其他类中加载模型并执行推理。
3.3 代码示例
下面是一个简单的 Java 版本示例(Kotlin 同理),演示如安在 Android 上初始化 ONNX Runtime、加载模型并进行一次推理。这里假设输入是 [1, 3, 224, 224] 的图像张量(如典型的 ImageNet 模型),根据模型现实情况更换。
3.3.1 MainActivity.java
- package com.example.myonnxapp;
- import androidx.appcompat.app.AppCompatActivity;
- import android.os.Bundle;
- import android.widget.TextView;
- import org.jetbrains.annotations.Nullable;
- import org.json.JSONObject;
- import org.tensorflow.lite.DataType;
- import java.io.IOException;
- import java.io.InputStream;
- import java.nio.FloatBuffer;
- import java.util.Arrays;
- import ai.onnxruntime.*;
- public class MainActivity extends AppCompatActivity {
- private TextView resultText;
- private OrtEnvironment env;
- private OrtSession session;
- @Override
- protected void onCreate(Bundle savedInstanceState) {
- super.onCreate(savedInstanceState);
- setContentView(R.layout.activity_main);
- resultText = findViewById(R.id.result_text);
- // 初始化 ONNX Runtime
- try {
- initOnnxRuntime();
- // 执行推理
- float[] outputScores = runInference();
- // 显示结果
- resultText.setText("Inference Output: " + Arrays.toString(outputScores));
- } catch (Exception e) {
- e.printStackTrace();
- resultText.setText("Error: " + e.getMessage());
- }
- }
- private void initOnnxRuntime() throws OrtException {
- // 创建 ORT 环境
- env = OrtEnvironment.getEnvironment();
- // 构建 SessionOptions
- OrtSession.SessionOptions sessionOptions = new OrtSession.SessionOptions();
- // 可选: 使用 CPU 或 NNAPI 等加速,如果需要,可启用如下:
- // sessionOptions.addNnapi();
-
- // 从assets加载模型
- try {
- InputStream modelStream = getAssets().open("model.onnx");
- byte[] modelBytes = new byte[modelStream.available()];
- modelStream.read(modelBytes);
- session = env.createSession(modelBytes, sessionOptions);
- } catch (IOException ioException) {
- throw new RuntimeException("Failed to load model from assets", ioException);
- }
- }
- private float[] runInference() throws OrtException {
- // 准备输入张量
- // 假设输入大小 [1, 3, 224, 224],数据类型 float32
- float[] inputData = new float[1 * 3 * 224 * 224];
- // 这里示例: 全部填充随机值 or 0.5f
- // 实际中可来自图像预处理
- for (int i = 0; i < inputData.length; i++) {
- inputData[i] = 0.5f;
- }
- // ONNX Runtime需要将Java数组包装成OnnxTensor
- long[] inputShape = new long[]{1, 3, 224, 224};
- OnnxTensor inputTensor = OnnxTensor.createTensor(env, inputData, inputShape);
- // 准备输入名 (与导出时的 input_names 对应)
- String inputName = session.getInputNames().iterator().next();
- // 运行会话
- OrtSession.Result result = session.run(Collections.singletonMap(inputName, inputTensor));
- // 假设输出名为 "output",或者取 getOutputNames() 的第一个
- String outputName = session.getOutputNames().iterator().next();
- float[][] outputRaw = (float[][]) result.get(0).getValue();
- // 此时 outputRaw 可能为 [1, num_classes],示例中只返回数组
- float[] outputScores = outputRaw[0];
- // 释放资源
- inputTensor.close();
- result.close();
- return outputScores;
- }
- @Override
- protected void onDestroy() {
- super.onDestroy();
- // 关闭 Session 和 Env,避免内存泄漏
- if (session != null) {
- try {
- session.close();
- } catch (OrtException e) {
- e.printStackTrace();
- }
- }
- if (env != null) {
- try {
- env.close();
- } catch (OrtException e) {
- e.printStackTrace();
- }
- }
- }
- }
复制代码 3.3.2 activity_main.xml
- <?xml version="1.0" encoding="utf-8"?>
- <LinearLayout
- xmlns:android="http://schemas.android.com/apk/res/android"
- android:layout_width="match_parent"
- android:layout_height="match_parent"
- android:orientation="vertical"
- android:gravity="center">
-
- <TextView
- android:id="@+id/result_text"
- android:layout_width="wrap_content"
- android:layout_height="wrap_content"
- android:text="ONNX Test"
- android:textSize="20sp"/>
- </LinearLayout>
复制代码 以上示例中:
- initOnnxRuntime(): 初始化 ONNX Runtime 情况,加载 assets 中的 model.onnx。
- runInference(): 构造一个假输入,并调用 session.run() 获取推理结果。
- 若要在手机上处理处罚现实图像或文本,需要在推理前进行预处理处罚(图像缩放、归一化,或分词、构造输入 ID 等)。
- 若模型输出不止一个,需根据真实的模型输着名(session.getOutputNames()可获取)来取出对应的张量。
4. 运行与测试示例
- 将 ONNX 模型放入 assets:确保 model.onnx 存在并可读取。
- 编译并运行:连接 Android 设备(或利用模拟器,但大型模型推理更适合真机),点击 Run,查看日志或界面的输出信息。
- 查抄结果:若模型是分类网络,可以对 outputScores 做 argmax,得到类别索引,并在界面中表现。
在真实应用中,你可以:
- 通过手机摄像头或相册加载图像 -> 预处理处罚 -> 输入模型 -> 输出猜测结果。
- 如果是语言模型或其他结构,也需要对应的数据预处理处罚与后处理处罚流程。
5. 优化方法
对于大模型,在移动端或嵌入式设备上的推理可能存在 内存、速率、功耗 等瓶颈。可从以下几个角度进行优化。
5.1 模型压缩与量化
- Post-Training Quantization:将 FP32 转为 INT8;配合校准数据,可明显淘汰模型体积并提升推理速率。
- Knowledge Distillation:训练一个尺寸更小的“学生模型”,在移动端部署。
- Pruning / Sparsity:移除不重要的权重或通道(需要硬件和库支持稀疏加快)。
ONNX Runtime 也支持量化后的模型推理,但需确保量化算子在该版本的 ORT for Android 上可用。
5.2 算子融合与图优化
- ONNX Graph Optimizer:在导出后,可对 ONNX 图进行融合、消除冗余节点等处理处罚。ONNX Runtime 默认会进行部分优化。
- 淘汰不必要的操作:将预处理处罚/后处理处罚逻辑只管简化,或者放到端上原生代码里去执行。
5.3 硬件加快接口
- NNAPI(Android):可通过 sessionOptions.addNnapi() 启用 Android NNAPI,让系统层面自动调度 GPU/NPU。
- GPU Delegate:ONNX Runtime 提供部分 GPU 后端支持,但兼容度可能不及 TensorFlow Lite GPU Delegate。
- DSP / NPU 厂商库:某些芯片厂商提供自定义加快库,可将 ONNX 模型进一步编译成特定格式。
6. 将来发起
- 更灵活的肴杂部署:
对于超大模型,可以考虑在云端服务器执行大部分推理或粗特征提取,只在移动端做小模型的精调或快速推理。
- 分片与流式推理:
在内存特别有限时,可以将模型分成多段,分批加载盘算。
- 连续关注 ONNX Runtime 更新:
随着新版本的推出,硬件加快和量化等特性会不断完善。
- 联合边缘专用硬件:
如果有专用设备(如 Google Coral TPU、NVIDIA Jetson NX、ARM Ethos 等),可考虑将 ONNX 模型部署到相应 SDK 中,大幅提升性能。
- 联合自动化 NAS:
如果对精度和性能要求极高,可利用神经网络架构搜索(NAS)探求更适合移动设备的模型结构,在保持结果的同时极大降低推理成本。
总结
通过以上步骤,各人可以将 ONNX 格式的大模型在 Android 设备上进行推理测试,并联合 ONNX Runtime 的接口进行快速部署。对于在移动端部署大模型,发起在精度与资源之间做充实的衡量,并利用量化、剪枝、蒸馏等方法进行模型优化。联合硬件加快与工具链的进一步发展,移动端也能承载越来越强盛的 AI 本事,满意更多的现实业务需求。
【哈佛博后带小白玩转呆板学习】 哔哩哔哩_bilibili
总课时超400+,时长75+小时
免责声明:如果侵犯了您的权益,请联系站长,我们会及时删除侵权内容,谢谢合作!更多信息从访问主页:qidao123.com:ToB企服之家,中国第一个企服评测及商务社交产业平台。 |