当前位置: 首页 > ops >正文

使用 BayesFlow 通过神经网络简化贝叶斯推断(一)

本篇文章Easier Bayesian Inference with Neural Networks using BayesFlow (Code Included)是介绍BayesFlow 大致功能的开篇。


文章目录

  • 1 什么是 BayesFlow?
    • 1.1 工作原理:BayesFlow 工作流
    • 1.2 组件
    • 1.3 核心功能
  • 2 实际应用
    • 2.1 简单入门
    • 2.2 专家级定制
    • 2.3 使用 BayesFlow 进行贝叶斯线性回归:摊销推理的实践入门
      • 2.3.1 为什么选择摊销推理?
      • 2.3.2 核心架构:摘要网络与推理网络
      • 2.3.3 分步实现
      • 2.3.4 定义生成模型
      • 2.3.5 通过适配器准备数据
      • 2.3.6 构建神经网络
        • 2.3.6.1 摘要网络
        • 2.3.6.2 推理网络
    • 2.4 连接所有组件:摊销器
    • 2.5 后验估计
    • 2.6 结果可视化


贝叶斯推断为不确定性下的推理、复杂系统建模以及基于观测数据进行预测提供了一种有原则且强大的方法。然而,尽管贝叶斯建模优雅,但它常常遇到严重的计算障碍:

后验分布通常难以处理。

模型验证和比较需要重复推断。

基于仿真的工作流(例如,校准、恢复、敏感性分析)变得慢得令人望而却步。

这种计算成本传统上限制了贝叶斯工作流的实际应用——直到 BayesFlow 的出现。

1 什么是 BayesFlow?

BayesFlow 是一个开源的 Python 库,旨在利用摊销神经网络加速和扩展贝叶斯推断。通过训练神经网络“学习”逆问题(从数据推断参数)或正向模型(从参数生成数据),BayesFlow 可以在初始训练后实现近乎即时的推断——通常在毫秒级完成。

核心思想: 一次性投入计算资源训练神经网络,然后将其重复用于数千次快速推断

BayesFlow 基于 TensorFlow 构建,无缝支持 GPU/TPU 加速,并与 TensorFlow Probability 集成,以实现灵活的先验和潜在变量。

1.1 工作原理:BayesFlow 工作流

BayesFlow 的核心是一个形式化、模块化的架构,它模仿了传统贝叶斯工作流的关键组件,但通过神经网络近似器对其进行了超强赋能。其工作原理如下:

1.2 组件

  1. 模拟 + 先验:定义你的生成模型(例如,流行病学中的 SIR 模型)。
  2. 配置器:准备用于训练的数据(例如,归一化、嵌入)。
  3. 神经网络
    • 摘要网络:将原始模拟数据或参数压缩为密集嵌入。
    • 后验网络:学习从数据到参数的逆映射。
    • 似然网络:学习从参数到数据的正向映射。

这些网络可以组合使用,也可以根据你的任务(后验估计、似然模拟、模型比较等)独立使用。

1.3 核心功能

BayesFlow 支持现代贝叶斯工作流的四个关键功能:

  1. Amortized 后验估计
    一次训练,多次推断。实现跨数据集的完整后验快速估计。
    → 解决逆问题。
  2. Amortized 似然估计
    模拟复杂模拟器以估计似然,无需重新运行。
    → 解决正向问题。
  3. Amortized 模型比较
    根据模型解释数据的能力对模型进行分类或排序——使用学习到的后验和似然。
    → 计算贝叶斯证据和预测准确性。
  4. 模型误设定检测
    诊断你的模拟器何时不再代表现实——即使推断“有效”。
    → 避免自信地犯错。

2 实际应用

BayesFlow 不仅仅是理论——它已被部署到广泛的领域:

  • 流行病学:使用基于模拟的 SIR 模型进行疾病传播建模。
  • 神经科学与精神病学:认知和计算模型的参数恢复。
  • 地震学:地震建模中的高维逆问题。
  • 粒子物理学:复杂模拟器的快速代理模型。
  • 航空航天、MEMS、风力涡轮机:不确定性下的工程设计。

简而言之:如果你有一个模拟器,你就可以使用 BayesFlow。

2.1 简单入门

以下是入门的简单方法:

import bayesflow as bfworkflow = bf.BasicWorkflow(inference_network=bf.networks.CouplingFlow(),summary_network=bf.networks.TimeSeriesNetwork(),inference_variables=["parameters"],summary_variables=["observables"],simulator=bf.simulators.SIR()
)
history = workflow.fit_online(epochs=15, batch_size=32, num_batches_per_epoch=200)
diagnostics = workflow.plot_default_diagnostics(test_data=300)

无需构建复杂的训练循环——BayesFlow 处理从模拟到诊断的所有环节。

2.2 专家级定制

BayesFlow 提供:

  • 一个用户友好的 API,适用于应用研究人员。
  • 一个模块化设计,供机器学习专家插入自定义网络、训练方案或推断策略。
  • 开箱即用的默认设置,适用于许多基于模拟的模型。

无论你是为认知建模构建管道,还是为航空航天设计调整代理模型,BayesFlow 都能适应你的工作流。

2.3 使用 BayesFlow 进行贝叶斯线性回归:摊销推理的实践入门

欢迎来到我们使用 BayesFlow 的第一个演练——一个用于通过神经网络进行摊销贝叶斯推断的强大库。在本教程中,我们将使用一个简单的线性回归示例来探索摊销后验估计的基本概念,并演示 BayesFlow 的模块化架构。

我们将通过使用 BayesFlow 的低级 API 来保持透明,从而完全控制每个组件——从模拟器创建到网络架构。如果你刚开始学习并想了解内部工作原理,这将是完美的选择。

2.3.1 为什么选择摊销推理?

传统贝叶斯推断中,我们根据观测数据估计模型参数的后验分布。这通常需要计算成本高昂的方法,如 MCMC 或变分推断——对于每个新数据集都是如此。

但是,如果我们能学会推断呢?

这就是摊销贝叶斯推断的切入点:我们不是为每个新数据集从头开始计算后验,而是训练一个神经网络学习一个函数,该函数直接将数据映射到后验估计。一旦训练完成,这种方法就可以对新数据集进行即时推断

这在高吞吐量、实时或基于模拟的推断设置中尤其有价值。

2.3.2 核心架构:摘要网络与推理网络

我们的 BayesFlow 模型由两个核心网络组成:

  • 摘要网络:将可变长度的输入数据(如观测值)转换为固定长度的嵌入。
  • 推理网络:使用条件生成模型(通常是可逆神经网络)基于此嵌入学习从近似后验中采样。

这些网络共同学习“反转”一个从潜在参数生成数据的模拟器。

2.3.3 分步实现

让我们首先导入必要的库并设置 BayesFlow 环境。

import numpy as np
from pathlib import Path
import keras
import bayesflow as bfnp.set_printoptions(suppress=True)

2.3.4 定义生成模型

我们首先为基本线性回归模型定义似然

def likelihood(beta, sigma, N):x = np.random.normal(0, 1, size=N)y = np.random.normal(beta[0] + beta[1] * x, sigma, size=N)return dict(y=y, x=x)

现在,定义我们模型参数的先验

def prior():beta = np.random.normal([2, 0], [3, 1])sigma = np.random.gamma(1, 1)return dict(beta=beta, sigma=sigma)

为了实现不同数据大小的摊销,定义一个元函数来采样数据集大小:

def meta():N = np.random.randint(5, 15)return dict(N=N)

现在,让我们将上述所有内容封装在一个 BayesFlow 模拟器中:

simulator = bf.simulators.make_simulator([prior, likelihood], meta_fn=meta)

从模拟器中采样:

sim_draws = simulator.sample(500)

2.3.5 通过适配器准备数据

BayesFlow 提供灵活的适配器管道来准备用于训练的原始模拟数据。

adapter = (bf.Adapter().broadcast("N", to="x").as_set(["x", "y"]).constrain("sigma", lower=0).standardize(exclude=["N"]).sqrt("N").convert_dtype("float64", "float32").concatenate(["beta", "sigma"], into="inference_variables").concatenate(["x", "y"], into="summary_variables").rename("N", "inference_conditions")
)

此适配器执行:

  • 上下文变量([N])的广播
  • 标准化(排除常量)
  • 维度检查
  • 连接和重塑

运行适配器:

processed_draws = adapter(sim_draws)

检查形状:

print(processed_draws["summary_variables"].shape)
print(processed_draws["inference_variables"].shape)
print(processed_draws["inference_conditions"].shape)

2.3.6 构建神经网络

2.3.6.1 摘要网络

由于我们的数据是置换不变的(顺序无关紧要),我们使用 SetTransformerDeepSet 架构从 ([x], [y]) 观测值中学习有意义的嵌入。

summary_net = bf.networks.DeepSet(input_shape=(None, 2), output_dim=64)
2.3.6.2 推理网络

我们将使用 BayesFlow 可逆网络来建模后验分布:

inference_net = bf.networks.InvertibleNetwork(n_params=3, num_coupling_layers=6)

2.4 连接所有组件:摊销器

BayesFlow 提供了一个方便的 Amortizer 类,它组合了所有组件。

amortizer = bf.amortizers.AmortizedPosterior(summary_net=summary_net,inference_net=inference_net
)

使用 Keras 风格的回调进行编译和训练:

amortizer.compile(optimizer="adam")
amortizer.train(processed_draws, epochs=30, batch_size=64)

2.5 后验估计

训练完成后,我们可以为任何新数据集推断后验样本:

test_data = adapter(simulator.sample(1))
posterior_samples = amortizer.sample(test_data["summary_variables"],conditions=test_data["inference_conditions"],n_samples=1000)

2.6 结果可视化

BayesFlow 包含方便的诊断工具来可视化结果:

bf.diagnostics.plots.pairs_samples(samples=posterior_samples,variable_names=[r"$\beta_0$", r"$\beta_1$", r"$\sigma$"]
)
http://www.xdnf.cn/news/19360.html

相关文章:

  • 中医文化学习软件,传承国粹精华
  • 动态滑动窗口还搞不清?一文搞定动态滑动窗口 | 基础算法
  • Windows系统安装Git详细教程
  • 【Java后端】Spring Boot 全局域名替换
  • TCP实现线程池竞争任务
  • FPGA|Quartus II 中使用TCL文件进行引脚一键分配
  • 深入理解零拷贝:本地IO与网络IO的性能优化利器
  • Docker基本介绍
  • MySQL 慢查询 debug:索引没生效的三重陷阱
  • 深度学习框架与工具使用心得:从入门到实战优化
  • 动作指令活体检测通过动态交互验证真实活人,保障安全
  • 数字后端tap cell:新老工艺tap cell区别
  • 软考中级数据库系统工程师学习专篇(67、数据库恢复)
  • Linux网络socket套接字(中)
  • AI人工智能大模型应用如何落地
  • DriveDreamer-2
  • C++ 模板全览:从“非特化”到“全特化 / 偏特化”的完整原理与区别
  • CUDA与图形API的深度互操作:解锁GPU硬件接口的真正潜力
  • Linux 系统都有哪些
  • Playwright Python 教程:实战篇
  • docker中的命令(四)
  • Coze源码分析-工作空间-项目开发-前端源码
  • 如何重置SVN被保存的用户名和密码
  • 【pve】
  • 轻量化注意力+脉冲机制,Transformer在低功耗AI中再度进化
  • 吴恩达机器学习作业十 PCA主成分分析
  • 基于单片机智能大棚/温室大棚/智慧农业/智能栽培种植系统/温湿度控制
  • LeetCode 37.解数独
  • k8s三阶段项目
  • 狂神说--Nginx--通俗易懂