import torch

# Seed the global torch RNG at import time so weight initialisation is
# reproducible. Kept directly after `import torch`, as in the original.
torch.manual_seed(0)
import torch.nn as nn
import torch.nn.functional as F
import torch.utils
import torch.distributions
import matplotlib.pyplot as plt

plt.rcParams['figure.dpi'] = 200
from src.cocktails.representation_learning.simple_model import SimpleNet

device = 'cuda' if torch.cuda.is_available() else 'cpu'

# Activation name -> attribute name on torch.nn.functional. The attribute is
# resolved lazily inside get_activation so that a name missing from the
# installed torch version only fails when it is actually requested.
_FUNCTIONAL_ACTIVATIONS = {
    'tanh': 'tanh',
    'relu': 'relu',
    'mish': 'mish',
    'sigmoid': 'sigmoid',
    'leakyrelu': 'leaky_relu',
}


def get_activation(activation):
    """Return the activation callable registered under the name *activation*.

    Supported names: 'tanh', 'relu', 'mish', 'sigmoid', 'leakyrelu'
    (all taken from ``torch.nn.functional``) and 'exp' (``torch.exp``).

    Raises:
        ValueError: if *activation* is not one of the supported names.
    """
    if activation == 'exp':
        return torch.exp
    try:
        functional_name = _FUNCTIONAL_ACTIVATIONS[activation]
    except (KeyError, TypeError):
        # TypeError covers unhashable inputs, which the original if/elif
        # chain also rejected with ValueError.
        supported = sorted(list(_FUNCTIONAL_ACTIVATIONS) + ['exp'])
        raise ValueError(
            f"Unknown activation {activation!r}; expected one of {supported}"
        ) from None
    return getattr(F, functional_name)


class IngredientEncoder(nn.Module):
    """MLP mapping an ``input_dim`` vector to a ``deepset_latent_dim`` vector.

    Layer widths are ``input_dim -> *hidden_dims -> deepset_latent_dim``.
    Every layer except the last is followed by dropout and then the
    activation; the last layer's output is returned as-is.
    """

    def __init__(self, input_dim, deepset_latent_dim, hidden_dims, activation, dropout):
        """Build the linear/dropout stacks.

        Args:
            input_dim: width of the input vectors.
            deepset_latent_dim: width of the output vectors.
            hidden_dims: sequence of hidden-layer widths (may be empty).
            activation: activation name understood by ``get_activation``.
            dropout: dropout probability used for every ``nn.Dropout``.
        """
        super().__init__()
        # list(...) lets callers pass a tuple as well as a list.
        dims = [input_dim] + list(hidden_dims) + [deepset_latent_dim]
        self.linears = nn.ModuleList()
        self.dropouts = nn.ModuleList()
        for d_in, d_out in zip(dims[:-1], dims[1:]):
            self.linears.append(nn.Linear(d_in, d_out))
            self.dropouts.append(nn.Dropout(dropout))
        self.activation = get_activation(activation)
        self.n_layers = len(self.linears)
        self.layer_range = range(self.n_layers)

    def forward(self, x):
        """Apply the MLP to *x* and return the last linear layer's output."""
        last_index = self.n_layers - 1
        for i_layer, (layer, dropout) in enumerate(zip(self.linears, self.dropouts)):
            x = layer(x)
            # The final layer gets neither dropout nor activation.
            # (Original author's open question: "do not use dropout on last layer?")
            if i_layer != last_index:
                x = self.activation(dropout(x))
        return x
class DeepsetCocktailEncoder(nn.Module):
    """Deep-set encoder: encode each ingredient, pool per cocktail, then map
    the pooled vector to the mean and log-variance of a latent distribution.
    """

    def __init__(self, input_dim, deepset_latent_dim, hidden_dims_ing, activation,
                 hidden_dims_cocktail, latent_dim, aggregation, dropout):
        """Build the per-ingredient encoder and the post-aggregation network.

        Args:
            input_dim: dimension of one ingredient slot (ingredient
                representation + quantity).
            deepset_latent_dim: dimension of the per-ingredient encoding, and
                therefore of the pooled cocktail vector.
            hidden_dims_ing: hidden widths of the ingredient encoder.
            activation: activation name understood by ``get_activation``.
            hidden_dims_cocktail: widths of the post-aggregation layers. The
                first post-aggregation layer maps ``deepset_latent_dim`` to
                ``hidden_dims_cocktail[0]``. May be empty, in which case the
                mean/logvar heads read the pooled vector directly.
            latent_dim: dimension of the returned mean / logvar vectors.
            aggregation: 'sum' or 'mean' (checked in ``forward``).
            dropout: dropout probability used throughout.
        """
        super().__init__()
        self.input_dim = input_dim  # dimension of ingredient representation + quantity
        # encodes each ingredient separately
        self.ingredient_encoder = IngredientEncoder(input_dim, deepset_latent_dim, hidden_dims_ing,
                                                    activation, dropout)
        self.deepset_latent_dim = deepset_latent_dim  # dimension of the deepset aggregation
        self.aggregation = aggregation
        self.latent_dim = latent_dim

        # post aggregation network
        self.linears = nn.ModuleList()
        self.dropouts = nn.ModuleList()
        dims = [deepset_latent_dim] + list(hidden_dims_cocktail)
        for d_in, d_out in zip(dims[:-1], dims[1:]):
            self.linears.append(nn.Linear(d_in, d_out))
            self.dropouts.append(nn.Dropout(dropout))
        # dims[-1] equals hidden_dims_cocktail[-1] when that list is non-empty
        # and falls back to deepset_latent_dim otherwise (the original raised
        # IndexError for an empty list).
        self.FC_mean = nn.Linear(dims[-1], latent_dim)
        self.FC_logvar = nn.Linear(dims[-1], latent_dim)
        self.softplus = nn.Softplus()  # not used by forward(); kept as a public attribute

        self.activation = get_activation(activation)
        self.n_layers = len(self.linears)
        self.layer_range = range(self.n_layers)

    def forward(self, nb_ingredients, x):
        """Encode a batch of cocktails.

        Args:
            nb_ingredients: one integer per cocktail giving how many leading
                ingredient slots of the matching row of *x* are used.
            x: 2-D tensor ``(batch_size, >= max_count * input_dim)``; row ``i``
                holds the ingredient slots of cocktail ``i`` back to back,
                each ``input_dim`` wide.

        Returns:
            ``(mean, logvar)``, the outputs of ``FC_mean`` and ``FC_logvar``.

        Raises:
            ValueError: if ``aggregation`` is not 'sum' or 'mean', or if
                *nb_ingredients* does not match *x*.
        """
        if self.aggregation not in ('sum', 'mean'):
            raise ValueError(
                f"Unknown aggregation {self.aggregation!r}; expected 'sum' or 'mean'"
            )

        batch_size = x.shape[0]
        counts = torch.as_tensor(nb_ingredients, dtype=torch.long, device=x.device).reshape(-1)
        if counts.numel() != batch_size:
            raise ValueError(
                f"nb_ingredients has {counts.numel()} entries but the batch has {batch_size} rows"
            )
        per_cocktail = counts.tolist()
        max_slots = x.shape[1] // self.input_dim
        if per_cocktail and (min(per_cocktail) < 0 or max(per_cocktail) > max_slots):
            raise ValueError(
                f"nb_ingredients must lie in [0, {max_slots}] for inputs of width {x.shape[1]}"
            )

        # reshape x in (batch size * nb ingredients, dim_ing_rep): view every
        # row as slots of width input_dim and keep only the used slots.
        # Boolean-mask indexing walks the mask in row-major order, so rows come
        # out cocktail by cocktail, slot by slot.
        slots = x[:, :max_slots * self.input_dim].reshape(batch_size, max_slots, self.input_dim)
        used = torch.arange(max_slots, device=x.device).unsqueeze(0) < counts.unsqueeze(1)
        flat_ingredients = slots[used]

        # encode ingredients in parallel
        ingredients_encodings = self.ingredient_encoder(flat_ingredients)
        assert ingredients_encodings.shape == (sum(per_cocktail), self.deepset_latent_dim)

        # aggregate: one contiguous chunk of encodings per cocktail
        reduce = torch.sum if self.aggregation == 'sum' else torch.mean
        if batch_size == 0:
            x = ingredients_encodings.new_zeros((0, self.deepset_latent_dim))
        else:
            chunks = torch.split(ingredients_encodings, per_cocktail, dim=0)
            x = torch.stack([reduce(chunk, dim=0) for chunk in chunks], dim=0)
        assert x.shape[0] == batch_size

        # post aggregation network
        for layer, dropout in zip(self.linears, self.dropouts):
            x = self.activation(dropout(layer(x)))
        mean = self.FC_mean(x)
        logvar = self.FC_logvar(x)
        return mean, logvar


class MultiHeadModel(nn.Module):
    """Shared encoder followed by one decoder head per auxiliary task."""

    def __init__(self, encoder, auxiliaries_dict, activation, hidden_dims_decoder):
        """Create one head per key of *auxiliaries_dict*.

        Args:
            encoder: module producing the latent code; must expose an
                ``output_dim`` attribute, which is used as the heads' input width.
            auxiliaries_dict: maps task name to a dict with at least the keys
                'dim_output' and 'final_activ'.
            activation: activation passed through to every head.
            hidden_dims_decoder: hidden widths used only for the
                'ingredients_quantities' head; all other heads have none.

        The 'taste_reps' head is stored separately as ``taste_reps_decoder``
        and is not part of ``auxiliaries`` / ``auxiliaries_str``. Heads are
        created in sorted key order.
        """
        super().__init__()
        self.encoder = encoder
        self.latent_dim = self.encoder.output_dim
        self.auxiliaries_str = []
        self.auxiliaries = nn.ModuleList()
        for aux_str in sorted(auxiliaries_dict.keys()):
            spec = auxiliaries_dict[aux_str]
            if aux_str == 'taste_reps':
                self.taste_reps_decoder = self._build_head(spec, [], activation)
                continue
            hidden = hidden_dims_decoder if aux_str == 'ingredients_quantities' else []
            self.auxiliaries_str.append(aux_str)
            self.auxiliaries.append(self._build_head(spec, hidden, activation))

    def _build_head(self, spec, hidden_dims, activation):
        """Build one SimpleNet decoder head from a task spec, without dropout."""
        return SimpleNet(input_dim=self.latent_dim,
                         hidden_dims=hidden_dims,
                         output_dim=spec['dim_output'],
                         activation=activation,
                         dropout=0.0,
                         final_activ=spec['final_activ'])

    def get_all_auxiliaries(self, x):
        """Return the outputs of every head in ``auxiliaries``, in order."""
        return [aux(x) for aux in self.auxiliaries]

    def get_auxiliary(self, z, aux_str):
        """Return the output of the head named *aux_str* applied to *z*.

        Raises:
            ValueError: if *aux_str* is neither 'taste_reps' nor a registered head.
        """
        if aux_str == 'taste_reps':
            # Only defined when 'taste_reps' was a key of auxiliaries_dict.
            return self.taste_reps_decoder(z)
        try:
            index = self.auxiliaries_str.index(aux_str)
        except ValueError:
            raise ValueError(
                f"Unknown auxiliary {aux_str!r}; available: {self.auxiliaries_str}"
            ) from None
        return self.auxiliaries[index](z)

    def forward(self, x, aux_str=None):
        """Encode *x* and decode with one head or with all of them.

        Returns:
            ``(z, output, [aux_str])`` when *aux_str* is given, otherwise
            ``(z, list_of_outputs, auxiliaries_str)``.
        """
        z = self.encoder(x)
        if aux_str is not None:
            return z, self.get_auxiliary(z, aux_str), [aux_str]
        return z, self.get_all_auxiliaries(z), self.auxiliaries_str


def get_multihead_model(input_dim, activation, hidden_dims_cocktail, latent_dim, dropout, auxiliaries_dict,
                        hidden_dims_decoder):
    """Build a MultiHeadModel whose encoder is a SimpleNet.

    The encoder is ``SimpleNet(input_dim, hidden_dims_cocktail, latent_dim,
    activation, dropout)``; the remaining arguments are forwarded to
    ``MultiHeadModel``.
    """
    encoder = SimpleNet(input_dim, hidden_dims_cocktail, latent_dim, activation, dropout)
    return MultiHeadModel(encoder, auxiliaries_dict, activation, hidden_dims_decoder)