当前位置: 首页 > news >正文

4.权重衰减(weight decay)

4.1 手动实现权重衰减

import torch
from torch import nn
from torch.utils.data import TensorDataset,DataLoader
import matplotlib.pyplot as plt
def synthetic_data(w, b, num_inputs):
    """Generate a noisy linear dataset ``y = X @ w + b + noise``.

    Args:
        w: Weight tensor; its first dimension is used as the feature count.
        b: Bias added to every row of ``X @ w``.
        num_inputs: Number of rows (examples) to generate. Despite the name,
            this is the sample count; the feature count comes from ``w``.

    Returns:
        A tuple ``(X, y)`` where ``X`` has shape
        ``(num_inputs, w.shape[0])`` and ``y`` is the noisy target.
    """
    # Features are drawn from a standard normal distribution.
    X = torch.normal(0, 1, size=(num_inputs, w.shape[0]))
    y = X @ w + b
    # Gaussian observation noise with standard deviation 0.1.
    y += torch.normal(0, 0.1, size=y.shape)
    return X, y
def load_array(data, batch_size, is_train=True):
    """Wrap a tuple of tensors in a mini-batch ``DataLoader``.

    Args:
        data: Iterable of tensors sharing the same first dimension; it is
            unpacked into ``TensorDataset``.
        batch_size: Number of examples per mini-batch.
        is_train: When true the loader shuffles the examples on every pass.

    Returns:
        A ``DataLoader`` over the wrapped dataset.
    """
    dataset = TensorDataset(*data)
    return DataLoader(dataset, batch_size=batch_size, shuffle=is_train)
def init_params(num_inputs):
    """Create trainable parameters for a linear model.

    Args:
        num_inputs: Number of input features.

    Returns:
        A list ``[w, b]``: ``w`` is a ``(num_inputs, 1)`` tensor drawn from a
        standard normal distribution and ``b`` is a one-element zero tensor.
        Both have ``requires_grad=True``.
    """
    w = torch.normal(0, 1, size=(num_inputs, 1), requires_grad=True)
    b = torch.zeros(1, requires_grad=True)
    return [w, b]
def l2_penalty(w):
    """Return half the sum of squared entries of ``w`` (L2 penalty)."""
    return 0.5 * torch.sum(w.pow(2))


def linear_reg(X, w, b):
    """Return the linear model output ``X @ w + b``."""
    return torch.matmul(X, w) + b
def mse_loss(y_hat, y):
    """Return the element-wise halved squared error between ``y_hat`` and ``y``.

    No reduction is applied: the result has the broadcast shape of the two
    inputs, so callers decide how to sum or average it.
    """
    residual = y_hat - y
    return residual ** 2 / 2
def sgd(params, lr, batch_size):
    """Apply one mini-batch SGD step in place and reset the gradients.

    Each parameter is updated as ``param -= lr * param.grad / batch_size``;
    its gradient is then zeroed so the next backward pass starts clean.

    Args:
        params: Iterable of tensors whose ``.grad`` has been populated.
        lr: Learning rate.
        batch_size: Divisor applied to the accumulated gradient.
    """
    # The update itself must not be recorded by autograd.
    with torch.no_grad():
        for param in params:
            # A parameter that took no part in the backward pass has no
            # gradient; leave it untouched instead of failing on ``None``.
            if param.grad is None:
                continue
            param -= lr * param.grad / batch_size
            param.grad.zero_()
def evaluate_loss(net, data_iter, loss):
    """Return the average per-example loss of ``net`` over ``data_iter``.

    Args:
        net: Callable mapping a feature batch to predictions.
        data_iter: Iterable yielding ``(X, y)`` batches.
        loss: Callable ``loss(y_hat, y)`` returning a tensor of losses; its
            entries are summed per batch.

    Returns:
        Total loss divided by the total number of label elements, as a float.
    """
    total_loss, total_samples = 0.0, 0
    # Evaluation only reads values, so skip building the autograd graph.
    with torch.no_grad():
        for X, y in data_iter:
            l = loss(net(X), y)
            total_loss += l.sum().item()
            total_samples += y.numel()
    return total_loss / total_samples
# Experiment setup: far fewer training examples (20) than features (200).
n_train, n_test, num_inputs, batch_size = 20, 100, 200, 5
true_w, true_b = torch.ones((num_inputs, 1)) * 0.01, 0.05
train_data = synthetic_data(true_w, true_b, n_train)
test_data = synthetic_data(true_w, true_b, n_test)
train_iter = load_array(train_data, batch_size)
test_iter = load_array(test_data, batch_size, is_train=False)

# Model: linear regression over manually managed parameters.
w, b = init_params(num_inputs)
net = lambda X: linear_reg(X, w, b)
loss = mse_loss

# lambd is the weight-decay strength multiplying the L2 penalty.
num_epochs, lr, lambd = 10, 0.05, 3
# animator = SimpleAnimator()
for epoch in range(num_epochs):
    for X, y in train_iter:
        # The scalar penalty is broadcast onto every per-example loss; the
        # later division by batch_size inside sgd() undoes that repetition.
        l = loss(net(X), y) + lambd * l2_penalty(w)
        l.sum().backward()
        sgd([w, b], lr, batch_size)
    # Report the unpenalised train/test loss every fifth epoch.
    if (epoch + 1) % 5 == 0:
        train_loss = evaluate_loss(net, train_iter, loss)
        test_loss = evaluate_loss(net, test_iter, loss)
        # animator.add(epoch + 1, train_loss, test_loss)
        print(f"Epoch {epoch+1}: Train Loss: {train_loss:.4f},test Loss: {test_loss:.4f}")
print('w的L2范数是:', torch.norm(w).item())
plt.show()

4.2 简单实现权重衰减

import torch
from torch import nn
from torch.utils.data import TensorDataset,DataLoader
import matplotlib.pyplot as plt
def synthetic_data(w, b, num_inputs):
    """Generate a noisy linear dataset ``y = X @ w + b + noise``.

    Args:
        w: Weight tensor; its first dimension is used as the feature count.
        b: Bias added to every row of ``X @ w``.
        num_inputs: Number of rows (examples) to generate. Despite the name,
            this is the sample count; the feature count comes from ``w``.

    Returns:
        A tuple ``(X, y)`` where ``X`` has shape
        ``(num_inputs, w.shape[0])`` and ``y`` is the noisy target.
    """
    # Features are drawn from a standard normal distribution.
    X = torch.normal(0, 1, size=(num_inputs, w.shape[0]))
    y = X @ w + b
    # Gaussian observation noise with standard deviation 0.1.
    y += torch.normal(0, 0.1, size=y.shape)
    return X, y
def load_array(data, batch_size, is_train=True):
    """Wrap a tuple of tensors in a mini-batch ``DataLoader``.

    Args:
        data: Iterable of tensors sharing the same first dimension; it is
            unpacked into ``TensorDataset``.
        batch_size: Number of examples per mini-batch.
        is_train: When true the loader shuffles the examples on every pass.

    Returns:
        A ``DataLoader`` over the wrapped dataset.
    """
    dataset = TensorDataset(*data)
    return DataLoader(dataset, batch_size=batch_size, shuffle=is_train)
def init_params(num_inputs):
    """Create trainable parameters for a linear model.

    Args:
        num_inputs: Number of input features.

    Returns:
        A list ``[w, b]``: ``w`` is a ``(num_inputs, 1)`` tensor drawn from a
        standard normal distribution and ``b`` is a one-element zero tensor.
        Both have ``requires_grad=True``.
    """
    w = torch.normal(0, 1, size=(num_inputs, 1), requires_grad=True)
    b = torch.zeros(1, requires_grad=True)
    return [w, b]
def l2_penalty(w):
    """Return half the sum of squared entries of ``w`` (L2 penalty)."""
    squared_sum = w.pow(2).sum()
    return 0.5 * squared_sum
def linear_reg(X, w, b):
    """Return the linear model output: ``X`` times ``w``, plus bias ``b``."""
    projection = X @ w
    return projection + b
def mse_loss(y_hat, y):
    """Return half the summed squared error between ``y_hat`` and ``y``.

    The result is a single scalar tensor: the squared differences are summed
    over every element (not averaged) and then halved.
    """
    squared_error = (y_hat - y) ** 2
    return squared_error.sum() / 2
def evaluate_loss(net, data_iter, loss):
    """Return the average per-example loss of ``net`` over ``data_iter``.

    Args:
        net: Callable mapping a feature batch to predictions.
        data_iter: Iterable yielding ``(X, y)`` batches.
        loss: Callable ``loss(y_hat, y)`` returning either per-example losses
            or a loss already summed over the batch.

    Returns:
        Total loss divided by the total number of label elements, as a float.
    """
    total_loss, total_samples = 0.0, 0
    # Evaluation only reads values, so skip building the autograd graph.
    with torch.no_grad():
        for X, y in data_iter:
            l = loss(net(X), y)
            # The loss is accumulated as a batch total. Multiplying a summed
            # loss by the batch size again would inflate the reported
            # average by that factor.
            total_loss += l.sum().item()
            total_samples += y.numel()
    return total_loss / total_samples
# Experiment setup: far fewer training examples (20) than features (200).
n_train, n_test, num_inputs, batch_size = 20, 100, 200, 5
true_w, true_b = torch.ones((num_inputs, 1)) * 0.01, 0.05
train_data = synthetic_data(true_w, true_b, n_train)
test_data = synthetic_data(true_w, true_b, n_test)
train_iter = load_array(train_data, batch_size)
test_iter = load_array(test_data, batch_size, is_train=False)

# Model: linear regression over manually managed parameters.
w, b = init_params(num_inputs)
net = lambda X: linear_reg(X, w, b)
loss = mse_loss

# NOTE(review): lambd is assigned but never used below; the decay strength
# actually applied is the hard-coded weight_decay=0.001 passed to the
# optimizer. Confirm which value is intended.
num_epochs, lr, lambd = 100, 0.001, 3
# NOTE(review): weight_decay is applied to both w and b here. Use separate
# parameter groups if the bias should not be decayed.
optimizer = torch.optim.SGD([w, b], lr=lr, weight_decay=0.001)
# animator = SimpleAnimator()
for epoch in range(num_epochs):
    for X, y in train_iter:
        optimizer.zero_grad()
        l = loss(net(X), y)
        l.backward()
        # sgd([w, b], lr, batch_size)
        optimizer.step()
    # Report the train/test loss every fifth epoch.
    if (epoch + 1) % 5 == 0:
        train_loss = evaluate_loss(net, train_iter, loss)
        test_loss = evaluate_loss(net, test_iter, loss)
        # animator.add(epoch + 1, train_loss, test_loss)
        print(f"Epoch {epoch+1}: Train Loss: {train_loss:.4f},test Loss: {test_loss:.4f}")
print('w的L2范数是:', torch.norm(w).item())
plt.show()
http://www.xdnf.cn/news/1087147.html

相关文章:

  • EPLAN 电气制图(四):EPLAN 总电源电路设计知识详解
  • 【openGLES】安卓端EGL的使用
  • vue3 el-table 行数据沾满格 取消自动换行
  • 小米YU7预售现象深度解析:智能电动汽车的下一个范式革命
  • 【Linux】Redis 6.2.6 的二进制部署【适用于多版本】
  • 中州养老Day02
  • Zookeeper是如何解决脑裂问题的?
  • 深入了解linux系统—— System V之消息队列和信号量
  • 从0到1搭建ELK日志收集平台
  • 扣子Coze纯前端部署多Agents
  • 使用python的 FastApi框架开发图书管理系统-前后端分离项目分享
  • 暑假算法日记第四天
  • Django双下划线查询
  • 汽车功能安全系统阶段开发【技术安全方案TSC以及安全分析】5
  • 基于Vue 3的AI前端框架汇总及工具对比表
  • HTTP/3.x协议详解:基于QUIC的下一代Web传输协议
  • react的条件渲染【简约风5min】
  • 图像梯度处理与边缘检测:OpenCV 实战指南
  • AIGC与影视制作:技术革命、产业重构与未来图景
  • 无缝矩阵的音频合成与音频分离功能详解
  • 静态路由实验以及核心原理
  • 音频主动降噪技术
  • 2025年深圳杉川机器人性格测评和Verify测评SHL题库高分攻略
  • Ubuntu22.04中Google浏览器138版本无法使用中文搜狗输入法
  • AI开源伦理临大考,如何判定抄袭
  • nng库使用
  • 数据结构:位图
  • 无缝矩阵支持音频分离带画面分割功能的全面解析
  • 进阶向:Python音频录制与分析系统详解,从原理到实践
  • 代码详细注释:ARM-Linux字符设备驱动开发案例:LCD汉字输出改进建议开发板断电重启还能显示汉字,显示汉字位置自定义