当前位置: 首页 > ai >正文

Python训练营---Day31

知识点回顾
  1. 规范的文件命名
  2. 规范的文件夹管理
  3. 机器学习项目的拆分
  4. 编码格式和类型注解

作业:尝试针对之前的心脏病项目ipynb,将他按照今天的示例项目整理成规范的形式,思考下哪些部分可以未来复用。

拆分后的目录结构如下:

1、src/data/load_data.py

# -*- coding: utf-8 -*-
import sys
import io
# 设置标准输出为 UTF-8 编码
sys.stdout = io.TextIOWrapper(sys.stdout.buffer, encoding='utf-8')
import pandas as pddef load_data(file_path: str) -> pd.DataFrame:"""加载数据文件Args:file_path: 数据文件路径Returns:加载的数据框"""return pd.read_csv(file_path)if __name__ == "__main__":# 测试代码data = load_data("testDay31/data/raw/heart.csv")print("数据读取完成!") 

2、src/data/preprocessing.py

# -*- coding: utf-8 -*-
import sys
import io
import os
# 设置标准输出为 UTF-8 编码
sys.stdout = io.TextIOWrapper(sys.stdout.buffer, encoding='utf-8')import pandas as pd
import numpy as np
from typing import Tuple, Dictdef load_data(file_path: str) -> pd.DataFrame:"""加载数据文件Args:file_path: 数据文件路径Returns:加载的数据框"""return pd.read_csv(file_path)# 仅以处理缺失值为例
def handle_missing_values(data: pd.DataFrame) -> pd.DataFrame:"""处理缺失值Args:data: 包含缺失值的数据框Returns:处理后的数据框"""data_clean = data.copy()continuous_features = data.select_dtypes(include=['int64', 'float64']).columns.tolist()for feature in continuous_features:mode_value = data[feature].mode()[0]data_clean[feature].fillna(mode_value, inplace=True)return data_cleanif __name__ == "__main__":# 测试代码data = load_data("testDay31/data/raw/heart.csv")# data_encoded, mappings = encode_categorical_features(data)data_clean = handle_missing_values(data)print("数据预处理完成!") 

3、models/train.py

# -*- coding: utf-8 -*-
import sys
import os
import io
# 设置标准输出为 UTF-8 编码
sys.stdout = io.TextIOWrapper(sys.stdout.buffer, encoding='utf-8')sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import train_test_split
from sklearn.metrics import classification_report, confusion_matrix
import time
import joblib # 用于保存模型
from typing import Tuple # 用于类型注解from data.preprocessing import  load_data,handle_missing_values
# from data.load_data import load_datadef prepare_data() -> Tuple:"""准备训练数据Returns:训练集和测试集的特征和标签"""# 加载和预处理数据data = load_data("testDay31/data/raw/heart.csv")data_clean = handle_missing_values(data)# 分离特征和标签X = data_clean.drop(['target'], axis=1)y = data_clean['target']# 划分训练集和测试集X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)return X_train, X_test, y_train, y_testdef train_model(X_train, y_train, model_params=None) -> RandomForestClassifier:"""训练随机森林模型Args:X_train: 训练特征y_train: 训练标签model_params: 模型参数字典Returns:训练好的模型"""if model_params is None:model_params = {'random_state': 42}model = RandomForestClassifier(**model_params)model.fit(X_train, y_train)return modeldef evaluate_model(model, X_test, y_test) -> None:"""评估模型性能Args:model: 训练好的模型X_test: 测试特征y_test: 测试标签"""y_pred = model.predict(X_test)print("\n分类报告:")print(classification_report(y_test, y_pred))print("\n混淆矩阵:")print(confusion_matrix(y_test, y_pred))def save_model(model, model_path: str) -> None:"""保存模型Args:model: 训练好的模型model_path: 模型保存路径"""os.makedirs(os.path.dirname(model_path), exist_ok=True)joblib.dump(model, model_path)print(f"\n模型已保存至: {model_path}")if __name__ == "__main__":# 准备数据X_train, X_test, y_train, y_test = prepare_data()# 记录开始时间start_time = time.time()# 训练模型model = train_model(X_train, y_train)# 记录结束时间end_time = time.time()print(f"\n训练耗时: {end_time - start_time:.4f} 秒")# 评估模型evaluate_model(model, X_test, y_test)# 保存模型save_model(model, "testDay31/models/random_forest_model.joblib") 

http://www.xdnf.cn/news/7650.html

相关文章:

  • 大模型幻觉
  • CAN总线
  • mbed驱动st7789屏幕-硬件选择及连接(1)
  • TDengine 更多安全策略
  • (二十四)Java网络编程全面解析:从基础到实践
  • 基于python的花卉识别系统
  • Playwright+Next.js:实例演示服务器端 API 模拟新方法
  • 从私有化到容器云:iVX 研发基座的高校智慧校园部署运维全解析
  • 多头注意力机制和单注意力头多输出的区别
  • 大型商业综合体AI智能保洁管理系统:开启智能保洁新时代
  • 麒麟系统编译osg —— 扩展篇
  • 02 if...else,switch,do..while,continue,break
  • DevExpressWinForms-XtraMessageBox-定制和汉化
  • 【python进阶知识】Day 31 文件的规范拆分和写法
  • vLLM框架高效原因分析
  • IntentUri页面跳转
  • 常见的 API 及相关知识总结
  • 如何查看Python内置函数列表
  • 面试之MySQL慢查询优化干货分享
  • AT2659S低噪声放大器芯片:1.4-3.6V宽电压供电,集成50Ω匹配
  • springboot+vue实现服装商城系统(带用户协同过滤个性化推荐算法)
  • 使用SFunction获取属性名,减少嵌入硬编码
  • 初识Linux 进程:进程创建、终止与进程地址空间
  • js绑定事件
  • RabbitMQ ⑤-顺序性保障 || 消息积压 || 幂等性
  • 在CuPy中使用多节点多GPU环境
  • C#基础:yield return关键字的特点
  • 2025ICPC武汉邀请赛-F
  • 游戏启动DLL文件缺失怎么解决 解决dll问题的方法
  • Vue学习路线