知识蒸馏 Knowledge Distillation 概率链式法则(Probability Chain Rule)
知识蒸馏 Knowledge Distillation 概率链式法则(Probability Chain Rule)
flyfish
代码实践
论文 Generalized Knowledge Distillation (GKD)
On-Policy Distillation of Language Models: Learning from Self-Generated Mistakes
“概率链式法则”(Probability Chain Rule),有时也被称为“乘法法则”(Multiplication Rule)的推广形式。它是概率论中用于分解“联合概率”为“条件概率乘积”** 的核心工具,尤其适用于处理多个随机变量或多个事件同时发生的概率计算。
概率链式法则的核心思想
当需要计算多个随机变量(或事件)同时发生的“联合概率”时,直接计算往往困难;而概率链式法则的作用是将复杂的联合概率,拆解成一系列更容易计算的“条件概率”的乘积,本质是“逐步细化概率依赖关系”。
多个事件同时发生的概率 = 第一个事件发生的概率 × 第二个事件在第一个事件发生下的概率 × 第三个事件在前两个事件都发生下的概率 × … × 最后一个事件在前面所有事件都发生下的概率。
概率链式法则的数学表达
根据随机变量的类型(离散型/连续型),链式法则的公式形式略有差异,但核心逻辑一致。以下以离散型随机变量为例(连续型只需将“概率”替换为“概率密度函数”,符号从P
改为f
)。
1. 基础形式:两个随机变量
对于两个随机变量 X1X_1X1 和 X2X_2X2,其联合概率 P(X1=x1,X2=x2)P(X_1 = x_1, X_2 = x_2)P(X1=x1,X2=x2)(简记为 P(X1,X2)P(X_1, X_2)P(X1,X2))可通过“乘法法则”拆解:
P(X1,X2)=P(X1)×P(X2∣X1)P(X_1, X_2) = P(X_1) \times P(X_2 \mid X_1) P(X1,X2)=P(X1)×P(X2∣X1)
- P(X1)P(X_1)P(X1):随机变量 X1X_1X1 的边缘概率;
- P(X2∣X1)P(X_2 \mid X_1)P(X2∣X1):在 X1X_1X1 已发生的条件下,X2X_2X2 发生的条件概率。
2. 推广形式:n个随机变量
对于 nnn 个随机变量 X1,X2,...,XnX_1, X_2, ..., X_nX1,X2,...,Xn,其联合概率 P(X1,X2,...,Xn)P(X_1, X_2, ..., X_n)P(X1,X2,...,Xn) 可通过链式法则拆解为 nnn 个条件概率的乘积:
P(X1,X2,...,Xn)=P(X1)×P(X2∣X1)×P(X3∣X1,X2)×...×P(Xn∣X1,X2,...,Xn−1)P(X_1, X_2, ..., X_n) = P(X_1) \times P(X_2 \mid X_1) \times P(X_3 \mid X_1, X_2) \times ... \times P(X_n \mid X_1, X_2, ..., X_{n-1}) P(X1,X2,...,Xn)=P(X1)×P(X2∣X1)×P(X3∣X1,X2)×...×P(Xn∣X1,X2,...,Xn−1)
这是概率链式法则的通用公式,后一个变量的概率依赖于所有前面已发生的变量。
用链式法则计算联合概率
假设关注三个随机变量:
- X1X_1X1:“今天是否下雨”(1=下雨,0=不下雨),已知 P(X1=1)=0.3P(X_1=1) = 0.3P(X1=1)=0.3;
- X2X_2X2:“是否带伞”(1=带伞,0=不带伞),已知下雨时带伞的概率 P(X2=1∣X1=1)=0.9P(X_2=1 \mid X_1=1) = 0.9P(X2=1∣X1=1)=0.9,不下雨时带伞的概率 P(X2=1∣X1=0)=0.1P(X_2=1 \mid X_1=0) = 0.1P(X2=1∣X1=0)=0.1;
- X3X_3X3:“是否淋湿”(1=淋湿,0=不淋湿),已知“下雨且带伞”时淋湿的概率 P(X3=1∣X1=1,X2=1)=0.1P(X_3=1 \mid X_1=1, X_2=1) = 0.1P(X3=1∣X1=1,X2=1)=0.1,“下雨且不带伞”时淋湿的概率 P(X3=1∣X1=1,X2=0)=0.95P(X_3=1 \mid X_1=1, X_2=0) = 0.95P(X3=1∣X1=1,X2=0)=0.95。
现在计算“今天下雨、带伞、且淋湿”的联合概率 P(X1=1,X2=1,X3=1)P(X_1=1, X_2=1, X_3=1)P(X1=1,X2=1,X3=1),直接套用链式法则:
P(X1=1,X2=1,X3=1)=P(X1=1)×P(X2=1∣X1=1)×P(X3=1∣X1=1,X2=1)=0.3×0.9×0.1=0.027\begin{align*} P(X_1=1, X_2=1, X_3=1) &= P(X_1=1) \times P(X_2=1 \mid X_1=1) \times P(X_3=1 \mid X_1=1, X_2=1) \\ &= 0.3 \times 0.9 \times 0.1 \\ &= 0.027 \end{align*} P(X1=1,X2=1,X3=1)=P(X1=1)×P(X2=1∣X1=1)×P(X3=1∣X1=1,X2=1)=0.3×0.9×0.1=0.027
通过拆解将复杂的联合概率转化为三个已知简单概率的乘积,大幅降低了计算难度。
∏\prod∏ 表示(后面有说明这个符号的含义)
概率链式法则(联合概率分解)
用于将多个随机变量的联合概率分解为条件概率的连乘,适用于随机变量序列 X1,X2,…,XnX_1, X_2, \dots, X_nX1,X2,…,Xn。
当 k=1k=1k=1 时,X1,…,X0X_1, \dots, X_{0}X1,…,X0 表示“无前置变量”,此时 P(X1∣X1,…,X0)=P(X1)P(X_1 \mid X_1, \dots, X_0) = P(X_1)P(X1∣X1,…,X0)=P(X1),即边缘概率)
-
公式:
P(X1,X2,…,Xn)=∏k=1nP(Xk∣X1,X2,…,Xk−1)P(X_1, X_2, \dots, X_n) = \prod_{k=1}^n P\left(X_k \mid X_1, X_2, \dots, X_{k-1}\right)P(X1,X2,…,Xn)=k=1∏nP(Xk∣X1,X2,…,Xk−1) -
解读:
联合概率 P(X1,…,Xn)P(X_1, \dots, X_n)P(X1,…,Xn) 等于从第1个变量到第n个变量的条件概率连乘:- 第1项(k=1k=1k=1):P(X1)P(X_1)P(X1)(无前置条件,直接是 X1X_1X1 的边缘概率);
- 第2项(k=2k=2k=2):P(X2∣X1)P(X_2 \mid X_1)P(X2∣X1)(依赖第1个变量);
- 第3项(k=3k=3k=3):P(X3∣X1,X2)P(X_3 \mid X_1, X_2)P(X3∣X1,X2)(依赖前2个变量);
- …
- 第n项(k=nk=nk=n):P(Xn∣X1,…,Xn−1)P(X_n \mid X_1, \dots, X_{n-1})P(Xn∣X1,…,Xn−1)(依赖前n-1个变量)。
展开后为:
P(X1,…,Xn)=P(X1)×P(X2∣X1)×P(X3∣X1,X2)×⋯×P(Xn∣X1,…,Xn−1)P(X_1, \dots, X_n) = P(X_1) \times P(X_2 \mid X_1) \times P(X_3 \mid X_1, X_2) \times \dots \times P(X_n \mid X_1, \dots, X_{n-1})P(X1,…,Xn)=P(X1)×P(X2∣X1)×P(X3∣X1,X2)×⋯×P(Xn∣X1,…,Xn−1)
条件概率链式法则(给定条件下的联合概率分解)
用于将**“给定某个条件变量”时的联合概率**分解为条件概率的连乘,适用于随机变量序列 X1,X2,…,XnX_1, X_2, \dots, X_nX1,X2,…,Xn 和条件变量 YYY。
当 k=1k=1k=1 时,P(X1∣X1,…,X0,Y)=P(X1∣Y)P(X_1 \mid X_1, \dots, X_0, Y) = P(X_1 \mid Y)P(X1∣X1,…,X0,Y)=P(X1∣Y),即给定Y时X1X_1X1的条件概率)
-
公式:
P(X1,X2,…,Xn∣Y)=∏k=1nP(Xk∣X1,X2,…,Xk−1,Y)P(X_1, X_2, \dots, X_n \mid Y) = \prod_{k=1}^n P\left(X_k \mid X_1, X_2, \dots, X_{k-1}, Y\right)P(X1,X2,…,Xn∣Y)=k=1∏nP(Xk∣X1,X2,…,Xk−1,Y) -
解读:
给定Y时,X1,…,XnX_1, \dots, X_nX1,…,Xn 的联合条件概率等于“每个变量依赖于前置变量和Y”的条件概率连乘:- 第1项(k=1k=1k=1):P(X1∣Y)P(X_1 \mid Y)P(X1∣Y)(仅依赖条件变量Y);
- 第2项(k=2k=2k=2):P(X2∣X1,Y)P(X_2 \mid X_1, Y)P(X2∣X1,Y)(依赖第1个变量和Y);
- 第3项(k=3k=3k=3):P(X3∣X1,X2,Y)P(X_3 \mid X_1, X_2, Y)P(X3∣X1,X2,Y)(依赖前2个变量和Y);
- …
- 第n项(k=nk=nk=n):P(Xn∣X1,…,Xn−1,Y)P(X_n \mid X_1, \dots, X_{n-1}, Y)P(Xn∣X1,…,Xn−1,Y)(依赖前n-1个变量和Y)。
展开后为:
P(X1,…,Xn∣Y)=P(X1∣Y)×P(X2∣X1,Y)×⋯×P(Xn∣X1,…,Xn−1,Y)P(X_1, \dots, X_n \mid Y) = P(X_1 \mid Y) \times P(X_2 \mid X_1, Y) \times \dots \times P(X_n \mid X_1, \dots, X_{n-1}, Y)P(X1,…,Xn∣Y)=P(X1∣Y)×P(X2∣X1,Y)×⋯×P(Xn∣X1,…,Xn−1,Y)
两个法则的区别
法则类型 | 连乘符号公式的核心差异 | 适用场景 |
---|---|---|
概率链式法则 | 每个项的条件仅包含“前置变量” | 直接分解多个变量的联合概率 |
条件概率链式法则 | 每个项的条件包含“前置变量+全局条件Y” | 分解“给定Y时”的联合概率 |
连乘符号(∏\prod∏)
一个规则:
∏下标=下限上限被乘项\prod_{\text{下标}=\text{下限}}^{\text{上限}} \text{被乘项}∏下标=下限上限被乘项 → 表示“下标从下限到上限变化时,所有被乘项的乘积”。
它的作用和求和符号(∑\sum∑)完全对应,只是将“加法”换成了“乘法”,是为了简化“重复运算”的书写。
最基础的连乘——n的阶乘(Factorial)
阶乘是连乘符号最经典的应用,表示从1到n的所有正整数的乘积。
- 公式:
n!=∏k=1nkn! = \prod_{k=1}^n kn!=k=1∏nk - 解读:
- 连乘下限:k=1k=1k=1(从1开始),上限:k=nk=nk=n(到n结束);
- 被乘项:kkk(每一项都是当前的k值);
- 展开含义:n!=1×2×3×⋯×(n−1)×nn! = 1 \times 2 \times 3 \times \dots \times (n-1) \times nn!=1×2×3×⋯×(n−1)×n。
比如当n=5n=5n=5时,5!=∏k=15k=1×2×3×4×5=1205! = \prod_{k=1}^5 k = 1×2×3×4×5 = 1205!=∏k=15k=1×2×3×4×5=120。
概率论——独立事件的联合概率
若多个事件相互独立(一个事件的发生不影响其他事件),则它们的“同时发生概率”(联合概率)等于每个事件概率的连乘。
- 公式:
P(A1∩A2∩⋯∩An)=∏i=1nP(Ai)P(A_1 \cap A_2 \cap \dots \cap A_n) = \prod_{i=1}^n P(A_i)P(A1∩A2∩⋯∩An)=i=1∏nP(Ai) - 解读:
- A1,A2,…,AnA_1, A_2, \dots, A_nA1,A2,…,An 是相互独立的随机事件;
- 左边:事件“A1A_1A1且A2A_2A2且…且AnA_nAn”同时发生的概率;
- 右边:连乘表示“每个事件概率的乘积”,展开为 P(A1)×P(A2)×⋯×P(An)P(A_1) \times P(A_2) \times \dots \times P(A_n)P(A1)×P(A2)×⋯×P(An)。
比如掷3次独立的硬币,“3次都正面”的概率:P(H1∩H2∩H3)=P(H1)×P(H2)×P(H3)=12×12×12=18P(H_1 \cap H_2 \cap H_3) = P(H_1)×P(H_2)×P(H_3) = \frac{1}{2}×\frac{1}{2}×\frac{1}{2} = \frac{1}{8}P(H1∩H2∩H3)=P(H1)×P(H2)×P(H3)=21×21×21=81。
线性代数——n阶行列式的定义
行列式是矩阵的核心属性,其定义本质是“所有置换对应的项的代数和”,其中每一项包含一个连乘。
- 公式:
det(A)=∑σ∈Sn(−1)sign(σ)∏i=1nai,σ(i)\det(A) = \sum_{\sigma \in S_n} (-1)^{\text{sign}(\sigma)} \prod_{i=1}^n a_{i, \sigma(i)}det(A)=σ∈Sn∑(−1)sign(σ)i=1∏nai,σ(i) - 解读:
- AAA 是n阶矩阵,ai,ja_{i,j}ai,j 是矩阵第i行第j列的元素;
- SnS_nSn 表示“n个元素的所有置换集合”(比如n=3时,S3S_3S3有6种置换);
- σ(i)\sigma(i)σ(i) 是置换函数(表示第i行元素对应的列索引);
- 连乘部分:∏i=1nai,σ(i)\prod_{i=1}^n a_{i, \sigma(i)}∏i=1nai,σ(i) 表示“每个行i选一个不同列σ(i)\sigma(i)σ(i)的元素相乘”(确保每行每列仅选一个元素);
- 整体含义:对所有置换对应的“连乘项”乘以符号((−1)sign(σ)(-1)^{\text{sign}(\sigma)}(−1)sign(σ))后求和,即行列式的值。
比如2阶矩阵A=(abcd)A=\begin{pmatrix}a&b\\c&d\end{pmatrix}A=(acbd),行列式为:det(A)=a×d−b×c\det(A) = a×d - b×cdet(A)=a×d−b×c(对应2种置换的连乘项求和)。
组合数学——排列数公式
排列数表示“从n个元素中选k个进行有序排列”的方案数,其公式可通过连乘符号简化。
- 公式:
P(n,k)=∏i=0k−1(n−i)P(n, k) = \prod_{i=0}^{k-1} (n - i)P(n,k)=i=0∏k−1(n−i) - 解读:
- 连乘下限:i=0i=0i=0,上限:i=k−1i=k-1i=k−1(共k项相乘);
- 被乘项:n−in - in−i(每一项比前一项少1);
- 展开含义:P(n,k)=n×(n−1)×(n−2)×⋯×(n−k+1)P(n, k) = n \times (n-1) \times (n-2) \times \dots \times (n - k + 1)P(n,k)=n×(n−1)×(n−2)×⋯×(n−k+1)。
比如从5个元素中选3个排列:P(5,3)=5×4×3=60P(5, 3) = 5×4×3 = 60P(5,3)=5×4×3=60(对应i=0,1,2i=0,1,2i=0,1,2时的连乘:5−0=55-0=55−0=5,5−1=45-1=45−1=4,5−2=35-2=35−2=3)。
无穷连乘——正弦函数的乘积展开
连乘符号不仅限于“有限项”,还可表示“无穷项的乘积”(需满足收敛条件),典型例子是正弦函数的无穷乘积展开。
- 公式:
sin(πx)=πx∏n=1∞(1−x2n2)\sin(\pi x) = \pi x \prod_{n=1}^{\infty} \left(1 - \frac{x^2}{n^2}\right)sin(πx)=πxn=1∏∞(1−n2x2) - 解读:
- 连乘上限:n=∞n=\inftyn=∞(无穷多项相乘);
- 被乘项:1−x2n21 - \frac{x^2}{n^2}1−n2x2;
- 展开含义:sin(πx)=πx×(1−x212)×(1−x222)×(1−x232)×…\sin(\pi x) = \pi x \times \left(1 - \frac{x^2}{1^2}\right) \times \left(1 - \frac{x^2}{2^2}\right) \times \left(1 - \frac{x^2}{3^2}\right) \times \dotssin(πx)=πx×(1−12x2)×(1−22x2)×(1−32x2)×…。
这是数学分析中著名的欧拉乘积公式,可用于近似计算正弦函数值(项数越多,精度越高)。
与微积分“链式法则”的区别
虽然名称中都有“链式”,但概率链式法则与微积分链式法则完全是两个领域的概念
对比维度 | 概率链式法则 | 微积分链式法则 |
---|---|---|
适用领域 | 概率论(概率计算) | 微积分(导数计算) |
核心作用 | 分解联合概率为条件概率乘积 | 求复合函数的导数(如 dydx=dydu⋅dudx\frac{dy}{dx} = \frac{dy}{du} \cdot \frac{du}{dx}dxdy=dudy⋅dxdu) |
操作对象 | 随机变量/事件的概率 | 函数的导数 |