import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from synicix_ml_pipeline.datajoint_tables.BaseTable import BaseTable
class SimpleMLP(nn.Module):
"""
A very simple multilayer perceptron (MLP) illustrating the typical model design for synicix_ml_pipeline
"""
def __init__(self, input_shape, output_shape, num_hidden_layers=1, hidden_size=1000, nonlinear_function_module_name='torch.nn', nonlinear_function_class_name='ELU', nonlinear_function_class_params=dict(), final_relu=False, l1_loss_lamda=0.0, l2_loss_lamda=0.0):
"""
Initialize function for SimpleMLP
Parameters:
input_shape (tuple or list of tuple): Shape of the input tensor, or a list of shapes when there are multiple inputs (only the first is used)
output_shape (tuple or list of tuple): Shape of the output tensor, or a list of shapes when there are multiple outputs (only the first is used)
num_hidden_layers (int): Number of fully connected hidden layers
hidden_size (int): Size of each hidden layer
nonlinear_function_module_name (str): Module to import the activation class from (e.g. 'torch.nn')
nonlinear_function_class_name (str): Name of the activation class within that module (e.g. 'ELU')
nonlinear_function_class_params (dict): Keyword arguments passed to the activation class constructor
final_relu (bool): If True, apply a ReLU to the network output
l1_loss_lamda (float): Multiplier for the L1 regularization loss
l2_loss_lamda (float): Multiplier for the L2 regularization loss
Returns:
None
"""
super().__init__()
self.input_shape = input_shape
self.output_shape = output_shape
self.nonlinear_function = BaseTable.import_class_from_module(nonlinear_function_module_name, nonlinear_function_class_name)(**nonlinear_function_class_params)
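# The activation is resolved dynamically from its module and class name, e.g. module 'torch.nn' with class 'ELU' yields nn.ELU();
# passing nonlinear_function_class_params={'alpha': 0.5} would construct nn.ELU(alpha=0.5).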
# Focus only on first input and output
if isinstance(input_shape[0], tuple):
self.input_shape_prod = np.prod(input_shape[0])
else:
self.input_shape_prod = np.prod(input_shape)
if isinstance(output_shape[0], tuple):
self.output_shape_prod = np.prod(output_shape[0])
else:
self.output_shape_prod = np.prod(output_shape)
# Fully Connected Layers
layers = []
layers.append(nn.Linear(self.input_shape_prod, hidden_size))
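# Note: the same activation module instance is reused in every layer; this is safe for stateless activations such as ELU or ReLU.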
layers.append(self.nonlinear_function)
layers.append(nn.BatchNorm1d(hidden_size))
for _ in range(num_hidden_layers):
layers.append(nn.Linear(hidden_size, hidden_size))
layers.append(self.nonlinear_function)
layers.append(nn.BatchNorm1d(hidden_size))
layers.append(nn.Linear(hidden_size, self.output_shape_prod))
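# Optionally clamp the network output to be non-negative.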
if final_relu:
layers.append(nn.ReLU())
self.layers = nn.Sequential(*layers)
# Regularization multipliers used by compute_regularizer
self.l1_loss_lamda = l1_loss_lamda
self.l2_loss_lamda = l2_loss_lamda
def forward(self, x):
"""
Forward function for SimpleMLP
Parameters:
x (Tensor or list of Tensor): Input batch; if a list is given, only the first tensor is used
Returns:
Tensor: Output of the model given the input
Tensor: Regularization loss to be added to the loss before backpropagation
"""
if isinstance(x, list):
x = x[0]
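# Flatten everything except the batch dimension before the fully connected stack.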
x = x.reshape(x.shape[0], -1)
x = self.layers(x)
output_shape = self.output_shape[0] if isinstance(self.output_shape[0], tuple) else self.output_shape
x = x.reshape(-1, *output_shape)
return x, self.compute_regularizer()
def compute_regularizer(self):
"""
Regularization computation function
Parameters:
None
Returns:
Tensor: Regularization loss to be added to the loss before backpropagation
"""
regularizer_loss = 0
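# Accumulate mean-based L1 and L2 penalties over every parameter tensor (weights and biases alike).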
for params in self.parameters():
regularizer_loss += self.l1_loss_lamda * params.float().abs().mean() + self.l2_loss_lamda * params.float().pow(2).mean()
return regularizer_loss
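# --- Usage sketch (not part of the module) ---
# A minimal example of how a model like this is typically driven: construct it from input/output
# shapes, run a forward pass, and add the returned regularizer to the task loss before backpropagation.
# The shapes, batch size, loss, and optimizer settings below are illustrative assumptions, not values
# prescribed by synicix_ml_pipeline.
if __name__ == '__main__':
    model = SimpleMLP(input_shape=(3, 32, 32), output_shape=(10,), num_hidden_layers=2,
                      hidden_size=256, l1_loss_lamda=1e-5, l2_loss_lamda=1e-4)
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

    x = torch.randn(8, 3, 32, 32)   # dummy batch of 8 inputs
    target = torch.randn(8, 10)     # dummy regression targets

    prediction, regularizer_loss = model(x)
    loss = F.mse_loss(prediction, target) + regularizer_loss
    loss.backward()
    optimizer.step()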