
Check-in Day 52

Simple CNN: further improving accuracy with the help of a tuning guide

Baseline CNN model code

import tensorflow as tf
from tensorflow.keras import layers, models
from tensorflow.keras.datasets import cifar10
from tensorflow.keras.utils import to_categorical

# Load the data
(train_images, train_labels), (test_images, test_labels) = cifar10.load_data()

# Preprocess: scale pixels to [0, 1] and one-hot encode the labels
train_images = train_images.astype('float32') / 255
test_images = test_images.astype('float32') / 255
train_labels = to_categorical(train_labels)
test_labels = to_categorical(test_labels)

# Baseline CNN model
model = models.Sequential([
    layers.Conv2D(32, (3, 3), activation='relu', input_shape=(32, 32, 3)),
    layers.MaxPooling2D((2, 2)),
    layers.Conv2D(64, (3, 3), activation='relu'),
    layers.MaxPooling2D((2, 2)),
    layers.Conv2D(64, (3, 3), activation='relu'),
    layers.Flatten(),
    layers.Dense(64, activation='relu'),
    layers.Dense(10, activation='softmax')
])

model.compile(optimizer='adam',
              loss='categorical_crossentropy',
              metrics=['accuracy'])

history = model.fit(train_images, train_labels,
                    epochs=10, batch_size=64,
                    validation_data=(test_images, test_labels))
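To get a feel for how small the baseline is before tuning it, the parameter count can be worked out by hand. The following is a minimal sketch in plain Python. The helper names (`conv2d_params`, `dense_params`, `valid_conv`, `max_pool`) are hypothetical and not part of Keras, and the sketch assumes the Conv2D defaults used above (stride 1, 'valid' padding).

```python
def conv2d_params(kernel, in_channels, filters):
    """Weights plus biases of a Conv2D layer with a square kernel."""
    return kernel * kernel * in_channels * filters + filters


def dense_params(in_features, units):
    """Weights plus biases of a Dense layer."""
    return in_features * units + units


def valid_conv(size, kernel=3):
    """Spatial size after a stride-1 convolution with 'valid' padding."""
    return size - kernel + 1


def max_pool(size, pool=2):
    """Spatial size after non-overlapping max pooling (floor division)."""
    return size // pool


# Trace the spatial size through the baseline: 32 -> 30 -> 15 -> 13 -> 6 -> 4
size = 32
size = max_pool(valid_conv(size))  # Conv2D(32) + MaxPooling2D
size = max_pool(valid_conv(size))  # Conv2D(64) + MaxPooling2D
size = valid_conv(size)            # Conv2D(64), no pooling afterwards
flat_features = size * size * 64   # input width of the first Dense layer

baseline_total = (
    conv2d_params(3, 3, 32)
    + conv2d_params(3, 32, 64)
    + conv2d_params(3, 64, 64)
    + dense_params(flat_features, 64)
    + dense_params(64, 10)
)
print(flat_features, baseline_total)  # prints: 1024 122570
```

If the hand calculation is right, `model.summary()` on the baseline should report the same total.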

Improvement methods

Increase model complexity

model = models.Sequential([
    layers.Conv2D(64, (3, 3), activation='relu', input_shape=(32, 32, 3), padding='same'),
    layers.BatchNormalization(),
    layers.Conv2D(64, (3, 3), activation='relu', padding='same'),
    layers.BatchNormalization(),
    layers.MaxPooling2D((2, 2)),
    layers.Dropout(0.25),
    layers.Conv2D(128, (3, 3), activation='relu', padding='same'),
    layers.BatchNormalization(),
    layers.Conv2D(128, (3, 3), activation='relu', padding='same'),
    layers.BatchNormalization(),
    layers.MaxPooling2D((2, 2)),
    layers.Dropout(0.25),
    layers.Conv2D(256, (3, 3), activation='relu', padding='same'),
    layers.BatchNormalization(),
    layers.Conv2D(256, (3, 3), activation='relu', padding='same'),
    layers.BatchNormalization(),
    layers.MaxPooling2D((2, 2)),
    layers.Dropout(0.25),
    layers.Flatten(),
    layers.Dense(512, activation='relu'),
    layers.BatchNormalization(),
    layers.Dropout(0.5),
    layers.Dense(10, activation='softmax')
])
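How much "more complex" this model is can be estimated without building it. The sketch below is a rough hand calculation in plain Python, not output from Keras. The helper names are hypothetical, and it assumes that BatchNormalization stores four values per channel (gamma, beta, moving mean, moving variance) and that Dropout and pooling layers have no parameters.

```python
def conv2d_params(in_channels, filters, kernel=3):
    """Weights plus biases of a Conv2D layer with a square kernel."""
    return kernel * kernel * in_channels * filters + filters


def batchnorm_params(channels):
    """gamma and beta (trainable) plus moving mean and variance (non-trainable)."""
    return 4 * channels


def dense_params(in_features, units):
    """Weights plus biases of a Dense layer."""
    return in_features * units + units


size, channels, total = 32, 3, 0
for filters in (64, 128, 256):
    # Each stage: two (Conv2D + BatchNormalization) pairs, then pooling and dropout
    for _ in range(2):
        total += conv2d_params(channels, filters) + batchnorm_params(filters)
        channels = filters
    size //= 2  # padding='same' keeps the size; MaxPooling2D((2, 2)) halves it

flat_features = size * size * channels  # 4 * 4 * 256
dense_block = dense_params(flat_features, 512) + batchnorm_params(512)
total += dense_block + dense_params(512, 10)

print(flat_features, total, round(dense_block / total, 2))
```

Under these assumptions the model has roughly 3.25 million parameters, and about two thirds of them sit in the `Dense(512)` block, which is one reason the heavier `Dropout(0.5)` is placed there.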

Optimizer tuning

from tensorflow.keras.optimizers import Adam

optimizer = Adam(learning_rate=0.001, beta_1=0.9, beta_2=0.999, epsilon=1e-07)
model.compile(optimizer=optimizer,
              loss='categorical_crossentropy',
              metrics=['accuracy'])
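To see what the four arguments control, it can help to write out one Adam update by hand. The following NumPy sketch follows the textbook form of the update rule on a toy quadratic; `adam_step` is a hypothetical helper, and the Keras implementation may differ in details such as where exactly `epsilon` is applied.

```python
import numpy as np


def adam_step(w, grad, m, v, t,
              learning_rate=0.001, beta_1=0.9, beta_2=0.999, epsilon=1e-7):
    """One textbook Adam update; t is the 1-based step counter."""
    m = beta_1 * m + (1.0 - beta_1) * grad        # running mean of gradients
    v = beta_2 * v + (1.0 - beta_2) * grad ** 2   # running mean of squared gradients
    m_hat = m / (1.0 - beta_1 ** t)               # bias correction for the zero init
    v_hat = v / (1.0 - beta_2 ** t)
    w = w - learning_rate * m_hat / (np.sqrt(v_hat) + epsilon)
    return w, m, v


# Toy problem (an assumption for illustration): minimize f(w) = sum(w ** 2)
w = np.array([1.0, -2.0])
initial_loss = float(np.sum(w ** 2))
m = np.zeros_like(w)
v = np.zeros_like(w)
for t in range(1, 201):
    grad = 2.0 * w
    w, m, v = adam_step(w, grad, m, v, t, learning_rate=0.05)

final_loss = float(np.sum(w ** 2))
print(initial_loss, final_loss)
```

On the very first step the update has magnitude close to `learning_rate` regardless of the gradient scale, which is why `learning_rate` is usually the first value worth tuning, while `beta_1`, `beta_2`, and `epsilon` are often left at their defaults.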

Data augmentation

from tensorflow.keras.preprocessing.image import ImageDataGenerator

datagen = ImageDataGenerator(
    rotation_range=15,
    width_shift_range=0.1,
    height_shift_range=0.1,
    horizontal_flip=True,
    zoom_range=0.1
)
datagen.fit(train_images)

history = model.fit(datagen.flow(train_images, train_labels, batch_size=64),
                    epochs=50,
                    validation_data=(test_images, test_labels))
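As an illustration of what two of these options do to a batch, here is a minimal NumPy sketch of a random shift and a random horizontal flip. `random_flip_and_shift` is a hypothetical helper, not part of Keras; rotation and zoom are left out, and repeating the border pixels only approximates how the generator fills the area exposed by a shift. A `max_shift` of 3 pixels roughly matches a shift range of 0.1 on 32-pixel images.

```python
import numpy as np


def random_flip_and_shift(images, max_shift=3, rng=None):
    """Randomly shift and horizontally flip a batch shaped (N, H, W, C)."""
    rng = np.random.default_rng() if rng is None else rng
    n, h, w, _ = images.shape
    # Pad by repeating the border pixels so a shifted crop stays fully defined
    pad = ((0, 0), (max_shift, max_shift), (max_shift, max_shift), (0, 0))
    padded = np.pad(images, pad, mode="edge")
    out = np.empty_like(images)
    for i in range(n):
        # Random top-left corner of an H x W window inside the padded image
        dy, dx = rng.integers(0, 2 * max_shift + 1, size=2)
        crop = padded[i, dy:dy + h, dx:dx + w, :]
        if rng.random() < 0.5:
            crop = crop[:, ::-1, :]  # reverse the width axis: horizontal flip
        out[i] = crop
    return out


# Random stand-in data with the CIFAR-10 image shape (an assumption for the demo)
rng = np.random.default_rng(0)
batch = rng.random((8, 32, 32, 3)).astype("float32")
augmented = random_flip_and_shift(batch, max_shift=3, rng=rng)
print(augmented.shape, augmented.dtype)
```

Labels are left unchanged because neither operation changes what the image depicts. That is also why only horizontal flips are enabled here: for most CIFAR-10 classes a vertically flipped image would look unnatural.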

Early stopping and model checkpoints

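from tensorflow.keras.callbacks import EarlyStopping, ModelCheckpoint

callbacks = [
    # Stop when val_loss has not improved for 10 epochs and roll back to the best weights
    EarlyStopping(monitor='val_loss', patience=10, restore_best_weights=True),
    # Keep only the checkpoint with the highest validation accuracy
    ModelCheckpoint('best_model.h5', monitor='val_accuracy', save_best_only=True)
]

# The "..." stands for the training data and validation arguments;
# both callbacks monitor validation metrics, so validation_data must be supplied.
history = model.fit(..., callbacks=callbacks, epochs=100)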
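The effect of `patience` is easiest to see on a made-up loss curve. The sketch below reimplements the basic counting logic in plain Python as an illustration. `early_stopping_epoch` is a hypothetical helper, the loss values are invented, and the real callback has further options (such as `min_delta`) that are ignored here.

```python
def early_stopping_epoch(val_losses, patience=10):
    """Return (stop_epoch, best_epoch), both 0-based, for a list of val_loss values."""
    best_loss = float("inf")
    best_epoch = -1
    wait = 0  # epochs since the last improvement
    for epoch, loss in enumerate(val_losses):
        if loss < best_loss:
            best_loss, best_epoch, wait = loss, epoch, 0
        else:
            wait += 1
            if wait >= patience:
                return epoch, best_epoch  # training would stop here
    return len(val_losses) - 1, best_epoch  # ran to the end without stopping


# Invented validation losses: a plateau, then a late improvement at epoch 6
losses = [1.0, 0.8, 0.7, 0.72, 0.71, 0.73, 0.69, 0.70, 0.71, 0.72]
print(early_stopping_epoch(losses, patience=3))   # prints: (5, 2)
print(early_stopping_epoch(losses, patience=10))  # prints: (9, 6)
```

With `patience=3` training stops at epoch 5 and misses the better value at epoch 6, while `patience=10` runs through and finds it. In both cases the weights restored by `restore_best_weights=True` would be those of the reported best epoch.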