import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence


class LSTM(nn.Module):
    """LSTM text encoder with a diagonal bilinear word-attention pooling layer."""

    def __init__(self, text_dim, embedding_dim, vocab_size, padding_idx=0):
        super().__init__()
        self.padding_idx = padding_idx
        self.word_embedding = nn.Embedding(
            vocab_size, embedding_dim, padding_idx=padding_idx
        )
        self.rnn = nn.LSTM(embedding_dim, text_dim, batch_first=True)
        # Diagonal bilinear attention weights, one scalar per hidden dimension.
        self.w_attn = nn.Parameter(torch.empty(1, text_dim))
        nn.init.xavier_uniform_(self.w_attn)

    def forward(self, padded_tokens, dropout=0.5):
        w_emb = self.word_embedding(padded_tokens)
        w_emb = F.dropout(w_emb, p=dropout, training=self.training)
        # True lengths, excluding padding; pack_padded_sequence expects them on CPU.
        len_seq = (padded_tokens != self.padding_idx).sum(dim=1).cpu()
        x_packed = pack_padded_sequence(
            w_emb, len_seq, batch_first=True, enforce_sorted=False
        )
        B = padded_tokens.shape[0]
        rnn_out, _ = self.rnn(x_packed)
        rnn_out, _ = pad_packed_sequence(rnn_out, batch_first=True)
        # Gather the last valid hidden state of each sequence (index len_seq - 1).
        h = rnn_out[torch.arange(B), len_seq - 1]
        final_feat, attn = self.word_attention(rnn_out, h, len_seq)
        return final_feat, attn

    def word_attention(self, R, h, len_seq):
        """
        Input:
            R: hidden states of the entire words
            h: the final hidden state after processing the entire words
            len_seq: the length of the sequence
        Output:
            final_feat: the final feature after the bilinear attention
            attn: word attention weights
        """
        B, N, D = R.shape
        device = R.device
        len_seq = len_seq.to(device)

        # Expand the learned vector into a per-batch diagonal matrix diag(w_attn).
        W_attn = (self.w_attn * torch.eye(D, device=device))[None].repeat(B, 1, 1)
        # (B, N, D) x (B, D, D) x (B, D, 1) -> (B, N, 1)
        score = torch.bmm(torch.bmm(R, W_attn), h.unsqueeze(-1))

        # Mask out padded positions so they receive (near-)zero attention weight.
        mask = torch.arange(N, device=device).reshape(1, N, 1).repeat(B, 1, 1)
        mask = mask < len_seq.reshape(B, 1, 1)

        score = score.masked_fill(~mask, -1e9)
        attn = F.softmax(score, dim=1)
        # Attention-weighted sum of hidden states: (B, D, N) x (B, N, 1) -> (B, D).
        final_feat = torch.bmm(R.transpose(1, 2), attn).squeeze(-1)

        return final_feat, attn.squeeze(-1)
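

# ---------------------------------------------------------------------------
# Usage sketch (illustrative only, not part of the original module): a minimal
# example of driving the encoder, assuming a toy vocabulary of 100 tokens,
# 32-dim embeddings, a 64-dim LSTM, and padding index 0. The batch below is
# made up for demonstration purposes.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    torch.manual_seed(0)
    model = LSTM(text_dim=64, embedding_dim=32, vocab_size=100, padding_idx=0)
    model.eval()  # disable dropout for a deterministic forward pass

    # Two sequences of true lengths 4 and 2, padded to length 5 with index 0.
    padded_tokens = torch.tensor([
        [5, 12, 7, 3, 0],
        [9, 4, 0, 0, 0],
    ])

    final_feat, attn = model(padded_tokens)
    print(final_feat.shape)  # torch.Size([2, 64])
    print(attn.shape)        # torch.Size([2, 4]) -- unpacked to the longest true length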