Source code for synicix_ml_pipeline.datajoint_tables.ModelConfig

import datajoint as dj
import os
import torch

from synicix_ml_pipeline.datajoint_tables.BaseTable import schema
from synicix_ml_pipeline.datajoint_tables.BaseTable import BaseTable

@schema
class ModelConfig(dj.Manual, BaseTable):
    """
    A dj.Manual table that stores PyTorch model definitions, plus a helper
    for loading the stored model class and its parameters.

    Each row records the module name and class name of a model class, the
    parameters stored alongside it, and an MD5 hash string.

    Typical usage is through the method insert_tuples (not defined in this
    class; presumably inherited from BaseTable -- confirm). An example can be
    found in the Pipeline Configuration Jupyter Notebook.
    """

    # DataJoint table declaration: one attribute per line, with `---`
    # separating the primary key (above) from the dependent attributes (below).
    definition = """
    model_config_id : int unsigned # MD5 Hash of network_class_name + network_module_code
    ---
    model_class_module_name : varchar(256)
    model_class_name : varchar(256) # Class name of the network
    model_class_params : longblob
    model_config_blobs_md5_hash : char(128)
    """

    @classmethod
    def get_model_class_and_params(cls, key):
        """
        Fetch the model class and its stored parameters for the given key.

        Parameters:
            key (dict): Restriction applied to TrainingTask. The ModelConfig
                entries matching the restricted TrainingTask must narrow down
                to exactly one tuple, because the lookup uses fetch1.

        Returns:
            tuple: (<user_defined_model_class>, <user_defined_model_class_params>),
                i.e. the object returned by import_class_from_module for the
                stored module/class names, and the stored model_class_params.
        """
        # Imported here instead of at module level to work around circular imports.
        from synicix_ml_pipeline.datajoint_tables.TrainingTask import TrainingTask

        # Restrict this table by the TrainingTask rows matching `key`, then
        # pull the three attributes needed to import the model class.
        model_class_module_name, model_class_name, model_class_params = (
            cls & (TrainingTask & key)
        ).fetch1('model_class_module_name', 'model_class_name', 'model_class_params')

        # import_class_from_module is resolved through the parent classes
        # (presumably provided by BaseTable -- confirm).
        model_class = super().import_class_from_module(model_class_module_name, model_class_name)
        return model_class, model_class_params