Early stopping in PyTorch
Question:
I tried to implement an early stopping function to keep my neural network model from overfitting. I'm pretty sure the logic is fine, but for some reason it doesn't work.
I want the early stopping function to return True when the validation loss stays greater than the training loss over several epochs. However, it never returns True, even after the validation loss becomes much greater than the training loss. Could you help me see where the problem is?
Early stopping function:
def early_stopping(train_loss, validation_loss, min_delta, tolerance):
    counter = 0
    if (validation_loss - train_loss) > min_delta:
        counter += 1
        if counter >= tolerance:
            return True
Calling the function during training:
for i in range(epochs):
    print(f"Epoch {i+1}")
    epoch_train_loss, pred = train_one_epoch(model, train_dataloader, loss_func, optimiser, device)
    train_loss.append(epoch_train_loss)
    # validation
    with torch.no_grad():
        epoch_validate_loss = validate_one_epoch(model, validate_dataloader, loss_func, device)
    validation_loss.append(epoch_validate_loss)
    # early stopping
    if early_stopping(epoch_train_loss, epoch_validate_loss, min_delta=10, tolerance=20):
        print("We are at epoch:", i)
        break
EDIT:
[Plot of the training and validation loss]
EDIT2:
def train_validate(model, train_dataloader, validate_dataloader, loss_func, optimiser, device, epochs):
    preds = []
    train_loss = []
    validation_loss = []
    min_delta = 5
    for e in range(epochs):
        print(f"Epoch {e+1}")
        epoch_train_loss, pred = train_one_epoch(model, train_dataloader, loss_func, optimiser, device)
        train_loss.append(epoch_train_loss)
        # validation
        with torch.no_grad():
            epoch_validate_loss = validate_one_epoch(model, validate_dataloader, loss_func, device)
        validation_loss.append(epoch_validate_loss)
        # early stopping
        early_stopping = EarlyStopping(tolerance=2, min_delta=5)
        early_stopping(epoch_train_loss, epoch_validate_loss)
        if early_stopping.early_stop:
            print("We are at epoch:", e)
            break
    return train_loss, validation_loss
Answers:
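The problem with your implementation is that every call to early_stopping() re-initializes counter to 0. Because counter is a local variable, it can never exceed 1, so with tolerance=20 the condition counter >= tolerance is never met.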
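A quick way to see this is to call the original function repeatedly with a validation loss far above the training loss. The loss values below are hypothetical; every call starts again from counter = 0, so the result is never True (and because the function has no final return statement, it actually returns None, which is falsy, rather than False).

```python
def early_stopping(train_loss, validation_loss, min_delta, tolerance):
    # Function from the question: counter is local, so it is 0 on every call.
    counter = 0
    if (validation_loss - train_loss) > min_delta:
        counter += 1
        if counter >= tolerance:
            return True
    # No explicit return here, so Python returns None.


# Hypothetical losses: validation loss 100 above training loss for 30 "epochs".
results = [
    early_stopping(train_loss=100.0, validation_loss=200.0, min_delta=10, tolerance=20)
    for _ in range(30)
]
print(set(results))  # {None}
```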
Here is a working solution using an object-oriented approach with __init__() and __call__() instead:
class EarlyStopping:
    def __init__(self, tolerance=5, min_delta=0):
        self.tolerance = tolerance
        self.min_delta = min_delta
        self.counter = 0
        self.early_stop = False
    def __call__(self, train_loss, validation_loss):
        if (validation_loss - train_loss) > self.min_delta:
            self.counter += 1
            if self.counter >= self.tolerance:
                self.early_stop = True
Call it like this:
early_stopping = EarlyStopping(tolerance=5, min_delta=10)
for i in range(epochs):
    print(f"Epoch {i+1}")
    epoch_train_loss, pred = train_one_epoch(model, train_dataloader, loss_func, optimiser, device)
    train_loss.append(epoch_train_loss)
    # validation
    with torch.no_grad():
        epoch_validate_loss = validate_one_epoch(model, validate_dataloader, loss_func, device)
    validation_loss.append(epoch_validate_loss)
    # early stopping
    early_stopping(epoch_train_loss, epoch_validate_loss)
    if early_stopping.early_stop:
        print("We are at epoch:", i)
        break
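The object has to be created once, before the epoch loop, so that counter survives from one epoch to the next. If it is created inside the loop, as in EDIT2 of the question, each epoch gets a fresh instance with counter = 0 and the original problem comes back. The sketch below replays the loss values from the example in this answer in plain Python instead of training a model; the helper first_stop_epoch is hypothetical and only exists to compare the two placements.

```python
class EarlyStopping:
    # Same class as in this answer.
    def __init__(self, tolerance=5, min_delta=0):
        self.tolerance = tolerance
        self.min_delta = min_delta
        self.counter = 0
        self.early_stop = False

    def __call__(self, train_loss, validation_loss):
        if (validation_loss - train_loss) > self.min_delta:
            self.counter += 1
            if self.counter >= self.tolerance:
                self.early_stop = True


# Loss values copied from the example in this answer.
train_loss = [642.14990234, 601.29278564, 561.98400879, 530.01501465, 497.1098938,
              466.92709351, 438.2364502, 413.76028442, 391.5090332, 370.79074097]
validate_loss = [509.13619995, 497.3125, 506.17315674, 497.68960571, 505.69918823,
                 459.78610229, 480.25592041, 418.08630371, 446.42675781, 372.09902954]


def first_stop_epoch(create_inside_loop):
    """Return the epoch index at which training stops, or None if it never does."""
    early_stopping = EarlyStopping(tolerance=2, min_delta=5)
    for epoch, (t, v) in enumerate(zip(train_loss, validate_loss)):
        if create_inside_loop:
            # Pattern from EDIT2: a fresh object (counter = 0) every epoch.
            early_stopping = EarlyStopping(tolerance=2, min_delta=5)
        early_stopping(t, v)
        if early_stopping.early_stop:
            return epoch
    return None


print(first_stop_epoch(create_inside_loop=False))  # 6
print(first_stop_epoch(create_inside_loop=True))   # None
```

Created once outside the loop, the object stops training at epoch 6; recreated every epoch, it never stops.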
Example:
early_stopping = EarlyStopping(tolerance=2, min_delta=5)
train_loss = [
    642.14990234,
    601.29278564,
    561.98400879,
    530.01501465,
    497.1098938,
    466.92709351,
    438.2364502,
    413.76028442,
    391.5090332,
    370.79074097,
]
validate_loss = [
    509.13619995,
    497.3125,
    506.17315674,
    497.68960571,
    505.69918823,
    459.78610229,
    480.25592041,
    418.08630371,
    446.42675781,
    372.09902954,
]
for i in range(len(train_loss)):
    early_stopping(train_loss[i], validate_loss[i])
    print(f"loss: {train_loss[i]} : {validate_loss[i]}")
    if early_stopping.early_stop:
        print("We are at epoch:", i)
        break
Output:
loss: 642.14990234 : 509.13619995
loss: 601.29278564 : 497.3125
loss: 561.98400879 : 506.17315674
loss: 530.01501465 : 497.68960571
loss: 497.1098938 : 505.69918823
loss: 466.92709351 : 459.78610229
loss: 438.2364502 : 480.25592041
We are at epoch: 6
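Note that EarlyStopping counts every epoch in which the gap exceeds min_delta, not only consecutive ones: in the output above, epoch 5 has a validation loss below the training loss, yet the counter keeps the value it reached at epoch 4, and epoch 6 brings it to 2. If you want tolerance to mean consecutive epochs, a minimal sketch (the class name ConsecutiveEarlyStopping is hypothetical, and this is a variant, not part of the answer above) resets the counter whenever the gap falls back within min_delta. On the same data it never triggers, because the gap never stays above 5 for two epochs in a row.

```python
class ConsecutiveEarlyStopping:
    """Hypothetical variant: stop only after `tolerance` consecutive epochs
    in which validation_loss - train_loss exceeds min_delta."""

    def __init__(self, tolerance=5, min_delta=0):
        self.tolerance = tolerance
        self.min_delta = min_delta
        self.counter = 0
        self.early_stop = False

    def __call__(self, train_loss, validation_loss):
        if (validation_loss - train_loss) > self.min_delta:
            self.counter += 1
            if self.counter >= self.tolerance:
                self.early_stop = True
        else:
            # Gap is back within min_delta: start counting from scratch.
            self.counter = 0


# Loss values copied from the example above.
train_loss = [642.14990234, 601.29278564, 561.98400879, 530.01501465, 497.1098938,
              466.92709351, 438.2364502, 413.76028442, 391.5090332, 370.79074097]
validate_loss = [509.13619995, 497.3125, 506.17315674, 497.68960571, 505.69918823,
                 459.78610229, 480.25592041, 418.08630371, 446.42675781, 372.09902954]

early_stopping = ConsecutiveEarlyStopping(tolerance=2, min_delta=5)
stopped_at = None
for i in range(len(train_loss)):
    early_stopping(train_loss[i], validate_loss[i])
    if early_stopping.early_stop:
        stopped_at = i
        break
print(stopped_at)  # None
```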
Although @KarelZe’s response solves your problem sufficiently and elegantly, I want to provide an alternative early stopping criterion that is arguably better.
Your early stopping criterion is based on how much (and for how long) the validation loss diverges from the training loss. This breaks down when the validation loss is still decreasing but stays more than min_delta above the training loss: training is stopped even though the model is still improving on the validation set. The goal of training is to reduce the validation loss, not to reduce the gap between the training and validation losses.
Hence, I would argue that a better early stopping criterion is to watch the trend of the validation loss alone: if further training no longer lowers the validation loss, terminate it. Here's an example implementation:
import numpy as np
class EarlyStopper:
    def __init__(self, patience=1, min_delta=0):
        self.patience = patience
        self.min_delta = min_delta
        self.counter = 0
        self.min_validation_loss = np.inf
    def early_stop(self, validation_loss):
        if validation_loss < self.min_validation_loss:
            self.min_validation_loss = validation_loss
            self.counter = 0
        elif validation_loss > (self.min_validation_loss + self.min_delta):
            self.counter += 1
            if self.counter >= self.patience:
                return True
        return False
Here’s how you’d use it:
early_stopper = EarlyStopper(patience=3, min_delta=10)
for epoch in range(n_epochs):
    train_loss = train_one_epoch(model, train_loader)
    validation_loss = validate_one_epoch(model, validation_loader)
    if early_stopper.early_stop(validation_loss):
        break
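To make the difference concrete, here's a sketch that runs EarlyStopper over the validation losses from @KarelZe's example, with float("inf") in place of np.inf so it needs no NumPy. With patience=3 and min_delta=10 it never stops, and the best validation loss it records is the last one (372.09902954), whereas the gap-based criterion in that example stopped at epoch 6. With patience=1 it would also stop at epoch 6, so the outcome depends on the chosen patience. The helper stop_epoch is hypothetical and only used for illustration.

```python
class EarlyStopper:
    # Same logic as the answer's class; float("inf") replaces np.inf.
    def __init__(self, patience=1, min_delta=0):
        self.patience = patience
        self.min_delta = min_delta
        self.counter = 0
        self.min_validation_loss = float("inf")

    def early_stop(self, validation_loss):
        if validation_loss < self.min_validation_loss:
            # New best validation loss: remember it and reset the counter.
            self.min_validation_loss = validation_loss
            self.counter = 0
        elif validation_loss > (self.min_validation_loss + self.min_delta):
            # Clearly worse than the best so far: count it.
            self.counter += 1
            if self.counter >= self.patience:
                return True
        return False


# Validation losses copied from @KarelZe's example.
validate_loss = [509.13619995, 497.3125, 506.17315674, 497.68960571, 505.69918823,
                 459.78610229, 480.25592041, 418.08630371, 446.42675781, 372.09902954]


def stop_epoch(patience, min_delta):
    """Return (epoch at which training stops or None, best validation loss seen)."""
    early_stopper = EarlyStopper(patience=patience, min_delta=min_delta)
    for epoch, validation_loss in enumerate(validate_loss):
        if early_stopper.early_stop(validation_loss):
            return epoch, early_stopper.min_validation_loss
    return None, early_stopper.min_validation_loss


print(stop_epoch(patience=3, min_delta=10))  # (None, 372.09902954)
print(stop_epoch(patience=1, min_delta=10))  # (6, 459.78610229)
```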