当前位置: 首页 > web >正文

flow-matching 之学习matcha-tts cosyvoice

文章目录

  • matcha 实现
  • cosyvoice 实现
  • chunk_fm
    • chunk_mask
    • cache_attn
  • stream token2wav

  • 关于flow-matching 很好的原理性解释文章, 值得仔细读,多读几遍,关于文章Flow Straight and Fast:
    Learning to Generate and Transfer Data with Rectified Flow 的讲解梳理。

matcha 实现

def fm_comput_loss()# x1 是target_mel# random timestept = torch.rand([b, 1, 1], device=mu.device, dtype=mu.dtype)# sample noise p(x_0)z = torch.randn_like(x1)y = (1 - (1 - self.sigma_min) * t) * z + t * x1u = x1 - (1 - self.sigma_min) * zpred_y = self.estimator(y, mask, mu, t.squeeze(), spks)loss = F.mse_loss(pred_y, u, reduction="sum") / (torch.sum(mask) * u.shape[1])return loss, y
def estimator_forward():x = pack(y, mu)x = pack(x, spks)q,k,v = x, x, xx = slf_attn(q,k,v)outputs = linear(x)return outputs

cosyvoice 实现

def fm_forward():# mu: encoder_outputs# x1: target_mel# cond: prompt_mel 随机取的部分conds = torch.zeros(feat.shape, device=token.device)for i, j in enumerate(feat_len):if random.random() < 0.5:continueindex = random.randint(0, int(0.3 * j))conds[i, :index] = feat[i, :index]conds = conds.transpose(1, 2)b, _, t = mu.shape# random timestept = torch.rand([b, 1, 1], device=mu.device, dtype=mu.dtype)if self.t_scheduler == 'cosine':t = 1 - torch.cos(t * 0.5 * torch.pi)# sample noise p(x_0)z = torch.randn_like(x1)y = (1 - (1 - self.sigma_min) * t) * z + t * x1u = x1 - (1 - self.sigma_min) * z# during training, we randomly drop condition to trade off mode coverage and sample fidelity# inference 的时候实际不需要condition, 给zero就可以if self.training_cfg_rate > 0:cfg_mask = torch.rand(b, device=x1.device) > self.training_cfg_ratemu = mu * cfg_mask.view(-1, 1, 1)spks = spks * cfg_mask.view(-1, 1)cond = cond * cfg_mask.view(-1, 1, 1)pred = self.estimator(y, mask, mu, t.squeeze(), spks, cond, streaming=streaming)loss = F.mse_loss(pred * mask, u * mask, reduction="sum") / (torch.sum(mask) * u.shape[1])return loss, ydef estimator(x, mu, spks, cond):x = pack(x, mu, spks, cond)x = slf_attn(x)outputs = linear(x)return outputs

chunk_fm

  • 训练的时候将特征进行chunk_mask,推理的时候只准备chunk的部分,pre_chunk 存为kv_cache,
  • cache 初始seq_len为0;每次得到的cache,只留下[-chunk_len:] 的长度,作为下一次的输入;特征x 的pos 按照真的来算;

chunk_mask

在这里插入图片描述

  • 训练阶段样本按照seq_len 维度被mask 成不同的可见部分;chunk_mask 和长度mask 都会出现,为了加速收敛;

cache_attn

def slf_attn_cache(x, cache):k_in, v_in, q_in = x, x, xkey_cache = linear1(k_in)value_cache = linear2(v_in)# NOTE here we judge cache.size(0) instead of cache.size(1), because init_cache has size (2, 0, 512, 2)if cache.size(0) != 0:# step into this branchkey = torch.concat([cache[:, :, :, 0], key_cache], dim=1)value = torch.concat([cache[:, :, :, 1], value_cache], dim=1)else:key, value = key_cache, value_cachecache = torch.stack([key_cache, value_cache], dim=3)outputs = scale_dot_production(key, value)return outputs, cache

stream token2wav

在这里插入图片描述

  • 第一个包没有kv_cache, 卷积cache 有,但是值为0;first chunk 推理完就可以存下kv_cache & cnn_cache;
  • 输入token+cache_token,得到token 对应的mel;
  • mel2wav 阶段也是,第一次没有hift cache,直接退出mel 对应的wav,最后8帧存下来作为hift_cache,用于hift_wav 预测以及输出的音频片段间平滑;前n-8 帧的音频输出;
def inference(self, speech_feat: torch.Tensor, cache_source: torch.Tensor = torch.zeros(1, 1, 0)):# mel->f0 mel:[1,80,T]#print('hift inference speech_feat', speech_feat.size())f0 = self.f0_predictor(speech_feat)# f0->sources = self.f0_upsamp(f0[:, None]).transpose(1, 2)  # bs,n,ts, _, _ = self.m_source(s)s = s.transpose(1, 2) #sample#print('f0 s', s.size()) # sample level# use cache_source to avoid glitchif cache_source.shape[2] != 0:# s[1,1,t*480]s[:, :, :cache_source.shape[2]] = cache_sourceprint('cache_source s', s.size())else:print('cache_source shape2 is 0')#print('hift inference s', s.size())generated_speech = self.decode(x=speech_feat, s=s)return generated_speech, s
http://www.xdnf.cn/news/4690.html

相关文章:

  • 集团云解决方案:集团企业IT基础架构的降本增效利器
  • RAG技术在测试用例生成中的应用
  • FAST角点检测算法原理附C++代码实现
  • HarmonyOS NEXT之深度解析ArkUI自定义组件:从基础实现到生产级登录组件的进化之路
  • 复盘20250508
  • CSS:元素显示模式与背景
  • 【Java ee 初阶】文件IO和操作(下)
  • 系统架构-面向服务架构(SOA)
  • 【嵌入式开发-SPI】
  • 常见的提示词攻击方法 和防御手段——提示词注入(Prompt Injection)攻击解析
  • 了解Dockerfile
  • 【计算机网络 第8版】谢希仁编著 第四章网络层 题型总结2
  • 如何用分布式防御抵扣大规模DDoS攻击?
  • 【PostgreSQL数据分析实战:从数据清洗到可视化全流程】电商数据分析案例-9.2 流量转化漏斗分析
  • 前端实战中的单例模式:以医疗药敏管理为例
  • [论文笔记] 超详细解读DeepSeek v3全论文技术报告
  • 零基础入门Hadoop:IntelliJ IDEA远程连接服务器中Hadoop运行WordCount
  • TDEngine 与 Grafana
  • 从零开始在亚马逊云科技 EC2上部署DeepSeek R1大语言模型:完整实战指南
  • Linux 网络命名空间:从内核资源管理到容器网络隔离
  • 算法与数据结构 - 常用图算法总结
  • 观测云:安全、可信赖的监控观测云服务
  • 《React Native性能优化:从卡顿到丝滑的蜕变之旅》
  • 菊厂笔试1
  • Django rest_framework 信号机制生成并使用token
  • SSH 服务部署指南
  • 学习基本乐理知识
  • 【C/C++】RPC与线程间通信:高效设计的关键选择
  • 如何使用npm下载指定版本的cli工具
  • Git查看某个commit的改动