当前位置: 首页 > web >正文

机器学习模型训练模块技术文档

一、模块结构概览

import numpy as np
from sklearn.model_selection import cross_validate, learning_curve
from sklearn.pipeline import make_pipeline
from sklearn.svm import SVC
from sklearn.neighbors import KNeighborsClassifier
from sklearn.preprocessing import StandardScaler
from sklearn.metrics import make_scorer, accuracy_score, recall_score, f1_score
import matplotlib.pyplot as plt
from sklearn.model_selection import train_test_split
from sklearn.ensemble import RandomForestClassifier
from sklearn.utils import shuffle
import os

依赖说明

  • numpy:处理数值计算

  • sklearn:提供机器学习算法和工具

  • matplotlib:可视化学习曲线

  • os:处理文件路径操作

二、核心类定义

2.1 类初始化

class ModelTrainer:def __init__(self):pass

功能:创建模型训练器的基础类,当前无需特殊初始化参数 

 

2.2 主训练方法 train_model

2.2.1 数据准备阶段
def train_model(self, X, y, output_dir="model_plots"):# 创建输出文件夹os.makedirs(output_dir, exist_ok=True)# 数据分割X_train, X_test, y_train, y_test = train_test_split(X, y,test_size=0.2,        # 20%测试集stratify=y,           # 保持类别分布random_state=42       # 可重复性种子)# 数据标准化scaler = StandardScaler()X_train_scaled = scaler.fit_transform(X_train)  # 训练集拟合+转换X_test_scaled = scaler.transform(X_test)        # 测试集仅转换# 合并标准化数据X_scaled = np.concatenate([X_train_scaled, X_test_scaled])y = np.concatenate([y_train, y_test])
 

关键技术点

  • stratify=y 保证分割后的数据保持原始类别分布

  • 标准化处理防止特征尺度差异影响模型性能

  • 合并数据集用于交叉验证

2.2.2 模型配置
models = {"Random Forest": RandomForestClassifier(n_estimators=200,  # 增加树数量提升模型容量max_depth=8,        # 限制深度防止过拟合n_jobs=-1          # 使用全部CPU核心),"Linear SVM": SVC(kernel='rbf',       # 选择径向基函数核C=0.5,             # 正则化强度参数gamma='auto',      # 自动计算gamma参数probability=True   # 启用概率估计),"KNN": KNeighborsClassifier(n_neighbors=3,     # 使用3近邻n_jobs=-1          # 并行计算)
}scoring = {'accuracy': make_scorer(accuracy_score),'recall': make_scorer(recall_score, average='macro'),  # 多分类宏平均'f1': make_scorer(f1_score, average='macro')
}
 

参数调优说明

  • 随机森林:通过限制max_depth平衡偏差-方差

  • SVM:调整C值控制正则化强度

  • KNN:小邻域数适合高维度数据

2.2.3 交叉验证流程
best_score = -1
best_model_name = ""
best_model = Nonefor name, model in models.items():# 交叉验证cv_results = cross_validate(model, X_scaled, y, cv=3,              # 3折交叉验证scoring=scoring    # 使用自定义指标)# 指标计算acc = np.mean(cv_results['test_accuracy'])rec = np.mean(cv_results['test_recall'])f1 = np.mean(cv_results['test_f1'])# 模型比较if f1 > best_score:best_score = f1best_model_name = namebest_model = model# 生成学习曲线self.plot_learning_curve(model, X_scaled, y, name, output_dir)

评估策略

  • 使用3折交叉验证降低数据划分敏感性

  • 以F1宏平均作为模型选择标准

  • 同步输出各模型指标的标准差

2.3 学习曲线绘制 plot_learning_curve

2.3.1 数据计算

def plot_learning_curve(self, model, X, y, model_name, output_dir):train_sizes, train_scores, test_scores = learning_curve(model, X, y, cv=3,               # 3折交叉验证scoring='accuracy', # 使用准确率指标n_jobs=-1          # 并行计算)# 统计量计算train_mean = np.mean(train_scores, axis=1)train_std = np.std(train_scores, axis=1)test_mean = np.mean(test_scores, axis=1)test_std = np.std(test_scores, axis=1)
2.3.2 可视化实现
    plt.figure(figsize=(8, 6))plt.fill_between(train_sizes,train_mean - train_std,train_mean + train_std,alpha=0.1, color="r")plt.plot(train_sizes, train_mean, 'o-', color="r", label="Training score")# 测试集曲线同理...plt.title(f"Learning Curve ({model_name})")plt.xlabel("Training Examples")plt.ylabel("Accuracy Score")plt.legend(loc="best")# 保存图像output_path = os.path.join(output_dir, f"{model_name}_learning_curve.png")plt.savefig(output_path)plt.close()
 

可视化分析

  • 阴影区域表示±1标准差范围

  • 训练曲线(红色)与验证曲线(绿色)对比

  • 图像尺寸设为8x6英寸保证可读性

三、使用流程示例

# 示例数据
X, y = load_your_data()  # 需自定义数据加载方法# 初始化训练器
trainer = ModelTrainer()# 执行训练
best_model = trainer.train_model(X, y,output_dir="my_models"  # 指定输出目录
)# 使用最佳模型预测
predictions = best_model.predict(new_data)

四、输出文件结构


model_plots/
├── Random Forest_learning_curve.png
├── Linear SVM_learning_curve.png
└── KNN_learning_curve.png

图像展示模型的学习过程,帮助诊断欠/过拟合问题

 

http://www.xdnf.cn/news/4024.html

相关文章:

  • AVHRR中国积雪物候数据集(1980-2020年)
  • yolo 用roboflow标注的数据集本地训练 kaggle训练 comet使用 训练笔记5
  • FISCO BCOS【初体验笔记】
  • Python 闭包:函数式编程中的魔法变量容器
  • ciscn_2019_c_1
  • 普联的AC100+AP+易展路由组网方案的一些问题
  • docker介绍以及安装
  • sherpa-ncnn:Linux_x86交叉编译Linux_arm32上的sherpa-ncnn -- 语音转文本大模型
  • 蓝桥杯单片机备战笔记
  • 【中间件】brpc_基础_TimerThread
  • 五一假期作业
  • springboot单体项目的执行流程
  • LFU算法解析
  • 【PostgreSQL数据分析实战:从数据清洗到可视化全流程】4.5 清洗流程自动化(存储过程/定时任务)
  • 【中间件】brpc_基础_单例
  • FreeRTOS学习系列·二值信号量
  • Linux查询日志常用命令
  • 解锁现代健康密码:科学养生新主张
  • 基于PLC的换热器温度控制系统设计
  • 状态模式(State Pattern)
  • 电子商务商家后台运营专员模板
  • C++ 中二级指针的正确释放方法
  • 【KWDB 创作者计划】_Ubuntu 22.04系统KWDB数据库安装部署使用教程
  • Qt中的UIC
  • Amazon Bedrock Converse API:开启对话式AI新体验
  • Qt开发:容器组控件的介绍和使用
  • 20、数据可视化:魔镜报表——React 19 图表集成
  • 408考研逐题详解:2009年第8题
  • Java后端程序员学习前端之CSS
  • Python matplotlib 成功使用SimHei 中文字体