30天打牢数模基础-灰色预测模型讲解
案例代码实现
一、代码说明
本代码实现了灰色预测模型(GM(1,1)),用于解决小样本(4-10个数据点)、单调趋势序列的预测问题。以某企业2024年1-4月能源消耗数据为例,预测5月的能耗,并包含模型精度评估(后验差比)和残差修正等核心步骤。
二、完整Python代码
import numpy as np
import math
from matplotlib import pyplot as pltdef ago(x0):"""累加生成(AGO)函数:将原始序列累加,增强趋势性参数:x0 - 原始序列(numpy数组,正数、单调)返回:x1 - 累加序列(numpy数组)"""return np.cumsum(x0)def adjacent_mean(x1):"""计算紧邻均值序列:用于将离散数据与连续微分方程联系起来参数:x1 - 累加序列(numpy数组)返回:z1 - 紧邻均值序列(numpy数组),长度为len(x1)-1"""return (x1[1:] + x1[:-1]) / 2def least_squares(x0, z1):"""最小二乘法求解GM(1,1)模型参数a(发展系数)和b(灰作用量)参数:x0 - 原始序列(numpy数组);z1 - 紧邻均值序列(numpy数组)返回:a - 发展系数;b - 灰作用量"""# 构造Y矩阵(原始序列从第二个元素开始,形状为(n-1,1))Y = x0[1:].reshape(-1, 1)# 构造B矩阵(第一列为-z1,第二列为全1,形状为(n-1,2))B = np.hstack([-z1.reshape(-1, 1), np.ones_like(z1).reshape(-1, 1)])# 计算(B^T * B)的逆矩阵try:B_T_B = B.T @ Binv_B_T_B = np.linalg.inv(B_T_B)except np.linalg.LinAlgError:raise ValueError("矩阵不可逆,请检查数据是否符合要求(单调、正数)")# 求解参数θ = [a, b]^Ttheta = inv_B_T_B @ B.T @ Yreturn theta[0, 0], theta[1, 0]def gm11_model(a, b, x1_0):"""构建GM(1,1)累加序列预测模型参数:a - 发展系数;b - 灰作用量;x1_0 - 累加序列初始值(x0[0])返回:model - 预测函数,输入t(时间点,从1开始),返回累加值x1(t)"""def model(t):if t < 1:raise ValueError("t必须≥1(1对应第一个数据点)")return (x1_0 - b / a) * math.exp(-a * (t - 1)) + b / areturn modeldef predict_original(gm_model, n):"""从累加序列预测原始序列(逆累加生成,IAGO)参数:gm_model - GM(1,1)累加预测函数;n - 预测长度(包括已有数据)返回:x0_hat - 原始序列预测值(numpy数组)"""# 预测累加序列(t=1到n)x1_hat = np.array([gm_model(t) for t in range(1, n+1)])# 逆累加:x0_hat[0] = x1_hat[0],x0_hat[k] = x1_hat[k] - x1_hat[k-1](k≥1)x0_hat = np.concatenate([[x1_hat[0]], np.diff(x1_hat)])return x0_hatdef residual_correction(x0, x0_hat):"""残差修正:用残差序列建立GM(1,1)模型,修正预测值参数:x0 - 原始序列;x0_hat - 原始预测值返回:x0_hat_corrected - 修正后预测值;residual_model - 残差模型(用于未来预测)"""# 计算残差(从第二个点开始,第一个点残差为0)e0 = x0 - x0_hate0 = e0[1:] # 残差序列长度为len(x0)-1# 残差全0,无需修正if np.all(e0 == 0):print("残差序列全为0,无需修正")return x0_hat, None# 对残差做AGO和紧邻均值e1 = ago(e0)z_e1 = adjacent_mean(e1)# 求解残差模型参数try:a_e, b_e = least_squares(e0, z_e1)except ValueError as e:print(f"残差模型无法建立:{e},无需修正")return x0_hat, None# 构建残差累加预测模型residual_model = gm11_model(a_e, b_e, e1[0])# 预测残差原始值(对应原始序列t=2到t=len(x0))e1_hat = np.array([residual_model(t) for t in range(1, len(e0)+1)])e0_hat = np.concatenate([[e1_hat[0]], np.diff(e1_hat)])# 修正原始预测值(t≥2)x0_hat_corrected = x0_hat.copy()x0_hat_corrected[1:] += e0_hatreturn x0_hat_corrected, residual_modeldef model_test(x0, x0_hat):"""模型精度检验:计算相对误差、后验差比C、小误差概率P参数:x0 - 原始序列;x0_hat - 预测值返回:relative_errors - 相对误差(%);C - 后验差比;P - 小误差概率"""# 残差序列(从第二个点开始)e0 = x0 - x0_hate0 = e0[1:]n = len(e0)# 相对误差(%)relative_errors = [abs(e / x0[i+1]) * 100 for i, e in enumerate(e0)]# 后验差比C:残差标准差/原始数据标准差if n == 0:C = 0P = 0else:mu_e = np.mean(e0)sigma_e = np.std(e0, ddof=1) # 样本标准差mu_x = np.mean(x0)sigma_x = np.std(x0, ddof=1)C = sigma_e / sigma_x if sigma_x != 0 else 0# 小误差概率P:|e_k - μ_e| < 0.6745σ_x 的比例threshold = 0.6745 * sigma_xP = sum(abs(e - mu_e) < threshold for e in e0) / n# 输出检验结果print("\n=== 模型精度检验 ===")print(f"相对误差(%):{[round(re, 2) for re in relative_errors]}")print(f"后验差比C:{round(C, 2)}(C<0.35为优秀)")print(f"小误差概率P:{round(P, 2)}(P>0.95为优秀)")# 判断模型等级if C < 0.35 and P > 0.95:print("模型等级:优秀")elif 0.35 <= C < 0.5 and 0.8 <= P <= 0.95:print("模型等级:良好")elif 0.5 <= C < 0.65 and 0.7 <= P < 0.8:print("模型等级:可用")else:print("模型等级:不可用")return relative_errors, C, Pif __name__ == "__main__":# 1. 输入原始数据(某企业2024年1-4月能源消耗,吨标准煤)x0 = np.array([120, 150, 190, 240])months = np.arange(1, len(x0)+1) # 1-4月print(f"原始数据(1-4月):{x0}")# 2. 累加生成(AGO):增强趋势性x1 = ago(x0)print(f"累加序列:{x1.round(2)}")# 3. 计算紧邻均值:连接离散与连续z1 = adjacent_mean(x1)print(f"紧邻均值序列:{z1.round(2)}")# 4. 求解GM(1,1)参数a和ba, b = least_squares(x0, z1)print(f"\n模型参数:发展系数a={a.round(4)},灰作用量b={b.round(4)}")# 5. 构建累加预测模型gm_model = gm11_model(a, b, x1[0])print(f"累加序列模型:x1(t) = ({x1[0].round(2)} - {b.round(2)}/{a.round(2)}) * exp(-{a.round(2)}*(t-1)) + {b.round(2)}/{a.round(2)}")# 6. 预测原始序列(1-4月拟合值)x0_hat = predict_original(gm_model, len(x0))print(f"\n原始序列拟合值(1-4月):{x0_hat.round(2)}")# 7. 残差修正:提高预测精度x0_hat_corrected, residual_model = residual_correction(x0, x0_hat)if residual_model is not None:print(f"修正后拟合值(1-4月):{x0_hat_corrected.round(2)}")# 8. 模型精度检验model_test(x0, x0_hat_corrected if residual_model is not None else x0_hat)# 9. 预测5月能源消耗(未来1步)n_predict = 1 # 预测1个月(5月)total_months = len(x0) + n_predict # 1-5月# 预测累加序列(1-5月)x1_hat_all = np.array([gm_model(t) for t in range(1, total_months+1)])# 逆累加得到原始预测(1-5月)x0_hat_all = predict_original(gm_model, total_months)# 5月预测值(原模型)x0_hat_may = x0_hat_all[-n_predict:]print(f"\n5月能源消耗预测(原模型):{x0_hat_may.round(2)}吨标准煤")# 10. 残差修正未来预测(如果残差模型存在)if residual_model is not None:# 残差序列对应原始序列t=2-4,未来t=5对应残差模型t=4e1_hat_may = residual_model(4) # 残差累加预测(t=4)e1_prev = ago(x0[1:] - x0_hat[1:])[-1] # 残差累加前值(t=3)e0_hat_may = e1_hat_may - e1_prev # 残差原始预测(t=5)# 修正5月预测值x0_hat_may_corrected = x0_hat_may + e0_hat_mayprint(f"5月能源消耗预测(残差修正后):{x0_hat_may_corrected.round(2)}吨标准煤")# 11. 可视化结果(帮助理解趋势)plt.figure(figsize=(10, 6))plt.plot(months, x0, 'o-', label='原始数据', markersize=8)plt.plot(months, x0_hat_corrected if residual_model is not None else x0_hat, 's-', label='拟合值', markersize=8)plt.plot([5], x0_hat_may_corrected if residual_model is not None else x0_hat_may, '^-', label='预测值', markersize=10, color='red')plt.xlabel('月份', fontsize=12)plt.ylabel('能源消耗(吨标准煤)', fontsize=12)plt.title('GM(1,1)模型能源消耗预测', fontsize=14)plt.legend(fontsize=12)plt.grid(True, linestyle='--', alpha=0.7)plt.xticks(np.arange(1, 6)) # 显示1-5月plt.show()
三、代码使用说明
1. 环境准备
需要安装numpy(矩阵运算)、matplotlib(可视化):
pip install numpy matplotlib
2. 数据替换
将代码中x0替换为你的数据(必须是正数、单调序列,长度4-10):
# 示例:替换为你的数据(比如最近4次月考成绩)
x0 = np.array(80, 85, 90, 92)
3. 运行代码
直接运行脚本,会输出以下结果:
原始数据、累加序列、紧邻均值序列;
模型参数(a、b);
1-4月拟合值(修正前/后);
模型精度检验(相对误差、后验差比、小误差概率);
5月预测值(原模型/残差修正后);
可视化图表(原始数据、拟合值、预测值)。
4. 结果解读
相对误差:越小越好(一般<10%为可接受);
后验差比C:C<0.35为优秀,C<0.5为良好;
小误差概率P:P>0.95为优秀,P>0.8为良好;
预测值:残差修正后的预测值更接近实际趋势(如果残差有规律)。
四、示例结果(以能源消耗数据为例)
原始数据:120, 150, 190, 240
累加序列:120, 270, 460, 700
紧邻均值:195, 365, 580
模型参数:a≈-0.2337(负号表示指数增长),b≈104.52
1-4月拟合值:120, 149.2, 188.3, 238.2(修正后更接近原始值)
模型检验:相对误差<1%,C≈0.01(优秀),P=1.0(优秀)
5月预测值:原模型≈297.8吨,残差修正后≈299.7吨(更可靠)
五、注意事项
数据要求:必须是正数、单调序列(若有负数,可加常数调整);
数据量:4-10个点(太少模型不稳定,太多不如用ARIMA);
残差修正:若残差波动大(如正负交替),无需修正;
适用场景:适合预测单调增长/减少的序列(如销售额、能耗、人口)。
通过以上代码,你可以快速实现GM(1,1)模型,解决小样本预测问题。赶紧用自己的数据试试吧!