Source code for synicix_ml_pipeline.trainers.NNTrainer

import torch
import numpy as np
import pandas as pd
import time
import os

from apex import amp

class NNTrainer():
    """
    NNTrainer is a trainer class provided with synicix_ml_pipeline. It handles training and saving/loading of models,
    with some additional useful features.

    One of the most notable features it provides is integrated fp16 (mixed-precision) training, which can drastically
    improve your network training speed at little to no loss in performance. This is provided via Nvidia's apex library:
    https://github.com/NVIDIA/apex

    The other feature is a built-in early stopping algorithm, implemented based on this paper:
    https://page.mi.fu-berlin.de/prechelt/Biblio/stop_tricks1997.pdf
    """

    def __init__(self, train_dataloader, validation_dataloader, test_dataloader, device, model_class, model_class_params,
                 optimizer_class, optimizer_class_params, criterion_class, criterion_class_params, model_save_path,
                 max_epochs, training_strip_length, sucessive_validation_score_decline_tolerance, fp16=False, fp16_opt_level=None):
        """
        Initialize function for NNTrainer

        Parameters:
            train_dataloader: Train PyTorch DataLoader
            validation_dataloader: Validation PyTorch DataLoader
            test_dataloader: Test PyTorch DataLoader
            device: PyTorch device to train the model on
            model_class: User-defined model class
            model_class_params: Parameters to pass to the model initialization function
            optimizer_class: User-defined optimizer class
            optimizer_class_params: Parameters to pass to the optimizer initialization function
            criterion_class: User-defined criterion class
            criterion_class_params: Parameters to pass to the criterion initialization function
            model_save_path: Path in which to save the models
            max_epochs: Maximum number of epochs before training stops
            training_strip_length: Training epoch window used for calculating the validation coefficient for early stopping
            sucessive_validation_score_decline_tolerance: How many successive validation coefficient increases are tolerated before stopping training
            fp16: Whether or not to enable fp16 training
            fp16_opt_level: If fp16 is enabled, the opt_level must be one of the following values: 'O1', 'O2', 'O3'.
                More info here: https://nvidia.github.io/apex/amp.html#opt-levels-and-properties

        Returns:
            None
        """

        # Store the dataloaders
        self.train_dataloader = train_dataloader
        self.validation_dataloader = validation_dataloader
        self.test_dataloader = test_dataloader

        # Store the fp16 flag and device
        self.fp16 = fp16
        self.device = device

        # Check if the device is cpu, and if so make sure that fp16 is disabled, otherwise training will blow up
        if device == torch.device('cpu'):
            if fp16:
                raise Exception('Cannot use fp16 on CPU Training')

        # Confirm that fp16_opt_level is defined if fp16 is true, else set it to O0
        if self.fp16:
            if fp16_opt_level not in ['O1', 'O2', 'O3']:
                raise Exception('Invalid fp16_opt_level ' + str(fp16_opt_level) + ' was passed')
        else:
            fp16_opt_level = 'O0'

        # Build the model and optimizer
        self.model = model_class(**model_class_params).to(device)
        self.optimizer = optimizer_class(self.model.parameters(), **optimizer_class_params)

        # Pass the model and optimizer into amp for initialization
        self.model, self.optimizer = amp.initialize(self.model, self.optimizer, opt_level=fp16_opt_level)

        # Print the model
        print(self.model)

        # Print out the current device being used
        print('Using Device:', device)

        # Build the criterion
        self.criterion = criterion_class(**criterion_class_params)

        # Store variables and initialize certain ones for the early stopping algorithm
        self.max_epochs = max_epochs
        self.training_strip_length = training_strip_length
        self.validation_stop_coefficients = []
        self.num_early_stopping_violations = 0
        self.sucessive_validation_score_decline_tolerance = sucessive_validation_score_decline_tolerance

        # Initialize variables for recording training history
        self.training_epoch_loss_history = []
        self.regularization_loss_history = []
        self.validation_epoch_loss_history = []
        self.test_epoch_loss_history = []
        self.test_score = 0.0
        self.smallest_validation_epoch_loss = None
        self.model_class_params_history_dict = dict()

        # Store the model_save_path
        self.model_save_path = model_save_path

        # Initialize self.model_class_params_history_dict for each set of parameters in the model
        for name, param in self.model.named_parameters():
            self.model_class_params_history_dict[name + '_mean'] = list()
            self.model_class_params_history_dict[name + '_var'] = list()

        self._record_model_class_params()
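    # Usage sketch (illustrative, not part of the original source): assuming a user-defined model class named
    # MyModel and standard PyTorch DataLoaders named train_dl, val_dl, and test_dl, a trainer could be
    # constructed roughly like this:
    #
    #     trainer = NNTrainer(
    #         train_dataloader=train_dl,
    #         validation_dataloader=val_dl,
    #         test_dataloader=test_dl,
    #         device=torch.device('cuda'),
    #         model_class=MyModel,
    #         model_class_params=dict(input_dim=128, output_dim=10),
    #         optimizer_class=torch.optim.Adam,
    #         optimizer_class_params=dict(lr=1e-3),
    #         criterion_class=torch.nn.MSELoss,
    #         criterion_class_params=dict(),
    #         model_save_path='checkpoints/model.pt',
    #         max_epochs=100,
    #         training_strip_length=5,
    #         sucessive_validation_score_decline_tolerance=3,
    #         fp16=True,
    #         fp16_opt_level='O1')
    #
    # MyModel, the dataloader names, and the parameter values above are assumptions for illustration only;
    # fp16=True requires a CUDA device and an apex installation.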
    def train(self):
        """
        Function to start the training

        Parameters:
            None

        Returns:
            None
        """

        # Set the model to train mode and save the initial model
        self.model.train()
        self.save_model()

        # Begin train loop
        for epoch in range(self.max_epochs):
            previous_time = time.time()

            # Initialize lists to store the loss and regularization values for the current epoch
            current_epoch_loss_history = []
            current_epoch_regularization_loss_history = []

            # Train loop for 1 epoch
            for data in self.train_dataloader:
                # Call train_step, which will return batch_loss and regularization_loss
                batch_loss, regularization_loss = self.train_step(data)

                # Append to current epoch loss history
                current_epoch_loss_history.append(batch_loss)
                current_epoch_regularization_loss_history.append(regularization_loss)

            # Print out current epoch loss
            current_epoch_loss = np.array(current_epoch_loss_history).mean()
            current_epoch_regularization_loss = np.array(current_epoch_regularization_loss_history).mean()
            print("Epoch ", (epoch + 1), " loss: ", current_epoch_loss, ' Elapsed Time: ', time.time() - previous_time)

            # Record epoch loss
            self.training_epoch_loss_history.append(current_epoch_loss)
            self.regularization_loss_history.append(current_epoch_regularization_loss)

            # If the current epoch is divisible by the training_strip_length, or it is the very first epoch, run the early stopping check
            if (epoch + 1) % self.training_strip_length == 0 or epoch == 0:
                # Record the model params at the checkpoint
                self._record_model_class_params()

                # Run the validation dataset and append the result to the validation_epoch_loss_history
                self.validation_epoch_loss_history.append(self.validate())

                # Run the test dataset and append the result to the test_epoch_loss_history
                self.test_epoch_loss_history.append(self.evaluate())

                # Print the current validation score and test score at the current epoch
                print('Validation Score at Epoch', (epoch + 1), ':', self.validation_epoch_loss_history[-1])
                print('Test Score at Epoch', (epoch + 1), ':', self.test_epoch_loss_history[-1])

                # Check for NaNs
                if np.isnan(self.validation_epoch_loss_history[-1]).any():
                    raise Exception('NaN occurred during validation, exiting due to instability')

                if np.isnan(self.test_epoch_loss_history[-1]):
                    raise Exception('NaN occurred during test set, exiting due to instability')

                # Check if we should stop based on the validation dataset results
                if self.should_early_stop(epoch):
                    print('Max number of validation violations reached, stopping training early')
                    break

            # Check for NaNs
            if np.isnan(current_epoch_loss_history).any() or np.isnan(current_epoch_regularization_loss_history).any():
                raise Exception('NaN occurred during training, exiting due to instability')

        # Compute the test score on the test dataset using the best performing model
        self.load_best_performing_model()
        self.test_score = self.evaluate()

        # Check if test_score is valid
        if np.isnan(self.test_score):
            raise Exception('NaN occurred during test set, exiting due to instability')

        # Print test_score
        print('Test Set Performance:', self.test_score)
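    # Illustrative call sequence (assumption, not from the original source): after construction, training and
    # inspection of the recorded histories could look like:
    #
    #     trainer.train()
    #     print(trainer.training_epoch_loss_history)     # one entry per epoch
    #     print(trainer.validation_epoch_loss_history)   # one entry per strip checkpoint
    #     print(trainer.test_score)                      # test loss of the best saved model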
    def validate(self):
        """
        Run the validation dataset through the model and return the average loss

        Parameters:
            None

        Returns:
            float: Model's validation dataset loss
        """

        # Set the model to eval mode
        self.model.eval()

        # Create a list to store the batch_loss for the current epoch
        validation_current_epoch_loss = []

        for data in self.validation_dataloader:
            batch_loss = self.eval_step(data)

            # Append to current epoch loss history
            validation_current_epoch_loss.append(batch_loss)

        self.model.train()

        return np.array(validation_current_epoch_loss).mean()
    def evaluate(self, return_outputs_targets_and_loss=False):
        """
        Run the test dataset through the model and return the average loss

        Parameters:
            return_outputs_targets_and_loss (bool): If True, return the outputs, targets, and losses as a list of dicts

        Returns:
            float: Model's test dataset loss if return_outputs_targets_and_loss is False
            list of dict(outputs, targets, loss): One dict per batch if return_outputs_targets_and_loss is True
        """

        self.model.eval()

        eval_step_returns = []

        for data in self.test_dataloader:
            # eval_step will return the batch loss if return_outputs_targets_and_loss is False, otherwise
            # it will return a dict(outputs, targets, loss)
            eval_step_return = self.eval_step(data, return_outputs_targets_and_loss)
            eval_step_returns.append(eval_step_return)

        self.model.train()

        if return_outputs_targets_and_loss:
            return eval_step_returns
        else:
            test_score = np.array(eval_step_returns).mean()
            return test_score
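    # Illustrative sketch (assumption): to look at raw predictions instead of the averaged loss, evaluate can be
    # called with return_outputs_targets_and_loss=True, which yields one dict per test batch:
    #
    #     for batch_result in trainer.evaluate(return_outputs_targets_and_loss=True):
    #         outputs = batch_result['outputs']
    #         targets = batch_result['targets']
    #         print(batch_result['loss'].item())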
    def train_step(self, data):
        """
        Training step function to handle training of one batch

        Parameters:
            data (tuple): Data point sample obtained from the PyTorch dataloader

        Returns:
            float: Model's loss for the given batch
            float: Regularization loss for the given batch
        """

        tensors = data[0]
        targets = data[-1]

        # Send to device
        tensors = self._move_tensor_to_device(tensors)
        targets = self._move_tensor_to_device(targets)

        outputs, regularization_loss = self.model(tensors)

        loss = self._compute_loss(outputs, targets)
        loss_with_reg = loss + regularization_loss

        with amp.scale_loss(loss_with_reg, self.optimizer) as scale_loss:
            scale_loss.backward()

        self.optimizer.step()

        # Zero out gradients
        self.optimizer.zero_grad()

        return loss.item(), regularization_loss.item()
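    # Note on the expected model interface (sketch, assuming a simple regression model): train_step unpacks each
    # batch as data[0] for the inputs and data[-1] for the targets, and expects the model's forward pass to return
    # a tuple of (outputs, regularization_loss), for example:
    #
    #     class MyModel(torch.nn.Module):
    #         def __init__(self, input_dim, output_dim):
    #             super().__init__()
    #             self.linear = torch.nn.Linear(input_dim, output_dim)
    #
    #         def forward(self, x):
    #             # No regularization term in this toy example, so return a zero tensor
    #             return self.linear(x), torch.zeros(1, device=x.device)
    #
    # MyModel and its parameters are hypothetical; only the (outputs, regularization_loss) return convention is
    # implied by the trainer code above.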
    def eval_step(self, data, return_outputs_targets_and_loss=False):
        """
        Evaluate step function to handle evaluation of one batch

        Parameters:
            data (tuple): Data point sample obtained from the PyTorch dataloader
            return_outputs_targets_and_loss (bool): If True, return the outputs, targets, and loss as a dict

        Returns:
            float: Model's loss for the given batch if return_outputs_targets_and_loss is False
            dict(outputs, targets, loss): If return_outputs_targets_and_loss is True
        """

        with torch.no_grad():
            tensors = data[0]
            targets = data[-1]

            # Send to device
            tensors = self._move_tensor_to_device(tensors)
            targets = self._move_tensor_to_device(targets)

            outputs, _ = self.model(tensors)

            loss = self._compute_loss(outputs, targets)

            if return_outputs_targets_and_loss:
                return dict(outputs=outputs, targets=targets, loss=loss)
            else:
                return loss.item()
    def should_early_stop(self, current_epoch):
        """
        Checks whether to stop early or not, based on the previous validation score and the training strip performance

        Parameters:
            current_epoch (int): Current epoch, mainly used for printing

        Returns:
            bool: Whether to stop early or not
        """

        # If the current validation loss is lower than the minimum seen so far, update the minimum
        if self.smallest_validation_epoch_loss is None or self.validation_epoch_loss_history[-1] < self.smallest_validation_epoch_loss:
            self.smallest_validation_epoch_loss = self.validation_epoch_loss_history[-1]

            # Save the best performing model so far
            self.save_model()

        # Compute the generalization loss and the training progress over the last strip
        generalization_loss = 100 * (self.validation_epoch_loss_history[-1] / self.smallest_validation_epoch_loss - 1)
        strip_loss_avg_over_min = 1000 * (np.array(self.training_epoch_loss_history[current_epoch - self.training_strip_length : current_epoch]).sum() /
                                          (self.training_strip_length * np.array(self.training_epoch_loss_history[current_epoch:current_epoch + self.training_strip_length]).min()) - 1)

        self.validation_stop_coefficients.append(generalization_loss / strip_loss_avg_over_min)

        print('Epoch', (current_epoch + 1), 'validation_coefficient:', self.validation_stop_coefficients[-1])

        if len(self.validation_stop_coefficients) == 1:
            return False
        elif self.validation_stop_coefficients[-1] > self.validation_stop_coefficients[-2]:
            self.num_early_stopping_violations += 1
        else:
            # Reset counter
            self.num_early_stopping_violations = 0

        if self.num_early_stopping_violations > self.sucessive_validation_score_decline_tolerance:
            return True
        else:
            return False
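    # Worked illustration of the early-stopping quantities (all numbers are made up for the example):
    # with smallest_validation_epoch_loss = 0.50 and a current validation loss of 0.55,
    #
    #     generalization_loss = 100 * (0.55 / 0.50 - 1) = 10.0
    #
    # and with a training strip whose losses sum to 2.10 over training_strip_length = 5 epochs and whose
    # minimum is 0.40,
    #
    #     strip_loss_avg_over_min = 1000 * (2.10 / (5 * 0.40) - 1) = 50.0
    #
    # giving a validation coefficient of 10.0 / 50.0 = 0.2. Training stops once this coefficient has increased
    # more than sucessive_validation_score_decline_tolerance checkpoints in a row.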
    def _record_model_class_params(self):
        """
        Helper function to record the mean and variance of the current model params

        Parameters:
            None

        Returns:
            None
        """

        for name, param in self.model.named_parameters():
            self.model_class_params_history_dict[name + '_mean'].append(param.detach().cpu().float().numpy().mean().tolist())
            self.model_class_params_history_dict[name + '_var'].append(param.detach().cpu().float().numpy().var().tolist())
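    # Illustrative sketch (assumption): since each recorded statistic is a plain list, the parameter history can
    # be inspected as a table with the pandas import at the top of this module, e.g.:
    #
    #     params_df = pd.DataFrame(trainer.model_class_params_history_dict)
    #     print(params_df.describe())
    #
    # Each column is '<parameter name>_mean' or '<parameter name>_var', with one row per recorded checkpoint.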
    def save_model(self):
        """
        Save the model, optimizer, and amp state to model_save_path

        Parameters:
            None

        Returns:
            None
        """

        print('Saving Model and Optimizer To', self.model_save_path)

        checkpoint = dict()
        checkpoint['model'] = self.model.state_dict()
        checkpoint['optimizer'] = self.optimizer.state_dict()
        checkpoint['amp'] = amp.state_dict()

        try:
            torch.save(checkpoint, self.model_save_path)
        except FileNotFoundError:
            # The save path's parent directory does not exist yet, so create it and retry the save
            os.makedirs(os.path.dirname(self.model_save_path), exist_ok=True)
            torch.save(checkpoint, self.model_save_path)
    def load_model(self, path):
        """
        Load the model, optimizer, and amp state from path

        Parameters:
            path (str): Path of the model file to load from

        Returns:
            None
        """

        print('Loading Model and Optimizer From', path)

        checkpoint = torch.load(path, map_location=self.device)

        self.model.load_state_dict(checkpoint['model'])
        self.optimizer.load_state_dict(checkpoint['optimizer'])
        amp.load_state_dict(checkpoint['amp'])
    def load_best_performing_model(self):
        """
        Load the best performing model, which is the most recent save with the lowest validation score, from model_save_path

        Parameters:
            None

        Returns:
            None
        """

        self.load_model(self.model_save_path)
    def _compute_loss(self, outputs, targets):
        """
        Helper function to deal with multiple outputs and target losses

        Parameters:
            outputs (torch.Tensor): Model output(s)
            targets (torch.Tensor): Target(s) to compute the model output(s) against

        Returns:
            torch.Tensor: The total loss
        """

        if type(targets) == list:
            # Accumulate the loss over each output/target pair
            loss = torch.zeros(1).to(self.device)
            for output, target in zip(outputs, targets):
                loss += self.criterion(output, target)
        else:
            loss = self.criterion(outputs, targets)

        return loss

    def _move_tensor_to_device(self, tensors):
        """
        Helper function to deal with moving multiple tensors to the device

        Parameters:
            tensors (torch.Tensor): Tensor or list of Tensors to move to the device

        Returns:
            torch.Tensor: A copy of the Tensor(s) that has been moved to the device
        """

        if type(tensors) == list:
            tensors = [input.to(self.device) for input in tensors]
        else:
            tensors = tensors.to(self.device)

        return tensors
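# Illustrative note on the expected data format (an assumption drawn from train_step/eval_step above, not part of
# the original module): each batch only needs to index as data[0] for the inputs and data[-1] for the targets, so a
# plain TensorDataset-backed DataLoader is sufficient, e.g.:
#
#     from torch.utils.data import DataLoader, TensorDataset
#
#     dataset = TensorDataset(torch.randn(1000, 128), torch.randn(1000, 10))
#     train_dl = DataLoader(dataset, batch_size=32, shuffle=True)
#     val_dl = DataLoader(dataset, batch_size=32)
#     test_dl = DataLoader(dataset, batch_size=32)
#
# These dataloaders can then be passed to NNTrainer as shown in the constructor sketch above, followed by
# trainer.train(). The tensor shapes and batch size here are arbitrary example values.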