当前位置: 首页 > news >正文

基于残差神经网络的垃圾分类

本课题旨在利用残差神经网络(ResNet)构建一个高效的图像分类模型,实现对垃圾图像的自动识别与分类。通过引入残差连接,有效缓解深层神经网络在训练过程中出现的梯度消失和退化问题,从而提升模型在复杂垃圾图像数据集上的识别精度与泛化能力。研究过程中将构建包含多类别垃圾图像的数据集,利用数据增强技术提升训练样本多样性,最终在测试集中实现对如“可回收物”“有害垃圾”“湿垃圾”“干垃圾”等类别的准确判别。该方法在智能垃圾投放、资源回收与环境管理等领域具有重要的实际应用价值。

# %%
import os
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from tensorflow import keras
from tensorflow.keras.preprocessing.image import ImageDataGenerator, load_img, img_to_array
from tensorflow.keras.models import Model, load_model
from tensorflow.keras.layers import Conv2D, MaxPooling2D, Flatten, Dense, Dropout, Input, Add
from sklearn.metrics import classification_report, confusion_matrix# 设置数据路径
data_dir = "database"
img_size = (128, 128)
batch_size = 32
sample_fraction = 0.1  # 仅使用10%数据# 数据预处理与增强
data_gen = ImageDataGenerator(rescale=1./255, validation_split=0.2)train_generator = data_gen.flow_from_directory(data_dir, target_size=img_size, batch_size=batch_size, subset='training', shuffle=True)val_generator = train_generator# 构建带残差连接的CNN模型
input_layer = Input(shape=(128, 128, 3))
conv1 = Conv2D(32, (3, 3), activation='relu', padding='same')(input_layer)
conv2 = Conv2D(32, (3, 3), activation='relu', padding='same')(conv1)
residual1 = Add()([conv1, conv2])  # 残差连接
pool1 = MaxPooling2D(2, 2)(residual1)conv3 = Conv2D(64, (3, 3), activation='relu', padding='same')(pool1)
conv4 = Conv2D(64, (3, 3), activation='relu', padding='same')(conv3)
residual2 = Add()([conv3, conv4])  # 残差连接
pool2 = MaxPooling2D(2, 2)(residual2)flatten = Flatten()(pool2)
dense1 = Dense(128, activation='relu')(flatten)
dropout = Dropout(0.5)(dense1)
output_layer = Dense(len(train_generator.class_indices), activation='softmax')(dropout)model = Model(inputs=input_layer, outputs=output_layer)
model.compile(optimizer='adam', loss='categorical_crossentropy', metrics=['accuracy'])# 训练模型
history = model.fit(train_generator, validation_data=val_generator, epochs=100)# 保存模型
model.save("garbage_classifier.h5")# %%
# 绘制训练过程的准确率和损失变化
plt.figure(figsize=(12, 5))# 训练 & 验证损失
plt.subplot(1, 2, 1)
plt.plot(history.history['loss'], label='Train Loss')
plt.plot(history.history['val_loss'], label='Validation Loss')
plt.xlabel('Epochs')
plt.ylabel('Loss')
plt.legend()
plt.title('Loss Over Epochs')# 训练 & 验证准确率
plt.subplot(1, 2, 2)
plt.plot(history.history['accuracy'], label='Train Accuracy')
plt.plot(history.history['val_accuracy'], label='Validation Accuracy')
plt.xlabel('Epochs')
plt.ylabel('Accuracy')
plt.legend()
plt.title('Accuracy Over Epochs')plt.show()# 计算混淆矩阵
val_generator.reset()
Y_pred = model.predict(val_generator)
y_pred = np.argmax(Y_pred, axis=1)
y_true = val_generator.classes# 计算分类报告
class_labels = list(train_generator.class_indices.keys())
print("分类报告:\n", classification_report(y_true, y_pred, target_names=class_labels))# 绘制混淆矩阵热力图
cm = confusion_matrix(y_true, y_pred)
plt.figure(figsize=(8, 6))
sns.heatmap(cm, annot=True, fmt="d", cmap="Blues", xticklabels=class_labels, yticklabels=class_labels)
plt.xlabel("Predicted Label")
plt.ylabel("True Label")
plt.title("Confusion Matrix Heatmap")
plt.show()# %%
# 定义垃圾分类预测接口
def predict_image(image_path):model = load_model("garbage_classifier.h5")image = load_img(image_path, target_size=img_size)image = img_to_array(image) / 255.0image = np.expand_dims(image, axis=0)prediction = model.predict(image)predicted_class = np.argmax(prediction)class_labels = list(train_generator.class_indices.keys())return class_labels[predicted_class]print(predict_image("C:\MY_PROJECT\PROJECT\project4\database\cardboard\cardboard3.jpg"))

http://www.xdnf.cn/news/1070803.html

相关文章:

  • Maven生命周期与阶段扩展深度解析
  • 嵌入式项目:基于QT与Hi3861的物联网智能大棚集成控制系统
  • jenkins中执行python脚本导入路径错误
  • Chrome浏览器访问https提示“您的连接不是私密连接”问题解决方案
  • 【C++特殊工具与技术】固有的不可移植的特性(3)::extern“C“
  • 力扣第455场周赛
  • MATLAB 4D作图
  • Hyperledger Fabric 入门笔记(二十)Fabric V2.5 测试网络进阶之Tape性能测试
  • OpenCV模版匹配方法的衡量指标比较
  • 修复opensuse 风滚草rabbitmq的Error: :plugins_dir_does_not_exist问题
  • 【STM32】外部中断
  • 【Linux】基础开发工具(2)
  • java枚举enum的使用示例
  • 大厂测开实习和小厂开发实习怎么选
  • Java设计模式->责任链模式的介绍
  • [AI]从0到1通过神经网络训练模型
  • python+requests接口自动化测试
  • 《规则怪谈》合集
  • [特殊字符]️ 用 Python 绘制专业风玫瑰图:从气象数据到可视化的全流程指南
  • vscode ssh远程连接到Linux并实现免密码登录
  • Apipost和Postman对比
  • 缓存与加速技术实践-MongoDB数据库应用
  • 【RESTful接口设计规范全解析】URL路径设计 + 动词名词区分 + 状态码 + 返回值结构 + 最佳实践 + 新手常见误区汇总
  • Python打卡:Day37
  • 算法打卡 day4
  • Spring Boot 项目中同时使用 Swagger 和 Javadoc 的完整指南
  • Selenium+Pytest自动化测试框架实战
  • 快速傅里叶变换(FFT)是什么?
  • uniapp微信小程序:editor组件placeholder字体样式修改
  • GC 学习笔记