音频分类模型笔记
目录
audio-mamba-aum
Small ImageNet
kimia_infer 音频理解
音频分类:
saurabhati/DASS_medium_AudioSet_50.2
ast
AudioClassification-Pytorch(未提供预训练模型,需要自己训练):
模型 | AudioSet mAP |
---|---|
OmniVec2.0 | 0.558 |
OmniVec | 0.548 |
EquiAV | 0.546 |
MAViL (Audio-Visual, single) | 0.533 |
PaSST-S / ConvNeXt-Tiny | 0.471 |
PSLA (Ensemble EfficientNet) | 0.474 |
AST(多模型集成结果;单模型约 0.459) | 0.485 |
Audio-MAE (SOTA self-supervised) | 0.473(论文报告,AS-2M;超过当时已有的监督方法) |
DASS_medium_AudioSet | 0.502 |
https://huggingface.co/hongroklim/omniverse-github/tree/main
https://github.com/JongSuk1/EquiAV
audio-mamba-aum
https://github.com/kaistmm/Audio-Mamba-AuM?tab=readme-ov-file
以下代码未经测试:
import torch
import torchaudio
from pathlib import Path

# Assumes the model is defined in src/model.py — replace with the actual path in the repo.
from src.model import AudioMambaModel
from src.config import get_config  # only needed if the config is stored separately


def load_model(checkpoint_path, config_path=None, device='cuda'):
    """Build an AudioMambaModel, load its weights and return it in eval mode on *device*.

    NOTE(review): assumes the checkpoint is a dict holding a 'model_state_dict'
    entry — confirm against the released Audio-Mamba checkpoints.
    NOTE: torch.load unpickles the file; only load checkpoints you trust.
    """
    # Initialise the model config (can be skipped when a single-file checkpoint is enough).
    config = get_config(config_path) if config_path else None
    model = AudioMambaModel(config) if config else AudioMambaModel()
    checkpoint = torch.load(checkpoint_path, map_location=device)
    model.load_state_dict(checkpoint['model_state_dict'])
    model.to(device).eval()
    return model


def preprocess_audio(audio_path, target_sample_rate=16000):
    """Load *audio_path* as a mono waveform of shape (1, num_samples) at *target_sample_rate*."""
    waveform, sr = torchaudio.load(audio_path)
    # Mix multi-channel audio down to mono so the model always receives one channel;
    # otherwise stereo files would yield a different input shape than mono files.
    if waveform.shape[0] > 1:
        waveform = waveform.mean(dim=0, keepdim=True)
    if sr != target_sample_rate:
        waveform = torchaudio.transforms.Resample(sr, target_sample_rate)(waveform)
    # Further preprocessing (framing, gain normalisation, feature extraction) goes here.
    return waveform


def infer(model, waveform, device='cuda'):
    """Return ``(top-1 class index, its probability)`` for one waveform.

    NOTE(review): softmax + top-1 assumes a single-label task; for a
    multi-label (e.g. AudioSet) checkpoint use torch.sigmoid with a threshold.
    """
    waveform = waveform.to(device)
    with torch.no_grad():
        outputs = model(waveform.unsqueeze(0))  # add the batch dimension
        probs = torch.softmax(outputs, dim=1)
        top_prob, top_label = torch.max(probs, dim=1)
    return top_label.item(), top_prob.item()


if __name__ == "__main__":
    checkpoint_path = "path/to/your/checkpoint.pth"  # replace with the downloaded checkpoint
    config_path = None  # optionally point to a config file
    audio_path = "path/to/your/audio.wav"
    # Fill in the full index -> class-name mapping for your task.
    # (A literal `...` entry inside a dict display is a SyntaxError.)
    label_map = {0: "class_a", 1: "class_b"}
    device = "cuda" if torch.cuda.is_available() else "cpu"
    model = load_model(checkpoint_path, config_path, device)
    waveform = preprocess_audio(audio_path)
    label_idx, confidence = infer(model, waveform, device)
    # Fall back to the raw index for classes missing from label_map instead of a KeyError.
    print(f"Predicted label: {label_map.get(label_idx, label_idx)} (confidence: {confidence:.4f})")
Small ImageNet
These are the checkpoints for the small models of the Bi-Bi variant (c), initialized with ImageNet-pretrained weights.
Dataset | #Params | Performance | Checkpoint |
---|---|---|---|
Audioset (mAP) | 25.5M | 39.74 | Link |
AS-20K (mAP) | 25.5M | 29.17 | Link |
VGGSound (Acc) | 25.5M | 49.61 | Link |
VoxCeleb (Acc) | 25.8M | 41.78 | Link |
Speech Commands V2 (Acc) | 25.2M | 97.61 | Link |
Epic Sounds (Acc) | 25.4M | 53.45 | Link |
kimia_infer 音频理解
git clone https://github.com/MoonshotAI/Kimi-Audio
# Submodules can only be initialised from inside the cloned repository.
cd Kimi-Audio
git submodule update --init --recursive
# Either build the image locally:
docker build -t kimi-audio:v0.1 .
# Alternatively, use the pre-built image:
# docker pull moonshotai/kimi-audio:v0.1
使用 docker pull 拉取预构建镜像的方式成功了。
# Open an interactive shell in the pulled image (locally it had image ID a49762d13a3d).
# --gpus all exposes the host GPUs (requires the NVIDIA Container Toolkit); add -v mounts
# for model/audio directories as needed.
docker run --gpus all -it moonshotai/kimi-audio:v0.1 bash
import soundfile as sf
import torch
from kimia_infer.api.kimia import KimiAudio


def main():
    """Run two Kimi-Audio demos: speech-to-text (ASR) and spoken QA (audio + text output)."""
    # 1. Load the model.
    model_id = "moonshotai/Kimi-Audio-7B-Instruct"  # or "Kimi/Kimi-Audio-7B"
    if not torch.cuda.is_available():
        # NOTE(review): KimiAudio appears to place the model on the GPU itself; CPU-only
        # execution is unlikely to work — confirm before relying on it.
        print("Warning: CUDA is not available; loading KimiAudio may fail.")
    # NOTE(review): the original called model.to(device) here. KimiAudio is used without
    # .to() in the upstream example and may not be an nn.Module (no .to method) — confirm.
    model = KimiAudio(model_path=model_id, load_detokenizer=True)

    # 2. Sampling parameters that control generation.
    sampling_params = {
        "audio_temperature": 0.8,
        "audio_top_k": 10,
        "text_temperature": 0.0,
        "text_top_k": 5,
        "audio_repetition_penalty": 1.0,
        "audio_repetition_window_size": 64,
        "text_repetition_penalty": 1.0,
        "text_repetition_window_size": 16,
    }

    # 3. Task A: audio -> text (ASR).
    asr_audio = "path/to/your_asr_audio.wav"  # replace with a real audio path
    messages_asr = [
        {"role": "user", "message_type": "text", "content": "请转录下面这段音频:"},
        {"role": "user", "message_type": "audio", "content": asr_audio},
    ]
    _, text_out = model.generate(messages_asr, **sampling_params, output_type="text")
    print("ASR 输出内容:", text_out)

    # 4. Task B: audio question answering (audio in, audio + text out).
    qa_audio = "path/to/your_qa_audio.wav"  # replace with a real audio path
    messages_qa = [{"role": "user", "message_type": "audio", "content": qa_audio}]
    wav_out, text_qa = model.generate(messages_qa, **sampling_params, output_type="both")
    output_wav = "output_generated.wav"
    # Flatten the generated waveform to 1-D and save it at 24 kHz.
    sf.write(output_wav, wav_out.detach().cpu().view(-1).numpy(), 24000)
    print("问答生成音频已保存至:", output_wav)
    print("问答输出文本:", text_qa)


if __name__ == "__main__":
    main()
音频分类:
# coding=utf-8
"""Zero-shot audio classification with Kimi-Audio: prompt the model to label every file in a directory."""
import glob
import sys
import os

current_dir = os.path.dirname(os.path.abspath(__file__))
os.chdir(current_dir)
print('current_dir', current_dir)
# Make this directory, its parent and ./src importable before importing kimia_infer.
paths = [current_dir, current_dir + '/../']
paths.append(os.path.join(current_dir, 'src'))
for path in paths:
    sys.path.insert(0, path)
    # PYTHONPATH only affects child processes; sys.path above is what this process uses.
    os.environ['PYTHONPATH'] = (os.environ.get('PYTHONPATH', '') + ':' + path).strip(':')

from kimia_infer.api.kimia import KimiAudio  # noqa: E402  (needs the sys.path setup above)
import soundfile as sf  # noqa: E402,F401  (kept from the original; unused below)
import argparse  # noqa: E402

if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    # parser.add_argument("--model_path", type=str, default="moonshotai/Kimi-Audio-7B-Instruct")
    parser.add_argument("--model_path", type=str, default="/nas/lbg/models/Kimi-Audio-7B-Instruct")
    # Directory and glob pattern of the audio files to classify (defaults match the original run).
    parser.add_argument("--audio_dir", type=str, default=r"/nas/lbg/project/Kimi-Audio/test_audios/music/Music")
    parser.add_argument("--pattern", type=str, default="*.mp3")
    args = parser.parse_args()

    model = KimiAudio(
        model_path=args.model_path,
        load_detokenizer=True,
    )

    sampling_params = {
        "audio_temperature": 0.8,
        "audio_top_k": 10,
        "text_temperature": 0.0,
        "text_top_k": 5,
        "audio_repetition_penalty": 1.0,
        "audio_repetition_window_size": 64,
        "text_repetition_penalty": 1.0,
        "text_repetition_window_size": 16,
    }

    # Alternative (untested here): built-in sound-event / acoustic-scene classification.
    # messages = [{"role": "user", "message_type": "audio", "content": "test_audios/asr_example.wav"}]
    # result = model.generate(messages, output_type="sec")  # sound event classification
    # print(">>> SEC 分类结果: ", result)
    # result = model.generate(messages, output_type="asc")  # acoustic scene classification
    # print(">>> ASC 分类结果: ", result)

    # Sort so results come out in a stable, reproducible order.
    files = sorted(glob.glob(os.path.join(args.audio_dir, args.pattern)))
    if not files:
        print(f"No files matching {args.pattern!r} found in {args.audio_dir}")
    for file in files:
        messages = [
            {"role": "user", "message_type": "text", "content": "请判断这个音频属于以下哪一类: 无人声无音乐、说话、纯音乐、唱歌。只输出类别。"},
            {
                "role": "user",
                "message_type": "audio",
                "content": file,
            },
        ]
        # Only the text answer is needed for classification; the audio output is discarded.
        _, text = model.generate(messages, **sampling_params, output_type="text")
        file_name = os.path.basename(file)
        print(">>> 分类结果: ", text, file_name)
saurabhati/DASS_medium_AudioSet_50.2
import torch
import librosa
from transformers import AutoConfig, AutoModelForAudioClassification, AutoFeatureExtractor

MODEL_ID = 'saurabhati/DASS_medium_AudioSet_50.2'
# Per-class decision threshold on the sigmoid probabilities (multi-label output).
THRESHOLD = 0.5

config = AutoConfig.from_pretrained(MODEL_ID, trust_remote_code=True)
audio_model = AutoModelForAudioClassification.from_pretrained(MODEL_ID, trust_remote_code=True)
feature_extractor = AutoFeatureExtractor.from_pretrained(MODEL_ID, trust_remote_code=True)

# librosa resamples to 16 kHz on load.
waveform, sr = librosa.load("audio/eval/_/_/--4gqARaEJE_0.000.flac", sr=16000)
inputs = feature_extractor(waveform, sr, return_tensors='pt')

with torch.no_grad():
    # Score every class independently so several labels can fire at once.
    probs = torch.sigmoid(audio_model(**inputs).logits)
predicted_class_ids = torch.where(probs[0] > THRESHOLD)[0]
predicted_label = [audio_model.config.id2label[i.item()] for i in predicted_class_ids]
# A bare `predicted_label` only echoes in a REPL; print it so the script shows the result.
print(predicted_label)
# Output recorded for the clip above:
# ['Animal', 'Domestic animals, pets', 'Dog']
ast
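模型名称 | 数据集 | patch 步长(频率 - 时间) | 性能指标 | 发布时间 | 核心特点 |
---|---|---|---|---|---|
MIT/ast-finetuned-audioset-16-16-0.442 | AudioSet | 16 - 16 | 0.442 mAP | 2023 | 步长最大、patch 数最少,计算量最低,mAP 略低 |
MIT/ast-finetuned-audioset-14-14-0.443 | AudioSet | 14 - 14 | 0.443 mAP | 2023 | 计算量与性能介于 16-16 和 12-12 之间 |
MIT/ast-finetuned-audioset-12-12-0.447 | AudioSet | 12 - 12 | 0.447 mAP | 2024 | 步长更小、patch 更多,计算量更高,mAP 小幅提升 |
MIT/ast-finetuned-speech-commands-v2 | Speech Commands V2 | - | 98.1% 准确率 | - | 在 Speech Commands V2 上微调,用于语音命令(关键词)识别 |

注:模型名中的两个数字是频率/时间方向的 patch 步长(fstride/tstride),不是层数和头数。这些模型都使用 ViT-Base 主干(12 层、12 头),参数量基本相同;步长越小,patch 越多,计算量越大。另有 MIT/ast-finetuned-audioset-10-10-0.4593(步长 10,mAP 0.4593)。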
https://github.com/YuanGongND/ast
import os
import torch
from models import ASTModel
# Pretrained weights are downloaded into this directory.
os.environ['TORCH_HOME'] = '../pretrained_models'


def _run_shape_check(batch_size=10, n_frames=100, n_mel_bins=128, n_classes=527):
    """Run a random spectrogram batch through a freshly built AST and return its output.

    Each of the *batch_size* samples is an (n_frames x n_mel_bins) spectrogram;
    the model is configured for *n_classes* output classes.
    """
    dummy_batch = torch.rand([batch_size, n_frames, n_mel_bins])
    model = ASTModel(label_dim=n_classes, input_tdim=n_frames, imagenet_pretrain=True)
    return model(dummy_batch)


# Expected: torch.Size([10, 527]) — 10 samples, one score per each of the 527 classes.
print(_run_shape_check().shape)
以下为 GPT 生成的代码:
import csv

import torch
import torchaudio
# The AST repo exposes ASTModel from src/models (run this from src/).
# `from ast.model import ...` cannot work: `ast` resolves to the standard-library module.
from models import ASTModel

# Fully fine-tuned AudioSet checkpoint and the AudioSet label list
# (CSV columns: index, mid, display_name) — adjust both paths to your setup.
CHECKPOINT_PATH = '../pretrained_models/audioset_10_10_0.4593.pth'
LABEL_CSV = '../egs/audioset/data/class_labels_indices.csv'
AUDIO_PATH = 'example.wav'

TARGET_SR = 16000
TARGET_FRAMES = 1024  # input_tdim the checkpoint was trained with
NUM_MEL_BINS = 128
# fbank normalisation statistics used by the original snippet (AST AudioSet recipe).
FBANK_MEAN = -4.2677393
FBANK_STD = 4.5689974


def load_ast_model(checkpoint_path, device):
    """Build an AST and load a fully fine-tuned AudioSet checkpoint (backbone + classifier head).

    audioset_pretrain=True only restores the backbone and attaches a freshly
    initialised head (it is meant for transfer learning), so its raw outputs are
    not AudioSet predictions. Here the complete checkpoint is loaded instead.
    NOTE: torch.load unpickles the file; only load checkpoints you trust.
    """
    model = ASTModel(label_dim=527, input_fdim=NUM_MEL_BINS, input_tdim=TARGET_FRAMES,
                     imagenet_pretrain=False, audioset_pretrain=False)
    state_dict = torch.load(checkpoint_path, map_location=device)
    # Checkpoints saved from nn.DataParallel prefix every key with 'module.'.
    prefix = 'module.'
    state_dict = {(k[len(prefix):] if k.startswith(prefix) else k): v for k, v in state_dict.items()}
    model.load_state_dict(state_dict)
    return model.to(device).eval()


def wav_to_fbank(audio_path, target_frames=TARGET_FRAMES):
    """Load *audio_path* and return a normalised (target_frames, NUM_MEL_BINS) fbank tensor."""
    waveform, sample_rate = torchaudio.load(audio_path)
    waveform = torch.mean(waveform, dim=0, keepdim=True)  # mix down to mono
    if sample_rate != TARGET_SR:
        waveform = torchaudio.transforms.Resample(sample_rate, TARGET_SR)(waveform)
    fbank = torchaudio.compliance.kaldi.fbank(
        waveform, htk_compat=True, sample_frequency=TARGET_SR, use_energy=False,
        window_type='hanning', num_mel_bins=NUM_MEL_BINS, dither=0.0, frame_shift=10)
    # Zero-pad short clips / truncate long ones to the frame count the model expects.
    n_pad = target_frames - fbank.shape[0]
    if n_pad > 0:
        fbank = torch.nn.functional.pad(fbank, (0, 0, 0, n_pad))
    else:
        fbank = fbank[:target_frames]
    return (fbank - FBANK_MEAN) / (FBANK_STD * 2)


def load_labels(csv_path):
    """Return the class display names from *csv_path*, ordered by their 'index' column."""
    with open(csv_path, newline='', encoding='utf-8') as f:
        rows = sorted(csv.DictReader(f), key=lambda row: int(row['index']))
    return [row['display_name'] for row in rows]


if __name__ == '__main__':
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = load_ast_model(CHECKPOINT_PATH, device)
    labels = load_labels(LABEL_CSV)

    input_tensor = wav_to_fbank(AUDIO_PATH).unsqueeze(0).to(device)  # add batch dimension
    with torch.no_grad():
        output = model(input_tensor)
    # Independent sigmoid per class (multi-label scoring), then the 5 highest.
    probabilities = torch.sigmoid(output)
    top5_prob, top5_labels = torch.topk(probabilities, 5)

    print("Top 5 predictions:")
    for prob, idx in zip(top5_prob[0].tolist(), top5_labels[0].tolist()):
        print(f"{labels[idx]}: {prob * 100:.2f}%")
AudioClassification-Pytorch(未提供预训练模型,需要自己训练):
https://github.com/yeyupiaoling/AudioClassification-Pytorch