Tensorflow音频分类

打印 上一主题 下一主题

主题 511|帖子 511|积分 1533

tensorflow

https://www.tensorflow.org/lite/examples/audio_classification/overview?hl=zh-cn

官方有移动端demo


前端不会  就只能找找有没有java支持




注意版本






注意JDK版本
  1. package com.example.demo17.controller;
  2. import org.tensorflow.*;
  3. import org.tensorflow.ndarray.*;
  4. import org.tensorflow.ndarray.impl.dense.FloatDenseNdArray;
  5. import org.tensorflow.proto.framework.DataType;
  6. import org.tensorflow.proto.framework.MetaGraphDef;
  7. import org.tensorflow.proto.framework.SignatureDef;
  8. import org.tensorflow.proto.framework.TensorInfo;
  9. import org.tensorflow.types.TFloat32;
  10. import org.tensorflow.types.TInt64;
  11. import javax.sound.sampled.AudioFormat;
  12. import javax.sound.sampled.AudioInputStream;
  13. import javax.sound.sampled.AudioSystem;
  14. import javax.sound.sampled.UnsupportedAudioFileException;
  15. import javax.xml.transform.Result;
  16. import java.io.File;
  17. import java.io.IOException;
  18. import java.io.InputStream;
  19. import java.nio.file.Files;
  20. import java.nio.file.Paths;
  21. import java.util.*;
  22. import java.util.concurrent.ConcurrentHashMap;
  23. public class Test {
  24.     private static FloatNdArray t1() {
  25. //        String audioFilePath = "D:\\ai\\cat.wav";
  26.         String audioFilePath = "C:\\Users\\user\\Downloads\\output_Wo9KJb-5zuz1_2.wav";
  27. //        String audioFilePath = "D:\\ai\\111\\111.wav";
  28.         // YAMNet期望的采样率
  29.         int sampleRate = 16000;
  30.         // YAMNet帧大小,0.96秒
  31.         int frameSizeInMs = 96;
  32.         // YAMNet帧步长,0.48秒
  33.         int hopSizeInMs = 48;
  34.         try (AudioInputStream audioStream = AudioSystem.getAudioInputStream(Paths.get(audioFilePath).toFile())) {
  35.             AudioFormat format = audioStream.getFormat();
  36.             if (format.getSampleRate() != sampleRate || format.getChannels() != 1) {
  37.                 System.out.println("Warning: Audio must be 16kHz mono. Consider preprocessing.");
  38.             }
  39.             int frameSize = (int) (sampleRate * frameSizeInMs / 1000);
  40.             int hopSize = (int) (sampleRate * hopSizeInMs / 1000);
  41.             byte[] buffer = new byte[frameSize * format.getFrameSize()];
  42.             short[] audioSamples = new short[frameSize];
  43.             // 存储每个帧的音频数据
  44.             List<Float> floatList = new ArrayList<>();
  45.             while (true) {
  46.                 int bytesRead = audioStream.read(buffer);
  47.                 if (bytesRead == -1) {
  48.                     break;
  49.                 }
  50.                 // 将读取的字节转换为short数组(假设16位精度)
  51.                 for (int i = 0; i < bytesRead / format.getFrameSize(); i++) {
  52.                     audioSamples[i] = (short) ((buffer[i * 2] & 0xFF) | (buffer[i * 2 + 1] << 8));
  53.                 }
  54.                 // 对当前帧进行处理(例如,归一化和准备送入模型)
  55.                 float[] floats = processFrame(audioSamples);
  56.                 for (float aFloat : floats) {
  57.                     floatList.add(aFloat);
  58.                 }
  59.                 // 移动到下一个帧
  60.                 System.arraycopy(audioSamples, hopSize, audioSamples, 0, frameSize - hopSize);
  61.             }
  62.             // 将List<Float>转换为float[]
  63.             float[] floatArray = new float[floatList.size()];
  64.             for (int i = 0; i < floatList.size(); i++) {
  65.                 floatArray[i] = floatList.get(i);
  66.             }
  67.             return StdArrays.ndCopyOf(floatArray);
  68.         } catch (UnsupportedAudioFileException | IOException e) {
  69.             e.printStackTrace();
  70.         }
  71.         return null;
  72.     }
  73.     private static float[] processFrame(short[] frame) {
  74.         // 示例:归一化音频数据到[-1.0, 1.0]
  75.         float[] normalizedFrame = new float[frame.length];
  76.         for (int i = 0; i < frame.length; i++) {
  77.             // short的最大值为32767,故除以32768得到[-1.0, 1.0]
  78.             normalizedFrame[i] = frame[i] / 32768f;
  79.         }
  80.         return normalizedFrame;
  81.     }
  82.     static Map<String,String> map=new ConcurrentHashMap<>();
  83.     public static void main(String[] args) throws Exception {
  84.         FloatNdArray floatNdArray = t1();
  85.         TFloat32 tFloat32 = TFloat32.tensorOf(floatNdArray);
  86.         //SavedModelBundle savedModelBundle = SavedModelBundle.load("D:\\saved_model", "serve");
  87.         SavedModelBundle savedModelBundle = SavedModelBundle.load("C:\\Users\\user\\Downloads\\archive", "serve");
  88.         Map<String, SignatureDef> signatureDefMap = MetaGraphDef.parseFrom(savedModelBundle.metaGraphDef().toByteArray()).getSignatureDefMap();
  89.         /**
  90.          * 获取基本定义信息
  91.          */
  92.         SignatureDef modelSig = signatureDefMap.get("serving_default");
  93.         String inputTensorName = modelSig.getInputsMap().get("waveform").getName();
  94.         String outputTensorName = modelSig.getOutputsMap().get("output_0").getName();
  95.         savedModelBundle.graph();
  96.         try (Session session = savedModelBundle.session()) {
  97.             /*JDK 17*/
  98. //            Result run = session.runner()
  99. //                    .feed(inputTensorName, tFloat32)
  100. //                    .fetch(outputTensorName)
  101. //                    .run();
  102. //            Tensor out = run.get(0);
  103. //            Shape shape = out.shape();
  104. //
  105. //            System.out.println(shape);
  106.             /*JDK 8*/
  107.             List<Tensor> run = session.runner()
  108.                     .feed(inputTensorName, tFloat32)
  109.                     .fetch(outputTensorName)
  110.                     .run();
  111.             Tensor tensor = run.get(0);
  112.             Shape shape = tensor.shape();
  113.             System.out.println(shape.asArray());
  114.             String l=String.valueOf(shape.asArray()[0]);
  115.             //读取CSV文件
  116.             String csvFile = "C:\\Users\\user\\Downloads\\archive\\assets\\yamnet_class_map.csv";
  117.             try {
  118.                 List<String> lines = Files.readAllLines(Paths.get(csvFile));
  119.                 for (String line : lines) {
  120.                     String[] values = line.split(",");
  121.                     map.put(values[0], values[2]);
  122.                 }
  123.             } catch (IOException e) {
  124.                 e.printStackTrace();
  125.             }
  126.             String s = map.get(l);
  127.             System.out.println(s);
  128.         }
  129.     }
  130. }
复制代码
免责声明:如果侵犯了您的权益,请联系站长,我们会及时删除侵权内容,谢谢合作!更多信息从访问主页:qidao123.com:ToB企服之家,中国第一个企服评测及商务社交产业平台。
回复

使用道具 举报

0 个回复

倒序浏览

快速回复

您需要登录后才可以回帖 登录 or 立即注册

本版积分规则

风雨同行

金牌会员
这个人很懒什么都没写!

标签云

快速回复 返回顶部 返回列表