当前位置: 首页 > backend >正文

Python打卡训练营day31-文件拆分

知识点回顾
  1. 规范的文件命名
  2. 规范的文件夹管理
  3. 机器学习项目的拆分
  4. 编码格式和类型注解

作业:尝试针对之前的心脏病项目ipynb,将他按照今天的示例项目整理成规范的形式,思考下哪些部分可以未来复用。

在有多级目录时,相对导入仅在同一包内有效,尤其在下级文件导入上级文件夹中的文件

# src/config.pyCONFIG = {"data_path": PROJECT_ROOT / "data/raw/heart.csv","test_size": 0.2,"random_state": 42,"models": {"random_forest": {"n_estimators": 100,"max_depth": 5},"xgboost": {"learning_rate": 0.1,"max_depth": 3,"n_estimators": 200}}
}

# src/data/loader.py
from pathlib import Path
import pandas as pd
from sklearn.model_selection import train_test_split
from src.config import CONFIGdef load_data() -> tuple:"""加载并拆分数据集"""df = pd.read_csv(CONFIG["data_path"])# 假设最后一列是目标变量X = df.iloc[:, :-1]y = df.iloc[:, -1]return train_test_split(X, y,test_size=CONFIG["test_size"],random_state=CONFIG["random_state"])

# src/models/base_model.py
from abc import ABC, abstractmethod
import pandas as pdclass BaseModel(ABC):"""所有模型的统一接口"""@abstractmethoddef train(self, X_train: pd.DataFrame, y_train: pd.Series):pass@abstractmethoddef predict(self, X_test: pd.DataFrame) -> pd.Series:pass@abstractmethoddef save(self, path: str):pass

# src/models/random_forest.py
from sklearn.ensemble import RandomForestClassifier
from .base_model import BaseModel
from src.config import CONFIGclass RandomForestModel(BaseModel):def __init__(self):self.model = RandomForestClassifier(n_estimators=CONFIG["models"]["random_forest"]["n_estimators"],max_depth=CONFIG["models"]["random_forest"]["max_depth"],random_state=CONFIG["random_state"])def train(self, X_train, y_train):self.model.fit(X_train, y_train)def predict(self, X_test):return self.model.predict(X_test)def save(self, path):joblib.dump(self.model, path)

# src/models/train.py
from .random_forest import RandomForestModel
from .xgboost_model import XGBoostModel
from src.data import loader
from src.evaluation import metrics
from src.utils import save_resultsdef train_all_models():X_train, X_test, y_train, y_test = loader.load_data()models = {"RandomForest": RandomForestModel(),"XGBoost": XGBoostModel()}results = {}for name, model in models.items():model.train(X_train, y_train)preds = model.predict(X_test)results[name] = metrics.calculate_all_metrics(y_test, preds)model.save(f"models/{name}_model.pkl")save_results(results)

# src/evaluation/metrics.py
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_scoredef calculate_all_metrics(y_true, y_pred) -> dict:return {"accuracy": accuracy_score(y_true, y_pred),"precision": precision_score(y_true, y_pred),"recall": recall_score(y_true, y_pred),"f1": f1_score(y_true, y_pred)}

# scripts/train_model.py
from src.models import trainif __name__ == "__main__":train.train_all_models()

http://www.xdnf.cn/news/8925.html

相关文章:

  • 【深度学习-Day 17】神经网络的心脏:反向传播算法全解析
  • 【工具变量】上市公司企业未来主业业绩数据集(2000-2023年)
  • 内存管理(第五、六章)
  • RV1126的RGA模块讲解
  • 7.Java String类深度解析:从不可变魔法到性能优化实战
  • 【电机控制】基于STM32F103C8T6的四轮智能车设计——直流有刷有感PID控制(硬件篇)
  • Java基础知识回顾
  • CLion-2025 嵌入式开发调试环境详细搭建
  • Mysql 中的锁
  • 2025京麒CTF挑战赛 计算器 WriteUP
  • 2024 CKA模拟系统制作 | Step-By-Step | 5、题目搭建-查看Pod CPU资源使用量
  • 滑动窗口算法:高效处理数组与字符串子序列问题的利器
  • (九)PMSM驱动控制学习---无感控制之高阶滑膜观测器
  • 61580 RT控制
  • SCI与EI期刊分区及影响因子汇总
  • 超越UniAD!百度哈工大X-Driver:基于视觉语言模型的可解释自动驾驶
  • 多线程的基础知识以及应用
  • 校园二手交易系统
  • AI预测3D新模型百十个定位预测+胆码预测+去和尾2025年5月25日第88弹
  • 法律大模型之阿里云通义法睿
  • DataX的json配置文件,{},[]讲解
  • 华硕FL8000U加装16G+32G=48G内存条
  • 英语六级-阅读篇
  • 分布式缓存:BASE理论实践指南
  • YOLOv1到YOLOv12各版本发展2025.5.25
  • Jetpack Compose 导航 (Navigation)
  • mysql 导入导出数据
  • Cache写策略
  • 【深度学习】1. 感知器,MLP, 梯度下降,激活函数,反向传播,链式法则
  • Unity3D 彩色打印