当前位置: 首页 > news >正文

训练一个线性模型

import tensorflow as tf
import pandas as pd# 读取数据
data = pd.read_csv('../data/line_fit_data.csv').values
# 划分训练集和测试集
x = data[:-10, 0]   #第一列排除后10行
y = data[:-10, 1]   #第二列排除后10行
x_test = data[-10:, 0] #第一列后10行
y_test = data[-10:, 1] #第二列后10行# 构建Sequential网络
model_net = tf.keras.models.Sequential()  # 实例化网络
model_net.add(tf.keras.layers.Dense(1, input_shape=(1, )))  # 添加全连接层
print(model_net.summary())# 构建损失函数
model_net.compile(loss='mse', optimizer=tf.keras.optimizers.SGD(learning_rate=0.5))# 模型训练
model_net.fit(x, y, verbose=1, epochs=20, validation_split=0.2)
pre = model_net.predict(x_test)# 利用均方误差进行模型评价
y_test = pd.DataFrame(y_test)
pre = pd.DataFrame(pre)
mse = (sum(y_test - pre) ** 2) / 10
print('均方误差为:', mse)

总结

model_net.add() :向模型中添加层,第一层需指定 `input_shape`       |
Dense(units=1) :定义全连接层 ,`units` 决定输出维度             |

`input_shape=(1,)` : 指定输入数据的形状 ,仅第一层需要,元组格式            |
model.summary(): 查看模型结构和参数数量           
 

**`units=1`**:输出维度为1(即该层只有1个神经元)。
  - **`input_shape=(1,)`**:指定输入数据的形状为 `(1,)`(即每个样本是1个数值)。

 

### **1. `model_net.compile()`:配置模型训练参数**
- **作用**:定义模型的损失函数、优化器和评估指标。
- **参数解析**:
  - **`loss='mse'`**:使用均方误差(Mean Squared Error)作为损失函数,适用于**回归任务**(如预测房价、温度等连续值)。
  - **`optimizer=tf.keras.optimizers.SGD(learning_rate=0.5)`**:
    - 优化器:随机梯度下降(Stochastic Gradient Descent, SGD)。
    - 学习率:`0.5`(较高的学习率,可能导致训练不稳定,需根据任务调整)。
  - **未显式指定 `metrics`**:如需要监控准确率等指标,可添加 `metrics=['mae']`(平均绝对误差)。

---

### **2. `model_net.fit()`:模型训练**
- **作用**:用训练数据拟合模型,更新权重参数。
- **参数解析**:
  - **`x, y`**:输入数据和标签(假设 `x` 是特征,`y` 是目标值)。
  - **`verbose=1`**:显示训练进度条(`0`=不显示,`1`=显示进度条,`2`=仅显示轮次结果)。
  - **`epochs=20`**:训练20轮(所有数据完整遍历一次为一轮)。
  - **`validation_split=0.2`**:从训练数据中自动划分20%作为验证集(例如,若 `x` 有100个样本,则80个用于训练,20个用于验证)。

 

**`pd.DataFrame()`** 是 Pandas 库中用于创建或转换数据为 **二维表格结构**(DataFrame)的函数。
- 这行代码的目的是将 `y_test`(可能是列表、NumPy 数组或其他格式)转换为 DataFrame,以便后续使用 Pandas 的功能(如数据操作、保存到文件、与其他 DataFrame 合并等)。

 

 

 

 

http://www.xdnf.cn/news/585901.html

相关文章:

  • Linux 线程(中)
  • 基于FPGA控制电容阵列与最小反射算法的差分探头优化设计
  • 使用pm2 部署react+nextjs项目到服务器
  • (Java基础笔记vlog)Java中常见的几种设计模式详解
  • java接口自动化(四) - 企业级代码管理工具Git的应用
  • 理解全景图像拼接
  • 动态网页爬取:Python如何获取JS加载的数据?
  • Jenkins与Maven的集成配置
  • C++中的string(1)简单介绍string中的接口用法以及注意事项
  • Web前端开发 - 制作简单的焦点图效果
  • 单例模式的运用
  • UniApp+Vue3微信小程序二维码生成、转图片、截图保存整页
  • uniapp实现的简约美观的票据、车票、飞机票模板
  • ffmpeg 转换视频格式
  • 【Windows】FFmpeg安装教程
  • 「Python教案」运算符的使用
  • 中国计算机学会——2024年9月等级考试5级——第四题、森森快递(贪心+线段树)
  • JavaScriptAPIs学习day3--事件高级
  • 破局制造业转型: R²AIN SUITE 提效实战教学
  • Unity3D 异步加载材质显示问题排查
  • Python安全密码生成器:告别弱密码的最佳实践
  • TRC20代币创建教程指南
  • 解决 IntelliJ IDEA 配置文件中文被转义问题
  • ClickHouse核心优势分析与场景实战
  • 论文流程图mermaid解决方案
  • uni-app学习笔记八-vue3条件渲染
  • 打卡Day34
  • 绕距#C语言
  • IP大科普:住宅IP、机房IP、原生IP、双ISP
  • Keepalived 与 LVS 集成及多实例配置详解