当前位置: 首页 > web >正文

python打卡day37@浙大疏锦行

知识点回顾:

  1. 过拟合的判断:测试集和训练集同步打印指标
  2. 模型的保存和加载
    1. 仅保存权重
    2. 保存权重和模型
    3. 保存全部信息checkpoint,还包含训练状态
  3. 早停策略

作业:对信贷数据集训练后保存权重,加载权重后继续训练50轮,并采取早停策略

一、首先实现模型训练和保存权重

import torch
from model import CreditRiskModel  # 假设这是您的模型类# ... 数据加载代码 ...def train_model():model = CreditRiskModel(input_size=30)  # 根据实际特征数调整criterion = nn.BCELoss()optimizer = torch.optim.Adam(model.parameters())# 训练循环for epoch in range(100):# ... 训练代码 ...# 保存权重if epoch % 10 == 0:torch.save(model.state_dict(), f'weights/epoch_{epoch}.pth')# 最终保存torch.save(model.state_dict(), 'weights/final_weights.pth')

二、加载权重并继续训练50轮

from train import train_model  # 导入之前的训练函数
from model import CreditRiskModeldef load_and_resume():# 初始化模型model = CreditRiskModel(input_size=30)# 加载保存的权重checkpoint = torch.load('weights/final_weights.pth')model.load_state_dict(checkpoint)# 继续训练50轮optimizer = torch.optim.Adam(model.parameters())for epoch in range(50):# ... 训练代码 ...

三、实现早停策略

class EarlyStopping:def __init__(self, patience=5, delta=0):self.patience = patienceself.delta = deltaself.counter = 0self.best_score = Noneself.early_stop = Falsedef __call__(self, val_loss):score = -val_lossif self.best_score is None:self.best_score = scoreelif score < self.best_score + self.delta:self.counter += 1if self.counter >= self.patience:self.early_stop = Trueelse:self.best_score = scoreself.counter = 0

四、整合到训练代码中

from utils.early_stopping import EarlyStoppingdef train_model():# ... 之前的初始化代码 ...early_stopping = EarlyStopping(patience=5)for epoch in range(100):# ... 训练代码 ...val_loss = validate(model)  # 假设有验证函数# 早停检查early_stopping(val_loss)if early_stopping.early_stop:print(f"Early stopping triggered at epoch {epoch}")break# 保存最佳模型if early_stopping.counter == 0:torch.save(model.state_dict(), 'weights/best_weights.pth')

①权重保存方式:

# 仅保存权重
torch.save(model.state_dict(), 'model_weights.pth')# 保存整个模型
torch.save(model, 'full_model.pth')# 保存checkpoint(包含优化器状态等)
torch.save({'epoch': epoch,'model_state_dict': model.state_dict(),'optimizer_state_dict': optimizer.state_dict(),'loss': loss,
}, 'checkpoint.ckpt')

②加载方式对应:

# 加载权重
model.load_state_dict(torch.load('model_weights.pth'))# 加载整个模型
model = torch.load('full_model.pth')# 加载checkpoint
checkpoint = torch.load('checkpoint.ckpt')
model.load_state_dict(checkpoint['model_state_dict'])
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])

http://www.xdnf.cn/news/9061.html

相关文章:

  • tc3975开发板上有ft2232这块的电路,我想知道这个开发板有哪些升级方式,重点关注是怎样通过ft2232实现的烧录升级的
  • 单片机上按键功能通常都是用什么方法写?
  • 《DeepSeek行业应用全景指南(视频微课版)》:从入门到精通的AI落地实践手册
  • 2025年文件加密软件——数据保险箱,为您的文件上锁
  • DIY 自己的 MCP 服务-核心概念、基本协议、一个例子(Python)
  • 在 Windows 系统下使用 Qt 配置 OpenCV 和 MySql
  • 游戏引擎学习第310天:利用网格划分完成排序加速优化
  • 小土堆pytorch--优化器
  • Spring AI系列之Spring AI 集成 ChromaDB 向量数据库
  • 【C++进阶篇】初识哈希
  • FFmpeg 4.3 H265 二十二.4,使用计算机摄像头,通过VCL软件, 模拟 监控摄像头 的 RTSP 流
  • @MySQL升级8.0.42(Ubuntu 22.04)-SOP
  • Flink核心概念小结
  • Spring AI 系列之一个很棒的 Spring AI 功能——Advisors
  • WeakAuras Lua Script [ICC BOSS 11 - Sindragosa]
  • 博图软件块的概述-块的结构详解
  • VR 展厅开启一场穿越时空的邂逅​
  • Java常用API
  • React从基础入门到高级实战:React 核心技术 - React 状态管理:Context 与 Redux
  • uniapp-商城-71-shop(4-商品列表,详情页中添加商品到购物车的处理)
  • 机器人工具中心点标定
  • 【Linux】网络--传输层--TCP协议基础
  • 深入浅出对抗学习:概念、攻击、防御与代码实践
  • Ansible常用模块
  • c++算法题
  • 【QT】对话框dialog类封装
  • Unity UGUI 中 InputField 组件处理拖拽超出文本框边界时自动滚动内容的核心协程
  • java虚拟机2
  • 高速通信时代的信号编码利器-PAM4技术解析
  • HTML 文件路径完全指南:相对路径、绝对路径解析与引用技巧