当前位置: 首页 > java >正文

whisper 语种检测学习笔记

目录

transformers推理:

transformers 源代码

网上的语种检测调用例子:

语种检测 api


transformers推理:

https://github.com/openai/whisper/blob/c0d2f624c09dc18e709e37c2ad90c039a4eb72a2/whisper/decoding.py

   waveform, sample_rate = torchaudio.load(file_path)# Ensure the sample rate is 16000 Hz (Whisper's expected sample rate)if sample_rate != 16000:waveform = torchaudio.transforms.Resample(orig_freq=sample_rate, new_freq=16000)(waveform)inputs = processor(waveform.squeeze().numpy(), return_tensors="pt", sampling_rate=16000)with torch.no_grad():generated_ids = model.generate(inputs["input_features"])# Extract language token from the model's outputlanguage_token = processor.tokenizer.decode(generated_ids[0][:2])  # First two tokensreturn processor.tokenizer.convert_tokens_to_string(language_token)

transformers 源代码

https://github.com/huggingface/transformers/blob/05000aefe173bf7a10fa1d90e4c528585b45d3c7/src/transformers/models/whisper/generation_whisper.py#L1622

def detect_language(self,input_features: Optional[torch.FloatTensor] = None,encoder_outputs: Optional[Union[torch.FloatTensor, BaseModelOutput]] = None,generation_config: Optional[GenerationConfig] = None,num_segment_frames: int = 3000,) -> torch.Tensor:"""Detects language from log-mel input features or encoder_outputsParameters:input_features (`torch.Tensor` of shape `(batch_size, feature_size, sequence_length)`, *optional*):Float values of log-mel features extracted from the raw speech waveform. The raw speech waveform can be obtained byloading a `.flac` or `.wav` audio file into an array of type `list[float]`, a `numpy.ndarray` or a `torch.Tensor`, *e.g.* viathe soundfile library (`pip install soundfile`). To prepare the array into `input_features`, the[`AutoFeatureExtractor`] should be used for extracting the mel features, padding and conversion into atensor of type `torch.FloatTensor`. See [`~WhisperFeatureExtractor.__call__`] for details.encoder_outputs (`tuple(tuple(torch.FloatTensor)`, *optional*):Tuple consists of (`last_hidden_state`, *optional*: `hidden_states`, *optional*: `attentions`)`last_hidden_state` of shape `(batch_size, sequence_length, hidden_size)`, *optional*) is a sequence ofhidden-states at the output of the last layer of the encoder. Used in the cross-attention of the decoder.generation_config (`~generation.GenerationConfig`, *optional*):The generation configuration to be used as base parametrization for the generation call. `**kwargs`passed to generate matching the attributes of `generation_config` will override them. If`generation_config` is not provided, the default will be used, which had the following loadingpriority: 1) from the `generation_config.json` model file, if it exists; 2) from the modelconfiguration. Please note that unspecified parameters will inherit [`~generation.GenerationConfig`]'sdefault values, whose documentation should be checked to parameterize generation.num_segment_frames (`int`, *optional*, defaults to 3000):The number of log-mel frames the model expectsReturn:A `torch.LongTensor` representing the detected language ids."""if input_features is None and encoder_outputs is None:raise ValueError("You have to specify either `input_features` or `encoder_outputs`")elif input_features is not None and encoder_outputs is not None:raise ValueError("Make sure to specify only one of `input_features` or `encoder_outputs` - not both!")elif input_features is not None:inputs = {"input_features": input_features[:, :, :num_segment_frames]}batch_size = input_features.shape[0]elif encoder_outputs is not None:inputs = {"encoder_outputs": encoder_outputs}batch_size = (encoder_outputs[0].shape[0] if isinstance(encoder_outputs, BaseModelOutput) else encoder_outputs[0])generation_config = generation_config or self.generation_configdecoder_input_ids = (torch.ones((batch_size, 1), device=self.device, dtype=torch.long)* generation_config.decoder_start_token_id)with torch.no_grad():logits = self(**inputs, decoder_input_ids=decoder_input_ids, use_cache=False).logits[:, -1]non_lang_mask = torch.ones_like(logits[0], dtype=torch.bool)non_lang_mask[list(generation_config.lang_to_id.values())] = Falselogits[:, non_lang_mask] = -np.inflang_ids = logits.argmax(-1)return lang_ids

网上的语种检测调用例子:

import whispermodel = whisper.load_model("base") # 加载预训练的语音识别模型,这里使用了名为"base"的模型。# load audio and pad/trim it to fit 30 seconds
audio = whisper.load_audio("audio.mp3")
audio = whisper.pad_or_trim(audio)  # 对加载的音频进行填充或裁剪,使其适合30秒的滑动窗口处理。# make log-Mel spectrogram and move to the same device as the model
mel = whisper.log_mel_spectrogram(audio).to(model.device) 
# 将音频转换为对数梅尔频谱图,并将其移动到与模型相同的设备(如GPU)上进行处理。# detect the spoken language
_, probs = model.detect_language(mel) # 使用模型进行语言检测,返回检测到的语言和对应的概率。
# 打印检测到的语言,选取概率最高的语言作为结果。
print(f"Detected language: {max(probs, key=probs.get)}")# decode the audio
# 置解码的选项,如语言模型、解码器等。
options = whisper.DecodingOptions()
# 使用模型对音频进行解码,生成识别结果。
result = whisper.decode(model, mel, options)# print the recognized text
# 打印识别结果,即模型识别出的文本内容。
print(result.text)
————————————————
版权声明:本文为CSDN博主「陌上阳光」的原创文章,遵循CC 4.0 BY-SA版权协议,转载请附上原文出处链接及本声明。
原文链接:https://blog.csdn.net/weixin_42831564/article/details/138667560

语种检测 api

语种检测源代码:

https://github.com/openai/whisper/blob/c0d2f624c09dc18e709e37c2ad90c039a4eb72a2/whisper/decoding.py

@torch.no_grad()
def detect_language(model: "Whisper", mel: Tensor, tokenizer: Tokenizer = None
) -> Tuple[Tensor, List[dict]]:"""Detect the spoken language in the audio, and return them as list of strings, along with the idsof the most probable language tokens and the probability distribution over all language tokens.This is performed outside the main decode loop in order to not interfere with kv-caching.Returns-------language_tokens : Tensor, shape = (n_audio,)ids of the most probable language tokens, which appears after the startoftranscript token.language_probs : List[Dict[str, float]], length = n_audiolist of dictionaries containing the probability distribution over all languages."""if tokenizer is None:tokenizer = get_tokenizer(model.is_multilingual, num_languages=model.num_languages)if (tokenizer.language is Noneor tokenizer.language_token not in tokenizer.sot_sequence):raise ValueError("This model doesn't have language tokens so it can't perform lang id")single = mel.ndim == 2if single:mel = mel.unsqueeze(0)# skip encoder forward pass if already-encoded audio features were givenif mel.shape[-2:] != (model.dims.n_audio_ctx, model.dims.n_audio_state):mel = model.encoder(mel)# forward pass using a single token, startoftranscriptn_audio = mel.shape[0]x = torch.tensor([[tokenizer.sot]] * n_audio).to(mel.device)  # [n_audio, 1]logits = model.logits(x, mel)[:, 0]# collect detected languages; suppress all non-language tokensmask = torch.ones(logits.shape[-1], dtype=torch.bool)mask[list(tokenizer.all_language_tokens)] = Falselogits[:, mask] = -np.inflanguage_tokens = logits.argmax(dim=-1)language_token_probs = logits.softmax(dim=-1).cpu()language_probs = [{c: language_token_probs[i, j].item()for j, c in zip(tokenizer.all_language_tokens, tokenizer.all_language_codes)}for i in range(n_audio)]if single:language_tokens = language_tokens[0]language_probs = language_probs[0]return language_tokens, language_probs

http://www.xdnf.cn/news/17898.html

相关文章:

  • 39 C++ STL模板库8-容器1-array
  • 解决hexo deploy报错:fatal: bad config line 1 in file .git/config
  • 跨网络 SSH 访问:借助 cpolar 内网穿透服务实现手机远程管理 Linux
  • 图像识别控制技术(Sikuli)深度解析:原理、应用与商业化前景
  • Vue 组件二次封装透传slots、refs、attrs、listeners
  • 把 AI 装进“冰箱贴”——基于超低功耗语音合成的小屏电子价签
  • StringBoot-SSE和WebFlux方式消息实时推送-默认单向-可增加交互接口
  • C语言中的输入输出函数:构建程序交互的基石
  • 开源数据发现平台:Amundsen Frontend Service 应用程序配置
  • 基于CodeBuddy的2D游戏开发实践:炫酷大便超人核心机制解析
  • NOI Online培训1至26期例题解析(16-20期)
  • week1-[一维数组]传送
  • MySQLl中OFFSET 的使用方法
  • PIDGenRc函数中lpstrRpc的由来和InitializePidVariables函数的关系
  • JMeter性能测试详细版(适合0基础小白学习--非常详细)
  • 基于SpringBoot的救援物资管理系统 受灾应急物资管理系统 物资管理小程序
  • 浏览器环境下AES-GCM JavaScript 加解密程序
  • Elasticsearch ABAC 配置:实现动态、细粒度的访问控制
  • 【C#】跨平台创建你的WinForms窗体应用(WindowsUbuntu)
  • 新手入门 Makefile:FPGA 项目实战教程(一)
  • Java面试场景题大全精简版
  • vue3使用leaflet地图
  • 力扣(LeetCode) ——225 用队列实现栈(C语言)
  • 算法基础 第3章 数据结构
  • C++类与对象核心知识点全解析(中)【六大默认成员函数详解】
  • P1281 [CERC1998] 书的复制
  • TCP 连接管理:深入分析四次握手与三次挥手
  • 2025年大模型安全岗的面试汇总(题目+回答)
  • 扩展用例-失败的嵌套
  • 大语言模型基础