from torch import nn
from transformers import PreTrainedModel

from autoencoder_model.configuration_autoencoder import AutoEncoderConfig


class ExtractRNNOutput(nn.Module):
    """Keeps only the output sequence from a recurrent layer, discarding the
    hidden state, so LSTM/RNN/GRU modules can be chained in nn.Sequential."""

    def forward(self, x):
        output, _ = x
        return output


def create_layers(model_section, layer_types, input_dim, latent_dim,
                  num_layers, dropout_rate, compression_rate):
    """Builds the encoder or decoder stack.

    Layer widths shrink from input_dim by compression_rate per layer, but
    never below latent_dim; the decoder reuses the same widths mirrored.
    """
    layers = []
    current_dim = input_dim

    # Compute every layer width in the encoder direction.
    input_dimensions = []
    output_dimensions = []
    for _ in range(num_layers):
        input_dimensions.append(current_dim)
        current_dim = max(int(current_dim * compression_rate), latent_dim)
        output_dimensions.append(current_dim)

    # The innermost layer must land exactly on the latent size.
    output_dimensions[num_layers - 1] = latent_dim

    # The decoder mirrors the encoder: swap the widths and reverse their order.
    if model_section == "decoder":
        input_dimensions, output_dimensions = output_dimensions, input_dimensions
        input_dimensions.reverse()
        output_dimensions.reverse()

    for idx, (in_dim, out_dim) in enumerate(zip(input_dimensions, output_dimensions)):
        if layer_types == "linear":
            layers.append(nn.Linear(in_dim, out_dim))
        elif layer_types == "lstm":
            layers.append(nn.LSTM(in_dim, out_dim, batch_first=True))
            layers.append(ExtractRNNOutput())
        elif layer_types == "rnn":
            layers.append(nn.RNN(in_dim, out_dim, batch_first=True))
            layers.append(ExtractRNNOutput())
        elif layer_types == "gru":
            layers.append(nn.GRU(in_dim, out_dim, batch_first=True))
            layers.append(ExtractRNNOutput())
        # No dropout after the final layer.
        if idx != num_layers - 1 and dropout_rate is not None:
            layers.append(nn.Dropout(dropout_rate))
    return nn.Sequential(*layers)
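

# Worked example (illustrative numbers, not from the original file):
# create_layers("encoder", "linear", input_dim=784, latent_dim=32,
#               num_layers=3, dropout_rate=0.1, compression_rate=0.5)
# builds Linear(784, 392) -> Dropout(0.1) -> Linear(392, 196) -> Dropout(0.1)
# -> Linear(196, 32), and the matching "decoder" call mirrors it back:
# Linear(32, 196) -> Dropout(0.1) -> Linear(196, 392) -> Dropout(0.1)
# -> Linear(392, 784).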


class AutoEncoder(PreTrainedModel):
    """Autoencoder wired into the Hugging Face PreTrainedModel machinery."""

    config_class = AutoEncoderConfig

    def __init__(self, config):
        super().__init__(config)
        self.encoder = create_layers(
            "encoder",
            config.layer_types, config.input_dim, config.latent_dim,
            config.num_layers, config.dropout_rate, config.compression_rate,
        )
        self.decoder = create_layers(
            "decoder",
            config.layer_types, config.input_dim, config.latent_dim,
            config.num_layers, config.dropout_rate, config.compression_rate,
        )

    def forward(self, x):
        # ExtractRNNOutput already unwraps recurrent outputs inside each
        # stack, so linear and recurrent configurations share one path.
        x = self.encoder(x)
        x = self.decoder(x)
        return x
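

# Minimal usage sketch, assuming AutoEncoderConfig accepts the fields read in
# __init__ above (layer_types, input_dim, latent_dim, num_layers, dropout_rate,
# compression_rate) as keyword arguments; the values are illustrative only.
if __name__ == "__main__":
    import torch

    config = AutoEncoderConfig(
        layer_types="linear",
        input_dim=784,
        latent_dim=32,
        num_layers=3,
        dropout_rate=0.1,
        compression_rate=0.5,
    )
    model = AutoEncoder(config)
    batch = torch.randn(8, 784)  # batch of 8 flattened 28x28 inputs
    reconstruction = model(batch)
    print(reconstruction.shape)  # expect torch.Size([8, 784])
    # For recurrent layer_types ("lstm", "rnn", "gru"), inputs would instead
    # be shaped (batch, seq_len, input_dim) because batch_first=True.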