tensorflow
https://www.tensorflow.org/lite/examples/audio_classification/overview?hl=zh-cn
官方提供了移动端的 demo(见上面的链接)。
我不会前端,所以只能找找有没有 Java 方面的支持。
注意 TensorFlow Java 依赖的版本。
注意 JDK 版本:下面的代码里,Session 的运行结果在 JDK 17 下用 Result 接收,在 JDK 8 下用 List<Tensor> 接收。
package com.example.demo17.controller;

import org.tensorflow.*;
import org.tensorflow.ndarray.*;
import org.tensorflow.ndarray.impl.dense.FloatDenseNdArray;
import org.tensorflow.proto.framework.DataType;
import org.tensorflow.proto.framework.MetaGraphDef;
import org.tensorflow.proto.framework.SignatureDef;
import org.tensorflow.proto.framework.TensorInfo;
import org.tensorflow.types.TFloat32;
import org.tensorflow.types.TInt64;
import javax.sound.sampled.AudioFormat;
import javax.sound.sampled.AudioInputStream;
import javax.sound.sampled.AudioSystem;
import javax.sound.sampled.UnsupportedAudioFileException;
import javax.xml.transform.Result;
import java.io.ByteArrayOutputStream;
import java.io.File;
import java.io.IOException;
import java.io.InputStream;
import java.nio.ByteBuffer;
import java.nio.ByteOrder;
import java.nio.file.Files;
import java.nio.file.Paths;
import java.util.*;
import java.util.concurrent.ConcurrentHashMap;
- public class Test {
- private static FloatNdArray t1() {
- // String audioFilePath = "D:\\ai\\cat.wav";
- String audioFilePath = "C:\\Users\\user\\Downloads\\output_Wo9KJb-5zuz1_2.wav";
- // String audioFilePath = "D:\\ai\\111\\111.wav";
- // YAMNet期望的采样率
- int sampleRate = 16000;
- // YAMNet帧大小,0.96秒
- int frameSizeInMs = 96;
- // YAMNet帧步长,0.48秒
- int hopSizeInMs = 48;
- try (AudioInputStream audioStream = AudioSystem.getAudioInputStream(Paths.get(audioFilePath).toFile())) {
- AudioFormat format = audioStream.getFormat();
- if (format.getSampleRate() != sampleRate || format.getChannels() != 1) {
- System.out.println("Warning: Audio must be 16kHz mono. Consider preprocessing.");
- }
- int frameSize = (int) (sampleRate * frameSizeInMs / 1000);
- int hopSize = (int) (sampleRate * hopSizeInMs / 1000);
- byte[] buffer = new byte[frameSize * format.getFrameSize()];
- short[] audioSamples = new short[frameSize];
- // 存储每个帧的音频数据
- List<Float> floatList = new ArrayList<>();
- while (true) {
- int bytesRead = audioStream.read(buffer);
- if (bytesRead == -1) {
- break;
- }
- // 将读取的字节转换为short数组(假设16位精度)
- for (int i = 0; i < bytesRead / format.getFrameSize(); i++) {
- audioSamples[i] = (short) ((buffer[i * 2] & 0xFF) | (buffer[i * 2 + 1] << 8));
- }
- // 对当前帧进行处理(例如,归一化和准备送入模型)
- float[] floats = processFrame(audioSamples);
- for (float aFloat : floats) {
- floatList.add(aFloat);
- }
- // 移动到下一个帧
- System.arraycopy(audioSamples, hopSize, audioSamples, 0, frameSize - hopSize);
- }
- // 将List<Float>转换为float[]
- float[] floatArray = new float[floatList.size()];
- for (int i = 0; i < floatList.size(); i++) {
- floatArray[i] = floatList.get(i);
- }
- return StdArrays.ndCopyOf(floatArray);
- } catch (UnsupportedAudioFileException | IOException e) {
- e.printStackTrace();
- }
- return null;
- }
- private static float[] processFrame(short[] frame) {
- // 示例:归一化音频数据到[-1.0, 1.0]
- float[] normalizedFrame = new float[frame.length];
- for (int i = 0; i < frame.length; i++) {
- // short的最大值为32767,故除以32768得到[-1.0, 1.0]
- normalizedFrame[i] = frame[i] / 32768f;
- }
- return normalizedFrame;
- }
- static Map<String,String> map=new ConcurrentHashMap<>();
- public static void main(String[] args) throws Exception {
- FloatNdArray floatNdArray = t1();
- TFloat32 tFloat32 = TFloat32.tensorOf(floatNdArray);
- //SavedModelBundle savedModelBundle = SavedModelBundle.load("D:\\saved_model", "serve");
- SavedModelBundle savedModelBundle = SavedModelBundle.load("C:\\Users\\user\\Downloads\\archive", "serve");
- Map<String, SignatureDef> signatureDefMap = MetaGraphDef.parseFrom(savedModelBundle.metaGraphDef().toByteArray()).getSignatureDefMap();
- /**
- * 获取基本定义信息
- */
- SignatureDef modelSig = signatureDefMap.get("serving_default");
- String inputTensorName = modelSig.getInputsMap().get("waveform").getName();
- String outputTensorName = modelSig.getOutputsMap().get("output_0").getName();
- savedModelBundle.graph();
- try (Session session = savedModelBundle.session()) {
- /*JDK 17*/
- // Result run = session.runner()
- // .feed(inputTensorName, tFloat32)
- // .fetch(outputTensorName)
- // .run();
- // Tensor out = run.get(0);
- // Shape shape = out.shape();
- //
- // System.out.println(shape);
- /*JDK 8*/
- List<Tensor> run = session.runner()
- .feed(inputTensorName, tFloat32)
- .fetch(outputTensorName)
- .run();
- Tensor tensor = run.get(0);
- Shape shape = tensor.shape();
- System.out.println(shape.asArray());
- String l=String.valueOf(shape.asArray()[0]);
- //读取CSV文件
- String csvFile = "C:\\Users\\user\\Downloads\\archive\\assets\\yamnet_class_map.csv";
- try {
- List<String> lines = Files.readAllLines(Paths.get(csvFile));
- for (String line : lines) {
- String[] values = line.split(",");
- map.put(values[0], values[2]);
- }
- } catch (IOException e) {
- e.printStackTrace();
- }
- String s = map.get(l);
- System.out.println(s);
- }
- }
- }
免责声明:如果侵犯了您的权益,请联系站长,我们会及时删除侵权内容,谢谢合作!更多信息请访问主页:qidao123.com(ToB企服之家,中国第一个企服评测及商务社交产业平台)。