import datajoint as dj
import os
import shutil
import json
import filecmp
from torch.utils.data import Subset, DataLoader
from torchvision.transforms import Compose
from synicix_ml_pipeline.datajoint_tables.BaseTable import schema
from synicix_ml_pipeline.datajoint_tables.BaseTable import BaseTable
@schema
class DatasetConfig(dj.Manual, BaseTable):
"""
A dj.Manual table class that handles the storage of dataset configs, with details on which dataset class and params to use to load that dataset and its dataloaders.
Initializing this class requires the corresponding dataset_dir and dataset_cache_dir to be defined and passed into __init__().
Typical usage of this class is through the insert_tuples method; an example can be found in the Pipeline Configuration Jupyter Notebook.
"""
definition = """
dataset_config_id : int unsigned
---
dataset_file_name : varchar(256)
dataset_type : varchar(256)
dataset_class_module_name : varchar(256)
dataset_class_name : varchar(256)
dataset_class_params : longblob
train_sampler_module_name : varchar(256)
train_sampler_class_name : varchar(256)
train_sampler_class_params : longblob
validation_sampler_module_name : varchar(256)
validation_sampler_class_name : varchar(256)
validation_sampler_class_params : longblob
test_sampler_module_name : varchar(256)
test_sampler_class_name : varchar(256)
test_sampler_class_params : longblob
input_shape : longblob
output_shape : longblob
additional_model_params : longblob
dataset_config_blobs_md5_hash : char(128)
"""
def __init__(self, dataset_dir=None, dataset_cache_dir=None):
"""
Initialize function
Parameters:
dataset_dir (str): directory where the dataset files are located
dataset_cache_dir (str): directory where the dataset files are cached
Returns:
None
"""
super().__init__()
self.dataset_dir = dataset_dir
self.dataset_cache_dir = dataset_cache_dir
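# Example (hedged sketch, hypothetical paths): the table is constructed with the two
# directories it needs before any insert or dataloader call, e.g.:
#   dataset_config_table = DatasetConfig(dataset_dir='/mnt/datasets/', dataset_cache_dir='/tmp/dataset_cache/')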
def insert_tuples(self, tuple_dicts):
"""
Function to compute the hash, build the dataloader, get the input_shape, output_shape, and additional_model_params, and insert into the DatasetConfig DJ table
Parameters:
tuple_dicts: A list of dicts containing the attributes to be inserted into DatasetConfig
Returns:
None
"""
# Iterate through each tuple_dict in tuple_dicts, get the input_shape, output_shape, and additional_model_params, and add them to the tuple_dict
for tuple_dict in tuple_dicts:
input_shape, output_shape, additional_model_params = self._get_dataset_additional_info(tuple_dict)
tuple_dict['input_shape'] = input_shape
tuple_dict['output_shape'] = output_shape
tuple_dict['additional_model_params'] = additional_model_params
super(DatasetConfig, self).insert_tuples(tuple_dicts)
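# Example (hedged sketch): a hypothetical tuple_dict for insert_tuples. The file name, dataset
# module/class names, and params below are illustrative only; input_shape, output_shape,
# additional_model_params, and the md5 hash are filled in during the insert. The sampler entries
# use real torch samplers whose constructors take the dataset as their first argument, assuming
# import_class_from_module resolves standard dotted module paths.
#   tuple_dict = dict(
#       dataset_config_id=0,
#       dataset_file_name='my_dataset.h5',
#       dataset_type='regression',
#       dataset_class_module_name='my_project.datasets',
#       dataset_class_name='MyDataset',
#       dataset_class_params=dict(normalize=True),
#       train_sampler_module_name='torch.utils.data',
#       train_sampler_class_name='RandomSampler',
#       train_sampler_class_params=dict(),
#       validation_sampler_module_name='torch.utils.data',
#       validation_sampler_class_name='SequentialSampler',
#       validation_sampler_class_params=dict(),
#       test_sampler_module_name='torch.utils.data',
#       test_sampler_class_name='SequentialSampler',
#       test_sampler_class_params=dict(),
#   )
#   dataset_config_table.insert_tuples([tuple_dict])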
def get_dataloaders(self, key, batch_size, num_workers=0):
"""
Method to build the dataloaders based on the primary key and return them.
Parameters:
key (dict): DatasetConfig datajoint table primary key.
batch_size (int): Batch size of the train, validation, and test dataloaders.
num_workers (int): Number of pytorch dataloader workers to use for the dataloaders.
Returns:
torch.utils.data.DataLoader, torch.utils.data.DataLoader, torch.utils.data.DataLoader: Train, validation, and test dataloaders
"""
# Fetch the details of the dataset configuration
tuple_dict = (self & key).fetch1()
self._check_if_dataset_file_is_cache(tuple_dict)
# Get the dataset class and params via helper function
dataset_class, dataset_class_params = self._get_dataset_class_and_params(tuple_dict)
# Construct the dataset class with the correct parameters for train, validation, and test
train_dataset = dataset_class(tuple_dict=tuple_dict, dataset_cache_dir=self.dataset_cache_dir, tier='train', **dataset_class_params)
validation_dataset = dataset_class(tuple_dict=tuple_dict, dataset_cache_dir=self.dataset_cache_dir, tier='validation', **dataset_class_params)
test_dataset = dataset_class(tuple_dict=tuple_dict, dataset_cache_dir=self.dataset_cache_dir, tier='test', **dataset_class_params)
# Get the sampler classes and params
train_sampler_class, train_sampler_params, validation_sampler_class, validation_sampler_params, test_sampler_class, test_sampler_params = self._get_samplers_class_and_params(tuple_dict)
# Creating the dataloader for train, validation, and test
train_dataloader = self._get_dataloader(dataset=train_dataset, batch_size=batch_size, shuffle=True, drop_last=True, num_workers=num_workers, sampler_class=train_sampler_class, sampler_params=train_sampler_params)
validation_dataloader = self._get_dataloader(dataset=validation_dataset, batch_size=batch_size, shuffle=False, drop_last=False, num_workers=num_workers, sampler_class=validation_sampler_class, sampler_params=validation_sampler_params)
test_dataloader = self._get_dataloader(dataset=test_dataset, batch_size=batch_size, shuffle=False, drop_last=False, num_workers=num_workers, sampler_class=test_sampler_class, sampler_params=test_sampler_params)
return train_dataloader, validation_dataloader, test_dataloader
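# Example (hedged sketch, hypothetical key): fetching the three dataloaders for a stored config:
#   train_dl, val_dl, test_dl = dataset_config_table.get_dataloaders(key=dict(dataset_config_id=0), batch_size=64, num_workers=4)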
def _get_dataset_additional_info(self, tuple_dict):
"""
Private helper function to get the input and output shape that the dataset expects the model to handle, plus any additional model params
Parameters:
tuple_dict (dict): DatasetConfig tuple dict containing the dataset class and params
Returns:
tuple or tuple of tuples, tuple or tuple of tuples, dict: input_shape, output_shape, additional_model_params
"""
# Check if the dataset is located in the cache directory
self._check_if_dataset_file_is_cache(tuple_dict)
# Get the dataset and dataset_class_params to create the validation_dataset
dataset_class, dataset_class_params = self._get_dataset_class_and_params(tuple_dict)
validation_dataset = dataset_class(tuple_dict=tuple_dict, dataset_cache_dir=self.dataset_cache_dir, tier='validation', **dataset_class_params)
# Get any additional_model_params
additional_model_params = validation_dataset.get_additional_model_params()
# Get the first datapoint in the validation dataset to get input_shape and output_shape
validation_data_point = validation_dataset[0]
return self._get_data_point_shape(validation_data_point[0]), self._get_data_point_shape(validation_data_point[1]), additional_model_params
def _get_data_point_shape(self, data_point):
"""
Private helper function to get the shape of the datapoint. This is mainly to handle the case of multiple inputs or outputs
Parameters:
data_point (np.array or list of np.array): the datapoint whose shape(s) to get
Returns:
tuple or tuple of tuples: the shape or shapes of the datapoint
"""
# Check if the datapoint is a list or not
if isinstance(data_point, list):
# If it is, loop through each sub_datapoint and return their shapes as a tuple
return tuple(data.shape for data in data_point)
else:
# If not then just return the shape
return data_point.shape
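# Example of the two cases handled above (numpy is not imported in this module; shapes are illustrative):
#   _get_data_point_shape(np.zeros((3, 32, 32)))                    -> (3, 32, 32)
#   _get_data_point_shape([np.zeros((3, 32, 32)), np.zeros((10,))]) -> ((3, 32, 32), (10,))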
def _check_if_dataset_file_is_cache(self, tuple_dict):
"""
Private helper function to check if the dataset file being referenced has been cached to the dataset_cache folder
Parameters:
tuple_dict (dict): Dictionary containing the dataset_file_name to be cached
Returns:
None
"""
# Check if the dataset_cache_dir exists, if not then create it
if not os.path.exists(self.dataset_cache_dir):
os.mkdir(self.dataset_cache_dir)
print('dataset_cache_dir created at: ' + self.dataset_cache_dir)
# Check if the file has been cached in the dataset cache directory
if not os.path.exists(self.dataset_cache_dir + tuple_dict['dataset_file_name']):
# File was not found in the cache directory, thus we need to copy it
print(tuple_dict['dataset_file_name'] + ' was not found in dataset_cache, thus it needs to be copied.')
shutil.copyfile(self.dataset_dir + tuple_dict['dataset_file_name'], self.dataset_cache_dir + tuple_dict['dataset_file_name'])
print(tuple_dict['dataset_file_name'] + ' copied successfully')
elif not filecmp.cmp(self.dataset_dir + tuple_dict['dataset_file_name'], self.dataset_cache_dir + tuple_dict['dataset_file_name']):
# The file was found, but its contents have changed, so re-cache it
print('Dataset file content has changed, re-caching dataset file.')
shutil.copyfile(self.dataset_dir + tuple_dict['dataset_file_name'], self.dataset_cache_dir + tuple_dict['dataset_file_name'])
else:
# The content didn't change and the file is already cached
print('Dataset file ' + str(tuple_dict['dataset_file_name']) + ' already cached')
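# Note on the caching above (making the path assumption explicit): paths are joined by plain
# string concatenation rather than os.path.join, so dataset_dir and dataset_cache_dir are
# expected to already end with a path separator, e.g.:
#   '/mnt/datasets/' + 'my_dataset.h5' -> '/mnt/datasets/my_dataset.h5'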
@staticmethod
def _get_dataloader(dataset, batch_size, shuffle, drop_last, num_workers, sampler_class, sampler_params):
"""
Private helper function to create and return a single dataloader with respect to the parameters that were passed in.
If sampler_class is defined, the dataloader will use that sampler class
Parameters:
dataset (torch.dataset): Standard pytorch dataset class.
batch_size (int): batch_size to be used for the dataloader
shuffle (bool): Whether to enable shuffling; however, if sampler is not None, this defaults to False
drop_last (bool): Whether to drop the last non-full batch or not
num_workers (int): Number of workers to use to load the dataset
sampler_class (torch.Sampler): Sampler class that defines the sampling strategy
sampler_params (dict): Sampler class parameters to be passed to the class
Returns:
torch.utils.data.DataLoader: Returns a dataloader initialized with the given parameters
"""
if sampler_class:
return DataLoader(dataset=dataset, batch_size=batch_size, shuffle=False, drop_last=drop_last, num_workers=num_workers, sampler=sampler_class(dataset, **sampler_params))
else:
return DataLoader(dataset=dataset, batch_size=batch_size, shuffle=shuffle, drop_last=drop_last, num_workers=num_workers)
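# Example (hedged sketch): any sampler whose constructor takes the dataset as its first argument
# fits the sampler_class(dataset, **sampler_params) convention above, e.g. torch.utils.data.RandomSampler:
#   from torch.utils.data import RandomSampler
#   loader = DatasetConfig._get_dataloader(dataset=my_dataset, batch_size=32, shuffle=True,
#                                          drop_last=True, num_workers=0,
#                                          sampler_class=RandomSampler,
#                                          sampler_params=dict(replacement=False))
# my_dataset here is a placeholder for any torch Dataset instance.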
@classmethod
def _get_dataset_class_and_params(cls, key):
"""
Helper function to get the dataset class and its params
Args:
key (dict): primary key of the tuple, or a full tuple_dict in the case of the initial insert
Returns:
dataset_class (<user-defined>), dict: Returns corresponding dataset_class along with the dataset_class_params
"""
# Check if key contains dataset_config_blobs_md5_hash; if not, assume it is a tuple_dict to handle the case of the initial insert
if 'dataset_config_blobs_md5_hash' in key.keys():
dataset_class_module_name, dataset_class_name, dataset_class_params = (cls & dict(dataset_config_blobs_md5_hash=key['dataset_config_blobs_md5_hash'])).fetch1('dataset_class_module_name', 'dataset_class_name', 'dataset_class_params')
else:
dataset_class_module_name = key['dataset_class_module_name']
dataset_class_name = key['dataset_class_name']
dataset_class_params = key['dataset_class_params']
return super().import_class_from_module(dataset_class_module_name, dataset_class_name), dataset_class_params
@classmethod
def _get_samplers_class_and_params(cls, tuple_dict):
"""
Helper function to get the corresponding sampler class and params for train, validation, and test
Args:
tuple_dict (dict): tuple dict containing the sampler module names, class names, and params
Returns:
(class, dict, class, dict, class, dict): Returns the train, validation, and test sampler classes along with their corresponding params
"""
# Import the sampler classes and params and return them
train_sampler_class = super().import_class_from_module(tuple_dict['train_sampler_module_name'], tuple_dict['train_sampler_class_name'])
train_sampler_class_params = tuple_dict['train_sampler_class_params']
validation_sampler_class = super().import_class_from_module(tuple_dict['validation_sampler_module_name'], tuple_dict['validation_sampler_class_name'])
validation_sampler_class_params = tuple_dict['validation_sampler_class_params']
test_sampler_class = super().import_class_from_module(tuple_dict['test_sampler_module_name'], tuple_dict['test_sampler_class_name'])
test_sampler_class_params = tuple_dict['test_sampler_class_params']
return train_sampler_class, train_sampler_class_params, validation_sampler_class, validation_sampler_class_params, test_sampler_class, test_sampler_class_params