# Python第31天打卡
import numpy as np
from tensorflow import keras
from tensorflow.keras import layers, optimizers, utils, datasets# 数据加载和预处理函数
def load_and_preprocess_data():
    """Load MNIST via ``datasets.mnist.load_data`` and prepare both splits.

    Each image array is reshaped to ``(-1, 28, 28, 1)``, cast to ``float32``
    and divided by 255.0. Each label array is converted to a one-hot
    encoding over 10 classes.

    Returns:
        ``((x_train, y_train), (x_test, y_test))`` with the processed arrays.
    """
    train_split, test_split = datasets.mnist.load_data()

    prepared = []
    for images, labels in (train_split, test_split):
        # Add the trailing channel axis and scale pixel values by 1/255.
        images = images.reshape(-1, 28, 28, 1).astype("float32") / 255.0
        # One-hot encode the labels over 10 classes.
        labels = utils.to_categorical(labels, 10)
        prepared.append((images, labels))

    return tuple(prepared)


# Model definitions
def create_simple_cnn():
    """Build the small CNN: one conv/pool stage and a 128-unit dense layer.

    Returns:
        An uncompiled ``keras.Sequential`` model taking ``(28, 28, 1)`` inputs
        and ending in a 10-way softmax.
    """
    layer_stack = [
        layers.Conv2D(16, (3, 3), activation='relu', input_shape=(28, 28, 1)),
        layers.MaxPooling2D((2, 2)),
        layers.Flatten(),
        layers.Dense(128, activation='relu'),
        layers.Dense(10, activation='softmax'),
    ]
    return keras.Sequential(layer_stack)


def create_complex_cnn():
    """Build the larger CNN: two conv/pool stages and two hidden dense layers.

    Returns:
        An uncompiled ``keras.Sequential`` model taking ``(28, 28, 1)`` inputs
        and ending in a 10-way softmax.
    """
    conv_stages = [
        layers.Conv2D(32, (3, 3), activation='relu', input_shape=(28, 28, 1)),
        layers.MaxPooling2D((2, 2)),
        layers.Conv2D(64, (3, 3), activation='relu'),
        layers.MaxPooling2D((2, 2)),
    ]
    classifier_head = [
        layers.Flatten(),
        layers.Dense(256, activation='relu'),
        layers.Dense(128, activation='relu'),
        layers.Dense(10, activation='softmax'),
    ]
    return keras.Sequential(conv_stages + classifier_head)


# Training and evaluation function
def train_and_evaluate(model, optimizer, x_train, y_train, x_test, y_test,
                       *, epochs=5, batch_size=64):
    """Compile ``model``, fit it, and return the recorded training history.

    The model is compiled with categorical cross-entropy loss and the
    ``accuracy`` metric, then fitted on the training data with the test data
    passed as ``validation_data``.

    Args:
        model: Object exposing ``compile`` and ``fit`` (e.g. a Keras model).
        optimizer: Optimizer forwarded to ``model.compile``.
        x_train: Training inputs.
        y_train: Training targets.
        x_test: Inputs used as validation data during fitting.
        y_test: Targets used as validation data during fitting.
        epochs: Number of epochs forwarded to ``model.fit``. Defaults to 5,
            the value that was previously hard-coded.
        batch_size: Batch size forwarded to ``model.fit``. Defaults to 64,
            the value that was previously hard-coded.

    Returns:
        The ``history`` attribute of the object returned by ``model.fit``.
    """
    model.compile(
        optimizer=optimizer,
        loss='categorical_crossentropy',
        metrics=['accuracy'],
    )
    history = model.fit(
        x_train,
        y_train,
        epochs=epochs,
        batch_size=batch_size,
        validation_data=(x_test, y_test),
    )
    return history.history


# Main program
if __name__ == "__main__":
    # Load the data once; every model/optimizer combination uses the same arrays.
    (x_train, y_train), (x_test, y_test) = load_and_preprocess_data()

    # Model builders, called once per run so each run starts from a new model.
    model_configs = [
        ('Simple CNN', create_simple_cnn),
        ('Complex CNN', create_complex_cnn),
    ]

    # Optimizer factories rather than shared instances: an optimizer object
    # becomes tied to the variables of the first model it trains and keeps its
    # internal state, so each run must get a newly constructed optimizer.
    optimizers_config = {
        'SGD': lambda: optimizers.SGD(learning_rate=0.01),
        'Adam': lambda: optimizers.Adam(learning_rate=0.001),
    }

    # Train and evaluate every model/optimizer combination.
    results = {}
    for model_name, model_fn in model_configs:
        for opt_name, make_optimizer in optimizers_config.items():
            print(f"\n{'='*50}")
            print(f"Training {model_name} with {opt_name} optimizer:")

            model = model_fn()
            history = train_and_evaluate(
                model,
                make_optimizer(),
                x_train, y_train,
                x_test, y_test,
            )

            # Record the per-epoch history under a "<model>_<optimizer>" key.
            results[f"{model_name}_{opt_name}"] = history

            print(f"\nTraining results for {model_name}/{opt_name}:")
            print(f"Final Training Accuracy: {history['accuracy'][-1]:.4f}")
            print(f"Final Validation Accuracy: {history['val_accuracy'][-1]:.4f}")
            print(f"Final Training Loss: {history['loss'][-1]:.4f}")
            print(f"Final Validation Loss: {history['val_loss'][-1]:.4f}")
# @浙大疏锦行