当前位置: 首页 > web >正文

打卡day36

一、数据准备与预处理

import pandas as pd
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import MinMaxScaler, StandardScaler, OneHotEncoder, LabelEncoder
from imblearn.over_sampling import SMOTE
import matplotlib.pyplot as plt
from tqdm import tqdm# 设置GPU设备
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(f"使用设备: {device}")# 加载信贷预测数据集
data = pd.read_csv('data.csv')# 丢弃掉Id列
data = data.drop(['Id'], axis=1)# 区分连续特征与离散特征
continuous_features = data.select_dtypes(include=['float64', 'int64']).columns.tolist()
discrete_features = data.select_dtypes(exclude=['float64', 'int64']).columns.tolist()# 离散特征使用众数进行补全
for feature in discrete_features:if data[feature].isnull().sum() > 0:mode_value = data[feature].mode()[0]data[feature].fillna(mode_value, inplace=True)# 连续变量用中位数进行补全
for feature in continuous_features:if data[feature].isnull().sum() > 0:median_value = data[feature].median()data[feature].fillna(median_value, inplace=True)# 有顺序的离散变量进行标签编码
mappings = {"Years in current job": {"10+ years": 10,"2 years": 2,"3 years": 3,"< 1 year": 0,"5 years": 5,"1 year": 1,"4 years": 4,"6 years": 6,"7 years": 7,"8 years": 8,"9 years": 9},"Home Ownership": {"Home Mortgage": 0,"Rent": 1,"Own Home": 2,"Have Mortgage": 3},"Term": {"Short Term": 0,"Long Term": 1}
}# 使用映射字典进行转换
data["Years in current job"] = data["Years in current job"].map(mappings["Years in current job"])
data["Home Ownership"] = data["Home Ownership"].map(mappings["Home Ownership"])
data["Term"] = data["Term"].map(mappings["Term"])# 对没有顺序的离散变量进行独热编码
data = pd.get_dummies(data, columns=['Purpose'])

二、数据集划分与归一化

# 分离特征数据和标签数据
X = data.drop(['Credit Default'], axis=1)  # 特征数据
y = data['Credit Default']  # 标签数据# 划分训练集和测试集
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)# 对特征数据进行归一化处理
scaler = MinMaxScaler()
X_train = scaler.fit_transform(X_train)
X_test = scaler.transform(X_test)  # 确保训练集和测试集是相同的缩放# 将数据转换为PyTorch张量
X_train = torch.FloatTensor(X_train).to(device)
y_train = torch.LongTensor(y_train.values).to(device)
X_test = torch.FloatTensor(X_test).to(device)
y_test = torch.LongTensor(y_test.values).to(device)

三、构建神经网络模型

class MLP(nn.Module):def __init__(self):super(MLP, self).__init__()self.fc1 = nn.Linear(X_train.shape[1], 64)  # 输入层到第一隐藏层self.relu = nn.ReLU()self.dropout = nn.Dropout(0.3)  # 添加Dropout防止过拟合self.fc2 = nn.Linear(64, 32)  # 第一隐藏层到第二隐藏层self.fc3 = nn.Linear(32, 2)  # 第二隐藏层到输出层def forward(self, x):x = self.fc1(x)x = self.relu(x)x = self.dropout(x)x = self.fc2(x)x = self.relu(x)x = self.dropout(x)x = self.fc3(x)return x# 初始化模型
model = MLP().to(device)

四、定义损失函数和优化器

criterion = nn.CrossEntropyLoss()  # 使用交叉熵损失函数
optimizer = optim.SGD(model.parameters(), lr=0.01)  # 使用SGD优化器

五、训练模型

num_epochs = 200  # 训练轮数
for epoch in range(num_epochs):model.train()  # 设置为训练模式optimizer.zero_grad()  # 清空梯度outputs = model(X_train)  # 前向传播loss = criterion(outputs, y_train)  # 计算损失loss.backward()  # 反向传播optimizer.step()  # 更新参数if (epoch + 1) % 10 == 0:print(f'Epoch [{epoch + 1}/{num_epochs}], Loss: {loss.item():.4f}')

六、评估模型

model.eval()  # 设置为评估模式
with torch.no_grad():correct = 0total = 0outputs = model(X_test)_, predicted = torch.max(outputs.data, 1)total += y_test.size(0)correct += (predicted == y_test).sum().item()accuracy = 100 * correct / total
print(f'Accuracy on test set: {accuracy:.2f}%')
http://www.xdnf.cn/news/8807.html

相关文章:

  • HUAWEI交换机配置镜像口验证(eNSP)
  • --legacy-peer-deps 是什么意思
  • 【不背八股】1.if __name__ == “__main__“ 有什么作用?
  • 【redis】redis和hiredis的基本使用
  • RabbitMQ 可靠性保障:消息确认与持久化机制(一)
  • day01
  • 算法打卡第六天
  • C++23 对部分特性的 constexpr 支持
  • 历年华南理工大学保研上机真题
  • 阿里千问系列:Qwen3技术报告解读(下)
  • 美团2025年校招笔试真题手撕教程(二)
  • 第一章 半导体基础知识
  • 腾讯云国际站可靠性测试
  • 13软件测试用例设计方法-场景法
  • UnLua源码分析(二)IUnLuaInterface
  • 并发编程(6)
  • Lua5.4.2常用API整理记录
  • 基于Python的分布式网络爬虫系统设计与实现
  • DAY33 简单神经网络
  • MongoDB 错误处理与调试完全指南:从入门到精通
  • 字符集和字符编码
  • 使用Arduino UNO复活电脑的风扇
  • CI/CD (持续集成/持续部署) GitHub Actions 自动构建
  • 【Linux】进程问题--僵尸进程
  • Github Actions工作流入门
  • 详解3DGS
  • MySQL---库操作
  • 深入解析MongoDB WiredTiger存储引擎:原理、优势与最佳实践
  • 如何通过API接口实现自动化上货跨平台铺货?商品采集|商品上传实现详细步骤
  • 论文阅读:PURPLE: Making a Large Language Model a Better SQL Writer