当前位置: 首页 > java >正文

【动手学深度学习】4.10 实战Kaggle比赛:预测房价


目录

    • 4.10 实战Kaggle比赛:预测房价
      • 1)数据预处理
      • 2)模型定义与训练
      • 3)模型评估与预测
      • 4)模型训练与预测提交
      • 5)示例超参数(可调)


4.10 实战Kaggle比赛:预测房价

数据来源:Kaggle房价预测比赛

.

1)数据预处理

读取数据

import pandas as pdtrain_data = pd.read_csv('../data/kaggle_house_pred_train.csv')
test_data = pd.read_csv('../data/kaggle_house_pred_test.csv')all_features = pd.concat((train_data.iloc[:, 1:-1], test_data.iloc[:, 1:]))

处理数值和类别特征

# 数值特征标准化
numeric_feats = all_features.dtypes[all_features.dtypes != 'object'].index
all_features[numeric_feats] = all_features[numeric_feats].apply(lambda x: (x - x.mean()) / x.std()
)
all_features[numeric_feats] = all_features[numeric_feats].fillna(0)# 类别特征独热编码
all_features = pd.get_dummies(all_features, dummy_na=True)

转换为张量

import torchn_train = train_data.shape[0]
X = torch.tensor(all_features[:n_train].values, dtype=torch.float32)
X_test = torch.tensor(all_features[n_train:].values, dtype=torch.float32)
y = torch.tensor(train_data.SalePrice.values, dtype=torch.float32).reshape(-1, 1)

.

2)模型定义与训练

模型定义

from torch import nndef get_net():net = nn.Sequential(nn.Linear(X.shape[1], 1))return net

损失函数:对数均方根误差(log RMSE)

import torch.nn.functional as Fdef log_rmse(net, features, labels):clipped_preds = torch.clamp(net(features), 1, float('inf'))rmse = torch.sqrt(F.mse_loss(torch.log(clipped_preds), torch.log(labels)))return rmse.item()

训练函数

def train(net, train_features, train_labels, test_features, test_labels,num_epochs, learning_rate, weight_decay, batch_size):train_ls, test_ls = [], []dataset = torch.utils.data.TensorDataset(train_features, train_labels)train_iter = torch.utils.data.DataLoader(dataset, batch_size, shuffle=True)optimizer = torch.optim.Adam(net.parameters(), lr=learning_rate, weight_decay=weight_decay)for epoch in range(num_epochs):for X_batch, y_batch in train_iter:optimizer.zero_grad()loss = F.mse_loss(net(X_batch), y_batch)loss.backward()optimizer.step()train_ls.append(log_rmse(net, train_features, train_labels))if test_labels is not None:test_ls.append(log_rmse(net, test_features, test_labels))return train_ls, test_ls

.

3)模型评估与预测

K折交叉验证

def get_k_fold_data(k, i, X, y):fold_size = X.shape[0] // kX_train, y_train = None, Nonefor j in range(k):idx = slice(j * fold_size, (j + 1) * fold_size)X_part, y_part = X[idx], y[idx]if j == i:X_valid, y_valid = X_part, y_partelif X_train is None:X_train, y_train = X_part, y_partelse:X_train = torch.cat((X_train, X_part), 0)y_train = torch.cat((y_train, y_part), 0)return X_train, y_train, X_valid, y_validdef k_fold(k, X_train, y_train, num_epochs, lr, weight_decay, batch_size):train_l_sum, valid_l_sum = 0, 0for i in range(k):data = get_k_fold_data(k, i, X_train, y_train)net = get_net()train_ls, valid_ls = train(net, *data, num_epochs, lr, weight_decay, batch_size)train_l_sum += train_ls[-1]valid_l_sum += valid_ls[-1]print(f'Fold {i+1}, Train log rmse: {train_ls[-1]:.4f}, Valid log rmse: {valid_ls[-1]:.4f}')return train_l_sum / k, valid_l_sum / k

.

4)模型训练与预测提交

使用全部数据训练并预测

def train_and_pred(train_features, test_features, train_labels, test_data,num_epochs, lr, weight_decay, batch_size):net = get_net()train(net, train_features, train_labels, None, None, num_epochs, lr, weight_decay, batch_size)preds = net(test_features).detach().numpy()submission = pd.DataFrame({"Id": test_data.Id,"SalePrice": preds.flatten()})submission.to_csv('submission.csv', index=False)

.

5)示例超参数(可调)

k, num_epochs, lr, weight_decay, batch_size = 5, 100, 5, 0, 64
train_l, valid_l = k_fold(k, X, y, num_epochs, lr, weight_decay, batch_size)
print(f'{k}-fold validation: Avg train log rmse: {train_l:.4f}, Avg valid log rmse: {valid_l:.4f}')
# 最终提交
train_and_pred(X, X_test, y, test_data, num_epochs, lr, weight_decay, batch_size)

.

总结:

  • 核心流程:数据预处理 → 建模 → K折验证 → 全数据训练 → 生成提交文件。

  • 模型简单但有效:线性回归 + 标准化 + One-Hot。

  • log_rmse 是比赛评分标准的重要转化。

.


声明:资源可能存在第三方来源,若有侵权请联系删除!

http://www.xdnf.cn/news/14960.html

相关文章:

  • S7-1500——(一)从入门到精通1、基于TIA 博途解析PLC程序结构(一)
  • 【04】MFC入门到精通——MFC 自己手动新添加对话框模板 并 创建对话框类
  • 从零开始学前端html篇2
  • React 编译器与性能优化:告别手动 Memoization
  • 网关助力航天喷涂:Devicenet与Modbus TCP的“跨界对话“
  • windows指定某node及npm版本下载
  • Linux入门篇学习——Linux 编写第一个自己的命令
  • 【TCP/IP】3. IP 地址
  • 250709-通过命令行上传模型文件到ModelsScope
  • yolo8实现目标检测
  • Mysql: Bin log原理以及三种格式
  • 权限分级看板管理:实时数据驱动决策的关键安全基石
  • python 在运行时没有加载修改后的版本
  • NLP:初识RNN模型(概念、分类、作用)
  • 从救火到赋能:运维的职责演进与云原生时代的未来图景
  • day10-Redis面试篇
  • SAP采购管理系统替代选谁?8Manage SRM全面优势测评与深度对比
  • Rust与人工智能(AI)技术
  • ✍️ Python 批量设置 Word 文档多级字体样式(标题/正文/名称/小节)
  • 【LeetCode 热题 100】136. 只出现一次的数字——异或
  • Pycharm 报错 Environment location directory is not empty 如何解决
  • Android ttyS2无法打开该如何配置 + ttyS0和ttyS1可以
  • 第1章 Excel界面环境与基础操作指南
  • springBoot使用XWPFDocument 和 LoopRowTableRenderPolicy 两种方式填充数据到word模版中
  • IT系统安全刚需:绝缘故障定位系统
  • 掌握PDF转CAD技巧,提升工程设计效率
  • iframe 的同源限制与反爬机制的冲突
  • [C语言初阶]操作符
  • HTML + CSS + JavaScript
  • uniapp+vue3+ts项目:实现小程序文件下载、预览、进度监听(含项目、案例、插件)