import torch
import numpy as np
import pandas as pd
import time
import os
from apex import amp
class NNTrainer:
"""
    NNTrainer is a trainer class provided with synicix_ml_pipeline. It handles training and saving/loading of models, along with some additional useful features.
    One of the most notable features it provides is integrated fp16 (mixed precision) training, which can drastically improve network training speed at little to no loss in performance.
    This is provided via NVIDIA's Apex library: https://github.com/NVIDIA/apex
    The other feature is a built-in early stopping algorithm, implemented based on this paper: https://page.mi.fu-berlin.de/prechelt/Biblio/stop_tricks1997.pdf
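
    Example (a minimal usage sketch; ``MyModel`` and ``MyDataset`` are hypothetical placeholders, where ``MyModel.forward`` is assumed to return a tuple of (outputs, regularization_loss) and ``MyDataset`` is assumed to yield (input, target) tuples, as train_step expects)::

        import torch
        from torch.utils.data import DataLoader

        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        trainer = NNTrainer(
            train_dataloader=DataLoader(MyDataset('train'), batch_size=32),
            validation_dataloader=DataLoader(MyDataset('val'), batch_size=32),
            test_dataloader=DataLoader(MyDataset('test'), batch_size=32),
            device=device,
            model_class=MyModel,
            model_class_params=dict(),
            optimizer_class=torch.optim.Adam,
            optimizer_class_params=dict(lr=1e-3),
            criterion_class=torch.nn.MSELoss,
            criterion_class_params=dict(),
            model_save_path='checkpoints/model.pt',
            max_epochs=100,
            training_strip_length=5,
            sucessive_validation_score_decline_tolerance=3,
            fp16=False)
        trainer.train()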
"""
def __init__(self,
train_dataloader,
validation_dataloader,
test_dataloader,
device,
model_class,
model_class_params,
optimizer_class,
optimizer_class_params,
criterion_class,
criterion_class_params,
model_save_path,
max_epochs,
training_strip_length,
sucessive_validation_score_decline_tolerance,
fp16=False,
fp16_opt_level=None):
"""
        Initialization function for NNTrainer
        Parameters:
            train_dataloader: Train PyTorch DataLoader
            validation_dataloader: Validation PyTorch DataLoader
            test_dataloader: Test PyTorch DataLoader
            device: PyTorch device to train the model on
            model_class: User-defined model class
            model_class_params: Parameters to pass to the model initialization function
            optimizer_class: User-defined optimizer class
            optimizer_class_params: Parameters to pass to the optimizer initialization function
            criterion_class: User-defined criterion class
            criterion_class_params: Parameters to pass to the criterion initialization function
            model_save_path: Path to save the models to
            max_epochs: Maximum number of epochs before training stops
            training_strip_length: Training epoch window used for calculating the validation coefficient for early stopping
            sucessive_validation_score_decline_tolerance: How many successive validation coefficient increases to tolerate before stopping training
            fp16: Whether or not to enable fp16 training
            fp16_opt_level: If fp16 is enabled, the opt_level must be one of 'O1', 'O2', or 'O3'. More info here: https://nvidia.github.io/apex/amp.html#opt-levels-and-properties
Returns:
None
"""
# Store the dataloaders
self.train_dataloader = train_dataloader
self.validation_dataloader = validation_dataloader
self.test_dataloader = test_dataloader
        # Store the fp16 flag and the device
self.fp16 = fp16
self.device = device
        # If the device is the CPU, make sure fp16 is disabled, otherwise amp initialization will fail
        if device == torch.device('cpu') and fp16:
            raise Exception('Cannot use fp16 on CPU training')
        # Confirm that fp16_opt_level is valid if fp16 is True, else set it to 'O0' (pure fp32)
        if self.fp16:
            if fp16_opt_level not in ['O1', 'O2', 'O3']:
                raise Exception('Invalid fp16_opt_level ' + str(fp16_opt_level) + ' was passed')
        else:
            fp16_opt_level = 'O0'
# Build the model and optimizer
self.model = model_class(**model_class_params).to(device)
self.optimizer = optimizer_class(self.model.parameters(), **optimizer_class_params)
# Pass the model and optimizer into amp for initialization
self.model, self.optimizer = amp.initialize(self.model, self.optimizer, opt_level=fp16_opt_level)
# Print the model
print(self.model)
# Print out the current device being used
print('Using Device:', device)
# Build the criterion
self.criterion = criterion_class(**criterion_class_params)
        # Store variables and initialize certain ones for the early stopping algorithm
self.max_epochs = max_epochs
self.training_strip_length = training_strip_length
self.validation_stop_coefficients = []
self.num_early_stopping_violations = 0
self.sucessive_validation_score_decline_tolerance = sucessive_validation_score_decline_tolerance
        # Initialize variables for recording training history
self.training_epoch_loss_history = []
self.regularization_loss_history = []
self.validation_epoch_loss_history = []
self.test_epoch_loss_history = []
self.test_score = 0.0
self.smallest_validation_epoch_loss = None
self.model_class_params_history_dict = dict()
# Store the model_save_path
self.model_save_path = model_save_path
        # Initialize an entry in model_class_params_history_dict for each named parameter in the model
for name, param in self.model.named_parameters():
self.model_class_params_history_dict[name + '_mean'] = list()
self.model_class_params_history_dict[name + '_var'] = list()
self._record_model_class_params()
    def train(self):
"""
Function to start the training
Parameters:
None
Returns:
None
"""
# Set the model to train and save the initial model
self.model.train()
self.save_model()
# Begin train loop
for epoch in range(self.max_epochs):
previous_time = time.time()
            # Initialize lists to store the loss and regularization values for the current epoch
current_epoch_loss_history = []
current_epoch_regularization_loss_history = []
# Train loop for 1 epoch
for data in self.train_dataloader:
# Call train step which will return batch_loss and regularization_loss
batch_loss, regularization_loss = self.train_step(data)
# Append to current epoch loss history
current_epoch_loss_history.append(batch_loss)
current_epoch_regularization_loss_history.append(regularization_loss)
# Print out current epoch loss
current_epoch_loss = np.array(current_epoch_loss_history).mean()
current_epoch_regularization_loss = np.array(current_epoch_regularization_loss_history).mean()
print("Epoch ", (epoch + 1), " loss: ", current_epoch_loss, ' Elapse Time: ', time.time() - previous_time)
# Record epoch loss
self.training_epoch_loss_history.append(current_epoch_loss)
self.regularization_loss_history.append(current_epoch_regularization_loss)
            # If the current epoch is divisible by the training_strip_length or is the very first epoch, run the early stopping check
if (epoch + 1) % self.training_strip_length == 0 or epoch == 0:
# record the model params at the checkpoint
self._record_model_class_params()
                # Run the validation dataset and append the result to the validation_epoch_loss_history
self.validation_epoch_loss_history.append(self.validate())
                # Run the test dataset and append the result to the test_epoch_loss_history
self.test_epoch_loss_history.append(self.evaluate())
# Print current Validation Score and Test Score at the current epoch
print('Validation Score at Epoch', (epoch + 1), ':', self.validation_epoch_loss_history[-1])
print('Test Score at Epoch', (epoch + 1), ':', self.test_epoch_loss_history[-1])
                # Check for NaNs
                if np.isnan(self.validation_epoch_loss_history[-1]):
                    raise Exception('NaN occurred during validation, exiting due to instability')
                if np.isnan(self.test_epoch_loss_history[-1]):
                    raise Exception('NaN occurred during the test set, exiting due to instability')
# Check if we should stop based on the validation dataset results
if self.should_early_stop(epoch):
print('Max number of validation violations reached, stopping training early')
break
            # Check for NaNs
            if np.isnan(current_epoch_loss_history).any() or np.isnan(current_epoch_regularization_loss_history).any():
                raise Exception('NaN occurred during training, exiting due to instability')
        # Load the best performing model and compute the test score on the test dataset
self.load_best_performing_model()
self.test_score = self.evaluate()
# Check if test_score is valid
        if np.isnan(self.test_score):
            raise Exception('NaN occurred during the test set, exiting due to instability')
# Print test_score
print('Test Set Performance:', self.test_score)
    def validate(self):
"""
Run the validation dataset through the model and return the average loss
Parameters:
None
Returns:
float: model's validation dataset loss
"""
# Set the model to eval mode
self.model.eval()
# Create a list to store the batch_loss for the current epoch
validation_current_epoch_loss = []
for data in self.validation_dataloader:
batch_loss = self.eval_step(data)
# Append to current epoch loss history
validation_current_epoch_loss.append(batch_loss)
self.model.train()
return np.array(validation_current_epoch_loss).mean()
    def evaluate(self, return_outputs_targets_and_loss=False):
"""
Run the test dataset through the model and return the average loss
Parameters:
            return_outputs_targets_and_loss (bool): If True, return the per-batch outputs, targets, and losses
        Returns:
            float: Model's average test dataset loss if return_outputs_targets_and_loss is False
            list of dict(outputs, targets, loss): One entry per test batch if return_outputs_targets_and_loss is True
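
        Example (a minimal sketch, assuming the model emits a single output tensor per batch)::

            results = trainer.evaluate(return_outputs_targets_and_loss=True)
            for batch in results:
                print(batch['loss'].item(), batch['outputs'].shape)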
"""
self.model.eval()
eval_step_returns = []
for data in self.test_dataloader:
# Eval step will return batch loss if return_outputs_targets_and_loss is false otherwise
# it will return a dict(outputs, targets, loss)
eval_step_return = self.eval_step(data, return_outputs_targets_and_loss)
eval_step_returns.append(eval_step_return)
self.model.train()
if return_outputs_targets_and_loss:
return eval_step_returns
else:
test_score = np.array(eval_step_returns).mean()
return test_score
    def train_step(self, data):
"""
Training step function to handle training of one batch
Parameters:
            data (tuple): Batch sample obtained from the PyTorch DataLoader
        Returns:
            float: Model's loss for the given batch
            float: Regularization loss for the given batch
"""
tensors = data[0]
targets = data[-1]
# Send to device
tensors = self._move_tensor_to_device(tensors)
targets = self._move_tensor_to_device(targets)
outputs, regularization_loss = self.model(tensors)
loss = self._compute_loss(outputs, targets)
loss_with_reg = loss + regularization_loss
        with amp.scale_loss(loss_with_reg, self.optimizer) as scaled_loss:
            scaled_loss.backward()
self.optimizer.step()
# Zero out gradients
self.optimizer.zero_grad()
return loss.item(), regularization_loss.item()
    def eval_step(self, data, return_outputs_targets_and_loss=False):
"""
        Evaluation step function to handle the evaluation of one batch
        Parameters:
            data (tuple): Batch sample obtained from the PyTorch DataLoader
            return_outputs_targets_and_loss (bool): If True, return the outputs, targets, and loss as a dict
        Returns:
            float: Model's loss for the given batch if return_outputs_targets_and_loss is False
            dict(outputs, targets, loss): If return_outputs_targets_and_loss is True
"""
with torch.no_grad():
tensors = data[0]
targets = data[-1]
# Send to device
tensors = self._move_tensor_to_device(tensors)
targets = self._move_tensor_to_device(targets)
outputs, _ = self.model(tensors)
loss = self._compute_loss(outputs, targets)
if return_outputs_targets_and_loss:
return dict(outputs=outputs, targets=targets, loss=loss)
else:
return loss.item()
    def should_early_stop(self, current_epoch):
"""
        Checks whether to early stop based on the previous validation scores and training strip performance
        Parameters:
            current_epoch (int): Current epoch, mainly used for printing
Returns:
bool: Whether to early stop or not
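
        The coefficient is computed following the PQ criterion from the early stopping paper linked in the class docstring. With E_va(t) the validation loss, E_opt(t) the smallest validation loss seen so far, E_tr the training loss, and k the training strip length::

            GL(t)  = 100 * (E_va(t) / E_opt(t) - 1)                     # generalization loss
            P_k(t) = 1000 * (sum(E_tr over last k epochs)
                             / (k * min(E_tr over last k epochs)) - 1)  # training progress
            validation_coefficient(t) = GL(t) / P_k(t)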
"""
# If current validation loss is lower than the min then update min
        if self.smallest_validation_epoch_loss is None or self.validation_epoch_loss_history[-1] < self.smallest_validation_epoch_loss:
self.smallest_validation_epoch_loss = self.validation_epoch_loss_history[-1]
# Save the best performing model so far
self.save_model()
        # Compute the generalization loss and the training progress over the last strip, per the PQ criterion
        generalization_loss = 100 * (self.validation_epoch_loss_history[-1] / self.smallest_validation_epoch_loss - 1)
        strip = np.array(self.training_epoch_loss_history[-self.training_strip_length:])
        strip_loss_avg_over_min = 1000 * (strip.sum() / (len(strip) * strip.min()) - 1)
        # Guard against division by zero when the training loss is flat over the strip (e.g. at the very first check)
        validation_coefficient = generalization_loss / strip_loss_avg_over_min if strip_loss_avg_over_min != 0 else 0.0
        self.validation_stop_coefficients.append(validation_coefficient)
print('Epoch', (current_epoch + 1), 'validation_coefficient:', self.validation_stop_coefficients[-1])
if len(self.validation_stop_coefficients) == 1:
return False
elif self.validation_stop_coefficients[-1] > self.validation_stop_coefficients[-2]:
self.num_early_stopping_violations += 1
else:
# Reset counter
self.num_early_stopping_violations = 0
if self.num_early_stopping_violations > self.sucessive_validation_score_decline_tolerance:
return True
else:
return False
def _record_model_class_params(self):
"""
        Helper function to record the mean and variance of the current model parameters
        Parameters:
            None
        Returns:
            None
"""
for name, param in self.model.named_parameters():
            self.model_class_params_history_dict[name + '_mean'].append(param.detach().cpu().float().numpy().mean().tolist())
            self.model_class_params_history_dict[name + '_var'].append(param.detach().cpu().float().numpy().var().tolist())
    def save_model(self):
"""
        Save the model, optimizer, and amp state to model_save_path
Parameters:
None
Returns:
None
"""
print('Saving Model and Optimizer To', self.model_save_path)
checkpoint = dict()
checkpoint['model'] = self.model.state_dict()
checkpoint['optimizer'] = self.optimizer.state_dict()
checkpoint['amp'] = amp.state_dict()
        try:
            torch.save(checkpoint, self.model_save_path)
        except FileNotFoundError:
            # Create the missing parent directory, then retry the save
            os.makedirs(os.path.dirname(self.model_save_path), exist_ok=True)
            torch.save(checkpoint, self.model_save_path)
    def load_model(self, path):
"""
        Load the model, optimizer, and amp state from path
        Parameters:
            path (str): Path to load the model file from
Returns:
None
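
        Example (a minimal sketch; the checkpoint path is a placeholder)::

            trainer.load_model('checkpoints/model.pt')
            test_score = trainer.evaluate()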
"""
print('Loading Model and Optimizer From', path)
checkpoint = torch.load(path, map_location=self.device)
self.model.load_state_dict(checkpoint['model'])
self.optimizer.load_state_dict(checkpoint['optimizer'])
amp.load_state_dict(checkpoint['amp'])
def _compute_loss(self, outputs, targets):
"""
        Helper function to compute the loss when the model has multiple outputs and targets
Parameters:
outputs (torch.Tensor): Model output(s)
targets (torch.Tensor): Target(s) to compute the model output(s) against
Returns:
torch.Loss: Return the total loss
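
        Example (a sketch of the multi-output case)::

            # outputs = [out_a, out_b], targets = [target_a, target_b]
            # total loss = criterion(out_a, target_a) + criterion(out_b, target_b)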
"""
        if isinstance(targets, list):
            # Initialize the loss accumulator once, then sum the criterion over each output/target pair
            loss = torch.zeros(1).to(self.device)
            for output, target in zip(outputs, targets):
                loss += self.criterion(output, target)
        else:
            loss = self.criterion(outputs, targets)
        return loss
def _move_tensor_to_device(self, tensors):
"""
Helper function to deal with moving multiple tensors to the device
Parameters:
tensors (torch.Tensor): Tensor or list of Tensor to move to the device
Returns:
            torch.Tensor: The tensor (or list of tensors) moved to the device
"""
        if isinstance(tensors, list):
            tensors = [tensor.to(self.device) for tensor in tensors]
        else:
            tensors = tensors.to(self.device)
return tensors