当前位置: 首页 > ops >正文

WIFI信号状态信息 CSI 深度学习篇之CNN(Python)

本博客是一篇非新手导向的CNN处理CSI图像帧的教程,基于tensorflow框架构建CNN模型进行训练,训练对象依然是前述博客中所提到的CSI图像帧(500 x 90 x 1)。代码里用到了深度可分离卷积,这种结构在减少计算量和参数数量方面比较有优势的。在多次试验后,发现就我的数据集而言,这个模型和普通的卷积结构相比,对处理CSI图像帧这种复杂数据更有帮助,性能会略好一些。不过模型本身的结构不重要,重要的是先把代码跑起来训练起来,上述模型也只是在我的数据集上表现良好,仅供参考。

如今已经是2025年,关于代码怎么运行等等,一切问题均可以问AI,故不讲废话直接上代码:

import tensorflow as tf
from tensorflow.keras.preprocessing.image import ImageDataGenerator
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.layers import Input, Conv2D, BatchNormalization, ReLU, DepthwiseConv2D, Add, GlobalAveragePooling2D, Dense, MaxPooling2D
from tensorflow.keras.models import Model
from tensorflow.keras.callbacks import LearningRateSchedulertrain_dir = '训练集地址'
val_dir = '验证集地址'
test_dir = '测试集地址'# 初始化数据生成器
train_datagen = ImageDataGenerator(rescale=1./255)
val_datagen = ImageDataGenerator(rescale=1./255)
test_datagen = ImageDataGenerator(rescale=1./255)# 创建训练集、验证集和测试集
train_generator = train_datagen.flow_from_directory(train_dir,target_size=(500, 90),batch_size=64,color_mode='grayscale',class_mode='categorical')val_generator = val_datagen.flow_from_directory(val_dir,target_size=(500, 90),batch_size=64,color_mode='grayscale',class_mode='categorical')test_generator = test_datagen.flow_from_directory(test_dir,target_size=(500, 90),batch_size=64,color_mode='grayscale',class_mode='categorical')# 定义深度可分离卷积块
def depthwise_separable_conv_block(input_tensor, channels, downsample=False):stride = (2, 2) if downsample else (1, 1)x = DepthwiseConv2D((3, 3), strides=stride, padding="same")(input_tensor)x = BatchNormalization()(x)x = ReLU()(x)x = Conv2D(channels, (1, 1), padding="same")(x)x = BatchNormalization()(x)x = ReLU()(x)if downsample:input_tensor = Conv2D(channels, (1, 1), strides=(2, 2), padding="same")(input_tensor)x = Add()([x, input_tensor])x = ReLU()(x)return x# 构建DSConvNet
def build_DSConvNet(input_shape, num_classes):inputs = Input(shape=input_shape)# 初始卷积块x = Conv2D(32, (3, 3), padding="same", strides=(2, 2))(inputs)x = BatchNormalization()(x)x = ReLU()(x)x = MaxPooling2D((3, 3), padding="same", strides=(2, 2))(x)# 第一组卷积块x = depthwise_separable_conv_block(x, 32)# 第二组卷积块x = depthwise_separable_conv_block(x, 64, downsample=True)# 第三组卷积块x = depthwise_separable_conv_block(x, 128, downsample=True)# 第四组卷积块x = depthwise_separable_conv_block(x, 256, downsample=True)x = GlobalAveragePooling2D()(x)outputs = Dense(num_classes, activation='softmax')(x)# 创建模型model = Model(inputs, outputs)return model# 创建模型实例
model = build_DSConvNet(input_shape=(500, 90, 1), num_classes=10)# 定义学习率调整函数并创建实例
def lr_schedule(epoch):if epoch < 10:return 0.01  # 前10轮学习率为0.01elif epoch < 20:return 0.001  # 第11到20轮为0.001else:return 0.0001  # 第21到30轮为0.0001
lr_scheduler = LearningRateScheduler(lr_schedule, verbose=1)# 编译模型
model.compile(optimizer=Adam(learning_rate=0.01),loss='categorical_crossentropy',metrics=['accuracy'])# 训练模型
history = model.fit(train_generator,steps_per_epoch=train_generator.samples // 64,epochs=30,validation_data=val_generator,validation_steps=val_generator.samples // 64,callbacks=[lr_scheduler])# 评估模型
test_loss, test_acc = model.evaluate(test_generator, steps=test_generator.samples // 64)
print(f'Test accuracy: {test_acc}, Test loss: {test_loss}')

http://www.xdnf.cn/news/7512.html

相关文章:

  • Typescript学习教程,从入门到精通,TypeScript 继承语法知识点及案例代码(8)
  • Kotlin 协程 (三)
  • vivado fpga程序固化
  • 学习黑客数据小包的TLS冒险之旅
  • Java 07异常
  • 将 Workbook 输出流直接上传到云盘
  • Apollo10.0学习——planning模块(8)之Frame类
  • 使用VGG-16模型来对海贼王中的角色进行图像分类分类
  • python打卡day31
  • SQLynx 团队协作实践:提升数据库开发效率的解决方案​
  • 4-5月份,思科,华为,微软,个别考试战报分享
  • Axure中使用动态面板实现图标拖动交换位置
  • C++23 新增扁平化关联容器详解
  • 微小店推客系统开发:构建全民营销矩阵,解锁流量增长密码
  • Java EE进阶1:导读
  • Spring Cloud Gateway深度解析:原理、架构与生产实践
  • 根据当前日期计算并选取上一个月和上一个季度的日期范围,用于日期控件的快捷选取功能
  • MySQL 8.0 OCP 英文题库解析(七)
  • 在 Git 中添加子模块(submodule)的详细步骤
  • kotlin 将一个list按条件分为两个list(partition )
  • 漏洞检测与渗透检验在功能及范围上究竟有何显著差异?
  • iOS Runtime与RunLoop的对比和使用
  • 基于flask+vue的电影可视化与智能推荐系统
  • PostgreSQL架构
  • HTML应用指南:利用POST请求获取全国申通快递服务网点位置信息
  • 华为云鲲鹏型kC2云服务器——鲲鹏920芯片性能测评
  • 【EI会议火热征稿中】第二届云计算与大数据国际学术会议(ICCBD 2025)
  • 程序运行报错分析文档
  • 使用 adb 命令截取 Android 设备的屏幕截图
  • CentOS 7连接公司网络配置指南