Garbage Classification Based on a Residual Neural Network
This project uses a residual neural network (ResNet) to build an image classification model that automatically recognizes and sorts garbage images. Residual connections mitigate the vanishing-gradient and degradation problems that arise when training deep networks, improving recognition accuracy and generalization on a complex garbage-image dataset. The work assembles a multi-class garbage image dataset, applies data augmentation to increase the diversity of the training samples, and ultimately classifies test images into categories such as "recyclable waste", "hazardous waste", "wet waste", and "dry waste". The approach has practical value for smart waste disposal, resource recovery, and environmental management.
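The core idea behind the residual connection is the identity shortcut: a block learns a residual mapping F(x) and outputs F(x) + x, so gradients can flow through the addition unchanged. Below is a minimal sketch of such a block in Keras; the helper name `residual_block` and the filter count are illustrative only, and the sketch assumes the input already has `filters` channels so the shapes match. The full model later in this file wires the same pattern inline.

from tensorflow.keras import layers

def residual_block(x, filters=32):
    # Residual branch: two same-padded convolutions learn F(x)
    f = layers.Conv2D(filters, (3, 3), padding='same', activation='relu')(x)
    f = layers.Conv2D(filters, (3, 3), padding='same', activation='relu')(f)
    # Identity shortcut: output is F(x) + x (input must have `filters` channels)
    return layers.Add()([f, x])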
# %%
import os
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from tensorflow import keras
from tensorflow.keras.preprocessing.image import ImageDataGenerator, load_img, img_to_array
from tensorflow.keras.models import Model, load_model
from tensorflow.keras.layers import Conv2D, MaxPooling2D, Flatten, Dense, Dropout, Input, Add
from sklearn.metrics import classification_report, confusion_matrix

# Set the data path
data_dir = "database"
img_size = (128, 128)
batch_size = 32
sample_fraction = 0.1  # use only 10% of the data (defined here but not applied anywhere below)

# Data preprocessing and augmentation
data_gen = ImageDataGenerator(rescale=1./255, validation_split=0.2)
train_generator = data_gen.flow_from_directory(
    data_dir, target_size=img_size, batch_size=batch_size,
    subset='training', shuffle=True)
# Use the held-out 20% split for validation; shuffle=False keeps prediction
# order aligned with val_generator.classes for the confusion matrix later.
val_generator = data_gen.flow_from_directory(
    data_dir, target_size=img_size, batch_size=batch_size,
    subset='validation', shuffle=False)
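# The abstract mentions data augmentation, but the generator above only rescales.
# A hedged sketch of an augmented training generator (parameter values are
# illustrative assumptions, not tuned for this dataset):
# aug_gen = ImageDataGenerator(
#     rescale=1./255, validation_split=0.2,
#     rotation_range=15, width_shift_range=0.1, height_shift_range=0.1,
#     zoom_range=0.1, horizontal_flip=True)
# train_generator = aug_gen.flow_from_directory(
#     data_dir, target_size=img_size, batch_size=batch_size,
#     subset='training', shuffle=True)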

# Build a CNN model with residual connections
input_layer = Input(shape=(128, 128, 3))
conv1 = Conv2D(32, (3, 3), activation='relu', padding='same')(input_layer)
conv2 = Conv2D(32, (3, 3), activation='relu', padding='same')(conv1)
residual1 = Add()([conv1, conv2])  # residual connection
pool1 = MaxPooling2D(2, 2)(residual1)

conv3 = Conv2D(64, (3, 3), activation='relu', padding='same')(pool1)
conv4 = Conv2D(64, (3, 3), activation='relu', padding='same')(conv3)
residual2 = Add()([conv3, conv4])  # residual connection
pool2 = MaxPooling2D(2, 2)(residual2)

flatten = Flatten()(pool2)
dense1 = Dense(128, activation='relu')(flatten)
dropout = Dropout(0.5)(dense1)
output_layer = Dense(len(train_generator.class_indices), activation='softmax')(dropout)

model = Model(inputs=input_layer, outputs=output_layer)
model.compile(optimizer='adam', loss='categorical_crossentropy', metrics=['accuracy'])

# Train the model
history = model.fit(train_generator, validation_data=val_generator, epochs=100)

# Save the model
model.save("garbage_classifier.h5")

# %%
# Plot training/validation accuracy and loss over epochs
plt.figure(figsize=(12, 5))

# Training & validation loss
plt.subplot(1, 2, 1)
plt.plot(history.history['loss'], label='Train Loss')
plt.plot(history.history['val_loss'], label='Validation Loss')
plt.xlabel('Epochs')
plt.ylabel('Loss')
plt.legend()
plt.title('Loss Over Epochs')

# Training & validation accuracy
plt.subplot(1, 2, 2)
plt.plot(history.history['accuracy'], label='Train Accuracy')
plt.plot(history.history['val_accuracy'], label='Validation Accuracy')
plt.xlabel('Epochs')
plt.ylabel('Accuracy')
plt.legend()
plt.title('Accuracy Over Epochs')

plt.show()

# Compute the confusion matrix on the validation set
val_generator.reset()
Y_pred = model.predict(val_generator)
y_pred = np.argmax(Y_pred, axis=1)
y_true = val_generator.classes

# Compute the classification report
class_labels = list(train_generator.class_indices.keys())
print("分类报告:\n", classification_report(y_true, y_pred, target_names=class_labels))# 绘制混淆矩阵热力图
cm = confusion_matrix(y_true, y_pred)
plt.figure(figsize=(8, 6))
sns.heatmap(cm, annot=True, fmt="d", cmap="Blues", xticklabels=class_labels, yticklabels=class_labels)
plt.xlabel("Predicted Label")
plt.ylabel("True Label")
plt.title("Confusion Matrix Heatmap")
plt.show()

# %%
# Define the garbage-classification prediction interface
def predict_image(image_path):
    # Load the trained model (reloaded on every call)
    model = load_model("garbage_classifier.h5")
    # Load and preprocess the image the same way as during training
    image = load_img(image_path, target_size=img_size)
    image = img_to_array(image) / 255.0
    image = np.expand_dims(image, axis=0)
    # Predict and map the class index back to its label
    prediction = model.predict(image)
    predicted_class = np.argmax(prediction)
    class_labels = list(train_generator.class_indices.keys())
    return class_labels[predicted_class]

# Raw string avoids backslash-escape issues in the Windows path
print(predict_image(r"C:\MY_PROJECT\PROJECT\project4\database\cardboard\cardboard3.jpg"))
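# Design note: predict_image reloads the model from disk on every call, which is
# slow when classifying many images. A hedged alternative (a sketch reusing the
# objects already defined above; predict_image_fast is an illustrative name) is
# to load the model and label list once:
# classifier = load_model("garbage_classifier.h5")
# labels = list(train_generator.class_indices.keys())
# def predict_image_fast(image_path):
#     image = img_to_array(load_img(image_path, target_size=img_size)) / 255.0
#     scores = classifier.predict(np.expand_dims(image, axis=0))
#     return labels[np.argmax(scores)]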