ToB企服应用市场:ToB评测及商务社交产业平台
标题:
Tensorflow音频分类
[打印本页]
作者:
风雨同行
时间:
2024-6-14 23:00
标题:
Tensorflow音频分类
tensorflow
https://www.tensorflow.org/lite/examples/audio_classification/overview?hl=zh-cn
官方提供了移动端 demo。
由于不熟悉前端/移动端开发,只能找找有没有 Java 版本的支持。
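注意 TensorFlow Java 依赖的版本:不同版本的 API 有差异。
注意 JDK 版本:JDK 17 与 JDK 8 对应写法中 Session.Runner.run() 的返回类型不同(分别为 Result 和 List<Tensor>),见下方代码中的两种写法。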
package com.example.demo17.controller;
import org.tensorflow.*;
import org.tensorflow.ndarray.*;
import org.tensorflow.ndarray.impl.dense.FloatDenseNdArray;
import org.tensorflow.proto.framework.DataType;
import org.tensorflow.proto.framework.MetaGraphDef;
import org.tensorflow.proto.framework.SignatureDef;
import org.tensorflow.proto.framework.TensorInfo;
import org.tensorflow.types.TFloat32;
import org.tensorflow.types.TInt64;
import javax.sound.sampled.AudioFormat;
import javax.sound.sampled.AudioInputStream;
import javax.sound.sampled.AudioSystem;
import javax.sound.sampled.UnsupportedAudioFileException;
import javax.xml.transform.Result;
import java.io.File;
import java.io.IOException;
import java.io.InputStream;
import java.nio.file.Files;
import java.nio.file.Paths;
import java.util.*;
import java.util.concurrent.ConcurrentHashMap;
/**
 * Classifies a WAV clip with a YAMNet SavedModel and prints the top-scoring class name.
 *
 * <p>Usage: {@code java Test [audio.wav] [savedModelDir] [classMap.csv]}; any argument that is
 * omitted falls back to the corresponding hard-coded default path below.
 */
public class Test {

    /** Default input clip; overridden by the first command-line argument. */
    private static final String DEFAULT_AUDIO_PATH = "C:\\Users\\user\\Downloads\\output_Wo9KJb-5zuz1_2.wav";
    /** Default SavedModel directory; overridden by the second command-line argument. */
    private static final String DEFAULT_MODEL_DIR = "C:\\Users\\user\\Downloads\\archive";
    /** Default class-index CSV; overridden by the third command-line argument. */
    private static final String DEFAULT_CLASS_MAP_CSV = "C:\\Users\\user\\Downloads\\archive\\assets\\yamnet_class_map.csv";

    /** Sample rate (Hz) the model expects; other rates only trigger a warning. */
    private static final int EXPECTED_SAMPLE_RATE = 16000;

    /** Number of audio frames requested per read() call. */
    private static final int FRAMES_PER_READ = 4096;

    /** Class index (decimal string, first CSV column) -> display name (third CSV column). */
    static Map<String, String> map = new ConcurrentHashMap<>();

    /** Reads the default clip; see {@link #t1(String)}. */
    private static FloatNdArray t1() throws IOException, UnsupportedAudioFileException {
        return t1(DEFAULT_AUDIO_PATH);
    }

    /**
     * Decodes a 16-bit signed PCM audio file into one 1-D waveform normalized to [-1.0, 1.0).
     *
     * <p>The whole clip is returned as a single waveform: the model's "waveform" input does its own
     * framing, so no client-side framing/overlap is done here. Multi-channel audio is downmixed to
     * mono by averaging the channels of each frame.
     *
     * @param audioFilePath path to the audio file
     * @return the normalized mono waveform, one element per audio frame
     * @throws UnsupportedAudioFileException if the file is not 16-bit signed PCM
     * @throws IOException if the file cannot be read or contains no samples
     */
    private static FloatNdArray t1(String audioFilePath) throws IOException, UnsupportedAudioFileException {
        try (AudioInputStream audioStream = AudioSystem.getAudioInputStream(Paths.get(audioFilePath).toFile())) {
            AudioFormat format = audioStream.getFormat();
            if (format.getSampleRate() != EXPECTED_SAMPLE_RATE || format.getChannels() != 1) {
                System.out.println("Warning: Audio must be 16kHz mono. Consider preprocessing.");
            }
            int channels = format.getChannels();
            int bytesPerFrame = format.getFrameSize();
            // Only 16-bit signed PCM is decoded below; fail loudly instead of producing noise.
            if (!AudioFormat.Encoding.PCM_SIGNED.equals(format.getEncoding())
                    || format.getSampleSizeInBits() != 16
                    || channels < 1
                    || bytesPerFrame != channels * 2) {
                throw new UnsupportedAudioFileException("Expected 16-bit signed PCM, got: " + format);
            }
            boolean bigEndian = format.isBigEndian();

            byte[] buffer = new byte[bytesPerFrame * FRAMES_PER_READ];
            short[] samples = new short[EXPECTED_SAMPLE_RATE]; // initial capacity, grown on demand
            int sampleCount = 0;
            int filled = 0; // bytes currently held in buffer (a partial frame may be carried over)
            int bytesRead;
            while ((bytesRead = audioStream.read(buffer, filled, buffer.length - filled)) != -1) {
                filled += bytesRead;
                int wholeFrames = filled / bytesPerFrame;
                for (int f = 0; f < wholeFrames; f++) {
                    int frameStart = f * bytesPerFrame;
                    int sum = 0;
                    for (int c = 0; c < channels; c++) {
                        sum += decodeSample(buffer, frameStart + c * 2, bigEndian);
                    }
                    if (sampleCount == samples.length) {
                        samples = Arrays.copyOf(samples, sampleCount * 2);
                    }
                    // Average of shorts always fits in a short.
                    samples[sampleCount++] = (short) (sum / channels);
                }
                // Keep any trailing partial frame for the next read; only decoded bytes are dropped,
                // so no stale data from a previous read can leak into the waveform.
                int consumed = wholeFrames * bytesPerFrame;
                filled -= consumed;
                System.arraycopy(buffer, consumed, buffer, 0, filled);
            }
            if (sampleCount == 0) {
                throw new IOException("No audio samples in " + audioFilePath);
            }
            return StdArrays.ndCopyOf(processFrame(Arrays.copyOf(samples, sampleCount)));
        }
    }

    /** Decodes one 16-bit two's-complement sample starting at {@code offset} with the given byte order. */
    private static short decodeSample(byte[] buf, int offset, boolean bigEndian) {
        int hi = bigEndian ? buf[offset] : buf[offset + 1];
        int lo = bigEndian ? buf[offset + 1] : buf[offset];
        return (short) ((lo & 0xFF) | (hi << 8));
    }

    /**
     * Normalizes 16-bit samples to floats.
     *
     * @param frame raw samples
     * @return a new array where each sample is divided by 32768, giving values in [-1.0, 32767/32768]
     */
    private static float[] processFrame(short[] frame) {
        float[] normalizedFrame = new float[frame.length];
        for (int i = 0; i < frame.length; i++) {
            // Short.MIN_VALUE is -32768, so dividing by 32768 keeps every value within [-1.0, 1.0).
            normalizedFrame[i] = frame[i] / 32768f;
        }
        return normalizedFrame;
    }

    /**
     * Loads "index,mid,display_name" rows from the class-map CSV into {@link #map}.
     *
     * <p>Splits on the first two commas only, so display names that contain commas are kept whole.
     * This assumes the first two columns never contain commas. A header row, if present, is stored
     * under its own first-column text and is harmless for numeric lookups.
     */
    private static void loadClassMap(String csvFile) throws IOException {
        // Files.readAllLines(Path) decodes as UTF-8.
        for (String line : Files.readAllLines(Paths.get(csvFile))) {
            String[] values = line.split(",", 3);
            if (values.length < 3) {
                continue; // blank or malformed line
            }
            map.put(values[0].trim(), unquote(values[2].trim()));
        }
    }

    /** Strips surrounding CSV double quotes and un-doubles embedded quotes; other fields are returned as-is. */
    private static String unquote(String field) {
        if (field.length() >= 2 && field.startsWith("\"") && field.endsWith("\"")) {
            return field.substring(1, field.length() - 1).replace("\"\"", "\"");
        }
        return field;
    }

    /**
     * Returns the class index with the highest mean score across frames.
     *
     * @param scores a rank-2 tensor indexed as [frame, class]
     * @return the winning class index (the column index)
     * @throws IllegalStateException if scores is not rank 2 or has no frames
     */
    private static int argmaxOfMeanScore(TFloat32 scores) {
        Shape shape = scores.shape();
        if (shape.numDimensions() != 2 || shape.size(0) == 0) {
            throw new IllegalStateException(
                    "Expected non-empty [frames, classes] scores, got shape " + Arrays.toString(shape.asArray()));
        }
        long frames = shape.size(0);
        long classes = shape.size(1);
        int best = 0;
        float bestSum = Float.NEGATIVE_INFINITY;
        for (long c = 0; c < classes; c++) {
            // Summing is equivalent to averaging for argmax, since every class has the same frame count.
            float sum = 0f;
            for (long f = 0; f < frames; f++) {
                sum += scores.getFloat(f, c);
            }
            if (sum > bestSum) {
                bestSum = sum;
                best = (int) c;
            }
        }
        return best;
    }

    /**
     * Runs the "serving_default" signature on the clip and prints the output shape and the
     * predicted class name.
     *
     * @param args optional: [0] audio file, [1] SavedModel directory, [2] class-map CSV
     */
    public static void main(String[] args) throws Exception {
        String audioPath = args.length > 0 ? args[0] : DEFAULT_AUDIO_PATH;
        String modelDir = args.length > 1 ? args[1] : DEFAULT_MODEL_DIR;
        String classMapCsv = args.length > 2 ? args[2] : DEFAULT_CLASS_MAP_CSV;

        FloatNdArray floatNdArray = t1(audioPath);
        // The bundle owns its session and graph and closes them; the input tensor is closed as well.
        try (SavedModelBundle savedModelBundle = SavedModelBundle.load(modelDir, "serve");
             TFloat32 tFloat32 = TFloat32.tensorOf(floatNdArray)) {
            Map<String, SignatureDef> signatureDefMap =
                    MetaGraphDef.parseFrom(savedModelBundle.metaGraphDef().toByteArray()).getSignatureDefMap();
            // Resolve the actual graph tensor names behind the signature's logical input/output keys.
            SignatureDef modelSig = signatureDefMap.get("serving_default");
            String inputTensorName = modelSig.getInputsMap().get("waveform").getName();
            String outputTensorName = modelSig.getOutputsMap().get("output_0").getName();

            Session session = savedModelBundle.session();
            /* JDK 8 (TF Java versions where run() returns List<Tensor>). */
            // JDK 17 variant: run() returns org.tensorflow.Result; use result.get(0) instead of list.get(0).
            // NOTE: this file imports javax.xml.transform.Result, which shadows org.tensorflow.Result,
            // so that variant needs the fully qualified name.
            List<Tensor> run = session.runner()
                    .feed(inputTensorName, tFloat32)
                    .fetch(outputTensorName)
                    .run();
            try (TFloat32 scores = (TFloat32) run.get(0)) {
                System.out.println(Arrays.toString(scores.shape().asArray()));
                int bestClass = argmaxOfMeanScore(scores);
                loadClassMap(classMapCsv);
                String s = map.get(String.valueOf(bestClass));
                System.out.println(s != null ? s : "unknown class index " + bestClass);
            }
        }
    }
}
复制代码
免责声明:如果侵犯了您的权益,请联系站长,我们会及时删除侵权内容,谢谢合作!更多信息从访问主页:qidao123.com:ToB企服之家,中国第一个企服评测及商务社交产业平台。
欢迎光临 ToB企服应用市场:ToB评测及商务社交产业平台 (https://dis.qidao123.com/)
Powered by Discuz! X3.4