File size: 993 Bytes
4943752
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
d65ddc0
 
 
4943752
 
d65ddc0
4943752
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
import numpy
import torch
import torch.nn as nn


class LCF_Pooler(nn.Module):
    """Pool a sequence by taking the hidden state at the centre of the LCF window.

    Works like a BERT-style pooler (dense layer followed by tanh), except that the
    token fed to the dense layer is not fixed at position 0. For each sample it is
    the middle position among those whose ``lcf_vec`` entries are exactly 1.0.
    """

    def __init__(self, config):
        super().__init__()
        self.config = config
        # ``config`` must expose ``hidden_size``; input and output widths are equal.
        self.dense = nn.Linear(config.hidden_size, config.hidden_size)
        self.activation = nn.Tanh()

    def forward(self, hidden_states, lcf_vec):
        """Return ``tanh(dense(h))``, where ``h`` is the selected token of each sample.

        Args:
            hidden_states: tensor of shape ``(batch, seq_len, hidden_size)``.
            lcf_vec: local-context weights aligned with ``hidden_states`` on the
                first two dims. Two layouts are accepted:
                ``(batch, seq_len, d)``: a position is in the window when every
                entry along the last dim equals 1.0.
                ``(batch, seq_len)``: a position is in the window when its
                entry equals 1.0.

        Returns:
            Tensor of shape ``(batch, hidden_size)`` with the dtype of
            ``hidden_states``. It stays connected to the autograd graph, so
            gradients flow back into ``hidden_states``.

        For each sample, the selected position is the ``k // 2``-th (0-based)
        in-window position, where ``k`` is the number of in-window positions.
        A sample with no in-window positions falls back to position 0 rather
        than raising ``IndexError``.
        """
        device = hidden_states.device
        batch_size = hidden_states.shape[0]

        # Boolean (batch, seq_len) mask of positions inside the local context window.
        if lcf_vec.dim() == 2:
            in_window = lcf_vec == 1.0
        else:
            in_window = (lcf_vec == 1.0).all(dim=-1)
        in_window = in_window.to(device)

        # Running count of in-window positions seen so far, per sample.
        running = in_window.long().cumsum(dim=1)
        count = running[:, -1]

        # The wanted position is the (count // 2 + 1)-th in-window position (1-based).
        # The first index where the running count reaches that rank is exactly it.
        target_rank = (count // 2 + 1).unsqueeze(1)
        hit = (running == target_rank) & in_window

        # argmax returns the first maximal index. If a sample has an empty window,
        # its row is all zeros and argmax yields 0, i.e. the first token.
        centre = hit.long().argmax(dim=1)

        # Gather with tensor indexing (no numpy round trip). This keeps the dtype
        # and the autograd graph of hidden_states.
        pooled_output = hidden_states[torch.arange(batch_size, device=device), centre]
        pooled_output = self.dense(pooled_output)
        pooled_output = self.activation(pooled_output)
        return pooled_output