import datajoint as dj
from synicix_ml_pipeline.datajoint_tables.BaseTable import schema
from synicix_ml_pipeline.datajoint_tables.BaseTable import BaseTable
@schema
class TrainingConfig(dj.Manual, BaseTable):
    """
    A dj.Manual table that stores training configurations: the trainer class to use, the optimizer and
    criterion classes, the parameters for each of those classes, and the batch size and epoch limit.

    Classes are stored as a (module name, class name) pair plus a params blob. The optimizer and criterion
    entries are turned back into Python classes by get_optimizer_class_and_params and
    get_criterion_class_and_params, which use the inherited import_class_from_module.

    Typical usage of this class is through the insert_tuples method. An example can be found in the
    Pipeline Configuration Jupyter Notebook.
    """

    # Each *_class_module_name / *_class_name pair names a class to import at run time; the matching
    # *_class_params column is a longblob holding that class's parameters.
    # NOTE(review): if training_config_blobs_md5_hash stores a hex MD5 digest, 32 characters would suffice;
    # char(128) is left unchanged because editing the definition would alter the table schema.
    definition = """
    training_config_id : int unsigned
    ---
    trainer_class_module_name : varchar(256)
    trainer_class_name : varchar(256)
    trainer_class_params : longblob
    batch_size : smallint unsigned
    epoch_limit : int unsigned
    optimizer_class_module_name : varchar(256)
    optimizer_class_name : varchar(256)
    optimizer_class_params : longblob
    criterion_class_module_name : varchar(256)
    criterion_class_name : varchar(256)
    criterion_class_params : longblob
    training_config_blobs_md5_hash : char(128)
    """

    @classmethod
    def _fetch_class_and_params(cls, key, prefix):
        """
        Shared implementation behind the get_<prefix>_class_and_params methods.

        Parameters:
            key (dict): restriction applied to TrainingTask; the restricted TrainingTask joined with
                TrainingConfig is read with fetch1, so it should select exactly one row
            prefix (str): attribute prefix, e.g. 'optimizer' or 'criterion'

        Returns:
            <class>, <params>: the class named by <prefix>_class_module_name / <prefix>_class_name, and the
            stored <prefix>_class_params value
        """
        # Import TrainingTask here rather than at module level to get around circular imports
        from synicix_ml_pipeline.datajoint_tables.TrainingTask import TrainingTask

        module_name, class_name, class_params = ((TrainingTask & key) * cls).fetch1(
            prefix + '_class_module_name',
            prefix + '_class_name',
            prefix + '_class_params',
        )
        # Resolve via super(), exactly as the original per-method implementations did
        return super().import_class_from_module(module_name, class_name), class_params

    @classmethod
    def get_optimizer_class_and_params(cls, key):
        """
        Load the optimizer class and its parameters from the DB for the given key.

        Parameters:
            key (dict): restriction applied to TrainingTask; joined with TrainingConfig it should select
                exactly one row

        Returns:
            <user_defined_optimizer_class>, <user_defined_optimizer_class_params>: the imported optimizer
            class and the stored optimizer_class_params for the given key
        """
        return cls._fetch_class_and_params(key, 'optimizer')

    @classmethod
    def get_criterion_class_and_params(cls, key):
        """
        Load the criterion class and its parameters from the DB for the given key.

        Parameters:
            key (dict): restriction applied to TrainingTask; joined with TrainingConfig it should select
                exactly one row

        Returns:
            <user_defined_criterion_class>, <user_defined_criterion_class_params>: the imported criterion
            class and the stored criterion_class_params for the given key
        """
        return cls._fetch_class_and_params(key, 'criterion')