当前位置: 首页 > news >正文

梯度下降代码

整体流程

数据预处理:标准化->加一列全为1的偏置项

训练:梯度下降,将数学公式转换成代码

预测

模型代码 

import numpy as np# 标准化函数:对特征做均值-方差标准化
# 返回标准化后的特征、新数据的均值和标准差,用于后续预测def standard(feats):new_feats = np.copy(feats).astype(float)mean = np.mean(new_feats, axis=0)std = np.std(new_feats, axis=0)std[std == 0] = 1new_feats = (new_feats - mean) / stdreturn new_feats, mean, stdclass LinearRegression:def __init__(self, data, labels):# 对训练数据进行标准化new_data, mean, std = standard(data)# 存储用于预测的均值和标准差self.mean = meanself.std = std# 样本数 m 和 原始特征数 nm, n = new_data.shape# 在特征矩阵前加一列 1 作为偏置项X = np.hstack((np.ones((m, 1)), new_data))  # shape (m, n+1)self.X = X                # 训练特征 (m, n+1)self.y = labels           # 训练标签 (m, 1)self.m = m                # 样本数self.n = n + 1            # 特征数(含偏置)# 初始化参数 thetaself.theta = np.zeros((self.n, 1))def train(self, alpha, num_iterations=500):"""执行梯度下降:param alpha: 学习率:param num_iterations: 迭代次数:return: 学习到的 theta 和每次迭代的损失历史"""cost_history = []for _ in range(num_iterations):self.gradient_step(alpha)cost_history.append(self.cost_function())return self.theta, cost_historydef gradient_step(self, alpha):# 计算预测值predictions = self.X.dot(self.theta)          # shape (m,1)# 计算误差delta = predictions - self.y                  # shape (m,1)# 计算梯度并更新 thetagrad = (self.X.T.dot(delta)) / self.m         # shape (n+1,1)self.theta -= alpha * graddef cost_function(self):# 计算当前 theta 下的损失delta = self.X.dot(self.theta) - self.y       # shape (m,1)return float((delta.T.dot(delta)) / (2 * self.m))def predict(self, data):"""对新数据进行预测:param data: 新数据,shape (m_new, n):return: 预测值,shape (m_new, 1)"""# 确保输入为二维数组data = np.array(data, ndmin=2)# 使用训练时的均值和标准差进行标准化new_data = (data - self.mean) / self.std# 加入偏置项m_new = new_data.shape[0]X_new = np.hstack((np.ones((m_new, 1)), new_data))# 返回预测结果return X_new.dot(self.theta)

测试代码

import numpy as np
import pandas as pd
import matplotlib.pyplot as pltfrom linear_regression import LinearRegression
data = pd.read_csv('../data/world-happiness-report-2017.csv')train_data = data.sample(frac = 0.8)
test_data = data.drop(train_data.index)
input_param_name = 'Economy..GDP.per.Capita.'
output_param_name = 'Happiness.Score'
# 取出城市gdp的值和对应的幸福指数
x_train = train_data[[input_param_name]].values
y_train = train_data[[output_param_name]].values
x_test = test_data[input_param_name].values
y_test = test_data[output_param_name].valuesnum_iterations = 500
learning_rate = 0.01
# 训练
# x_train是gdp值,y_train是幸福指数
linear_regression = LinearRegression(x_train,y_train)
# 梯度下降比率,训练轮数
(theta,cost_history) = linear_regression.train(learning_rate,num_iterations)print ('开始时的损失:',cost_history[0])
print ('训练后的损失:',cost_history[-1])plt.plot(range(num_iterations),cost_history)
plt.xlabel('Iter')
plt.ylabel('cost')
plt.title('GD')
plt.show()predictions_num = 100
# 最小值,最大值,多少个等间隔的数,然后做成列向量的形式
x_predictions = np.linspace(x_train.min(),x_train.max(),predictions_num).reshape(predictions_num,1)y_predictions = linear_regression.predict(x_predictions)plt.scatter(x_train,y_train,label='Train data')
plt.scatter(x_test,y_test,label='test data')
plt.plot(x_predictions,y_predictions,'r',label = 'Prediction')
plt.xlabel(input_param_name)
plt.ylabel(output_param_name)
plt.title('Happy')
plt.legend()
plt.show()

http://www.xdnf.cn/news/14257.html

相关文章:

  • fatdds:传输层SHM和DATA-SHARING的区别
  • 数据结构|基数排序及八个排序总结
  • Python爬虫入门
  • 使用veaury,在vue项目中运行react组件
  • 汉诺塔专题:P1760 通天之汉诺塔 题解 + Problem D: 汉诺塔 题解
  • AI写程序: 多线程网络扫描网段ip工具
  • STM32使用rand()生成随机数并显示波形
  • 【最后203篇系列】028 FastAPI的后台任务处理
  • JVM之经典垃圾回收器
  • C++数据结构与二叉树详解
  • Kubernetes》》k8s》》Namespace
  • ProfibusDP转ModbusRTU网关,流量计接入新方案!
  • React 中如何获取 DOM:用 useRef 操作非受控组件
  • 珈和科技:无人机技术赋能智慧农业,精准施肥与病虫害监控全面升级
  • Perf学习
  • 使用最新threejs复刻经典贪吃蛇游戏的3D版,附完整源码
  • Spring Boot配置文件优先级全解析:如何优雅覆盖默认配置?
  • 盲超分-双循环对比学习退化网络(蒸馏)
  • Cursor 生成java测试用例
  • k8s低版本1.15安装prometheus+grafana进行Spring boot数据采集
  • npx 的作用以及延伸知识(.bin目录,npm run xx 执行)
  • AI 推理框架详解,包含如COT、ReAct、LLM+P等的详细说明和分类整理,涵盖其原理、应用场景及对比分析
  • Linux 线程互斥
  • Power BI 中 EXCEPT() 函数的使用指南
  • 悟空CRM系统部署+迁移
  • Vue.directive自定义v-指令
  • 【AI部署】腾讯云GPU-常见故障—SadTalker的AI数字人视频—未来之窗超算中心 tb-lightly
  • JAVA中多线程的经典案例
  • 4.黑马学习笔记-SpringMVC(P43-P47)
  • 学习设计模式《一》——简单工厂