Keras/TensorFlow 中 `fit()` 方法参数详细说明
Keras/TensorFlow 中 fit()
方法参数详细说明
Keras/TensorFlow 中的 fit()
方法是训练神经网络的核心API,提供了丰富的参数来控制训练过程。以下是所有参数的详细说明:
一、基础参数
1. x
/y
- 作用:输入数据和目标数据
- 类型:
- NumPy数组
- TensorFlow张量
- 字典(用于具名输入)
- tf.data数据集
- 示例:
model.fit(x=train_images, y=train_labels)
2. batch_size
- 作用:每个梯度更新的样本数
- 类型:整数或
None
- 默认值:32
- 注意:
- 如果使用数据集对象并且指定了
steps_per_epoch
,则不需要设置 - 典型值:16/32/64/128/256
- 如果使用数据集对象并且指定了
3. epochs
- 作用:训练轮次数
- 类型:整数
- 默认值:1
- 示例:
model.fit(..., epochs=50)
4. verbose
- 作用:控制训练过程输出的详细程度
- 类型:整数
- 可选值:
- 0:静默模式
- 1:进度条(默认)
- 2:每个epoch一行输出
二、验证相关参数
5. validation_split
- 作用:从训练数据中分出部分作为验证集的比例
- 类型:0-1之间的浮点数
- 默认值:0.0(不使用)
- 示例:
model.fit(..., validation_split=0.2) # 使用20%数据作为验证集
6. validation_data
- 作用:手动指定验证数据集
- 类型:与
x
/y
相同的格式 - 优先级:高于
validation_split
- 示例:
model.fit(..., validation_data=(val_images, val_labels))
7. validation_freq
- 作用:指定每隔多少epoch进行一次验证
- 类型:整数或列表
- 默认值:1(每个epoch都验证)
- 示例:
model.fit(..., validation_freq=3) # 每3个epoch验证一次
三、数据相关参数
8. shuffle
- 作用:是否在每个epoch前打乱数据
- 类型:布尔值
- 默认值:
True
- 注意:使用
tf.data
数据集时优先使用数据集自身的shuffle操作
9. class_weight
- 作用:为不同类别分配权重(用于不平衡数据集)
- 类型:字典
- 示例:
model.fit(..., class_weight={0: 1., 1: 0.5}) # 类别1的权重是类别0的一半
10. sample_weight
- 作用:为每个样本分配权重
- 类型:NumPy数组
- 示例:
weights = np.array([1.0, 1.5]) # 第二个样本权重更大 model.fit(..., sample_weight=weights)
11. initial_epoch
- 作用:从指定epoch开始训练(用于恢复训练)
- 类型:整数
- 默认值:0
- 示例:
model.fit(..., initial_epoch=10) # 从第10个epoch开始
四、回调与控制参数
12. callbacks
- 作用:训练过程中执行的回调函数列表
- 类型:列表
- 常见回调:
EarlyStopping
- 早停ModelCheckpoint
- 保存模型TensorBoard
- 可视化LearningRateScheduler
- 学习率调整
- 示例:
callbacks = [tf.keras.callbacks.EarlyStopping(patience=3),tf.keras.callbacks.ModelCheckpoint('model.h5') ] model.fit(..., callbacks=callbacks)
五、高级参数
13. steps_per_epoch
- 作用:每个epoch执行的batch步数
- 类型:整数
- 默认值:
None
(自动计算:样本数/batch_size) - 适用场景:
- 使用无限数据集时必需指定
- 部分数据集训练
14. validation_steps
- 作用:验证时使用的batch步数
- 类型:整数
- 适用场景:
- 验证数据为无限数据集时必需指定
15. max_queue_size
- 作用:生成器队列的最大大小
- 类型:整数
- 默认值:10
- 适用场景:使用Python生成器作为输入时
16. workers
- 作用:生成器预处理的最大进程数
- 类型:整数
- 默认值:1
17. use_multiprocessing
- 作用:是否使用多进程处理数据
- 类型:布尔值
- 默认值:
False
- 注意:设置
True
可能导致性能下降
六、实际使用示例
# 完整参数示例
history = model.fit(x=train_images,y=train_labels,batch_size=64,epochs=100,verbose=1,callbacks=[tf.keras.callbacks.EarlyStopping(patience=5),tf.keras.callbacks.ReduceLROnPlateau(factor=0.1, patience=3)],validation_data=(val_images, val_labels),validation_freq=2,shuffle=True,class_weight={0: 1.0, 1: 2.0}, # 假设类别1更重要initial_epoch=0,steps_per_epoch=None,validation_steps=None,max_queue_size=10,workers=4,use_multiprocessing=False
)
七、返回值
fit()
方法返回 History
对象,包含:
history.history
:字典,包含训练过程中的loss和metrics记录history.epoch
:完成的epoch列表history.params
:训练参数history.model
:对应的模型对象
# 使用训练历史
plt.plot(history.history['accuracy'])
plt.plot(history.history['val_accuracy'])
plt.title('Model Accuracy')
plt.ylabel('Accuracy')
plt.xlabel('Epoch')
plt.legend(['Train', 'Val'], loc='upper left')
plt.show()