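TensorFlow + CNN Garbage Classification: A Hands-On, End-to-End Deep Learning Tutorial
A concise look at the pain points that TensorFlow and convolutional neural networks (CNNs) solve
Project Overview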
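Garbage classification is an important part of sustainable development. This tutorial uses TensorFlow and a classic convolutional neural network (CNN) to walk through the whole pipeline, from environment setup to single-image inference. It skips the lengthy background and covers only the key steps, so you can quickly build an efficient, interpretable automated classification system. If you want to reproduce my environment in one step, install Conda first; if you have questions, see my earlier article 👉 Conda from Scratch: A Complete Guide to Installation, Creating Environments, and Managing Dependencies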
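- Environment management: one-step reproduction with environment.yml
- Dataset preparation: download link and directory structure
- Data cleaning: automatically delete corrupted images
- Data augmentation: improve model robustness
- Model building and training: the CNN architecture in detail
- Training visualization: loss and accuracy curves
- Single-image inference: real-time classification and interpretability analysis
- CNN vs. Transformer: a guide to choosing an architecture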
1. Environment Management
Create a file named environment.yml in the project root with contents like the following:
name: tf_gpu
channels:
  - defaults
  - conda-forge
dependencies:
  - _openmp_mutex=4.5
  - blas=1.0
  - brotli-python=1.0.9
  - bzip2=1.0.8
  - ca-certificates=2025.4.26
  - contourpy=1.3.1
  - cudatoolkit=11.2.2
  - cudnn=8.1.0.77
  - cycler=0.11.0
  - expat=2.7.1
  - fonttools=4.55.3
  - freetype=2.13.3
  - glib=2.84.0
  - glib-tools=2.84.0
  - gst-plugins-base=1.24.7
  - gstreamer=1.24.7
  - icc_rt=2022.1.0
  - icu=75.1
  - intel-openmp=2023.2.0
  - joblib=1.4.2
  - kiwisolver=1.4.8
  - krb5=1.21.3
  - lcms2=2.17
  - lerc=4.0.0
  - libblas=3.9.0
  - libcblas=3.9.0
  - libclang13=20.1.7
  - libdeflate=1.24
  - libffi=3.4.4
  - libfreetype=2.13.3
  - libfreetype6=2.13.3
  - libgcc=15.1.0
  - libglib=2.84.0
  - libgomp=15.1.0
  - libhwloc=2.11.2
  - libiconv=1.18
  - libintl=0.22.5
  - libintl-devel=0.22.5
  - libjpeg-turbo=3.1.0
  - liblapack=3.9.0
  - liblzma=5.8.1
  - liblzma-devel=5.8.1
  - libogg=1.3.5
  - libpng=1.6.47
  - libsqlite=3.50.1
  - libtiff=4.7.0
  - libvorbis=1.3.7
  - libwebp-base=1.5.0
  - libwinpthread=12.0.0.r4.gg4f2fc60ca
  - libxcb=1.17.0
  - libxml2=2.13.8
  - libzlib=1.3.1
  - matplotlib=3.10.0
  - matplotlib-base=3.10.0
  - mkl=2023.2.0
  - mkl-service=2.4.1
  - openjpeg=2.5.3
  - openssl=3.5.0
  - pcre2=10.44
  - pillow=11.2.1
  - pip=25.1
  - ply=3.11
  - pthread-stubs=0.4
  - pyparsing=3.2.0
  - pyqt=5.15.10
  - pyqt5-sip=12.13.0
  - python=3.10.16
  - python-dateutil=2.9.0post0
  - python_abi=3.10
  - qt-main=5.15.15
  - scikit-learn=1.6.1
  - setuptools=78.1.1
  - sip=6.7.12
  - six=1.17.0
  - sqlite=3.45.3
  - tbb=2021.13.0
  - threadpoolctl=3.5.0
  - tk=8.6.13
  - tomli=2.0.1
  - tornado=6.5.1
  - tzdata=2025b
  - ucrt=10.0.22621.0
  - unicodedata2=15.1.0
  - vc=14.42
  - vc14_runtime=14.42.34438
  - vs2015_runtime=14.42.34438
  - wheel=0.45.1
  - xorg-libxau=1.0.12
  - xorg-libxdmcp=1.1.5
  - xz=5.8.1
  - xz-tools=5.8.1
  - zlib=1.3.1
  - zstd=1.5.7
  - pip:
    - absl-py==2.3.0
    - astunparse==1.6.3
    - cachetools==5.5.2
    - certifi==2025.4.26
    - charset-normalizer==3.4.2
    - flatbuffers==25.2.10
    - gast==0.4.0
    - google-auth==2.40.3
    - google-auth-oauthlib==0.4.6
    - google-pasta==0.2.0
    - grpcio==1.73.0
    - h5py==3.14.0
    - idna==3.10
    - keras==2.10.0
    - keras-preprocessing==1.1.2
    - libclang==18.1.1
    - markdown==3.8
    - markupsafe==3.0.2
    - numpy==1.23.5
    - oauthlib==3.2.2
    - opt-einsum==3.4.0
    - packaging==25.0
    - protobuf==3.19.6
    - pyasn1==0.6.1
    - pyasn1-modules==0.4.2
    - requests==2.32.4
    - requests-oauthlib==2.0.0
    - rsa==4.9.1
    - scipy==1.15.3
    - tensorboard==2.10.1
    - tensorboard-data-server==0.6.1
    - tensorboard-plugin-wit==1.8.1
    - tensorflow==2.10.0
    - tensorflow-estimator==2.10.0
    - tensorflow-io-gcs-filesystem==0.31.0
    - termcolor==3.1.0
    - typing-extensions==4.14.0
    - urllib3==2.4.0
    - werkzeug==3.1.3
    - wrapt==1.17.2
prefix: C:\Users\Wenhao\.conda\envs\tf_gpu
Create the environment in one step
conda env create -f environment.yml
conda activate tf_gpu
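After activating the environment, it can be worth confirming that the pinned packages are the ones Python actually sees. The following is a minimal sketch using only the standard library. The helper names `installed_version` and `check_versions` are illustrative and not part of the original project; the expected versions are copied from the pip section of environment.yml.

```python
from importlib import metadata

# Versions pinned in the pip section of environment.yml
EXPECTED = {"tensorflow": "2.10.0", "keras": "2.10.0", "numpy": "1.23.5"}


def installed_version(package):
    """Return the installed version string, or None if the package is missing."""
    try:
        return metadata.version(package)
    except metadata.PackageNotFoundError:
        return None


def check_versions(expected):
    """Map each package name to (expected, installed, matches)."""
    report = {}
    for name, want in expected.items():
        have = installed_version(name)
        report[name] = (want, have, have == want)
    return report


if __name__ == "__main__":
    for name, (want, have, ok) in check_versions(EXPECTED).items():
        status = "OK  " if ok else "DIFF"
        print(f"{status} {name}: expected {want}, installed {have}")
```

A "DIFF" line does not necessarily mean the environment is broken, only that it differs from the pinned file.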
2. Dataset Preparation
2.1 Download and Extract
- Source: Alibaba Cloud Tianchi "Garbage Classification Dataset"
https://tianchi.aliyun.com/dataset/138860
2.2 Directory Structure
project-root/
├── dataset/
│ ├── Harmful/ # hazardous waste
│ ├── Kitchen/ # kitchen waste
│ ├── Other/ # other waste
│ └── Recyclable/ # recyclable waste
├── clean_data.py
├── train.py
├── visualize.py
├── predict.py
└── environment.yml
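Before training, a quick count of images per class folder helps confirm that the layout above was extracted correctly and shows any class imbalance. This is a small sketch under stated assumptions: the function name `count_images_per_class` and the list of image extensions are illustrative choices, not part of the original project.

```python
import os

# Assumed set of image extensions; adjust to match the actual dataset
IMAGE_EXTS = {".jpg", ".jpeg", ".png", ".bmp"}


def count_images_per_class(data_dir):
    """Return {class_name: image_count} for each sub-directory of data_dir."""
    counts = {}
    for entry in sorted(os.listdir(data_dir)):
        class_dir = os.path.join(data_dir, entry)
        if not os.path.isdir(class_dir):
            continue  # skip loose files in the dataset root
        counts[entry] = sum(
            1
            for fname in os.listdir(class_dir)
            if os.path.splitext(fname)[1].lower() in IMAGE_EXTS
        )
    return counts


if __name__ == "__main__":
    if os.path.isdir("dataset"):
        for name, n in count_images_per_class("dataset").items():
            print(f"{name:<12} {n}")
```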
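3. Data Cleaning
3.1 Purpose
- Automatically remove images that cannot be opened or are truncated, so that training is not interrupted.
3.2 Implementation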
# clean_data.py
import os
from PIL import Image, ImageFile

# Allow Pillow to load truncated images
ImageFile.LOAD_TRUNCATED_IMAGES = True

DATA_DIR = "dataset/"
bad_images = []

# Walk the dataset and collect every file that Pillow cannot verify
for root, _, files in os.walk(DATA_DIR):
    for fname in files:
        path = os.path.join(root, fname)
        try:
            with Image.open(path) as img:
                img.verify()
        except Exception:
            bad_images.append(path)

if bad_images:
    print(f"Removing {len(bad_images)} corrupted images:")
    for p in bad_images:
        os.remove(p)
        print(" ✔", p)
else:
    print("✅ No corrupted images found")
Run the script:
python clean_data.py
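Pillow's `Image.verify()` checks a file without decoding the pixel data, so a JPEG that was cut off partway may still pass. As a complementary check, the sketch below looks for the JPEG start marker (FF D8) and end marker (FF D9) using only the standard library. The function name `looks_truncated_jpeg` and the `tail_bytes` window are hypothetical. The heuristic is an assumption as well: some valid files carry extra bytes after the end marker, so it searches the last few bytes instead of requiring an exact match. Review flagged files before deleting them.

```python
def looks_truncated_jpeg(path, tail_bytes=64):
    """Heuristic: True if the file starts like a JPEG but has no end marker near its end."""
    with open(path, "rb") as f:
        if f.read(2) != b"\xff\xd8":
            return False  # not a JPEG; leave it to other checks
        f.seek(0, 2)  # move to the end of the file
        size = f.tell()
        f.seek(max(0, size - tail_bytes))
        tail = f.read()
    # A complete JPEG normally ends with the end-of-image marker FF D9
    return b"\xff\xd9" not in tail
```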
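4. Data Augmentation
4.1 Augmentation Techniques
- Geometric transforms: rotation, translation, shear, zoom
- Color transforms: brightness and channel shifts
- Flipping and filling: horizontal flip plus reflection at the borders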
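To make the techniques above concrete, here is a minimal NumPy sketch of three of them: horizontal flip, brightness scaling, and a shift whose gap is filled by reflecting the border. It illustrates the ideas only and is not how Keras implements them internally; the function names are hypothetical. Images are assumed to be float arrays of shape (H, W, C) with values in [0, 1].

```python
import numpy as np


def horizontal_flip(img):
    """Mirror an (H, W, C) image left to right."""
    return img[:, ::-1, :]


def adjust_brightness(img, factor):
    """Scale pixel values by `factor` and clip back into [0, 1]."""
    return np.clip(img * factor, 0.0, 1.0)


def shift_right_reflect(img, pixels):
    """Shift the image right by `pixels`, filling the gap by mirroring the left border."""
    if pixels == 0:
        return img
    # NumPy's "symmetric" mode repeats the edge pixel (b a | a b c d),
    # which corresponds to the 'reflect' fill pattern described for Keras.
    padded = np.pad(img, ((0, 0), (pixels, 0), (0, 0)), mode="symmetric")
    return padded[:, : img.shape[1], :]


if __name__ == "__main__":
    demo = np.random.default_rng(0).random((4, 4, 3))
    print(horizontal_flip(demo).shape, shift_right_reflect(demo, 1).shape)
```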
4.2 Code Example
# Data generator section of train.py
from tensorflow.keras.preprocessing.image import ImageDataGenerator

IMAGE_SIZE = (128, 128)
BATCH_SIZE = 32

# Rescaling, the train/validation split, and augmentation in one generator
datagen = ImageDataGenerator(
    rescale=1./255,
    validation_split=0.2,
    rotation_range=20,
    width_shift_range=0.1,
    height_shift_range=0.1,
    shear_range=10,
    zoom_range=0.2,
    brightness_range=[0.8, 1.2],
    channel_shift_range=15,
    horizontal_flip=True,
    fill_mode='reflect'
)

# Training subset (80% of the images)
train_gen = datagen.flow_from_directory(
    "dataset/",
    target_size=IMAGE_SIZE,
    batch_size=BATCH_SIZE,
    class_mode='categorical',
    subset='training'
)

# Validation subset (20% of the images)
val_gen = datagen.flow_from_directory(
    "dataset/",
    target_size=IMAGE_SIZE,
    batch_size=BATCH_SIZE,
    class_mode='categorical',
    subset='validation'
)
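5. Model Building and Training
5.1 Model Architecture
- Convolution + pooling layers: extract features at multiple levels
- Batch normalization: stabilizes and speeds up training
- Global average pooling: fewer parameters, less overfitting
- Fully connected layer + Dropout: classification output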
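The claim that global average pooling keeps the parameter count low can be checked with a little arithmetic. The sketch below traces the feature-map size through three 3×3 convolutions (no padding, stride 1) each followed by 2×2 max pooling, starting from a 128×128×3 input, and compares the size of a Dense(128) layer placed after global average pooling with one placed after Flatten. It uses plain Python, ignores batch-normalization parameters, and the helper names are illustrative.

```python
def conv_out(size, kernel=3):
    """Spatial size after a convolution with no padding and stride 1."""
    return size - kernel + 1


def pool_out(size, pool=2):
    """Spatial size after max pooling whose stride equals the pool size."""
    return size // pool


def conv_params(c_in, c_out, kernel=3):
    """Kernel weights plus one bias per output channel."""
    return kernel * kernel * c_in * c_out + c_out


def dense_params(n_in, n_out):
    """Weight matrix plus one bias per output unit."""
    return n_in * n_out + n_out


size, channels = 128, 3
conv_total = 0
for filters in (32, 64, 128):
    conv_total += conv_params(channels, filters)
    size = pool_out(conv_out(size))
    channels = filters

# Head A: GlobalAveragePooling2D -> Dense(128) sees one value per channel
gap_head = dense_params(channels, 128)
# Head B: Flatten -> Dense(128) sees every spatial position of every channel
flatten_head = dense_params(size * size * channels, 128)

print("final feature map:", (size, size, channels))
print("conv parameters:", conv_total)
print("Dense(128) after global average pooling:", gap_head)
print("Dense(128) after Flatten:", flatten_head)
```

Under these assumptions the Flatten head is roughly two hundred times larger than the pooled head, which is the saving the bullet refers to.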
5.2 Training Script
# train.py
import tensorflow as tf
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import (
    Conv2D, BatchNormalization, MaxPooling2D,
    GlobalAveragePooling2D, Dense, Dropout
)
from tensorflow.keras.callbacks import EarlyStopping, ReduceLROnPlateau

print("✅ GPU:", tf.config.list_physical_devices('GPU'))

# train_gen and val_gen come from the data generator section of train.py
num_classes = train_gen.num_classes

# Three conv blocks, then global average pooling and a small dense head
model = Sequential([
    Conv2D(32, 3, activation='relu', input_shape=(128, 128, 3)),
    BatchNormalization(), MaxPooling2D(),
    Conv2D(64, 3, activation='relu'),
    BatchNormalization(), MaxPooling2D(),
    Conv2D(128, 3, activation='relu'),
    BatchNormalization(), MaxPooling2D(),
    GlobalAveragePooling2D(),
    Dense(128, activation='relu'),
    Dropout(0.5),
    Dense(num_classes, activation='softmax'),
])

model.compile(
    optimizer=tf.keras.optimizers.Adam(1e-4),
    loss='categorical_crossentropy',
    metrics=['accuracy']
)
model.summary()

# Stop early when val_loss stalls, and halve the learning rate on plateaus
callbacks = [
    EarlyStopping(monitor='val_loss', patience=5, restore_best_weights=True),
    ReduceLROnPlateau(monitor='val_loss', factor=0.5, patience=3)
]

history = model.fit(
    train_gen,
    validation_data=val_gen,
    epochs=50,
    callbacks=callbacks
)

model.save("custom_garbage_classifier.h5")
print("✅ Model saved to custom_garbage_classifier.h5")
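To see how the two callbacks in the training script interact, the sketch below replays a list of validation losses through a simplified model of their rules. It is an approximation written for illustration, not the Keras implementation: an epoch counts as an improvement only if val_loss is strictly lower than the best value so far, and options such as min_delta, cooldown, and min_lr are left out. The function name `simulate_callbacks` is hypothetical.

```python
def simulate_callbacks(val_losses, lr=1e-4, factor=0.5, lr_patience=3, stop_patience=5):
    """Replay val_loss values; return per-epoch learning rates and the stopping epoch.

    lrs[i] is the learning rate in effect after the callbacks ran for epoch i + 1.
    stopped_at is the 1-based epoch at which training would stop, or None.
    """
    best = float("inf")
    lr_wait = 0    # epochs without improvement since the last LR reduction
    stop_wait = 0  # epochs without improvement since the last best value
    lrs = []
    stopped_at = None
    for epoch, loss in enumerate(val_losses, start=1):
        if loss < best:
            best = loss
            lr_wait = stop_wait = 0
        else:
            lr_wait += 1
            stop_wait += 1
            if lr_wait >= lr_patience:
                lr *= factor  # plateau detected: shrink the learning rate
                lr_wait = 0
        lrs.append(lr)
        if stop_wait >= stop_patience:
            stopped_at = epoch
            break
    return lrs, stopped_at


if __name__ == "__main__":
    losses = [1.0, 0.9, 0.95, 0.96, 0.97, 0.98, 0.99, 0.5]
    lrs, stopped_at = simulate_callbacks(losses)
    print("learning rates:", lrs)
    print("stopped at epoch:", stopped_at)
```

With patience values of 3 and 5, the learning rate is reduced before early stopping triggers, which gives the optimizer a chance to recover from a plateau first.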
6. Training Visualization
# visualize.py
# Note: `history` is the object returned by model.fit() in train.py,
# so run this code in the same session or append it to the end of train.py.
import matplotlib.pyplot as plt

# Loss curves
plt.plot(history.history['loss'], label='train_loss')
plt.plot(history.history['val_loss'], label='val_loss')
plt.title("Loss Curve")
plt.legend()
plt.show()

# Accuracy curves
plt.plot(history.history['accuracy'], label='train_acc')
plt.plot(history.history['val_accuracy'], label='val_acc')
plt.title("Accuracy Curve")
plt.legend()
plt.show()
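Because visualize.py reads `history`, which only exists in the process that ran model.fit(), one way to run the two scripts separately is to write the metrics to disk after training and read them back before plotting. The sketch below uses JSON from the standard library. The helper names and the file name history.json are assumptions for illustration, not part of the original project.

```python
import json


def save_history(history_dict, path):
    """Write Keras-style history (metric name -> per-epoch values) to a JSON file."""
    # Cast to float so that NumPy scalars serialize cleanly
    clean = {name: [float(v) for v in values] for name, values in history_dict.items()}
    with open(path, "w", encoding="utf-8") as f:
        json.dump(clean, f)


def load_history(path):
    """Read the metrics back as a plain dict of lists."""
    with open(path, "r", encoding="utf-8") as f:
        return json.load(f)
```

In train.py this would be called as `save_history(history.history, "history.json")`; visualize.py could then use `hist = load_history("history.json")` and plot `hist['loss']`, `hist['val_loss']`, and so on.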
Complete Training Code
import os
from PIL import Image, ImageFile
import tensorflow as tf
from tensorflow.keras.preprocessing.image import ImageDataGenerator
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Conv2D, MaxPooling2D, Flatten, Dense, Dropout
from tensorflow.keras.callbacks import EarlyStopping

# Allow Pillow to load truncated images
ImageFile.LOAD_TRUNCATED_IMAGES = True

# Dataset path
DATA_DIR = "dataset/"

# Step 1: automatically remove all corrupted or truncated images
bad_images = []
for root, _, files in os.walk(DATA_DIR):
    for fname in files:
        path = os.path.join(root, fname)
        try:
            with Image.open(path) as img:
                img.verify()
        except Exception:
            bad_images.append(path)

if bad_images:
    print(f"Found {len(bad_images)} bad images. Removing…")
    for p in bad_images:
        os.remove(p)
        print(" Removed", p)
else:
    print("No corrupted images found.")

# ---- Training script starts here ---- #

# Check whether a GPU is available
print("✅ GPU devices:", tf.config.list_physical_devices('GPU'))

# Parameters
IMAGE_SIZE = (128, 128)
BATCH_SIZE = 32
EPOCHS = 15

# Data augmentation + preprocessing
datagen = ImageDataGenerator(
    rescale=1./255,
    validation_split=0.2,
    rotation_range=15,
    width_shift_range=0.1,
    height_shift_range=0.1,
    zoom_range=0.1,
    horizontal_flip=True
)

train_gen = datagen.flow_from_directory(
    DATA_DIR,
    target_size=IMAGE_SIZE,
    batch_size=BATCH_SIZE,
    class_mode='categorical',
    subset='training'
)

val_gen = datagen.flow_from_directory(
    DATA_DIR,
    target_size=IMAGE_SIZE,
    batch_size=BATCH_SIZE,
    class_mode='categorical',
    subset='validation'
)

# Custom CNN model
model = Sequential([
    Conv2D(32, (3, 3), activation='relu', input_shape=(IMAGE_SIZE[0], IMAGE_SIZE[1], 3)),
    MaxPooling2D(2, 2),
    Conv2D(64, (3, 3), activation='relu'),
    MaxPooling2D(2, 2),
    Conv2D(128, (3, 3), activation='relu'),
    MaxPooling2D(2, 2),
    Flatten(),
    Dense(128, activation='relu'),
    Dropout(0.5),
    Dense(train_gen.num_classes, activation='softmax')
])

# Compile the model
model.compile(
    optimizer='adam',
    loss='categorical_crossentropy',
    metrics=['accuracy']
)

# Print the model structure
model.summary()

# EarlyStopping to limit overfitting
early_stop = EarlyStopping(monitor='val_loss', patience=3, restore_best_weights=True)

# Train the model
history = model.fit(
    train_gen,
    validation_data=val_gen,
    epochs=EPOCHS,
    callbacks=[early_stop]
)

# Save the model
model.save("custom_garbage_classifier.h5")
print("✅ Training finished; model saved as custom_garbage_classifier.h5")
7. Single-Image Inference and Explainable AI
# predict.py
import numpy as np
import matplotlib.pyplot as plt
from tensorflow.keras.models import load_model
from tensorflow.keras.preprocessing.image import load_img, img_to_array

model = load_model("custom_garbage_classifier.h5")
img_path = "evalImageSet/5.jpg"
IMG_SIZE = (128, 128)

# Load the image and apply the same preprocessing as in training
img = load_img(img_path, target_size=IMG_SIZE)
x = img_to_array(img) / 255.0
x = np.expand_dims(x, 0)

probs = model.predict(x)[0]
class_idx = np.argmax(probs)

# Index-to-label mapping (must match the generator used in training)
class_indices = {'Harmful': 0, 'Kitchen': 1, 'Other': 2, 'Recyclable': 3}
labels = {v: k for k, v in class_indices.items()}

print(f"▶ {img_path} → {labels[class_idx]} ({probs[class_idx]:.1%})")
print("Class probabilities:")
for i, p in enumerate(probs):
    print(f" {labels[i]:<12}: {p:.2%}")

plt.imshow(img)
plt.title(f"{labels[class_idx]} ({probs[class_idx]:.1%})")
plt.axis('off')
plt.show()
- Optional: use Grad-CAM to visualize the regions the model attends to.
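The core arithmetic of Grad-CAM is small enough to sketch in NumPy. Given the activations of the last convolutional layer and the gradients of the predicted class score with respect to those activations, each channel is weighted by its average gradient, the weighted maps are summed, and negative values are dropped. This sketch assumes both arrays are already available with shape (H, W, K); obtaining them from the trained model (for example with `tf.GradientTape`) is not shown here, and the function name is hypothetical.

```python
import numpy as np


def grad_cam_heatmap(feature_maps, gradients):
    """Combine (H, W, K) feature maps and gradients into an (H, W) heatmap in [0, 1]."""
    # One importance weight per channel: the spatial mean of its gradient
    weights = gradients.mean(axis=(0, 1))
    # Weighted sum of the feature maps over the channel axis
    cam = np.tensordot(feature_maps, weights, axes=([2], [0]))
    # Keep only regions that push the class score up
    cam = np.maximum(cam, 0.0)
    peak = cam.max()
    return cam / peak if peak > 0 else cam
```

The resulting heatmap has the spatial size of the feature maps, so it would typically be resized to the input image size before being overlaid on the picture.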
Complete Inference Code
import numpy as np
import matplotlib.pyplot as plt
from tensorflow.keras.models import load_model
from tensorflow.keras.preprocessing.image import load_img, img_to_array

# 1. Load the model
model = load_model("custom_garbage_classifier.h5")

# 2. Path of the image to analyze
img_path = "evalImageSet/5.jpg"  # change to your own image

# 3. Load and preprocess
IMG_SIZE = (128, 128)
img = load_img(img_path, target_size=IMG_SIZE)
x = img_to_array(img) / 255.0  # normalize to [0, 1]
x = np.expand_dims(x, axis=0)  # shape becomes (1, 128, 128, 3)

# 4. Predict
probs = model.predict(x)[0]   # vector with one entry per class
class_idx = np.argmax(probs)  # index of the predicted class

# 5. Look up the class name
# This assumes you have the class_indices dict from the training generator,
# for example: {'glass': 0, 'paper': 1, 'plastic': 2, ...}
# Replace it with your actual mapping
class_indices = {'Harmful': 0, 'Kitchen': 1, 'Other': 2, 'Recyclable': 3}
labels = {v: k for k, v in class_indices.items()}

pred_label = labels[class_idx]
pred_prob = probs[class_idx]

# 6. Print the result
print(f"▶ Image analyzed: {img_path}")
print(f"Predicted class: {pred_label}, confidence: {pred_prob:.4%}")
print("\nClass probabilities:")
for idx, p in enumerate(probs):
    print(f" {labels[idx]:<8}: {p:.2%}")

# 7. (Optional) show the image
plt.imshow(img)
plt.title(f"Pred: {pred_label} ({pred_prob:.1%})")
plt.axis('off')
plt.show()
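The class_indices dict in the inference code is typed by hand, which can drift out of sync with training. `flow_from_directory` assigns label indices to class folders in alphanumeric order, and the generator exposes the result as `train_gen.class_indices`. When only the dataset folder is available, the mapping can be rebuilt along the same lines. The sketch below is one way to do that; the function names are illustrative, and the generator's own `class_indices` remains the authoritative source.

```python
import os


def class_indices_from_dirs(data_dir):
    """Rebuild a name -> index mapping from the sorted sub-directory names."""
    names = sorted(
        entry
        for entry in os.listdir(data_dir)
        if os.path.isdir(os.path.join(data_dir, entry))
    )
    return {name: idx for idx, name in enumerate(names)}


def index_to_label(class_indices):
    """Invert the mapping so a predicted index can be turned into a class name."""
    return {idx: name for name, idx in class_indices.items()}


if __name__ == "__main__":
    if os.path.isdir("dataset"):
        print(index_to_label(class_indices_from_dirs("dataset")))
```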
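8. CNN vs. Transformer
Dimension | CNN | Transformer |
---|---|---|
Core modules | Conv2D + Pooling | Self-Attention + Feed-Forward |
Receptive field | Grows as layers are stacked | Global within a single layer |
Parameter sharing | Convolution kernels reused across space/time | Projection and feed-forward weights shared across all token positions |
Position sensitivity | Translation-equivariant; locality is built in, no positional encoding required | Order-agnostic on its own; needs explicit positional encoding |
Parallelism | High (local parallelism) | Very high (global parallelism) |
Computational complexity | O(N·K²·C_in·C_out) | O(N²·D) |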
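- In common: both rely on the same underlying tensor operations, backpropagation, optimizers, and regularization methods.
- Choosing: decide by whether the task depends mainly on local or global dependencies; the two can also be combined (for example Conformer, or the ViT and CLIP variants that use a convolutional backbone).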
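The complexity row of the table can be made tangible with rough operation counts. In the sketch below, a convolution layer is counted as about N·K²·C_in·C_out multiply-adds (with the input-channel factor written out explicitly), and self-attention as about 2·N²·D for the score matrix and its product with the values. Projections and feed-forward layers are ignored, and the sizes used (3×3 kernels, 64 channels, D = 64) are assumptions chosen only for illustration.

```python
def conv_cost(n_positions, kernel, c_in, c_out):
    """Approximate multiply-adds for one conv layer: N * K^2 * C_in * C_out."""
    return n_positions * kernel * kernel * c_in * c_out


def attention_cost(n_tokens, dim):
    """Approximate multiply-adds for attention scores plus weighted values: 2 * N^2 * D."""
    return 2 * n_tokens * n_tokens * dim


if __name__ == "__main__":
    for side in (16, 32, 64):
        n = side * side  # number of spatial positions or tokens
        print(
            f"N={n:5d}  conv={conv_cost(n, 3, 64, 64):>14,}  "
            f"attention={attention_cost(n, 64):>14,}"
        )
```

Quadrupling N multiplies the convolution count by 4 and the attention count by 16, which is why global attention becomes expensive on large feature maps.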
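With the material above you should now have a basic understanding of the method, and I have shown all the basic usage. If you can bring it all together, I am confident you will go far.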
Best
Wenhao (楠博万)