当前位置: 首页 > news >正文

【深度学习-pytorch篇】4. 正则化方法(Regularization Techniques)

正则化方法(Regularization Techniques)

1. 目标

  • 理解什么是过拟合及其影响
  • 掌握常见正则化技术:L2 正则化、Dropout、Batch Normalization、Early Stopping
  • 能够使用 PyTorch 编程实现这些正则化方法并进行比较分析

2. 数据构造与任务设定

本实验是一个带噪声的回归任务,目标函数为 y = x + N ( 0 , σ 2 ) y = x + \mathcal{N}(0, \sigma^2) y=x+N(0,σ2)。使用均匀分布采样输入 x ∈ [ − 1 , 1 ] x \in [-1, 1] x[1,1]

import numpy as np
import torch
import torch.utils.data as DataN_SAMPLES = 20
NOISE_RATE = 0.4train_x = np.linspace(-1, 1, N_SAMPLES)[:, np.newaxis]
train_y = train_x + np.random.normal(0, NOISE_RATE, train_x.shape)validate_x = np.linspace(-1, 1, N_SAMPLES // 2)[:, np.newaxis]
validate_y = validate_x + np.random.normal(0, NOISE_RATE, validate_x.shape)test_x = np.linspace(-1, 1, N_SAMPLES)[:, np.newaxis]
test_y = test_x + np.random.normal(0, NOISE_RATE, test_x.shape)# 转换为 Tensor
train_x = torch.tensor(train_x, dtype=torch.float32)
train_y = torch.tensor(train_y, dtype=torch.float32)
validate_x = torch.tensor(validate_x, dtype=torch.float32)
validate_y = torch.tensor(validate_y, dtype=torch.float32)
test_x = torch.tensor(test_x, dtype=torch.float32)
test_y = torch.tensor(test_y, dtype=torch.float32)train_dataset = Data.TensorDataset(train_x, train_y)
train_loader = Data.DataLoader(dataset=train_dataset, batch_size=10, shuffle=True)

3. 模型定义

3.1 原始 MLP(无正则化)

import torch.nn as nn
import torch.nn.init as initclass FC_Classifier(nn.Module):def __init__(self, input_dim=1, hidden_dim=100, output_dim=1):super().__init__()self.fc1 = nn.Linear(input_dim, hidden_dim)self.fc2 = nn.Linear(hidden_dim, output_dim)self.activation = nn.ReLU()self._init_weights()def _init_weights(self):init.normal_(self.fc1.weight, mean=0.0, std=0.1)init.constant_(self.fc1.bias, 0)init.normal_(self.fc2.weight, mean=0.0, std=0.1)init.constant_(self.fc2.bias, 0)def forward(self, x):x = self.activation(self.fc1(x))return self.fc2(x)

3.2 Dropout MLP

class DropoutMLP(nn.Module):def __init__(self, dropout_rate=0.5):super().__init__()self.fc1 = nn.Linear(1, 100)self.dropout = nn.Dropout(dropout_rate)self.fc2 = nn.Linear(100, 1)self.activation = nn.ReLU()self._init_weights()def _init_weights(self):init.normal_(self.fc1.weight, mean=0.0, std=0.1)init.constant_(self.fc1.bias, 0)init.normal_(self.fc2.weight, mean=0.0, std=0.1)init.constant_(self.fc2.bias, 0)def forward(self, x):x = self.dropout(self.fc1(x))x = self.activation(x)return self.fc2(x)

3.3 Batch Normalization MLP

class BNMLP(nn.Module):def __init__(self):super().__init__()self.bn_input = nn.BatchNorm1d(1)self.fc1 = nn.Linear(1, 100)self.bn_hidden = nn.BatchNorm1d(100)self.fc2 = nn.Linear(100, 1)self.activation = nn.ReLU()def forward(self, x):x = self.bn_input(x)x = self.fc1(x)x = self.bn_hidden(x)x = self.activation(x)return self.fc2(x)

4. Early Stopping 策略

当验证集误差连续若干轮无提升时,提前停止训练,避免过拟合。

max_patience = 5
patience = 0
best_val_loss = float("inf")
is_early_stop = False

5. RMSNorm 实现与讲解

5.1 原理说明

RMSNorm 是一种替代 LayerNorm 的轻量化归一化方法:

  • 不减均值
  • 仅用激活值的均方根进行归一化
  • 不依赖 batch 维度

数学公式:

RMS ( x ) = 1 n ∑ i = 1 n x i 2 \text{RMS}(x) = \sqrt{\frac{1}{n} \sum_{i=1}^n x_i^2} RMS(x)=n1i=1nxi2

RMSNorm ( x ) = x RMS ( x ) + ϵ ⋅ γ \text{RMSNorm}(x) = \frac{x}{\text{RMS}(x) + \epsilon} \cdot \gamma RMSNorm(x)=RMS(x)+ϵxγ

其中 γ \gamma γ 为可学习参数, ϵ \epsilon ϵ 是一个很小的数避免除以 0。

5.2 代码实现

class RMSNorm(nn.Module):def __init__(self, hidden_size, eps=1e-6):super().__init__()self.weight = nn.Parameter(torch.ones(hidden_size))self.eps = epsdef forward(self, x):rms = torch.sqrt(torch.mean(x ** 2, dim=-1, keepdim=True) + self.eps)return self.weight * x / rms

5.3 与其他归一化对比

方法是否减均值是否除方差是否依赖 batch
BatchNorm
LayerNorm
RMSNorm是 (仅 RMS)

6. 实验建议

  • 尝试不同的 Dropout 比例(如 0.1 / 0.3 / 0.5)并观察效果;
  • 对比是否每层都加 BatchNorm 是否更优;
  • 比较 L2 正则项中 weight decay 的不同取值;
  • 使用 RMSNorm 替代 LayerNorm 做对比实验。
http://www.xdnf.cn/news/699733.html

相关文章:

  • 使用u盘重装银河麒麟操作系统
  • 【人工智能】微调革命:释放大模型的无限潜能
  • 【系统架构设计师】2025年上半年真题论文回忆版: 论软件测试方法及应用(包括解题思路和参考素材)
  • 社交类网站设计:经典feed流系统架构详细设计(小红书微博等)
  • window 显示驱动开发-处理 E_INVALIDARG 返回值
  • ArgoDB表类型及常用命令
  • 491. Non-decreasing Subsequences
  • DeepSeek R1 与 V3 的全面对比,两个版本有什么差别?
  • 【Linux】linux上看到的内存和实际内存不一样?
  • Linux云计算训练营笔记day17(Python)
  • Cisco Packer Tracer 组建虚拟局域网(VLAN)
  • 【前端】【Jquery】一篇文章学习Jquery所有知识点
  • keepalived两台设备同时出现VIP问题
  • MySql--explain的用法
  • 【Linux网络篇】:简单的TCP网络程序编写以及相关内容的扩展
  • css样式块重复调用
  • 楼宇自控系统重塑建筑设备管理:告别低效,迈向智能管理时代
  • 华为OD机试真题——书籍叠放(2025A卷:200分)Java/python/JavaScript/C/C++/GO最佳实现
  • Linux系统之cal命令的基本使用
  • 国有企业采购方式及适用情形
  • Java集合进阶
  • C++补充基础小知识:什么是接口类 和 抽象类?为什么要继承?
  • 线程的生命周期?怎么终止线程?线程和线程池有什么区别?如何创建线程池?说一下 ThreadPoolExecutor 的参数含义?
  • yolov12毕设前置知识准备 1
  • Linux基本指令/上
  • Python常用模块实用指南
  • Python人工智能算法学习 禁忌搜索算法求解旅行商问题(TSP)的研究与实现
  • .net Winfrom 如何将窗口设置为MDI容器
  • QGIS新手教程2:线图层与多边形图层基础操作指南(点线互转、中心点提取与WKT导出)
  • Git:现代软件开发的基石——原理、实践与行业智慧·优雅草卓伊凡