IT评测·应用市场-qidao123.com技术社区

标题: Yolov8目的检测——在Android上摆设Yolov8 tflite模型 [打印本页]

作者: 悠扬随风    时间: 2024-6-21 02:00
标题: Yolov8目的检测——在Android上摆设Yolov8 tflite模型
1. 简介

YOLOv8 是一种用于目的检测的深度学习模型,它是 YOLO(You Only Look Once)系列的最新版本之一。YOLO 系列因其高效和准确性而在计算机视觉范畴非常受欢迎,特别是在需要实时目的检测的应用中,如视频监控、主动驾驶汽车、机器人视觉等。
以下是 YOLOv8 的一些关键特点:

2.模型转换

2.1 tflite模型

TensorFlow Lite (tflite) 是一种用于移动和嵌入式设备上的机器学习模型的格式。它允许开发者将训练好的 TensorFlow 模型转换为一个更小、更快、更高效的格式,以便于在资源受限的情况中运行,比如智能手机和微控制器。

2.2 Pytorch 格式转换为 tflite 格式

YOLOv8 是以 pytorch 格式构建的。将其转换为 tflite,以便在 Android 上使用。
安装 Ultralytics 框架
使用 pip 安装 Ultralytics 框架,该框架包含了 YOLOv8:
  1. conda create -n yolov8 python=3.8
  2. activate ylolv8
  3. pip install ultralytics
复制代码
转换模型为 tflite 格式
使用 Ultralytics 框架提供的 YOLO 类来加载 PyTorch 格式的 YOLOv8 模型,并导出为 tflite 格式:
  1.   from ultralytics import YOLO
  2.   model = YOLO('yolov8s.pt')  # 这里 'yolov8s.pt' 是模型权重文件
  3.   model.export(format="tflite")
复制代码
这将天生一个 tflite 文件,比方 yolov8s_saved_model/yolov8s_float16.tflite。
处置惩罚转换过程中的错误
如果在转换过程中遇到错误,特别是与 TensorFlow 版本相关的问题,需要安装一个特定版本的 TensorFlow 来解决兼容性问题:
  1.   pip install tensorflow==2.13.0
复制代码
3.创建项目

3.1 创建项目

创建一个安卓项目,语言选择Kotlin,如下图所示:

然后在 Android Studio 项目的 app 目录中创建一个 assets 目录(文件 → 新建 → 文件夹 → 资产文件夹),并将 tflite 文件(比方 yolov8s_float32.tflite)和 labels.txt 添加进去。labels.txt其中描述了 YOLOv8 模型的种别名称。
3.2 添加依靠

将以下内容添加到 app/build.gradle.kts 中的依靠项以安装 tflite 框架。
  1. implementation("org.tensorflow:tensorflow-lite:2.14.0")
  2. implementation("org.tensorflow:tensorflow-lite-support:0.4.4")
复制代码
导入所需的模块
  1. import org.tensorflow.lite.DataType
  2. import org.tensorflow.lite.Interpreter
  3. import org.tensorflow.lite.gpu.CompatibilityList
  4. import org.tensorflow.lite.gpu.GpuDelegate
  5. import org.tensorflow.lite.support.common.FileUtil
  6. import org.tensorflow.lite.support.common.ops.CastOp
  7. import org.tensorflow.lite.support.common.ops.NormalizeOp
  8. import org.tensorflow.lite.support.image.ImageProcessor
  9. import org.tensorflow.lite.support.image.TensorImage
  10. import org.tensorflow.lite.support.tensorbuffer.TensorBuffer
  11. import java.io.BufferedReader
  12. import java.io.IOException
  13. import java.io.InputStream
  14. import java.io.InputStreamReader
复制代码
3.3 初始化模型

  1. private val modelPath = "yolov8s_float32.tflite"
  2. private val labelPath = "labels.txt"
  3. private var interpreter: Interpreter? = null
  4. private var tensorWidth = 0
  5. private var tensorHeight = 0
  6. private var numChannel = 0
  7. private var numElements = 0
  8. private var labels = mutableListOf<String>()
  9. private val imageProcessor = ImageProcessor.Builder()
  10.     .add(NormalizeOp(INPUT_MEAN, INPUT_STANDARD_DEVIATION))
  11.     .add(CastOp(INPUT_IMAGE_TYPE))
  12.     .build() // preprocess input
  13. companion object {
  14.     private const val INPUT_MEAN = 0f
  15.     private const val INPUT_STANDARD_DEVIATION = 255f
  16.     private val INPUT_IMAGE_TYPE = DataType.FLOAT32
  17.     private val OUTPUT_IMAGE_TYPE = DataType.FLOAT32
  18.     private const val CONFIDENCE_THRESHOLD = 0.3F
  19.     private const val IOU_THRESHOLD = 0.5F
  20. }
复制代码
初始化 tflite 模型。获取模型文件并将其传递给 tflite 的 Interpreter。选择推理使用的线程数。
  1. val model = FileUtil.loadMappedFile(context, modelPath)
  2. val options = Interpreter.Options()
  3. options.numThreads = 4
  4. interpreter = Interpreter(model, options)
复制代码
从 Interpreter 获取 yolov8s 输入和输层:
  1. val inputShape = interpreter.getInputTensor(0).shape()
  2. val outputShape = interpreter.getOutputTensor(0).shape()
  3. tensorWidth = inputShape[1]
  4. tensorHeight = inputShape[2]
  5. numChannel = outputShape[1]
  6. numElements = outputShape[2]
复制代码
3.4 从 label.txt 文件中读取类名称

  1. try {
  2.     val inputStream: InputStream = context.assets.open(labelPath)
  3.     val reader = BufferedReader(InputStreamReader(inputStream))
  4.     var line: String? = reader.readLine()
  5.     while (line != null && line != "") {
  6.         labels.add(line)
  7.         line = reader.readLine()
  8.     }
  9.     reader.close()
  10.     inputStream.close()
  11. } catch (e: IOException) {
  12.     e.printStackTrace()
  13. }
复制代码
3.5 对图像进行推理

在 Android 应用中,输入是位图(Bitmap),需要根据模型的输入格式进行预处置惩罚:

  1. import android.graphics.Bitmap;
  2. import android.graphics.ImageFormat;
  3. import org.tensorflow.lite.Interpreter;
  4. import java.nio.ByteBuffer;
  5. import java.nio.ByteOrder;
  6. import java.nio.channels.WritableByteChannel;
  7. // 假设 tflite 已经初始化,且 bitmap 是您要处理的位图
  8. Bitmap bitmap
  9. val resizedBitmap = Bitmap.createScaledBitmap(bitmap, tensorWidth, tensorHeight, false)
  10. val tensorImage = TensorImage(DataType.FLOAT32)
  11. tensorImage.load(resizedBitmap)
  12. val processedImage = imageProcessor.process(tensorImage)
  13. val imageBuffer = processedImage.buffer
复制代码
创建一个与模型输出层匹配的输出张量缓冲区,并将其与上面的输入 imageBuffer 一起传递给解释器以实行。
  1. val output = TensorBuffer.createFixedSize(intArrayOf(1 , numChannel, numElements), OUTPUT_IMAGE_TYPE)
  2. interpreter.run(imageBuffer, output.buffer)
复制代码
3.6 处置惩罚输出

输出框被视为 BoudingBox 类。这是一个具有种别、框和置信度级别的类。其中x1,y1 是起始点。x2, y2 是终点,cx, cy 是中心。w 宽度,h 是高度。
  1. data class BoundingBox(
  2.     val x1: Float,
  3.     val y1: Float,
  4.     val x2: Float,
  5.     val y2: Float,
  6.     val cx: Float,
  7.     val cy: Float,
  8.     val w: Float,
  9.     val h: Float,
  10.     val cnf: Float,
  11.     val cls: Int,
  12.     val clsName: String
  13. )
复制代码
提取置信度高于置信度阈值的框,在重叠的框中,留下置信度最高的框。(nms)
  1. private fun bestBox(array: FloatArray) : List<BoundingBox>? {
  2.     val boundingBoxes = mutableListOf<BoundingBox>()
  3.     for (c in 0 until numElements) {
  4.         var maxConf = -1.0f        var maxIdx = -1        var j = 4        var arrayIdx = c + numElements * j
  5.         while (j < numChannel){
  6.             if (array[arrayIdx] > maxConf) {
  7.                 maxConf = array[arrayIdx]
  8.                 maxIdx = j - 4
  9.             }
  10.             j++
  11.             arrayIdx += numElements
  12.         }
  13.         if (maxConf > CONFIDENCE_THRESHOLD) {
  14.             val clsName = labels[maxIdx]
  15.             val cx = array[c] // 0            val cy = array[c + numElements] // 1            val w = array[c + numElements * 2]
  16.             val h = array[c + numElements * 3]
  17.             val x1 = cx - (w/2F)
  18.             val y1 = cy - (h/2F)
  19.             val x2 = cx + (w/2F)
  20.             val y2 = cy + (h/2F)
  21.             if (x1 < 0F || x1 > 1F) continue            if (y1 < 0F || y1 > 1F) continue            if (x2 < 0F || x2 > 1F) continue            if (y2 < 0F || y2 > 1F) continue
  22.             boundingBoxes.add(
  23.                 BoundingBox(
  24.                     x1 = x1, y1 = y1, x2 = x2, y2 = y2,
  25.                     cx = cx, cy = cy, w = w, h = h,
  26.                     cnf = maxConf, cls = maxIdx, clsName = clsName
  27.                 )
  28.             )
  29.         }
  30.     }
  31.     if (boundingBoxes.isEmpty()) return null    return applyNMS(boundingBoxes)
  32. }
  33. private fun applyNMS(boxes: List<BoundingBox>) : MutableList<BoundingBox> {
  34.     val sortedBoxes = boxes.sortedByDescending { it.cnf }.toMutableList()
  35.     val selectedBoxes = mutableListOf<BoundingBox>()
  36.     while(sortedBoxes.isNotEmpty()) {
  37.         val first = sortedBoxes.first()
  38.         selectedBoxes.add(first)
  39.         sortedBoxes.remove(first)
  40.         val iterator = sortedBoxes.iterator()
  41.         while (iterator.hasNext()) {
  42.             val nextBox = iterator.next()
  43.             val iou = calculateIoU(first, nextBox)
  44.             if (iou >= IOU_THRESHOLD) {
  45.                 iterator.remove()
  46.             }
  47.         }
  48.     }
  49.     return selectedBoxes
  50. }
  51. private fun calculateIoU(box1: BoundingBox, box2: BoundingBox): Float {
  52.     val x1 = maxOf(box1.x1, box2.x1)
  53.     val y1 = maxOf(box1.y1, box2.y1)
  54.     val x2 = minOf(box1.x2, box2.x2)
  55.     val y2 = minOf(box1.y2, box2.y2)
  56.     val intersectionArea = maxOf(0F, x2 - x1) * maxOf(0F, y2 - y1)
  57.     val box1Area = box1.w * box1.h
  58.     val box2Area = box2.w * box2.h
  59.     return intersectionArea / (box1Area + box2Area - intersectionArea)
  60. }
复制代码
将得到 yolov8 的输出。
  1. val bestBoxes = bestBox(output.floatArray)
复制代码
将输出框绘制到图像上
  1. fun drawBoundingBoxes(bitmap: Bitmap, boxes: List<BoundingBox>): Bitmap {
  2.     val mutableBitmap = bitmap.copy(Bitmap.Config.ARGB_8888, true)
  3.     val canvas = Canvas(mutableBitmap)
  4.     val paint = Paint().apply {
  5.         color = Color.RED
  6.         style = Paint.Style.STROKE
  7.         strokeWidth = 8f
  8.     }
  9.     val textPaint = Paint().apply {
  10.         color = Color.WHITE
  11.         textSize = 40f
  12.         typeface = Typeface.DEFAULT_BOLD
  13.     }
  14.     for (box in boxes) {
  15.         val rect = RectF(
  16.             box.x1 * mutableBitmap.width,
  17.             box.y1 * mutableBitmap.height,
  18.             box.x2 * mutableBitmap.width,
  19.             box.y2 * mutableBitmap.height
  20.         )
  21.         canvas.drawRect(rect, paint)
  22.         canvas.drawText(box.clsName, rect.left, rect.bottom, textPaint)
  23.     }
  24.     return mutableBitmap
  25. }
复制代码
运行效果:


免责声明:如果侵犯了您的权益,请联系站长,我们会及时删除侵权内容,谢谢合作!更多信息从访问主页:qidao123.com:ToB企服之家,中国第一个企服评测及商务社交产业平台。




欢迎光临 IT评测·应用市场-qidao123.com技术社区 (https://dis.qidao123.com/) Powered by Discuz! X3.4