import datajoint as dj
import os
import torch
from synicix_ml_pipeline.datajoint_tables.BaseTable import schema
from synicix_ml_pipeline.datajoint_tables.BaseTable import BaseTable
@schema
class ModelConfig(dj.Manual, BaseTable):
    """
    A dj.Manual table that stores PyTorch model definitions, plus a helper for loading the model class.

    Each row records the module and class name of a model together with the
    parameters (stored as a blob) that should be passed to it.

    Typical usage is through the insert_tuples method; an example can be found
    in the Pipeline Configuration Jupyter Notebook.
    """

    # NOTE(review): the column comment on model_config_id mentions an MD5 hash while the
    # column type is `int unsigned`; it is left unchanged because the text is part of the
    # table definition sent to the database -- confirm the intended meaning.
    definition = """
    model_config_id : int unsigned # MD5 Hash of network_class_name + network_module_code
    ---
    model_class_module_name : varchar(256)
    model_class_name : varchar(256) # Class name of the network
    model_class_params : longblob
    model_config_blobs_md5_hash : char(128)
    """

    @classmethod
    def get_model_class_and_params(cls, key):
        """
        Get the model class and its parameters for the training task selected by key.

        The key is applied to TrainingTask first, and that restricted TrainingTask
        is then used to restrict this table. The combined restriction is fetched
        with fetch1, so it is expected to match exactly one ModelConfig row.

        Parameters:
            key (dict): A restriction that narrows TrainingTask (and through it ModelConfig) down to one tuple

        Returns:
            <user_defined_model_class>, <user_defined_model_class_params>: The class object
            returned by import_class_from_module for the stored module/class names, and the
            stored model_class_params
        """
        # Imported here rather than at module level to work around circular imports
        from synicix_ml_pipeline.datajoint_tables.TrainingTask import TrainingTask

        # Fetch the information required to import the model class, along with its parameters
        model_class_module_name, model_class_name, model_class_params = (
            cls & (TrainingTask & key)
        ).fetch1('model_class_module_name', 'model_class_name', 'model_class_params')

        # Resolve the class object from its module and class name via the inherited helper
        model_class = super().import_class_from_module(model_class_module_name, model_class_name)
        return model_class, model_class_params