当前位置: 首页 > web >正文

Day40打卡 @浙大疏锦行

知识点回顾:

  1. 彩色和灰度图片测试和训练的规范写法:封装在函数中
  2. 展平操作:除第一个维度batchsize外全部展平
  3. dropout操作:训练阶段随机丢弃神经元,测试阶段eval模式关闭dropout
import torchdef flatten_tensor(x):"""保留batch维度,展平其余所有维度"""return x.view(x.size(0), -1)

class NeuralNetwork(torch.nn.Module):def __init__(self, input_dim, hidden_dim, dropout_prob=0.5):super().__init__()self.layer1 = torch.nn.Linear(input_dim, hidden_dim)self.dropout = torch.nn.Dropout(dropout_prob)self.layer2 = torch.nn.Linear(hidden_dim, 10)def forward(self, x):x = torch.relu(self.layer1(x))x = self.dropout(x)  # 训练时激活,测试时自动关闭return self.layer2(x)

model = NeuralNetwork(input_dim=784, hidden_dim=256)# 训练阶段
model.train()
output_train = model(flattened_grayscale)  # Dropout生效# 测试阶段
model.eval()
with torch.no_grad():output_test = model(flattened_grayscale)  # Dropout自动关闭

def train_epoch(model, dataloader, criterion, optimizer):model.train()for inputs, targets in dataloader:inputs = flatten_tensor(inputs)optimizer.zero_grad()outputs = model(inputs)loss = criterion(outputs, targets)loss.backward()optimizer.step()def evaluate(model, dataloader):model.eval()total_correct = 0with torch.no_grad():for inputs, targets in dataloader:inputs = flatten_tensor(inputs)outputs = model(inputs)_, predicted = torch.max(outputs, 1)total_correct += (predicted == targets).sum().item()return total_correct / len(dataloader.dataset)

@浙大疏锦行​​​​​​​

http://www.xdnf.cn/news/9998.html

相关文章:

  • 低功耗架构突破:STM32H750 与 SD NAND (存储芯片)如何延长手环续航至 14 天
  • 使用vscode进行c/c++开发的时候,输出报错乱码、cpp文件本身乱码的问题解决
  • 外包项目交付后还能怎么加固?我用 Ipa Guard 给 iOS IPA 增加了一层保障
  • 数据库暴露--Get型注入攻击
  • C++?多态!!!
  • Git的简单介绍分析及常用使用方法
  • openppp2 -- 1.0.0.25225 优化多线接入运营商路由调配
  • 电路笔记(通信):CAN 仲裁机制(Arbitration Mechanism) 位级监视线与特性先占先得非破坏性仲裁
  • 【机器人】具身导航 VLN 最新论文汇总 | Vision-and-Language Navigation
  • 人工智能100问☞第37问:什么是扩散模型?
  • 【清晰教程】利用Git工具将本地项目push上传至GitHub仓库中
  • 【开源工具】音频格式转换大师:基于PyQt5与FFmpeg的高效格式转换工具开发全解析
  • Go语言使用阿里云模版短信服务
  • 类 Excel 数据填报
  • LVS-NAT 负载均衡群集
  • C++高级编程深度指南:内存管理、安全函数、递归、错误处理、命令行参数解析、可变参数应用与未定义行为规避
  • 历年西安电子科技大学计算机保研上机真题
  • Redisson学习专栏(三):高级特性与实战(Spring/Spring Boot 集成,响应式编程,分布式服务,性能优化)
  • Real SQL Programming
  • 安装一个包 myPhysicsLab
  • Numpy知识点
  • Cesium 8 ,在 Cesium 上实现雷达动画和车辆动画效果,并控制显示和隐藏
  • 提示词优化技巧
  • 【Java】线程池的实现原理是怎样的?CPU密集型任务与IO密集型任务的区别?
  • Java基础面试题--jdk和jre的区别
  • openbmc kvm vnc client connection
  • 四、若依从数据库
  • 【JavaWeb】基本概念、web服务器、Tomcat、HTTP协议
  • 数据结构数组总结
  • 大模型调用数据库表实践:基于自然语言的SQL生成与数据查询系统