风雨同行 发表于 2024-6-14 23:00:21

Tensorflow音频分类

tensorflow

https://www.tensorflow.org/lite/examples/audio_classification/overview?hl=zh-cn
https://img-blog.csdnimg.cn/direct/ffd26b0af64a402fbad5ce78622a8503.png
官方有移动端demo
https://img-blog.csdnimg.cn/direct/a912f803740d4ead83aaa64af2def91c.png

前端不会  就只能找找有没有java支持
https://img-blog.csdnimg.cn/direct/5f225e2cd07e4154bfe36bd1d30dfcaa.png
https://img-blog.csdnimg.cn/direct/8ad1597388894c0b89e8083cec23c8b2.png
https://img-blog.csdnimg.cn/direct/7c98611441a6476ca069de08cf44ddaf.png

注意版本
https://img-blog.csdnimg.cn/direct/0e214d5587a449b288781f43aa84e153.png

https://img-blog.csdnimg.cn/direct/d4d8654c07b3496484412e15c04cab9f.png

https://img-blog.csdnimg.cn/direct/276e65a6319847889a4d8b7ecec05f3c.png

注意JDK版本
package com.example.demo17.controller;


import org.tensorflow.*;
import org.tensorflow.ndarray.*;
import org.tensorflow.ndarray.impl.dense.FloatDenseNdArray;
import org.tensorflow.proto.framework.DataType;
import org.tensorflow.proto.framework.MetaGraphDef;
import org.tensorflow.proto.framework.SignatureDef;
import org.tensorflow.proto.framework.TensorInfo;
import org.tensorflow.types.TFloat32;
import org.tensorflow.types.TInt64;

import javax.sound.sampled.AudioFormat;
import javax.sound.sampled.AudioInputStream;
import javax.sound.sampled.AudioSystem;
import javax.sound.sampled.UnsupportedAudioFileException;
import javax.xml.transform.Result;
import java.io.File;
import java.io.IOException;
import java.io.InputStream;
import java.nio.file.Files;
import java.nio.file.Paths;
import java.util.*;
import java.util.concurrent.ConcurrentHashMap;

public class Test {


    private static FloatNdArray t1() {
//      String audioFilePath = "D:\\ai\\cat.wav";
      String audioFilePath = "C:\\Users\\user\\Downloads\\output_Wo9KJb-5zuz1_2.wav";
//      String audioFilePath = "D:\\ai\\111\\111.wav";
      // YAMNet期望的采样率
      int sampleRate = 16000;
      // YAMNet帧大小,0.96秒
      int frameSizeInMs = 96;
      // YAMNet帧步长,0.48秒
      int hopSizeInMs = 48;

      try (AudioInputStream audioStream = AudioSystem.getAudioInputStream(Paths.get(audioFilePath).toFile())) {
            AudioFormat format = audioStream.getFormat();
            if (format.getSampleRate() != sampleRate || format.getChannels() != 1) {
                System.out.println("Warning: Audio must be 16kHz mono. Consider preprocessing.");
            }
            int frameSize = (int) (sampleRate * frameSizeInMs / 1000);
            int hopSize = (int) (sampleRate * hopSizeInMs / 1000);

            byte[] buffer = new byte;
            short[] audioSamples = new short;
            // 存储每个帧的音频数据
            List<Float> floatList = new ArrayList<>();
            while (true) {
                int bytesRead = audioStream.read(buffer);
                if (bytesRead == -1) {
                  break;
                }
                // 将读取的字节转换为short数组(假设16位精度)
                for (int i = 0; i < bytesRead / format.getFrameSize(); i++) {
                  audioSamples = (short) ((buffer & 0xFF) | (buffer << 8));
                }
                // 对当前帧进行处理(例如,归一化和准备送入模型)
                float[] floats = processFrame(audioSamples);
                for (float aFloat : floats) {
                  floatList.add(aFloat);
                }
                // 移动到下一个帧
                System.arraycopy(audioSamples, hopSize, audioSamples, 0, frameSize - hopSize);
            }

            // 将List<Float>转换为float[]
            float[] floatArray = new float;
            for (int i = 0; i < floatList.size(); i++) {
                floatArray = floatList.get(i);
            }

            return StdArrays.ndCopyOf(floatArray);
      } catch (UnsupportedAudioFileException | IOException e) {
            e.printStackTrace();
      }
      return null;
    }


    private static float[] processFrame(short[] frame) {
      // 示例:归一化音频数据到[-1.0, 1.0]
      float[] normalizedFrame = new float;
      for (int i = 0; i < frame.length; i++) {
            // short的最大值为32767,故除以32768得到[-1.0, 1.0]
            normalizedFrame = frame / 32768f;
      }
      return normalizedFrame;
    }

    static Map<String,String> map=new ConcurrentHashMap<>();

    public static void main(String[] args) throws Exception {
      FloatNdArray floatNdArray = t1();
      TFloat32 tFloat32 = TFloat32.tensorOf(floatNdArray);

      //SavedModelBundle savedModelBundle = SavedModelBundle.load("D:\\saved_model", "serve");
      SavedModelBundle savedModelBundle = SavedModelBundle.load("C:\\Users\\user\\Downloads\\archive", "serve");
      Map<String, SignatureDef> signatureDefMap = MetaGraphDef.parseFrom(savedModelBundle.metaGraphDef().toByteArray()).getSignatureDefMap();
      /**
         * 获取基本定义信息
         */
      SignatureDef modelSig = signatureDefMap.get("serving_default");
      String inputTensorName = modelSig.getInputsMap().get("waveform").getName();
      String outputTensorName = modelSig.getOutputsMap().get("output_0").getName();
      savedModelBundle.graph();
      try (Session session = savedModelBundle.session()) {
            /*JDK 17*/
//            Result run = session.runner()
//                  .feed(inputTensorName, tFloat32)
//                  .fetch(outputTensorName)
//                  .run();
//            Tensor out = run.get(0);
//            Shape shape = out.shape();
//
//            System.out.println(shape);
            /*JDK 8*/
            List<Tensor> run = session.runner()
                  .feed(inputTensorName, tFloat32)
                  .fetch(outputTensorName)
                  .run();
            Tensor tensor = run.get(0);
            Shape shape = tensor.shape();
            System.out.println(shape.asArray());
            String l=String.valueOf(shape.asArray());
            //读取CSV文件
            String csvFile = "C:\\Users\\user\\Downloads\\archive\\assets\\yamnet_class_map.csv";
            try {
                List<String> lines = Files.readAllLines(Paths.get(csvFile));
                for (String line : lines) {
                  String[] values = line.split(",");
                  map.put(values, values);
                }
            } catch (IOException e) {
                e.printStackTrace();
            }
            String s = map.get(l);
            System.out.println(s);
      }
    }
}

免责声明:如果侵犯了您的权益,请联系站长,我们会及时删除侵权内容,谢谢合作!更多信息从访问主页:qidao123.com:ToB企服之家,中国第一个企服评测及商务社交产业平台。
页: [1]
查看完整版本: Tensorflow音频分类