import pandas as pd
import numpy as np
from torch.utils.data import Dataset
[docs]class StockDataset(Dataset):
def __init__(self, tuple_dict, dataset_cache_dir, tier, train_percentage, validation_percentage,
inputs_columns=['Price', 'Open', 'High', 'Low', 'Vol.'],
target_columns=['Change %'],
input_time_window_len=30,
target_time_window_len=1):
self.tuple_dict = tuple_dict
self.dataset_cache_dir = dataset_cache_dir
self.tier = tier
self.train_percentage = train_percentage
self.validation_percentage = validation_percentage
self.input_time_window_len = input_time_window_len
self.target_time_window_len = target_time_window_len
self.dataset = pd.read_csv(self.dataset_cache_dir + tuple_dict['dataset_file_name'])
self.input_columns = inputs_columns
self.target_columns = target_columns
self.dataset_stats = dict()
self.additional_model_params = dict()
# Remove commas from all columns
for column in ['Price', 'Open', 'High', 'Low', 'Vol.']:
self.dataset[column] = self.dataset[column].str.replace(',', '')
if column == 'Vol.':
self.dataset[column] = self.dataset[column].apply(self.value_to_float)
else:
self.dataset[column] = self.dataset[column].str.replace(',', '').astype(float)
# Convert Percentage Value to decimal
self.dataset['Change %'] = self.dataset['Change %'].str.replace('%', '').astype(float) / 100
# Compute the difference
dataset = dict()
for column_name in np.unique(self.input_columns + self.target_columns):
if not column_name == 'Change %':
dataset[column_name] = np.diff(self.dataset[column_name])
else:
dataset[column_name] = self.dataset[column_name][:-1]
self.dataset = pd.DataFrame().from_dict(dataset)
# Compute starting index base on tier
if tier == 'train':
self.starting_index_offset = 0
self.len = int(self.train_percentage * (len(self.dataset) - self.input_time_window_len - self.target_time_window_len + 1))
elif tier == 'validation':
self.starting_index_offset = int(self.train_percentage * (len(self.dataset) - self.input_time_window_len - self.target_time_window_len + 1))
self.len = int(self.validation_percentage * (len(self.dataset) - self.input_time_window_len - self.target_time_window_len + 1))
elif tier == 'test':
self.starting_index_offset = int((self.train_percentage + self.validation_percentage) * (len(self.dataset) - self.input_time_window_len - self.target_time_window_len + 1))
self.len = int((1 - (self.train_percentage + self.validation_percentage)) * (len(self.dataset) - self.input_time_window_len - self.target_time_window_len + 1))
[docs] def get_additional_model_params(self):
"""
Return additional_model_params
Parameters:
None
Returns:
dict : additional_model_params to be passed on to the model at creation time
"""
return self.additional_model_params
[docs] def value_to_float(self, x):
if 'K' in x:
if len(x) > 1:
return float(x.replace('K', '')) * 1000
return 1000.0
elif 'M' in x:
if len(x) > 1:
return float(x.replace('M', '')) * 1000000
return 1000000.0
elif 'B' in x:
return float(x.replace('B', '')) * 1000000000
else:
return 0.0
def __getitem__(self, index):
inputs = np.empty(shape=(len(self.input_columns), self.input_time_window_len), dtype=float)
targets = np.empty(shape=(len(self.target_columns), self.target_time_window_len), dtype=float)
offset_index = index + self.starting_index_offset
inputs_stopping_index = offset_index + self.input_time_window_len
for i, column_name in enumerate(self.input_columns):
inputs[i] = self.dataset[column_name][offset_index : inputs_stopping_index]
inputs_min = inputs[i].min()
inputs_max = inputs[i].max()
if inputs_min == inputs_max:
inputs[:] = 0.0
else:
inputs[i] = ((inputs[i] - inputs_min) / (inputs_max - inputs_min)) - 0.5
for i, column_name in enumerate(self.target_columns):
targets[i] = self.dataset[column_name][inputs_stopping_index : inputs_stopping_index + self.target_time_window_len]
targets_min = targets[i].min()
targets_max = targets[i].max()
'''
if targets_min == targets_max:
targets[:] = 0.0
else:
targets[i] = ((targets[i] - targets_min) / (targets_max - targets_min)) - 0.5
'''
return inputs.astype(np.float32), targets.astype(np.float32)
def __len__(self):
return self.len