
Audio Classification Model Notes

Contents

audio-mamba-aum

Small ImageNet

kimia_infer audio understanding

Audio classification:

saurabhati/DASS_medium_AudioSet_50.2

ast

AudioClassification-Pytorch has no pretrained model; train your own:


Model                                     AudioSet mAP
OmniVec2.0                                0.558
OmniVec                                   0.548
EquiAV                                    0.546
MAViL (Audio-Visual, single)              0.533
PaSST-S / ConvNeXt-Tiny                   0.471
PSLA (Ensemble EfficientNet)              0.474
AST                                       0.485
Audio-MAE (SOTA self-supervised)          above prior supervised methods (exact value not given)
DASS_medium_AudioSet (discussed here)     0.502

https://huggingface.co/hongroklim/omniverse-github/tree/main

https://github.com/JongSuk1/EquiAV

audio-mamba-aum

https://github.com/kaistmm/Audio-Mamba-AuM?tab=readme-ov-file

Code not tested:

import torch
import torchaudio
from pathlib import Path

# Assume the model is defined in src/model.py (replace with the actual path)
from src.model import AudioMambaModel
from src.config import get_config  # if the config lives in a separate module


def load_model(checkpoint_path, config_path=None, device='cuda'):
    # Build the model config (can be skipped for single-file checkpoints)
    config = get_config(config_path) if config_path else None
    model = AudioMambaModel(config) if config else AudioMambaModel()
    checkpoint = torch.load(checkpoint_path, map_location=device)
    model.load_state_dict(checkpoint['model_state_dict'])
    model.to(device).eval()
    return model


def preprocess_audio(audio_path, target_sample_rate=16000):
    waveform, sr = torchaudio.load(audio_path)
    if sr != target_sample_rate:
        waveform = torchaudio.transforms.Resample(sr, target_sample_rate)(waveform)
    # Additional preprocessing (framing, gain normalization, etc.) can go here
    return waveform


def infer(model, waveform, device='cuda'):
    waveform = waveform.to(device)
    with torch.no_grad():
        outputs = model(waveform.unsqueeze(0))  # add a batch dimension
    probs = torch.softmax(outputs, dim=1)
    top_prob, top_label = torch.max(probs, dim=1)
    return top_label.item(), top_prob.item()


if __name__ == "__main__":
    checkpoint_path = "path/to/your/checkpoint.pth"  # replace with the downloaded checkpoint
    config_path = None  # provide a config file path if needed
    audio_path = "path/to/your/audio.wav"
    label_map = {0: "class_a", 1: "class_b", ...}  # fill in the label mapping for your task
    device = "cuda" if torch.cuda.is_available() else "cpu"

    model = load_model(checkpoint_path, config_path, device)
    waveform = preprocess_audio(audio_path)
    label_idx, confidence = infer(model, waveform, device)
    print(f"Predicted label: {label_map[label_idx]} (confidence: {confidence:.4f})")

Small ImageNet

These are the checkpoints for the small models with the variant Bi-Bi (c), initialized with ImageNet pretrained weights.

Dataset                     #Params   Performance   Checkpoint
AudioSet (mAP)              25.5M     39.74         Link
AS-20K (mAP)                25.5M     29.17         Link
VGGSound (Acc)              25.5M     49.61         Link
VoxCeleb (Acc)              25.8M     41.78         Link
Speech Commands V2 (Acc)    25.2M     97.61         Link
Epic Sounds (Acc)           25.4M     53.45         Link
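
Before wiring a downloaded checkpoint into the AuM model, it is worth inspecting what the file actually contains. A minimal sketch (the file name is a placeholder and the wrapping key names are assumptions; the repo's own loading code is authoritative):

import torch

# Hypothetical local path for a checkpoint downloaded via one of the links above
ckpt_path = "checkpoints/aum_small_audioset.pth"

# Load on CPU first so no GPU is needed just to inspect the contents
ckpt = torch.load(ckpt_path, map_location="cpu")

# Checkpoints are usually either a raw state_dict or a dict wrapping one;
# print the top-level keys to find out which case this is
if isinstance(ckpt, dict):
    print(list(ckpt.keys())[:10])

# If the weights are nested (e.g. under 'model' or 'model_state_dict'),
# unwrap them before calling model.load_state_dict(...)
state_dict = ckpt.get("model", ckpt.get("model_state_dict", ckpt)) if isinstance(ckpt, dict) else ckpt
print(f"{len(state_dict)} tensors in the state_dict")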

kimia_infer audio understanding

git clone https://github.com/MoonshotAI/Kimi-Audio
git submodule update --init
cd Kimi-Audio
docker build -t kimi-audio:v0.1 .

Alternatively, you can use the pre-built image:

docker pull moonshotai/kimi-audio:v0.1

The docker pull route worked.

docker run -it a49762d13a3d bash

import soundfile as sf
import torch
from kimia_infer.api.kimia import KimiAudio


def main():
    # 1. Load the model
    model_id = "moonshotai/Kimi-Audio-7B-Instruct"  # or "Kimi/Kimi-Audio-7B"
    device = "cuda" if torch.cuda.is_available() else "cpu"
    model = KimiAudio(model_path=model_id, load_detokenizer=True)
    model.to(device)

    # 2. Sampling parameters that control generation behavior
    sampling_params = {
        "audio_temperature": 0.8,
        "audio_top_k": 10,
        "text_temperature": 0.0,
        "text_top_k": 5,
        "audio_repetition_penalty": 1.0,
        "audio_repetition_window_size": 64,
        "text_repetition_penalty": 1.0,
        "text_repetition_window_size": 16,
    }

    # 3. Example task A: audio-to-text (ASR)
    asr_audio = "path/to/your_asr_audio.wav"  # replace with a real audio path
    messages_asr = [
        {"role": "user", "message_type": "text", "content": "请转录下面这段音频:"},
        {"role": "user", "message_type": "audio", "content": asr_audio},
    ]
    _, text_out = model.generate(messages_asr, **sampling_params, output_type="text")
    print("ASR output:", text_out)

    # 4. Example task B: audio Q&A (audio-to-audio/text)
    qa_audio = "path/to/your_qa_audio.wav"  # replace with a real audio path
    messages_qa = [
        {"role": "user", "message_type": "audio", "content": qa_audio},
    ]
    wav_out, text_qa = model.generate(messages_qa, **sampling_params, output_type="both")
    output_wav = "output_generated.wav"
    sf.write(output_wav, wav_out.detach().cpu().view(-1).numpy(), 24000)
    print("Generated answer audio saved to:", output_wav)
    print("Answer text:", text_qa)


if __name__ == "__main__":
    main()

Audio classification:


# coding=utf-8
import glob
import sys
import os
current_dir = os.path.dirname(os.path.abspath(__file__))
os.chdir(current_dir)
print('current_dir', current_dir)
paths = [current_dir, current_dir+'/../']
paths.append(os.path.join(current_dir, 'src'))
for path in paths:
    sys.path.insert(0, path)
    os.environ['PYTHONPATH'] = (os.environ.get('PYTHONPATH', '') + ':' + path).strip(':')

from kimia_infer.api.kimia import KimiAudio
import os
import soundfile as sf
import argparse

if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    # parser.add_argument("--model_path", type=str, default="moonshotai/Kimi-Audio-7B-Instruct")
    parser.add_argument("--model_path", type=str, default="/nas/lbg/models/Kimi-Audio-7B-Instruct")
    args = parser.parse_args()

    model = KimiAudio(
        model_path=args.model_path,
        load_detokenizer=True,
    )

    sampling_params = {
        "audio_temperature": 0.8,
        "audio_top_k": 10,
        "text_temperature": 0.0,
        "text_top_k": 5,
        "audio_repetition_penalty": 1.0,
        "audio_repetition_window_size": 64,
        "text_repetition_penalty": 1.0,
        "text_repetition_window_size": 16,
    }

    # messages = [
    #     {
    #         "role": "user",
    #         "message_type": "audio",
    #         "content": "test_audios/asr_example.wav",
    #     }
    # ]
    # result = model.generate(messages, output_type="sec")  # sound event classification
    # print(">>> SEC result: ", result)
    # result = model.generate(messages, output_type="asc")  # acoustic scene classification
    # print(">>> ASC result: ", result)

    base_dir = r"/nas/lbg/project/Kimi-Audio/test_audios/music/Music"
    files = glob.glob(base_dir + "/*.mp3")
    for file in files:
        messages = [
            {"role": "user", "message_type": "text", "content": "请判断这个音频属于以下哪一类: 无人声无音乐、说话、纯音乐、唱歌。只输出类别。"},
            {
                "role": "user",
                "message_type": "audio",
                "content": file,
            },
        ]
        wav, text = model.generate(messages, **sampling_params, output_type="text")
        file_name = os.path.basename(file)
        print(">>> Classification result: ", text, file_name)

saurabhati/DASS_medium_AudioSet_50.2

import torch
import librosa
from transformers import AutoConfig, AutoModelForAudioClassification, AutoFeatureExtractor

config = AutoConfig.from_pretrained('saurabhati/DASS_medium_AudioSet_50.2', trust_remote_code=True)
audio_model = AutoModelForAudioClassification.from_pretrained('saurabhati/DASS_medium_AudioSet_50.2', trust_remote_code=True)
feature_extractor = AutoFeatureExtractor.from_pretrained('saurabhati/DASS_medium_AudioSet_50.2', trust_remote_code=True)

waveform, sr = librosa.load("audio/eval/_/_/--4gqARaEJE_0.000.flac", sr=16000)

inputs = feature_extractor(waveform, sr, return_tensors='pt')
with torch.no_grad():
    logits = torch.sigmoid(audio_model(**inputs).logits)

predicted_class_ids = torch.where(logits[0] > 0.5)[0]
predicted_label = [audio_model.config.id2label[i.item()] for i in predicted_class_ids]
predicted_label
['Animal', 'Domestic animals, pets', 'Dog']

ast

Model                                     Dataset              Strides (from name)   Performance      Released   Notes
MIT/ast-finetuned-audioset-16-16-0.442    AudioSet             16-16                 0.442 mAP        2023       Coarsest patch overlap of the three AudioSet checkpoints
MIT/ast-finetuned-audioset-14-14-0.443    AudioSet             14-14                 0.443 mAP        2023       Slightly higher mAP
MIT/ast-finetuned-audioset-12-12-0.447    AudioSet             12-12                 0.447 mAP        2024       Densest patch overlap, best mAP of the three
MIT/ast-finetuned-speech-commands-v2      Speech Commands V2   lightweight setup     98.1% accuracy   -          Built for speech-command recognition, suitable for real-time use

(The paired numbers in the AudioSet checkpoint names are the frequency/time patch strides, per the AST repo, not layer/head counts.)
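
For a quick test of one of these checkpoints without cloning the repo, the transformers audio-classification pipeline can load them directly. A sketch, assuming the first checkpoint from the table above and a local example.wav (ffmpeg is needed for the pipeline to decode audio files):

import torch
from transformers import pipeline

# Build an audio-classification pipeline around one of the AST checkpoints listed above
device = 0 if torch.cuda.is_available() else -1
clf = pipeline(
    "audio-classification",
    model="MIT/ast-finetuned-audioset-16-16-0.442",
    device=device,
)

# Pass a path to an audio file; the pipeline handles decoding and resampling
results = clf("example.wav", top_k=5)
for r in results:
    print(f"{r['label']}: {r['score']:.3f}")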

https://github.com/YuanGongND/ast

import os 
import torch
from models import ASTModel 
# download pretrained model in this directory
os.environ['TORCH_HOME'] = '../pretrained_models'  
# assume each input spectrogram has 100 time frames
input_tdim = 100
# assume the task has 527 classes
label_dim = 527
# create a pseudo input: a batch of 10 spectrogram, each with 100 time frames and 128 frequency bins 
test_input = torch.rand([10, input_tdim, 128]) 
# create an AST model
ast_mdl = ASTModel(label_dim=label_dim, input_tdim=input_tdim, imagenet_pretrain=True)
test_output = ast_mdl(test_input) 
# output should be in shape [10, 527], i.e., 10 samples, each with prediction of 527 classes. 
print(test_output.shape)  

Code generated by GPT:

import torch
import torchaudio
from models import ASTModel  # from the YuanGongND/ast repo (src/models), not the stdlib ast module

# 1. Load the pretrained AST model
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = ASTModel(label_dim=527, input_fdim=128, input_tdim=1024, imagenet_pretrain=True, audioset_pretrain=True)
model = model.to(device)
model.eval()

# 2. Load and preprocess the audio file
audio_path = "example.wav"
waveform, sample_rate = torchaudio.load(audio_path)

# Convert to mono and resample to 16 kHz
waveform = torch.mean(waveform, dim=0, keepdim=True)
if sample_rate != 16000:
    resampler = torchaudio.transforms.Resample(sample_rate, 16000)
    waveform = resampler(waveform)

# 3. Compute the log-mel filterbank spectrogram
fbank = torchaudio.compliance.kaldi.fbank(
    waveform, htk_compat=True, sample_frequency=16000, use_energy=False,
    window_type='hanning', num_mel_bins=128, dither=0.0, frame_shift=10
)

# Pad or truncate the spectrogram to the model's expected input length
n_frames = fbank.shape[0]
p = 1024 - n_frames
if p > 0:
    # pad if too short
    m = torch.nn.ZeroPad2d((0, 0, 0, p))
    fbank = m(fbank)
elif p < 0:
    # truncate if too long
    fbank = fbank[0:1024, :]

# 4. Normalize (AudioSet mean/std used by AST)
fbank = (fbank - (-4.2677393)) / (4.5689974 * 2)

# 5. Prepare the input tensor
input_tensor = fbank.unsqueeze(0).to(device)

# 6. Run the prediction
with torch.no_grad():
    output = model(input_tensor)

# 7. Read out the predictions
probabilities = torch.sigmoid(output)
top5_prob, top5_labels = torch.topk(probabilities, 5)

# Load labels (assuming you have the label list)
labels = [...]  # your class label list goes here

print("Top 5 predictions:")
for i in range(5):
    print(f"{labels[top5_labels[0][i]]}: {top5_prob[0][i]*100:.2f}%")
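
The labels list above is left as a placeholder. For AudioSet, the 527 class names can be read from a class_labels_indices.csv file; the path below assumes the copy shipped with the AST repo under egs/audioset/data/ (adjust it to wherever your copy lives):

import csv

# Read AudioSet class names from class_labels_indices.csv
# (columns: index, mid, display_name); the path is an assumption
labels_csv = "egs/audioset/data/class_labels_indices.csv"
with open(labels_csv, newline='') as f:
    reader = csv.DictReader(f)
    labels = [row["display_name"] for row in reader]

print(len(labels), labels[:3])  # expect 527 AudioSet classes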

AudioClassification-Pytorch has no pretrained model; train your own (a generic training sketch follows below):

https://github.com/yeyupiaoling/AudioClassification-Pytorch
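
Since there is no ready-made checkpoint, training is unavoidable. The repo provides its own data-preparation and training scripts; purely to illustrate the shape of such a training loop, here is a generic toy sketch with random tensors standing in for real log-mel features (not the repo's actual API or models):

import torch
import torch.nn as nn

# Toy audio classifier: a tiny CNN over (1, mel_bins, frames) spectrogram-like inputs
num_classes = 10
model = nn.Sequential(
    nn.Conv2d(1, 16, kernel_size=3, padding=1), nn.ReLU(),
    nn.AdaptiveAvgPool2d(1), nn.Flatten(),
    nn.Linear(16, num_classes),
)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
criterion = nn.CrossEntropyLoss()

for step in range(5):
    # Stand-in batch: 8 spectrograms of shape (1, 64 mel bins, 100 frames)
    x = torch.randn(8, 1, 64, 100)
    y = torch.randint(0, num_classes, (8,))
    optimizer.zero_grad()
    loss = criterion(model(x), y)
    loss.backward()
    optimizer.step()
    print(f"step {step}: loss {loss.item():.3f}")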

