# Source code for synicix_ml_pipeline.datajoint_tables.TrainingConfig

import datajoint as dj

from synicix_ml_pipeline.datajoint_tables.BaseTable import schema
from synicix_ml_pipeline.datajoint_tables.BaseTable import BaseTable

@schema
class TrainingConfig(dj.Manual, BaseTable):
    """
    A dj.Manual table class that handles the storage of training configurations, such as
    the trainer class to use and the various other parameters related to it.

    Besides the trainer, each row names an optimizer class and a criterion class (module
    name, class name and a params blob for each) together with the batch size and the
    epoch limit.

    Typical usage of this class is done through the method insert_tuples. An example of
    this can be found in the Pipeline Configuration Jupyter Notebook.
    """

    definition = """
    training_config_id : int unsigned
    ---
    trainer_class_module_name : varchar(256)
    trainer_class_name : varchar(256)
    trainer_class_params : longblob
    batch_size : smallint unsigned
    epoch_limit : int unsigned
    optimizer_class_module_name : varchar(256)
    optimizer_class_name : varchar(256)
    optimizer_class_params : longblob
    criterion_class_module_name : varchar(256)
    criterion_class_name : varchar(256)
    criterion_class_params : longblob
    training_config_blobs_md5_hash : char(128)
    """

    @classmethod
    def get_optimizer_class_and_params(cls, key):
        """
        Load the optimizer class and its params from the DB based on the given key.

        Parameters:
            key (dict): restriction applied to TrainingTask; the restricted TrainingTask
                joined with this table is read with fetch1, so it is expected to
                resolve to a single tuple

        Returns:
            <user_defined_optimizer_class>, <user_defined_optimizer_class_params>:
                the optimizer class imported via import_class_from_module, and the
                optimizer_class_params stored for the given key
        """
        # TrainingTask is imported here rather than at module level to get around
        # circular imports
        from synicix_ml_pipeline.datajoint_tables.TrainingTask import TrainingTask

        # Restrict TrainingTask by the key, then join with this table to reach the
        # optimizer columns of the matching training config
        matching_config = (TrainingTask & key) * cls
        module_name, class_name, params = matching_config.fetch1(
            'optimizer_class_module_name',
            'optimizer_class_name',
            'optimizer_class_params',
        )

        optimizer_class = super().import_class_from_module(module_name, class_name)
        return optimizer_class, params

    @classmethod
    def get_criterion_class_and_params(cls, key):
        """
        Load the criterion class and its params from the DB based on the given key.

        Parameters:
            key (dict): restriction applied to TrainingTask; the restricted TrainingTask
                joined with this table is read with fetch1, so it is expected to
                resolve to a single tuple

        Returns:
            <user_defined_criterion_class>, <user_defined_criterion_class_params>:
                the criterion class imported via import_class_from_module, and the
                criterion_class_params stored for the given key
        """
        # TrainingTask is imported here rather than at module level to get around
        # circular imports
        from synicix_ml_pipeline.datajoint_tables.TrainingTask import TrainingTask

        # Restrict TrainingTask by the key, then join with this table to reach the
        # criterion columns of the matching training config
        matching_config = (TrainingTask & key) * cls
        module_name, class_name, params = matching_config.fetch1(
            'criterion_class_module_name',
            'criterion_class_name',
            'criterion_class_params',
        )

        criterion_class = super().import_class_from_module(module_name, class_name)
        return criterion_class, params