提前停止训练(Early Stopping)
提前停止(Early Stopping) 是一种在训练机器学习模子(尤其是深度学习模子)时常用的正则化技能,用于防止过拟合并提升模子的泛化能力。它通过监控验证集的性能,在性能不再提高或开始降落时终止训练,从而选择性能最佳的模子。
工作原理
提前停止的基本头脑是:
- 在每个训练轮次(epoch)后,评估模子在验证集上的性能(通常使用损失函数值或评价指标,如正确率)。
- 如果验证集性能在多个轮次内未改善,则停止训练并恢复到性能最佳的模子状态。
实现步骤
- 分割数据集: 将训练数据分为训练集和验证集,训练集用于优化模子参数,验证集用于监控模子的泛化性能。
- 设定监控指标: 选择一个监控指标(如验证损失、验证正确率等),作为权衡模子性能的尺度。
- 设定耐心值(Patience): 耐心值是指答应验证集性能在指定轮次内未改善的次数。如果凌驾耐心值还未见性能提升,则停止训练。
- 生存最佳模子: 在训练过程中,记录验证集性能最优的模子状态,停止训练后使用该状态作为最终模子。
优点
- 防止过拟合:通过终止训练,制止模子过分拟合训练数据。
- 提高泛化能力:选择验证集上性能最优的模子,提升模子对未见数据的体现。
- 节流训练时间:减少不须要的迭代,节省计算资源。
- 动态调解:适应数据集的差别复杂度,不需要预设固定的训练轮次。
缺点
- 需要验证集:需要分出一部分数据作为验证集,大概导致训练数据减少。
- 过早停止的风险:模子大概在某些训练阶段出现短暂波动,提前停止大概会错过更好的优化结果。
- 适合深度学习模子:对于小规模模子或简单问题,提前停止的效果大概不明显。
实现方式
1. 使用 TensorFlow/Keras 实现
- import numpy as np
- from tensorflow.keras.models import Sequential
- from tensorflow.keras.layers import Dense
- from tensorflow.keras.callbacks import EarlyStopping
- # 示例数据集
- X_train = np.random.rand(1000, 20) # 1000个样本,每个样本20个特征
- y_train = np.random.randint(2, size=(1000, 1)) # 1000个样本的二分类标签
- X_val = np.random.rand(200, 20) # 200个样本,每个样本20个特征
- y_val = np.random.randint(2, size=(200, 1)) # 200个样本的二分类标签
- model = Sequential([
- Dense(64, activation='relu', input_shape=(X_train.shape[1],)),
- Dense(1, activation='sigmoid')
- ])
- # 定义 EarlyStopping 回调,监控验证集损失,如果连续5个epoch没有改善则停止训练,并恢复最佳权重
- early_stopping = EarlyStopping(
- monitor='val_loss', # 监控的指标
- patience=5, # 在验证集性能不提升的轮数后停止
- restore_best_weights=True # 恢复验证集性能最优的模型
- )
- # 编译模型
- model.compile(optimizer='adam', loss='mse', metrics=['accuracy'])
- # 训练模型,使用早停机制
- model.fit(
- X_train, y_train,
- validation_data=(X_val, y_val),
- epochs=100,
- callbacks=[early_stopping]
- )
复制代码 运行结果
- Epoch 1/100
- 32/32 [==============================] - 1s 8ms/step - loss: 0.2522 - accuracy: 0.5210 - val_loss: 0.2504 - val_accuracy: 0.5350
- Epoch 2/100
- 32/32 [==============================] - 0s 2ms/step - loss: 0.2491 - accuracy: 0.5320 - val_loss: 0.2502 - val_accuracy: 0.5300
- Epoch 3/100
- 32/32 [==============================] - 0s 2ms/step - loss: 0.2484 - accuracy: 0.5320 - val_loss: 0.2507 - val_accuracy: 0.4950
- Epoch 4/100
- 32/32 [==============================] - 0s 2ms/step - loss: 0.2468 - accuracy: 0.5260 - val_loss: 0.2521 - val_accuracy: 0.4950
- Epoch 5/100
- 32/32 [==============================] - 0s 2ms/step - loss: 0.2456 - accuracy: 0.5560 - val_loss: 0.2524 - val_accuracy: 0.5150
- Epoch 6/100
- 32/32 [==============================] - 0s 2ms/step - loss: 0.2452 - accuracy: 0.5450 - val_loss: 0.2540 - val_accuracy: 0.5000
- Epoch 7/100
- 32/32 [==============================] - 0s 2ms/step - loss: 0.2457 - accuracy: 0.5500 - val_loss: 0.2529 - val_accuracy: 0.4750
复制代码
2. 使用 PyTorch 实现
- import torch
- import torch.nn as nn
- class EarlyStopping:
- def __init__(self, patience=5, delta=0, path='checkpoint.pt'):
- self.patience = patience
- self.delta = delta
- self.best_loss = None
- self.counter = 0
- self.early_stop = False
- self.path = path
- def __call__(self, val_loss, model):
- if self.best_loss is None or val_loss < self.best_loss - self.delta:
- self.best_loss = val_loss
- self.counter = 0
- torch.save(model.state_dict(), self.path)
- else:
- self.counter += 1
- if self.counter >= self.patience:
- self.early_stop = True
- # 定义示例数据集
- X_train = torch.randn(1000, 20) # 1000个样本,每个样本20个特征
- y_train = torch.randint(0, 2, (1000, 1)) # 1000个样本的二分类标签
- X_val = torch.randn(200, 20) # 200个样本,每个样本20个特征
- y_val = torch.randint(0, 2, (200, 1)) # 200个样本的二分类标签
- # 定义模型
- model = nn.Sequential(
- nn.Linear(20, 64),
- nn.ReLU(),
- nn.Linear(64, 1),
- nn.Sigmoid()
- )
- # 定义训练和验证函数
- def train():
- pass
- def validate():
- return torch.tensor(0.5) # 示例验证损失
- # 使用示例
- early_stopping = EarlyStopping(patience=5)
- for epoch in range(100):
- train() # 训练过程
- val_loss = validate() # 验证损失
- early_stopping(val_loss, model)
- if early_stopping.early_stop:
- model.load_state_dict(torch.load('checkpoint.pt'))
- break
复制代码
总结
提前停止训练是机器学习和深度学习中的一种简单高效的正则化方法,可以或许显著提升模子的泛化能力,同时减少训练时间。团结耐心值(patience)、监控指标以及最佳模子生存机制,可以灵活地应用到各种场景中。
免责声明:如果侵犯了您的权益,请联系站长,我们会及时删除侵权内容,谢谢合作!更多信息从访问主页:qidao123.com:ToB企服之家,中国第一个企服评测及商务社交产业平台。 |