Source code for synicix_ml_pipeline.models.SimpleMLP

import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np

from synicix_ml_pipeline.datajoint_tables.BaseTable import BaseTable

class SimpleMLP(nn.Module):
    """
    A very simple example of a multilayer perceptron network illustrating the
    typical model design for synicix_ml_pipeline.

    The network is a stack of ``Linear -> nonlinearity -> BatchNorm1d`` blocks
    followed by a final ``Linear`` projection (optionally followed by a ReLU).
    ``forward`` returns both the prediction and a regularization term.
    """

    def __init__(self, input_shape, output_shape, num_hidden_layers=1, hidden_size=1000,
                 nonlinear_function_module_name='torch.nn', nonlinear_function_class_name='ELU',
                 nonlinear_function_class_params=None, final_relu=False,
                 l1_loss_lamda=0.0, l2_loss_lamda=0.0):
        """
        Initialize function for SimpleMLP

        Parameters:
            input_shape (tuple): Shape of the input tensor or list of tensor shapes
                (only the first shape is used)
            output_shape (tuple): Shape of the output tensor or list of tensor shapes
                (only the first shape is used)
            num_hidden_layers (int): Number of hidden-to-hidden fully connected blocks
                added after the initial input-to-hidden block
            hidden_size (int): Hidden layer sizes
            nonlinear_function_module_name (str): Module name handed to
                BaseTable.import_class_from_module to look up the nonlinearity
            nonlinear_function_class_name (str): Class name of the nonlinearity
            nonlinear_function_class_params (dict or None): Keyword arguments used to
                instantiate the nonlinearity, None means no arguments
            final_relu (bool): If True, append a ReLU after the last linear layer
            l1_loss_lamda: Multiplier for L1 Loss
            l2_loss_lamda: Multiplier for L2 Loss

        Returns:
            None
        """
        super().__init__()

        # A None default avoids sharing one mutable dict between every call.
        if nonlinear_function_class_params is None:
            nonlinear_function_class_params = {}

        self.input_shape = input_shape
        self.output_shape = output_shape

        nonlinear_function_class = BaseTable.import_class_from_module(
            nonlinear_function_module_name, nonlinear_function_class_name)
        self.nonlinear_function = nonlinear_function_class(**nonlinear_function_class_params)

        # Focus only on first input and output. The products are cast to plain
        # ints because np.prod returns a NumPy scalar (a float for an empty shape).
        self.input_shape_prod = int(np.prod(self._first_shape(input_shape)))
        self.output_shape_prod = int(np.prod(self._first_shape(output_shape)))

        # Fully Connected Layers
        # NOTE: the same nonlinearity instance is reused in every block, exactly
        # as before, so an activation with learnable parameters would share them.
        layers = [
            nn.Linear(self.input_shape_prod, hidden_size),
            self.nonlinear_function,
            nn.BatchNorm1d(hidden_size),
        ]
        for _ in range(num_hidden_layers):
            layers.append(nn.Linear(hidden_size, hidden_size))
            layers.append(self.nonlinear_function)
            layers.append(nn.BatchNorm1d(hidden_size))

        layers.append(nn.Linear(hidden_size, self.output_shape_prod))

        if final_relu:
            layers.append(nn.ReLU())

        self.layers = nn.Sequential(*layers)

        # Regularization multipliers read by compute_regularizer
        self.l1_loss_lamda = l1_loss_lamda
        self.l2_loss_lamda = l2_loss_lamda

    @staticmethod
    def _first_shape(shape):
        """
        Return the shape to use when given either one shape or a list of shapes

        Parameters:
            shape (tuple or list): A single shape, or a sequence whose entries
                are themselves shapes

        Returns:
            tuple or list: The first nested shape if there is one, otherwise shape itself
        """
        # isinstance also accepts tuple subclasses (e.g. torch.Size); the length
        # check keeps an empty (scalar) shape from raising IndexError.
        if len(shape) > 0 and isinstance(shape[0], (tuple, list)):
            return shape[0]
        return shape

    def forward(self, x):
        """
        Forward function for SimpleMLP

        Parameters:
            x (Tensor or list of Tensor): input batch, by default it only uses the first input tensor

        Returns:
            Tensor: Output of the model given the input
            Tensor: Regularization loss to be added during the loss backpropagation
        """
        if isinstance(x, (list, tuple)):
            x = x[0]

        # Flatten everything except the leading (batch) dimension
        x = x.reshape(x.shape[0], -1)
        x = self.layers(x)

        # Reshape to the first output shape only; unpacking a list of shapes
        # straight into reshape would fail.
        x = x.reshape(-1, *self._first_shape(self.output_shape))

        return x, self.compute_regularizer()

    def compute_regularizer(self):
        """
        Regularization Computation Function

        For every parameter of the module this adds
        l1_loss_lamda * mean(|p|) + l2_loss_lamda * mean(p ** 2).

        Parameters:
            None

        Returns:
            Tensor: Regularization loss to be added during the loss backpropagation
            (the int 0 if the module has no parameters)
        """
        regularizer_loss = 0
        for params in self.parameters():
            values = params.float()  # convert once instead of once per term
            regularizer_loss = regularizer_loss + (
                self.l1_loss_lamda * values.abs().mean()
                + self.l2_loss_lamda * values.pow(2).mean()
            )

        return regularizer_loss