Source code for synicix_ml_pipeline.datajoint_tables.DatasetConfig

import datajoint as dj
import os
import shutil
import json
import filecmp

from torch.utils.data import Subset, DataLoader
from torchvision.transforms import Compose

from synicix_ml_pipeline.datajoint_tables.BaseTable import schema
from synicix_ml_pipeline.datajoint_tables.BaseTable import BaseTable

@schema
class DatasetConfig(dj.Manual, BaseTable):
    """
    A dj.Manual table class that handles the storage of dataset configs, with details on which
    dataset class and params to use when loading the dataset and dataloaders.

    Initializing this class requires the corresponding dataset_dir and dataset_cache_dir to be
    defined and passed into __init__().

    Typical usage of this class is through the method insert_tuples. An example can be found in
    the Pipeline Configuration Jupyter Notebook; an illustrative sketch is also included at the
    end of this module.
    """

    definition = """
    dataset_config_id : int unsigned
    ---
    dataset_file_name : varchar(256)
    dataset_type : varchar(256)
    dataset_class_module_name : varchar(256)
    dataset_class_name : varchar(256)
    dataset_class_params : longblob
    train_sampler_module_name : varchar(256)
    train_sampler_class_name : varchar(256)
    train_sampler_class_params : longblob
    validation_sampler_module_name : varchar(256)
    validation_sampler_class_name : varchar(256)
    validation_sampler_class_params : longblob
    test_sampler_module_name : varchar(256)
    test_sampler_class_name : varchar(256)
    test_sampler_class_params : longblob
    input_shape : longblob
    output_shape : longblob
    additional_model_params : longblob
    dataset_config_blobs_md5_hash : char(128)
    """

    def __init__(self, dataset_dir=None, dataset_cache_dir=None):
        """
        Initialize function

        Parameters:
            dataset_dir (str): directory where the dataset files are located
            dataset_cache_dir (str): directory where to cache the dataset files

        Returns:
            None
        """

        super().__init__()

        self.dataset_dir = dataset_dir
        self.dataset_cache_dir = dataset_cache_dir
    def insert_tuples(self, tuple_dicts):
        """
        Function to compute the hash, build the dataloader, get the input_shape, output_shape,
        and additional_model_params, and insert into the DatasetConfig DJ table

        Parameters:
            tuple_dicts: A list of dicts containing the attributes to be inserted into DatasetConfig

        Returns:
            None
        """

        # Iterate through each tuple_dict in tuple_dicts, get input_shape, output_shape, and
        # additional_model_params, and add them to the tuple_dict
        for tuple_dict in tuple_dicts:
            input_shape, output_shape, additional_model_params = self._get_dataset_additional_info(tuple_dict)
            tuple_dict['input_shape'] = input_shape
            tuple_dict['output_shape'] = output_shape
            tuple_dict['additional_model_params'] = additional_model_params

        super(DatasetConfig, self).insert_tuples(tuple_dicts)
    def get_dataloaders(self, key, batch_size, num_workers=0):
        """
        Method to build the dataloaders based on the primary key and return them.

        Parameters:
            key (dict): DatasetConfig datajoint table primary key.
            batch_size (int): Batch size for the train, validation, and test dataloaders.
            num_workers (int): Number of pytorch dataloader workers to use for the dataloaders.

        Returns:
            torch.utils.data.DataLoader, torch.utils.data.DataLoader, torch.utils.data.DataLoader:
                Train, validation, and test dataloaders
        """

        # Fetch the details of the dataset configuration
        tuple_dict = (self & key).fetch1()

        self._check_if_dataset_file_is_cache(tuple_dict)

        # Get the dataset class and params via helper function
        dataset_class, dataset_class_params = self._get_dataset_class_and_params(tuple_dict)

        # Construct the dataset class with the correct parameters for train, validation, and test
        train_dataset = dataset_class(tuple_dict=tuple_dict, dataset_cache_dir=self.dataset_cache_dir, tier='train', **dataset_class_params)
        validation_dataset = dataset_class(tuple_dict=tuple_dict, dataset_cache_dir=self.dataset_cache_dir, tier='validation', **dataset_class_params)
        test_dataset = dataset_class(tuple_dict=tuple_dict, dataset_cache_dir=self.dataset_cache_dir, tier='test', **dataset_class_params)

        # Get the sampler classes and params
        train_sampler_class, train_sampler_params, validation_sampler_class, validation_sampler_params, test_sampler_class, test_sampler_params = self._get_samplers_class_and_params(tuple_dict)

        # Create the dataloaders for train, validation, and test
        train_dataloader = self._get_dataloader(dataset=train_dataset, batch_size=batch_size, shuffle=True, drop_last=True, num_workers=num_workers, sampler_class=train_sampler_class, sampler_params=train_sampler_params)
        validation_dataloader = self._get_dataloader(dataset=validation_dataset, batch_size=batch_size, shuffle=False, drop_last=False, num_workers=num_workers, sampler_class=validation_sampler_class, sampler_params=validation_sampler_params)
        test_dataloader = self._get_dataloader(dataset=test_dataset, batch_size=batch_size, shuffle=False, drop_last=False, num_workers=num_workers, sampler_class=test_sampler_class, sampler_params=test_sampler_params)

        return train_dataloader, validation_dataloader, test_dataloader
    def _get_dataset_additional_info(self, tuple_dict):
        """
        Private helper function to get the input and output shapes that the dataset expects from
        the model, plus any additional model params

        Parameters:
            tuple_dict (dict): dict containing the dataset configuration attributes

        Returns:
            tuple or list of tuple, tuple or list of tuple, dict: input_shape, output_shape, additional_model_params
        """

        # Check if the dataset is located in the cache directory
        self._check_if_dataset_file_is_cache(tuple_dict)

        # Get the dataset class and dataset_class_params to create the validation_dataset
        dataset_class, dataset_class_params = self._get_dataset_class_and_params(tuple_dict)
        validation_dataset = dataset_class(tuple_dict=tuple_dict, dataset_cache_dir=self.dataset_cache_dir, tier='validation', **dataset_class_params)

        # Get any additional_model_params
        additional_model_params = validation_dataset.get_additional_model_params()

        # Get the first datapoint in the validation dataset to get input_shape and output_shape
        validation_data_point = validation_dataset[0]

        return self._get_data_point_shape(validation_data_point[0]), self._get_data_point_shape(validation_data_point[1]), additional_model_params

    def _get_data_point_shape(self, data_point):
        """
        Private helper function to get the shape of a datapoint. This is mainly to deal with the
        case of multiple inputs or outputs

        Parameters:
            data_point (np.array or list of np.array)

        Returns:
            tuple or tuple of tuple: the shape or shapes of the datapoint
        """

        # Check if the datapoint is a list or not
        if type(data_point) == list:
            # If it is, loop through each sub data_point, get its shape, and return the shapes as a tuple
            return tuple(data.shape for data in data_point)
        else:
            # If not, just return the shape
            return data_point.shape

    def _check_if_dataset_file_is_cache(self, tuple_dict):
        """
        Private helper function to check if the dataset file being referenced has been cached to
        the dataset_cache folder

        Parameters:
            tuple_dict (dict): Dictionary containing the dataset_file_name to be cached

        Returns:
            None
        """

        # Check if the dataset_cache_dir exists; if not, create it
        if not os.path.exists(self.dataset_cache_dir):
            os.mkdir(self.dataset_cache_dir)
            print('dataset_cache_dir created at: ' + self.dataset_cache_dir)

        # Check if the file has been cached in the dataset cache directory
        if not os.path.exists(self.dataset_cache_dir + tuple_dict['dataset_file_name']):
            # File was not found in the cache directory, thus it needs to be copied
            print(tuple_dict['dataset_file_name'] + ' was not found in dataset_cache, thus it needs to be copied.')
            shutil.copyfile(self.dataset_dir + tuple_dict['dataset_file_name'], self.dataset_cache_dir + tuple_dict['dataset_file_name'])
            print(tuple_dict['dataset_file_name'] + ' copied successfully')
        elif not filecmp.cmp(self.dataset_dir + tuple_dict['dataset_file_name'], self.dataset_cache_dir + tuple_dict['dataset_file_name']):
            # The file was found, but its contents have changed, so re-cache it
            print('Dataset file content has changed, re-caching dataset file.')
            shutil.copyfile(self.dataset_dir + tuple_dict['dataset_file_name'], self.dataset_cache_dir + tuple_dict['dataset_file_name'])
        else:
            # The content didn't change and the file is already cached
            print('Dataset file ' + str(tuple_dict['dataset_file_name']) + ' already cached')

    @staticmethod
    def _get_dataloader(dataset, batch_size, shuffle, drop_last, num_workers, sampler_class, sampler_params):
        """
        Private helper function to create and return a single dataloader with respect to the
        parameters that were passed in.

        If sampler_class is defined, then the dataloader will use that sampler class.

        Parameters:
            dataset (torch.utils.data.Dataset): Standard pytorch dataset class.
            batch_size (int): batch_size to be used for the dataloader
            shuffle (bool): Whether to enable shuffling; if a sampler is given, this defaults to False
            drop_last (bool): Whether to drop the last non-full batch or not
            num_workers (int): Number of workers to use to load the dataset
            sampler_class (torch.utils.data.Sampler): Sampler class that defines the sampling strategy
            sampler_params (dict): Sampler class parameters to be passed to the class

        Returns:
            torch.utils.data.DataLoader: dataloader initialized with the given parameters
        """

        if sampler_class:
            # A sampler is mutually exclusive with shuffle, so shuffle is forced to False
            return DataLoader(dataset=dataset, batch_size=batch_size, shuffle=False, drop_last=drop_last, num_workers=num_workers, sampler=sampler_class(dataset, **sampler_params))
        else:
            return DataLoader(dataset=dataset, batch_size=batch_size, shuffle=shuffle, drop_last=drop_last, num_workers=num_workers)

    @classmethod
    def _get_dataset_class_and_params(cls, key):
        """
        Helper function to get the dataset class and its params

        Args:
            key (dict): primary key of the tuple, or a tuple_dict in the case of an initial insert

        Returns:
            dataset_class (<user-defined>), dict: the corresponding dataset_class along with the dataset_class_params
        """

        # Check if key contains dataset_config_md5_hash; if not, assume it is a tuple_dict to handle the case of the initial insert
        if 'dataset_config_md5_hash' in key.keys():
            dataset_class_module_name, dataset_class_name, dataset_class_params = (cls & dict(dataset_config_md5_hash=key['dataset_config_md5_hash'])).fetch1('dataset_class_module_name', 'dataset_class_name', 'dataset_class_params')
        else:
            dataset_class_module_name = key['dataset_class_module_name']
            dataset_class_name = key['dataset_class_name']
            dataset_class_params = key['dataset_class_params']

        return super().import_class_from_module(dataset_class_module_name, dataset_class_name), dataset_class_params

    @classmethod
    def _get_samplers_class_and_params(cls, tuple_dict):
        """
        Helper function to get the corresponding sampler classes and params for train, validation, and test

        Args:
            tuple_dict (dict): dict containing the sampler module names, class names, and params

        Returns:
            Sampler class and params dict for each of train, validation, and test
        """

        # Import the sampler classes, collect their params, and return them
        train_sampler_class = super().import_class_from_module(tuple_dict['train_sampler_module_name'], tuple_dict['train_sampler_class_name'])
        train_sampler_class_params = tuple_dict['train_sampler_class_params']

        validation_sampler_class = super().import_class_from_module(tuple_dict['validation_sampler_module_name'], tuple_dict['validation_sampler_class_name'])
        validation_sampler_class_params = tuple_dict['validation_sampler_class_params']

        test_sampler_class = super().import_class_from_module(tuple_dict['test_sampler_module_name'], tuple_dict['test_sampler_class_name'])
        test_sampler_class_params = tuple_dict['test_sampler_class_params']

        return train_sampler_class, train_sampler_class_params, validation_sampler_class, validation_sampler_class_params, test_sampler_class, test_sampler_class_params
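
# ---------------------------------------------------------------------------
# Illustrative usage sketch. This is not part of the pipeline itself; it shows,
# under assumed names, how a DatasetConfig entry might be inserted and its
# dataloaders retrieved. 'my_project.datasets', 'ExampleDataset',
# 'example_dataset.h5', and the directory paths are hypothetical placeholders;
# the sampler classes are the standard ones from torch.utils.data.
if __name__ == '__main__':
    dataset_config_table = DatasetConfig(dataset_dir='/data/datasets/',
                                         dataset_cache_dir='/data/dataset_cache/')

    example_config = dict(
        dataset_config_id=0,
        dataset_file_name='example_dataset.h5',            # hypothetical file under dataset_dir
        dataset_type='hdf5',                                # hypothetical dataset type label
        dataset_class_module_name='my_project.datasets',   # hypothetical module defining the dataset class
        dataset_class_name='ExampleDataset',                # hypothetical dataset class
        dataset_class_params=dict(),
        train_sampler_module_name='torch.utils.data',
        train_sampler_class_name='RandomSampler',
        train_sampler_class_params=dict(),
        validation_sampler_module_name='torch.utils.data',
        validation_sampler_class_name='SequentialSampler',
        validation_sampler_class_params=dict(),
        test_sampler_module_name='torch.utils.data',
        test_sampler_class_name='SequentialSampler',
        test_sampler_class_params=dict(),
    )

    # insert_tuples fills in input_shape, output_shape, and additional_model_params by
    # instantiating the dataset once, then delegates to BaseTable.insert_tuples for the insert
    dataset_config_table.insert_tuples([example_config])

    # Build the train, validation, and test dataloaders for the inserted config
    train_dl, validation_dl, test_dl = dataset_config_table.get_dataloaders(
        key=dict(dataset_config_id=0), batch_size=32, num_workers=2)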