import numpy as np
import h5py
import json
from torch.utils.data import Dataset
from torch.utils.data import Subset
from torchvision.transforms import Compose
class NeuroDataDataset(Dataset):
    def __init__(self, tuple_dict, dataset_cache_dir, tier, mode):
        """
        Initialization function for NeuroDataDataset.
        NeuroDataDataset is the dataset class for loading .h5 files of neuron-based datasets from the neurodata pipeline.
        Currently it only supports static datasets; the logic for dynamic datasets still needs to be written.
        Parameters:
            self (NeuroDataDataset): instance of the class
            tuple_dict (dict): A dictionary containing all the columns of DatasetConfig for this given dataset
            dataset_cache_dir (str): Path to where the local cache copy of the dataset is stored
            tier (str): Selects the dataset examples based on the tier defined in the dataset, either train, validation, or test
            mode (str): Mode that the dataset should use, which changes what it returns. More details in the __getitem__ function
        Returns:
            None
        """
        self.tuple_dict = tuple_dict
        self.dataset_cache_dir = dataset_cache_dir
        self.tier = tier
        self.mode = mode
        self.additional_model_params = dict()
        self.h5py_file = None  # Workaround placeholder: an open h5py file cannot be pickled, which breaks PyTorch DataLoaders with num_workers >= 1
        # Open the h5py file temporarily to get the initial stats, then close it so PyTorch can pickle the dataset correctly
        with h5py.File(self.dataset_cache_dir + self.tuple_dict['dataset_file_name'], 'r') as h5py_file:
            # Find the indices of the samples that match the current tier; this also works around the h5py limitation of not supporting multiple workers on an open file
            self.tiers = np.array([string.decode() for string in h5py_file['tiers']])
            self.indices = np.where(self.tiers == tier)[0]
            self.len = len(self.indices)
    def get_additional_model_params(self):
        """
        Return additional_model_params
        Parameters:
            None
        Returns:
            dict: additional_model_params to be passed on to the model at creation time
        """
        return self.additional_model_params
    def __getitem__(self, index):
        """
        Standard __getitem__ function required of a PyTorch dataset class.
        Because an opened HDF5 file is not pickleable, and the Dataset must be serialized with pickle to be sent to the workers' processes,
        the HDF5 file cannot be opened in __init__. Instead, open it in __getitem__ and store it as a singleton.
        Do not open it on every call, as that introduces huge overhead.
        https://discuss.pytorch.org/t/dataloader-when-num-worker-0-there-is-bug/25643/16
        Parameters:
            self (NeuroDataDataset): instance of the class
            index (int): index of the example to return from the defined dataset given the mode
        Returns:
            tuple: the returned datapoint varies depending on the dataset mode
        """
        if self.h5py_file is None:
            self.h5py_file = h5py.File(self.dataset_cache_dir + self.tuple_dict['dataset_file_name'], 'r')
        if self.mode == 'decoding':
            # Return neuron responses as inputs and the stimulus image as targets
            return self.h5py_file['responses'][self.indices[index]], self.h5py_file['images'][self.indices[index]]
        elif self.mode == 'full-encoding':
            # Return images, behavior, and pupil_center as inputs and responses as targets
            images = self.h5py_file['images'][self.indices[index]] / 255  # Scale to the range 0 to 1
            behavior = self.h5py_file['behavior'][self.indices[index]]
            pupil_center = self.h5py_file['pupil_center'][self.indices[index]]
            responses = self.h5py_file['responses'][self.indices[index]] / self.h5py_file['responses'][self.indices[index]].std()
            return [images.astype(np.float32), behavior.astype(np.float32), pupil_center.astype(np.float32)], responses.astype(np.float32)
        else:
            raise ValueError(f"Unsupported dataset mode: {self.mode}")
    def __len__(self):
        """
        Standard __len__ function required of a PyTorch dataset class.
        Returns the number of samples available given the dataset configuration.
        Parameters:
            self (NeuroDataDataset): instance of the class
        Returns:
            int: number of samples in the selected tier
        """
        return self.len
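
# A minimal usage sketch (not part of the original module), assuming a hypothetical
# cache directory and dataset file name in the DatasetConfig dict. It illustrates why
# the lazy h5py open in __getitem__ matters: the dataset object can be pickled and
# sent to DataLoader worker processes, so num_workers > 0 works.
if __name__ == '__main__':
    from torch.utils.data import DataLoader

    example_config = {'dataset_file_name': 'example_static_scan.h5'}  # hypothetical file name
    dataset = NeuroDataDataset(example_config, '/tmp/neurodata_cache/', 'train', 'full-encoding')
    loader = DataLoader(dataset, batch_size=32, shuffle=True, num_workers=4)
    for (images, behavior, pupil_center), responses in loader:
        # In 'full-encoding' mode each batch is ([images, behavior, pupil_center], responses)
        print(images.shape, behavior.shape, pupil_center.shape, responses.shape)
        break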