How can I stop Keras training once the accuracy has reached 1.0? I have monitored the loss value before, but I have not yet managed to stop training when the accuracy is already 1.0.
I tried the code below with no luck:
stopping_criterions =[
    EarlyStopping(monitor='loss', min_delta=0, patience = 1000),
    EarlyStopping(monitor='acc', baseline=1.0, patience =0)
]
model.summary()
model.compile(Adam(), loss='binary_crossentropy', metrics=['accuracy'])
model.fit(scaled_train_samples, train_labels, batch_size=1000, epochs=1000000, callbacks=[stopping_criterions], shuffle = True, verbose=2)
UPDATE:
The training immediately stops at first epoch, even if the accuracy is still not 1.0.
Please help.
Update: tested in keras 2.4.3 (Dec.2020)
I don't know why EarlyStopping does not work in this case. Instead, I defined a custom callback that stops training when acc (or val_acc) reaches a specified baseline:
from keras.callbacks import Callback

class TerminateOnBaseline(Callback):
    """Callback that terminates training when either acc or val_acc reaches a specified baseline
    """
    def __init__(self, monitor='accuracy', baseline=0.9):
        super(TerminateOnBaseline, self).__init__()
        self.monitor = monitor
        self.baseline = baseline

    def on_epoch_end(self, epoch, logs=None):
        logs = logs or {}
        acc = logs.get(self.monitor)
        if acc is not None:
            if acc >= self.baseline:
                print('Epoch %d: Reached baseline, terminating training' % (epoch))
                self.model.stop_training = True
You can use it like this:
callbacks = [TerminateOnBaseline(monitor='accuracy', baseline=0.8)]
callbacks = [TerminateOnBaseline(monitor='val_accuracy', baseline=0.95)]
Note: This solution does not work.
If you want to stop training when the training (or validation) accuracy exactly reaches 100%, then use EarlyStopping callback and set the baseline argument to 1.0 and patience to zero:
EarlyStopping(monitor='acc', baseline=1.0, patience=0)  # use 'val_acc' instead to monitor validation accuracy
Using an EarlyStopping with baseline callback doesn't do the trick here as far as I know.
'baseline' is the minimum value of the monitored variable (here, accuracy) that must be reached for training to continue. Here the baseline is 1.0. At the end of the first epoch the accuracy is below the baseline (you obviously can't expect an accuracy of 1.0 after the first epoch), and since patience is set to zero, training stops right there.
Using a custom callback does the work here.
import tensorflow as tf

class MyThresholdCallback(tf.keras.callbacks.Callback):
    def __init__(self, threshold):
        super(MyThresholdCallback, self).__init__()
        self.threshold = threshold

    def on_epoch_end(self, epoch, logs=None):
        # The metric key is "acc" in older Keras versions and "accuracy" in newer ones.
        accuracy = logs["acc"]
        if accuracy >= self.threshold:
            self.model.stop_training = True
And calling the callback in the model.fit
callback=MyThresholdCallback(threshold=1.0)
model.fit(scaled_train_samples, train_labels, batch_size=1000, epochs=1000000, callbacks=[callback], shuffle = True, verbose=2)
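One thing that tends to trip people up with callbacks like the one above is the name of the metric key in logs: older Keras versions report "acc", newer ones report "accuracy". Here is a minimal sketch of the threshold check written so that it tolerates both. The helper name reached_threshold and its keys parameter are made up for illustration; they are not part of Keras.

```python
def reached_threshold(logs, threshold, keys=("accuracy", "acc")):
    """Return True if the first metric found under `keys` is >= threshold.

    `logs` is the dict Keras passes to on_epoch_end (it may be None).
    The helper and its parameters are hypothetical, not a Keras API.
    """
    logs = logs or {}
    for key in keys:
        value = logs.get(key)
        if value is not None:
            return value >= threshold
    # No accuracy metric was logged, so there is nothing to compare.
    return False


print(reached_threshold({"loss": 0.02, "accuracy": 1.0}, threshold=1.0))
print(reached_threshold({"loss": 0.30, "acc": 0.91}, threshold=1.0))
```

Inside on_epoch_end you would then write something like: if reached_threshold(logs, self.threshold): self.model.stop_training = True.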
The name baseline is misleading. Although not easy to interpret from the source code below, baseline should be understood as:
While the monitored value is worse than the baseline (i.e. lower for accuracies, higher for losses), keep training for max patience epochs longer. If it's better, elevate the baseline and repeat.
The relevant (trimmed) source code of EarlyStopping:
self.best = baseline  # in initialization
...
def on_epoch_end(self, epoch, logs=None):
    current = self.get_monitor_value(logs)
    if self.monitor_op(current - self.min_delta, self.best):  # read as `current > self.best` (for accuracy)
        self.best = current
        self.wait = 0
    else:
        self.wait += 1
        if self.wait >= self.patience:
            self.model.stop_training = True
Then your example
EarlyStopping(monitor='acc', baseline=1.0, patience=0) means: while the monitored value is worse than 1.0 (which it always is), keep training for 0 epochs longer (i.e. terminate immediately).
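To see this behaviour without running a model, here is a small pure-Python sketch that replays the trimmed logic quoted above on a list of per-epoch values. It is a simplified re-implementation for illustration, not the Keras source; the function name simulate_early_stopping and the example accuracies are made up.

```python
import operator


def simulate_early_stopping(values, baseline=None, patience=0,
                            min_delta=0.0, higher_is_better=True):
    """Return the 0-based epoch at which training would stop, or None.

    Mirrors the trimmed EarlyStopping logic above; hypothetical helper.
    """
    better = operator.gt if higher_is_better else operator.lt
    # For losses the tolerance is applied in the opposite direction.
    delta = min_delta if higher_is_better else -min_delta
    if baseline is not None:
        best = baseline
    else:
        best = float("-inf") if higher_is_better else float("inf")
    wait = 0
    for epoch, current in enumerate(values):
        if better(current - delta, best):
            best = current
            wait = 0
        else:
            wait += 1
            if wait >= patience:
                return epoch
    return None


accuracies = [0.62, 0.81, 0.93, 0.97, 1.0]  # made-up training curve

# baseline=1.0 with patience=0: no epoch can beat 1.0, so it stops at once.
print(simulate_early_stopping(accuracies, baseline=1.0, patience=0))
# No baseline, patience=2: the curve keeps improving, so it never stops.
print(simulate_early_stopping(accuracies, baseline=None, patience=2))
```

Note that the comparison is strict, so even an accuracy of exactly 1.0 does not count as better than a baseline of 1.0.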
If you want these semantics:
While the monitored value is worse than the baseline, keep training. If it is better, keep training until no progress is made for patience consecutive epochs.
To get this while retaining all other features of EarlyStopping, may I suggest this:
from tensorflow.keras.callbacks import EarlyStopping

class MyEarlyStopping(EarlyStopping):
    def __init__(self, *args, **kw):
        super().__init__(*args, **kw)
        self.baseline_attained = False

    def on_epoch_end(self, epoch, logs=None):
        if not self.baseline_attained:
            current = self.get_monitor_value(logs)
            if current is None:
                return
            if self.monitor_op(current, self.baseline):
                if self.verbose > 0:
                    print('Baseline attained.')
                self.baseline_attained = True
            else:
                return
        super(MyEarlyStopping, self).on_epoch_end(epoch, logs)
Related
I am training a single-layer neural network using PyTorch and saving the model after the validation loss decreases. Once the network has finished training, I load the saved model and pass my test set features through that (rather than the model from the last epoch) to see how well it does. However, more often than not, the validation loss will stop decreasing after about 150 epochs, and I'm worried that the network is overfitting the data. Would it be better for me to load the saved model during training if the validation loss has not decreased for some number of iterations (say, after 5 epochs), and then train on that saved model instead?
Also, are there any recommendations for how to avoid a situation where the validation loss stops decreasing? I've had some models where the validation loss continues to decrease even after 500 epochs and others where it stops decreasing after 100. Here is my code so far:
import numpy as np
import torch
import torch.nn as nn
from torch.optim.lr_scheduler import ReduceLROnPlateau

class NeuralNetwork(nn.Module):
    def __init__(self, input_dim, output_dim, nodes):
        super(NeuralNetwork, self).__init__()
        self.linear1 = nn.Linear(input_dim, nodes)
        self.tanh = nn.Tanh()
        self.linear2 = nn.Linear(nodes, output_dim)

    def forward(self, x):
        output = self.linear1(x)
        output = self.tanh(output)
        output = self.linear2(output)
        return output
epochs = 500  # (start small for now)
learning_rate = 0.01
w_decay = 0.1
momentum = 0.9
input_dim = 4
output_dim = 1
nodes = 8
model = NeuralNetwork(input_dim, output_dim, nodes)
criterion = nn.MSELoss()
optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate, momentum=momentum, weight_decay=w_decay)
scheduler = ReduceLROnPlateau(optimizer, 'min', patience=5)
losses = []
val_losses = []
min_validation_loss = np.inf
means = []  # we want to store the mean and standard deviation for the test set later
stdevs = []
torch.save({
    'epoch': 0,
    'model_state_dict': model.state_dict(),
    'optimizer_state_dict': optimizer.state_dict(),
    'training_loss': 0.0,
    'validation_loss': 0.0,
    'means': [],
    'stdevs': [],
}, new_model_path)
new_model_saved = True
for epoch in range(epochs):
    curr_loss = 0.0
    validation_loss = 0.0

    if new_model_saved:
        checkpoint = torch.load(new_model_path)
        model.load_state_dict(checkpoint['model_state_dict'])
        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        means = checkpoint['means']
        stdevs = checkpoint['stdevs']
        new_model_saved = False

    model.train()
    for i, batch in enumerate(train_dataloader):
        x, y = batch
        x, new_mean, new_std = normalize_data(x, means, stdevs)
        means = new_mean
        stdevs = new_std
        optimizer.zero_grad()
        predicted_outputs = model(x)
        loss = criterion(torch.squeeze(predicted_outputs), y)
        loss.backward()
        optimizer.step()
        curr_loss += loss.item()

    model.eval()
    for x_val, y_val in val_dataloader:
        x_val, val_means, val_std = normalize_data(x_val, means, stdevs)
        predicted_y = model(x_val)
        loss = criterion(torch.squeeze(predicted_y), y_val)
        validation_loss += loss.item()

    curr_lr = optimizer.param_groups[0]['lr']
    if epoch % 10 == 0:
        print(f'Epoch {epoch} \t\t Training Loss: {curr_loss/len(train_dataloader)} \t\t Validation Loss: {validation_loss/len(val_dataloader)} \t\t Learning rate: {curr_lr}')
    if min_validation_loss > validation_loss:
        print(f' For epoch {epoch}, validation loss decreased ({min_validation_loss:.6f}--->{validation_loss:.6f}) \t learning rate: {curr_lr} \t saving the model')
        min_validation_loss = validation_loss
        torch.save({
            'epoch': epoch,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'training_loss': curr_loss/len(train_dataloader),
            'validation_loss': validation_loss/len(val_dataloader),
            'means': means,
            'stdevs': stdevs
        }, new_model_path)
        new_model_saved = True

    losses.append(curr_loss/len(train_dataloader))
    val_losses.append(validation_loss/len(val_dataloader))
    scheduler.step(curr_loss/len(train_dataloader))
When the validation loss increases while the training loss keeps decreasing, the model is overfitting. Overfitting is a problem when training a model and should be avoided; it is worth reading more on this topic. It may set in after any number of epochs and depends on many variables (learning rate, dataset size, dataset diversity, and more). As a rule of thumb, test your model at the "pivot point", i.e. exactly where the validation loss begins to increase while the training loss continues to decrease. My recommendation is therefore to save the model after each epoch in which the validation loss decreases. If the validation loss keeps increasing for some number X of epochs, you have probably reached a "deep" minimum of the loss and further training will not be beneficial (this has some exceptions, but it is enough for this level of discussion).
I encourage you to read and learn more about this subject; it is very interesting and has significant implications.
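The "save on improvement, give up after X epochs without one" rule can be sketched independently of PyTorch. The snippet below is a rough illustration only: the function name best_checkpoint_epoch, the patience value and the loss values are all made up, and in a real loop the marked line is where torch.save would go.

```python
def best_checkpoint_epoch(val_losses, patience=5):
    """Return (best_epoch, stop_epoch) for per-epoch validation losses.

    Hypothetical helper: best_epoch is where a checkpoint would last be
    saved, stop_epoch is where training would be abandoned.
    """
    best_loss = float("inf")
    best_epoch = None
    epochs_without_improvement = 0
    for epoch, loss in enumerate(val_losses):
        if loss < best_loss:
            best_loss = loss
            best_epoch = epoch  # a real loop would save the checkpoint here
            epochs_without_improvement = 0
        else:
            epochs_without_improvement += 1
            if epochs_without_improvement >= patience:
                return best_epoch, epoch
    # Ran out of epochs before patience was exhausted.
    return best_epoch, len(val_losses) - 1


# Made-up validation curve: improves until epoch 2, then drifts upward.
val_curve = [1.00, 0.80, 0.60, 0.65, 0.70, 0.72, 0.90]
best, stopped = best_checkpoint_epoch(val_curve, patience=3)
print(f"best epoch: {best}, stopped at epoch: {stopped}")
```

The model you evaluate on the test set is then the one saved at the best epoch, not the one from the epoch where training stopped.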
I am working my way through Keras modeling and think I have now sussed how to use the callback feature to capture the best fit and prevent overfitting; all seems good. While I understand that the verbose parameter will display the information I require, it makes the output messy and I prefer to set it to zero. I would still like to capture the "epoch" count that gave the best result to incorporate into my own display; is there some way I can get at this? Thanks
model.compile(optimizer='adam', loss='mse')
cbfile = 'best_model.h5'
calls = [
    EarlyStopping(monitor='val_loss', mode='auto', verbose=0, patience=10),
    ModelCheckpoint(cbfile, monitor='val_loss', mode='auto',
                    save_best_only=True)]
history = model.fit(Xvect, Yvect, epochs=mcycl, batch_size=32,
                    validation_split=dsplit, verbose=0, callbacks=calls)
saved = load_model('best_model.h5')
score = saved.evaluate(Xvect, Yvect, verbose=0)
print('"Overall loss for best fit":', np.round(score, 4))
How about writing your own custom EarlyStopping callback? The Tensorflow docs provide a very good example of how you could get started:
import numpy as np
from tensorflow import keras

class EarlyStoppingAtMinLoss(keras.callbacks.Callback):
    """Stop training when the loss is at its min, i.e. the loss stops decreasing.

    Arguments:
        patience: Number of epochs to wait after min has been hit. After this
            number of no improvement, training stops.
    """
    def __init__(self, patience=0):
        super(EarlyStoppingAtMinLoss, self).__init__()
        self.patience = patience
        # best_weights to store the weights at which the minimum loss occurs.
        self.best_weights = None

    def on_train_begin(self, logs=None):
        # The number of epoch it has waited when loss is no longer minimum.
        self.wait = 0
        # The epoch the training stops at.
        self.stopped_epoch = 0
        # Initialize the best as infinity.
        self.best = np.inf

    def on_epoch_end(self, epoch, logs=None):
        current = logs.get("loss")
        if np.less(current, self.best):
            self.best = current
            self.wait = 0
            # Record the best weights if current results is better (less).
            self.best_weights = self.model.get_weights()
        else:
            self.wait += 1
            if self.wait >= self.patience:
                self.stopped_epoch = epoch
                self.model.stop_training = True
                print("Restoring model weights from the end of the best epoch.")
                self.model.set_weights(self.best_weights)

    def on_train_end(self, logs=None):
        if self.stopped_epoch > 0:
            print("Epoch %05d: early stopping" % (self.stopped_epoch + 1))
Note the self.stopped_epoch variable in the example. This way you have full control of what you display and how your early stopping logic works. Furthermore, using the logs dictionary, you can access your current loss and accuracy for epoch x. On the other hand, if you just want to use a simple print statement after training your model, you could just get the last epoch of your callback and print it:
model.compile(optimizer='adam', loss='mse')
cbfile = 'best_model.h5'
early_stopping = EarlyStopping(monitor='val_loss', mode='auto', verbose=0, patience=10)
calls = [early_stopping,
         ModelCheckpoint(cbfile, monitor='val_loss', mode='auto',
                         save_best_only=True)]
history = model.fit(Xvect, Yvect, epochs=mcycl, batch_size=32,
                    validation_split=dsplit, verbose=0, callbacks=calls)
saved = load_model('best_model.h5')
score = saved.evaluate(Xvect, Yvect, verbose=0)
print('"Overall loss for best fit":', np.round(score, 4))
print("Epoch %05d: early stopping" % (early_stopping.stopped_epoch + 1))
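Keep in mind that stopped_epoch is the epoch at which training was halted, which is patience epochs after the best one. If what you want is the epoch that produced the best val_loss, one option is to look it up in the history afterwards. Here is a small sketch of that, assuming the dict comes from history.history as returned by model.fit; the helper name best_epoch_from_history and the numbers are made up for illustration.

```python
def best_epoch_from_history(history_dict, monitor="val_loss", mode="min"):
    """Return (1-based epoch, value) of the best entry for `monitor`.

    `history_dict` is expected to look like history.history; the helper
    itself is hypothetical and not part of Keras.
    """
    values = history_dict[monitor]
    pick = min if mode == "min" else max
    # Index of the best value; ties resolve to the earliest epoch.
    best_index = pick(range(len(values)), key=values.__getitem__)
    return best_index + 1, values[best_index]


# Stand-in for history.history after a run with verbose=0.
example_history = {"val_loss": [0.90, 0.55, 0.41, 0.43, 0.47]}
epoch, loss = best_epoch_from_history(example_history)
print(f"Best epoch: {epoch} (val_loss={loss:.4f})")
```

With a real run you would call best_epoch_from_history(history.history) right after model.fit.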
Usually, we can define a callback for a model to stop training once the accuracy reaches a certain level.
I am working on the adjustment of parameters. The val_acc is highly unstable as shown in the picture.
def LSTM_model(X_train, y_train, X_test, y_test, num_classes, batch_size=68, units=128, learning_rate=0.005, epochs=20,
               dropout=0.2, recurrent_dropout=0.2):
    class myCallback(tf.keras.callbacks.Callback):
        def on_epoch_end(self, epoch, logs={}):
            if (logs.get('acc') > 0.90):
                print("\nReached 90% accuracy so cancelling training!")
                self.model.stop_training = True

    callbacks = myCallback()
The graphs show that val_acc (orange) is fluctuating within a range and no longer really going up.
Is there a way to automatically stop the training once the general trend of the val_acc stops increasing?
You can achieve this with a callback like this
class terminate_on_plateau(keras.callbacks.Callback):
    def __init__(self):
        super().__init__()
        self.patience = 10
        self.val_loss = deque([], self.patience)
        self.std_threshold = 1e-2

    def on_epoch_end(self, epoch, logs=None):
        # x_val and y_val are the validation arrays defined at module level.
        val_loss, val_mae = self.model.evaluate(x_val, y_val)
        self.val_loss.append(val_loss)
        if len(self.val_loss) >= self.patience:
            std = np.std(self.val_loss)
            if std < self.std_threshold:
                print('\n\n EarlyStopping on std invoked! \n\n')
                # clear the deque
                self.val_loss = deque([], self.patience)
                self.model.stop_training = True
As you can see, in terminate_on_plateau, val_loss of epochs are stored in a deque of max length self.patience. Once the length of the deque reaches self.patience, standard deviation of the val_loss will be calculated for every new epoch, and the training process will be terminated (the deque of val_loss will also be cleared), if the calculated std is smaller than a threshold.
Below is a simple script that shows you how to use this
from collections import deque
import numpy as np
import tensorflow as tf
from tensorflow import keras
import tensorflow.keras.backend as K
from sklearn.model_selection import train_test_split
from tensorflow.keras.models import Model
from tensorflow.keras.layers import Input,Dense
x = np.linspace(0,10,1000)
np.random.shuffle(x)
y = np.sin(x) + x
x_train,x_val,y_train,y_val = train_test_split(x,y,test_size=0.3)
input_x = Input(shape=(1,))
y = Dense(10,activation='relu')(input_x)
y = Dense(10,activation='relu')(y)
y = Dense(1,activation='relu')(y)
model = Model(inputs=input_x,outputs=y)
adamopt = tf.keras.optimizers.Adam(learning_rate=0.01, beta_1=0.9, beta_2=0.999, epsilon=1e-8)

class terminate_on_plateau(keras.callbacks.Callback):
    def __init__(self):
        super().__init__()
        self.patience = 10
        self.val_loss = deque([], self.patience)
        self.std_threshold = 1e-2

    def on_epoch_end(self, epoch, logs=None):
        val_loss, val_mae = self.model.evaluate(x_val, y_val)
        self.val_loss.append(val_loss)
        if len(self.val_loss) >= self.patience:
            std = np.std(self.val_loss)
            if std < self.std_threshold:
                print('\n\n EarlyStopping on std invoked! \n\n')
                # clear the deque
                self.val_loss = deque([], self.patience)
                self.model.stop_training = True

model.compile(loss='mse', optimizer=adamopt, metrics=['mae'])
history = model.fit(x_train, y_train,
                    batch_size=8,
                    epochs=100,
                    validation_data=(x_val, y_val),
                    verbose=1,
                    callbacks=[terminate_on_plateau()])
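If you want to tune patience and the std threshold without retraining each time, the windowed check can be pulled out into a small framework-free class and fed a recorded curve. This is only a sketch of the same idea: the class name PlateauDetector and the sample values are made up, and statistics.pstdev is used because it computes the same population standard deviation that np.std returns by default.

```python
from collections import deque
from statistics import pstdev


class PlateauDetector:
    """Flags a plateau when the last `patience` values barely vary.

    Hypothetical helper that mirrors the deque-plus-std idea above.
    """

    def __init__(self, patience=10, std_threshold=1e-2):
        self.patience = patience
        self.std_threshold = std_threshold
        # A bounded deque drops the oldest value automatically.
        self.window = deque(maxlen=patience)

    def update(self, value):
        """Add one epoch's value; return True once the window is full and flat."""
        self.window.append(value)
        if len(self.window) < self.patience:
            return False
        return pstdev(self.window) < self.std_threshold


detector = PlateauDetector(patience=3, std_threshold=0.01)
recorded = [1.0, 0.5, 0.30, 0.301, 0.302]  # made-up validation losses
flags = [detector.update(v) for v in recorded]
print(flags)
```

The same detector works for val_acc, since it only looks at how much the values vary and not at their direction.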
The code below is for a custom callback that stops training when the monitored quantity fails to improve for patience epochs. Set the parameter acc_or_loss to 'loss' to monitor validation loss, or to 'acc' to monitor validation accuracy. I recommend NOT monitoring validation accuracy, as it can swing wildly, particularly in the early epochs. I put in print statements so you can see what is going on during training; you can of course remove them later. If you are monitoring validation loss, the callback halts training once the validation loss has failed to drop below the lowest loss found so far for patience consecutive epochs. If you are monitoring validation accuracy, the callback halts training once the validation accuracy has failed to exceed the highest validation accuracy recorded so far for patience consecutive epochs.
class halt(keras.callbacks.Callback):
    def __init__(self, patience, acc_or_loss):
        self.acc_or_loss=acc_or_loss
        super(halt, self).__init__()
        self.patience=patience # specifies how many epochs without improvement before training is halted
        self.lowest_loss=np.inf
        self.highest_acc=0
        self.count=0
        print ('initializing values ', 'count= ', self.count, ' lowest_loss= ', self.lowest_loss, 'highest acc= ', self.highest_acc)

    def on_epoch_end(self, epoch, logs=None):
        v_loss=logs.get('val_loss') # get the validation loss for this epoch
        v_acc=logs.get('val_accuracy')
        if self.acc_or_loss=='loss':
            print (' for epoch ', epoch +1, ' v_loss= ', v_loss, ' lowest_loss= ', self.lowest_loss, 'count= ', self.count)
            if v_loss< self.lowest_loss:
                self.lowest_loss=v_loss
                self.count=0
            else:
                self.count=self.count +1
                if self.count>=self.patience:
                    print('There have been ', self.patience, ' epochs with no reduction of validation loss below the lowest loss')
                    print ('Terminating training')
                    self.model.stop_training = True
        else:
            print (' for epoch ', epoch +1, ' v_acc= ', v_acc, ' highest accuracy= ', self.highest_acc, 'count= ', self.count)
            if v_acc>self.highest_acc:
                self.count=0
                self.highest_acc=v_acc
            else:
                self.count=self.count +1
                if self.count>=self.patience:
                    print('There have been ', self.patience, ' epochs with no increase in validation accuracy')
                    print ('Terminating training')
                    self.model.stop_training = True

patience= 2 # specify the patience value
acc_or_loss='loss' # specify to monitor validation loss or validation accuracy
callbacks=[halt(patience=patience, acc_or_loss=acc_or_loss)]
# in model.fit include callbacks=callbacks
Or you can just use Keras API in tensorflow : tf.keras.callbacks.EarlyStopping
Given your initial question I'm not sure why you would need custom callbacks
Here is an example of application:
history = model.fit([trainX,trainX,trainX],
                    np.array(trainLabels),
                    validation_data = ([testX, testX, testX], np.array(testLabels)),
                    epochs=EPOCH,
                    batch_size=BATCH_SIZE,
                    steps_per_epoch = None,
                    callbacks=[tf.keras.callbacks.EarlyStopping(
                        monitor="val_acc",
                        patience=5,
                        mode="max",  # accuracy should increase, so the best value is the maximum
                        restore_best_weights = True)])
Some of the above answers are a little complex, you can use the below code.
opt = tf.optimizers.Adadelta(learning_rate=0.01)
model.compile(loss='binary_crossentropy', optimizer=opt, metrics=["accuracy"])
es = EarlyStopping(monitor='val_accuracy', mode='max', patience=20)
# stops if validation accuracy has not improved for 20 epochs; you can give any number in patience.
ms = ModelCheckpoint('save_model.h5', monitor='val_accuracy', mode='max', save_best_only=True)
training_history = model.fit(x=X_train, y=y_train, validation_split=0.1, batch_size=5, epochs=1000, verbose=1,
                             callbacks = [es, ms])
I just copied this code from my project, which is not for LSTM, you can adjust this code according to your problem/task.
I am implementing a decaying learning rate based on accuracy from the previous epoch.
Capturing Metrics:
class CustomMetrics(tf.keras.callbacks.Callback):
    def on_train_begin(self, logs={}):
        self.metrics={'loss': [],'accuracy': [],'val_loss': [],'val_accuracy': []}
        self.lr=[]

    def on_epoch_end(self, epoch, logs={}):
        print(f"\nEPOCH {epoch} Callng from METRICS CLASS")
        self.metrics['loss'].append(logs.get('loss'))
        self.metrics['accuracy'].append(logs.get('accuracy'))
        self.metrics['val_loss'].append(logs.get('val_loss'))
        self.metrics['val_accuracy'].append(logs.get('val_accuracy'))
Custom Learning Decay:
from tensorflow.keras.callbacks import LearningRateScheduler

def changeLearningRate(epoch):
    initial_learningrate=0.1
    #print(f"EPOCH {epoch}, Calling from ChangeLearningRate:")
    lr = 0.0
    if epoch != 0:
        if custom_metrics_dict.metrics['accuracy'][epoch] < custom_metrics_dict.metrics['accuracy'][epoch-1]:
            print(f"Accuracy # epoch {epoch} is less than acuracy at epoch {epoch-1}")
            print("[INFO] Decreasing Learning Rate.....")
            lr = initial_learningrate*(0.1)
            print(f"LR Changed to {lr}")
    return lr
Model Preparation:
input_layer = Input(shape=(2))
layer1 = Dense(32,activation='tanh',kernel_initializer=tf.random_uniform_initializer(0,1,seed=30))(input_layer)
output = Dense(2,activation='softmax',kernel_initializer=tf.random_uniform_initializer(0,1,seed=30))(layer1)
model = Model(inputs=input_layer,outputs=output)
custom_metrics_dict=CustomMetrics()
lrschedule = LearningRateScheduler(changeLearningRate, verbose=1)
optimizer = tf.keras.optimizers.SGD(learning_rate=0.1,momentum=0.9)
model.compile(optimizer=optimizer, loss='categorical_crossentropy', metrics=['accuracy'])
model.fit(X_train,Y_train,epochs=4, validation_data=(X_test,Y_test), batch_size=16 ,callbacks=[custom_metrics_dict,lrschedule])
It's erroring out with index out of range error. From what I noticed, per epoch, LRScheduler code is being called more than once. I am unable to figure a way to make appropriate function calls. What can I try next?
The signature of the scheduler function is def scheduler(epoch, lr):, which means you should take the lr from that parameter.
You shouldn't hard-code initial_learningrate = 0.1. If you do, your lr will not keep decaying: you will return the same value every time the accuracy decreases.
As for the out-of-range exception: you check that epoch is not 0, which means that for epoch = 1 you read custom_metrics_dict.metrics['accuracy'][epoch] and custom_metrics_dict.metrics['accuracy'][epoch-1]. But the scheduler runs at the start of an epoch, before that epoch's accuracy has been recorded, so at that point custom_metrics_dict.metrics['accuracy'] holds only one value (the one for epoch 0) and index 1 does not exist.
I've run your code correctly with this function
from tensorflow.keras.callbacks import LearningRateScheduler

def changeLearningRate(epoch, lr):
    print(f"EPOCH {epoch}, Calling from ChangeLearningRate: {custom_metrics_dict.metrics['accuracy']}")
    if epoch > 1:
        # Compare the two most recently completed epochs (epoch-1 and epoch-2).
        if custom_metrics_dict.metrics['accuracy'][epoch - 1] < custom_metrics_dict.metrics['accuracy'][epoch - 2]:
            print(f"Accuracy at epoch {epoch - 1} is less than accuracy at epoch {epoch - 2}")
            print("[INFO] Decreasing Learning Rate.....")
            lr = lr*(0.1)
            print(f"LR Changed to {lr}")
    return lr
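To check the indexing without training anything, the same rule can be replayed on a plain list. The sketch below is only an illustration: decayed_lr and the accuracy values are made up, and the loop imitates the order in which Keras calls things (scheduler at the start of an epoch, metrics recorded at the end).

```python
def decayed_lr(epoch, lr, accuracy_history, factor=0.1):
    """Return the learning rate to use for 0-based `epoch`.

    `accuracy_history` holds one entry per *completed* epoch, so at the
    start of epoch N it has N entries. Hypothetical helper.
    """
    # Two completed epochs are needed before any comparison is possible.
    if epoch >= 2 and len(accuracy_history) >= epoch:
        if accuracy_history[epoch - 1] < accuracy_history[epoch - 2]:
            return lr * factor
    return lr


recorded = []  # filled at the end of each epoch, like CustomMetrics
lr = 0.1
learning_rates = []
for epoch, acc in enumerate([0.60, 0.70, 0.65, 0.66]):  # made-up accuracies
    lr = decayed_lr(epoch, lr, recorded)  # start of the epoch
    learning_rates.append(lr)
    recorded.append(acc)                  # end of the epoch
print(learning_rates)
```

The drop from 0.70 to 0.65 happens in the third epoch, so the reduced learning rate only takes effect from the fourth epoch onwards.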
I designed a network for a text classification problem. To do this, I'm using the huggingface transformers BERT model with a linear layer on top for fine-tuning. My problem is that the loss on the training set is decreasing, which is fine, but when I evaluate on the development set after each epoch, the loss increases with the epochs. I'm posting my code to investigate whether there's something wrong with it.
for epoch in range(1, args.epochs + 1):
    total_train_loss = 0
    trainer.set_train()
    for step, batch in enumerate(train_dataloader):
        loss = trainer.step(batch)
        total_train_loss += loss
    avg_train_loss = total_train_loss / len(train_dataloader)
    logger.info(('Training loss for epoch %d/%d: %4.2f') % (epoch, args.epochs, avg_train_loss))
    print("\n-------------------------------")
    logger.info('Start validation ...')
    trainer.set_eval()
    y_hat = list()
    y = list()
    total_dev_loss = 0
    for step, batch_val in enumerate(dev_dataloader):
        true_labels_ids, predicted_labels_ids, loss = trainer.validate(batch_val)
        total_dev_loss += loss
        y.extend(true_labels_ids)
        y_hat.extend(predicted_labels_ids)
    avg_dev_loss = total_dev_loss / len(dev_dataloader)
    print(("\n-Total dev loss: %4.2f on epoch %d/%d\n") % (avg_dev_loss, epoch, args.epochs))
print("Training terminated!")
Following is the trainer file, which I use for doing a forward pass on a given batch and then backpropagate accordingly.
class Trainer(object):
    def __init__(self, args, model, device, data_points, is_test=False, train_stats=None):
        self.args = args
        self.model = model
        self.device = device
        self.loss = nn.CrossEntropyLoss(reduction='none')
        if is_test:
            # Should load the model from checkpoint
            self.model.eval()
            self.model.load_state_dict(torch.load(args.saved_model))
            logger.info('Loaded saved model from %s' % args.saved_model)
        else:
            self.model.train()
            self.optim = AdamW(model.parameters(), lr=2e-5, eps=1e-8)
            total_steps = data_points * self.args.epochs
            self.scheduler = get_linear_schedule_with_warmup(self.optim, num_warmup_steps=0,
                                                             num_training_steps=total_steps)

    def step(self, batch):
        batch = tuple(t.to(self.device) for t in batch)
        batch_input_ids, batch_input_masks, batch_labels = batch
        self.model.zero_grad()
        outputs = self.model(batch_input_ids,
                             attention_mask=batch_input_masks,
                             labels=batch_labels)
        loss = self.loss(outputs, batch_labels)
        loss = loss.sum()
        (loss / loss.numel()).backward()
        torch.nn.utils.clip_grad_norm_(self.model.parameters(), 1.0)
        self.optim.step()
        self.scheduler.step()
        return loss

    def validate(self, batch):
        batch = tuple(t.to(self.device) for t in batch)
        batch_input_ids, batch_input_masks, batch_labels = batch
        with torch.no_grad():
            model_output = self.model(batch_input_ids,
                                      attention_mask=batch_input_masks,
                                      labels=batch_labels)
        predicted_label_ids = self._predict(model_output)
        label_ids = batch_labels.to('cpu').numpy()
        loss = self.loss(model_output, batch_labels)
        loss = loss.sum()
        return label_ids, predicted_label_ids, loss

    def _predict(self, logits):
        return np.argmax(logits.to('cpu').numpy(), axis=1)
Finally, the following is my model (i.e., Classifier) class:
import torch.nn as nn
from transformers import BertModel

class Classifier(nn.Module):
    def __init__(self, args, is_eval=False):
        super(Classifier, self).__init__()
        self.bert_model = BertModel.from_pretrained(
            args.init_checkpoint,
            output_attentions=False,
            output_hidden_states=True,
        )
        self.is_eval_mode = is_eval
        self.linear = nn.Linear(768, 2)  # binary classification

    def switch_state(self):
        self.is_eval_mode = not self.is_eval_mode

    def forward(self, input_ids, attention_mask=None, labels=None):
        bert_outputs = self.bert_model(input_ids,
                                       token_type_ids=None,
                                       attention_mask=attention_mask)
        # Should give the logits to the linear layer
        model_output = self.linear(bert_outputs[1])
        return model_output
For visualization the loss throughout the epochs:
When I've used BERT for text classification, my model has generally behaved as you describe. In part this is expected, because pre-trained models tend to require few epochs to fine-tune; in fact, if you check the BERT paper, the number of epochs recommended for fine-tuning is between 2 and 4.
On the other hand, I've usually found the optimum at just 1 or 2 epochs, which coincides with your case also. My guess is: there is a trade-off when fine-tuning pre-trained models between fitting to your downstream task and forgetting the weights learned at pre-training. Depending on the data you have, the equilibrium point may happen sooner or later and overfitting starts after that. But this paragraph is speculation based on my experience.
When the validation loss increases while the training loss keeps decreasing, it means your model is overfitting.