import torch
import torch.nn as nn

from utils.transformer_modules import *
from utils.transformer_modules import _gen_timing_signal, _gen_bias_mask
from utils.hparams import HParams

use_cuda = torch.cuda.is_available()


class self_attention_block(nn.Module):
    def __init__(self, hidden_size, total_key_depth, total_value_depth, filter_size, num_heads,
                 bias_mask=None, layer_dropout=0.0, attention_dropout=0.0, relu_dropout=0.0, attention_map=False):
        super(self_attention_block, self).__init__()

        self.attention_map = attention_map
        self.multi_head_attention = MultiHeadAttention(hidden_size, total_key_depth, total_value_depth,
                                                       hidden_size, num_heads, bias_mask, attention_dropout, attention_map)
        self.positionwise_convolution = PositionwiseFeedForward(hidden_size, filter_size, hidden_size,
                                                                layer_config='cc', padding='both', dropout=relu_dropout)
        self.dropout = nn.Dropout(layer_dropout)
        self.layer_norm_mha = LayerNorm(hidden_size)
        self.layer_norm_ffn = LayerNorm(hidden_size)

    def forward(self, inputs):
        x = inputs

        # Layer normalization
        x_norm = self.layer_norm_mha(x)
        # Multi-head attention
        if self.attention_map is True:
            y, weights = self.multi_head_attention(x_norm, x_norm, x_norm)
        else:
            y = self.multi_head_attention(x_norm, x_norm, x_norm)
        # Dropout and residual connection
        x = self.dropout(x + y)

        # Layer normalization
        x_norm = self.layer_norm_ffn(x)
        # Position-wise feed-forward network
        y = self.positionwise_convolution(x_norm)
        # Dropout and residual connection
        y = self.dropout(x + y)

        if self.attention_map is True:
            return y, weights
        return y
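
# Usage sketch for self_attention_block (hypothetical sizes; assumes the MultiHeadAttention and
# PositionwiseFeedForward modules in utils.transformer_modules keep the (batch, time, hidden) layout):
#   block = self_attention_block(hidden_size=128, total_key_depth=128, total_value_depth=128,
#                                filter_size=128, num_heads=4, attention_map=True)
#   y, weights = block(torch.randn(2, 108, 128))   # y: (2, 108, 128)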


class bi_directional_self_attention(nn.Module):
    def __init__(self, hidden_size, total_key_depth, total_value_depth, filter_size, num_heads, max_length,
                 layer_dropout=0.0, attention_dropout=0.0, relu_dropout=0.0):
        super(bi_directional_self_attention, self).__init__()

        self.weights_list = list()

        params = (hidden_size,
                  total_key_depth or hidden_size,
                  total_value_depth or hidden_size,
                  filter_size,
                  num_heads,
                  _gen_bias_mask(max_length),
                  layer_dropout,
                  attention_dropout,
                  relu_dropout,
                  True)
        self.attn_block = self_attention_block(*params)

        params = (hidden_size,
                  total_key_depth or hidden_size,
                  total_value_depth or hidden_size,
                  filter_size,
                  num_heads,
                  torch.transpose(_gen_bias_mask(max_length), dim0=2, dim1=3),
                  layer_dropout,
                  attention_dropout,
                  relu_dropout,
                  True)
        self.backward_attn_block = self_attention_block(*params)

        self.linear = nn.Linear(hidden_size * 2, hidden_size)

    def forward(self, inputs):
        x, weights_list = inputs

        # Forward self-attention block
        encoder_outputs, weights = self.attn_block(x)
        # Backward self-attention block
        reverse_outputs, reverse_weights = self.backward_attn_block(x)
        # Concatenation and fully-connected layer
        outputs = torch.cat((encoder_outputs, reverse_outputs), dim=2)
        y = self.linear(outputs)

        # Attention weights for visualization
        self.weights_list = weights_list
        self.weights_list.append(weights)
        self.weights_list.append(reverse_weights)

        return y, self.weights_list
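
# Note on the bi-directional block (based on the code above; _gen_bias_mask is defined elsewhere):
# assuming _gen_bias_mask builds the usual future-masking (upper-triangular) attention bias, the
# forward branch attends only to past/current frames while the transposed mask makes the backward
# branch attend only to current/future frames. The two (batch, time, hidden) outputs are concatenated
# on the last dimension and projected back to hidden_size, and weights_list accumulates one forward
# and one backward attention map per layer for visualization.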


class bi_directional_self_attention_layers(nn.Module):
    def __init__(self, embedding_size, hidden_size, num_layers, num_heads, total_key_depth, total_value_depth,
                 filter_size, max_length=100, input_dropout=0.0, layer_dropout=0.0,
                 attention_dropout=0.0, relu_dropout=0.0):
        super(bi_directional_self_attention_layers, self).__init__()

        self.timing_signal = _gen_timing_signal(max_length, hidden_size)

        params = (hidden_size,
                  total_key_depth or hidden_size,
                  total_value_depth or hidden_size,
                  filter_size,
                  num_heads,
                  max_length,
                  layer_dropout,
                  attention_dropout,
                  relu_dropout)

        self.embedding_proj = nn.Linear(embedding_size, hidden_size, bias=False)
        self.self_attn_layers = nn.Sequential(*[bi_directional_self_attention(*params) for _ in range(num_layers)])
        self.layer_norm = LayerNorm(hidden_size)
        self.input_dropout = nn.Dropout(input_dropout)

    def forward(self, inputs):
        # Input dropout
        x = self.input_dropout(inputs)
        # Project to hidden size
        x = self.embedding_proj(x)
        # Add timing signal (positional encoding)
        x += self.timing_signal[:, :inputs.shape[1], :].type_as(inputs.data)
        # Stack of bi-directional self-attention layers
        y, weights_list = self.self_attn_layers((x, []))
        # Layer normalization
        y = self.layer_norm(y)

        return y, weights_list
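
# Shape sketch for the encoder stack (hypothetical sizes; assumes _gen_timing_signal returns a
# (1, max_length, hidden_size) sinusoidal table as in the standard Transformer):
#   layers = bi_directional_self_attention_layers(embedding_size=144, hidden_size=128, num_layers=8,
#                                                 num_heads=4, total_key_depth=128, total_value_depth=128,
#                                                 filter_size=128, max_length=108)
#   y, weights_list = layers(torch.randn(2, 108, 144))   # y: (2, 108, 128), len(weights_list) == 2 * num_layers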


class BTC_model(nn.Module):
    def __init__(self, config):
        super(BTC_model, self).__init__()

        self.timestep = config['timestep']
        self.probs_out = config['probs_out']

        params = (config['feature_size'],
                  config['hidden_size'],
                  config['num_layers'],
                  config['num_heads'],
                  config['total_key_depth'],
                  config['total_value_depth'],
                  config['filter_size'],
                  config['timestep'],
                  config['input_dropout'],
                  config['layer_dropout'],
                  config['attention_dropout'],
                  config['relu_dropout'])

        self.self_attn_layers = bi_directional_self_attention_layers(*params)
        self.output_layer = SoftmaxOutputLayer(hidden_size=config['hidden_size'],
                                               output_size=config['num_chords'],
                                               probs_out=config['probs_out'])

    def forward(self, x, labels):
        labels = labels.view(-1, self.timestep)

        # Output of the bi-directional self-attention layers
        self_attn_output, weights_list = self.self_attn_layers(x)

        # Return logit values for CRF decoding
        if self.probs_out is True:
            logits = self.output_layer(self_attn_output)
            return logits

        # Output layer and softmax
        prediction, second = self.output_layer(self_attn_output)
        prediction = prediction.view(-1)
        second = second.view(-1)

        # Loss calculation
        loss = self.output_layer.loss(self_attn_output, labels)

        return prediction, loss, weights_list, second
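
# Note (inferred only from how the output layer is used here; SoftmaxOutputLayer is not shown):
# with probs_out=True the model returns per-frame logits/probabilities for an external CRF decoder;
# otherwise it appears to return the best and second-best chord index per frame, flattened to
# (batch_size * timestep,), together with the loss computed by self.output_layer.loss.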
| if __name__ == "__main__": | |
| config = HParams.load("run_config.yaml") | |
| device = torch.device("cuda" if use_cuda else "cpu") | |
| batch_size = 2 | |
| timestep = 108 | |
| feature_size = 144 | |
| num_chords = 25 | |
| features = torch.randn(batch_size,timestep,feature_size,requires_grad=True).to(device) | |
| chords = torch.randint(25,(batch_size*timestep,)).to(device) | |
| model = BTC_model(config=config.model).to(device) | |
| prediction, loss, weights_list, second = model(features, chords) | |
| print(prediction.size()) | |
| print(loss) | |
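    # Expected result of the smoke test above (an assumption, since SoftmaxOutputLayer is defined
    # elsewhere): prediction flattens to batch_size * timestep = 216 chord indices, so
    # prediction.size() should print torch.Size([216]) and loss should be a scalar tensor.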