File size: 4,871 Bytes
5e9bd47
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
import torch
from .decode_strategy import DecodeStrategy


def sample_with_temperature(logits, sampling_temp, keep_topk):
    """Select next tokens randomly from the top k possible next tokens.

    Samples from a categorical distribution over the ``keep_topk`` words using
    the category probabilities ``logits / sampling_temp``.
    """

    if sampling_temp == 0.0 or keep_topk == 1:
        # argmax
        topk_scores, topk_ids = logits.topk(1, dim=-1)
        if sampling_temp > 0:
            topk_scores /= sampling_temp
    else:
        logits = torch.div(logits, sampling_temp)
        if keep_topk > 0:
            top_values, top_indices = torch.topk(logits, keep_topk, dim=1)
            kth_best = top_values[:, -1].view([-1, 1])
            kth_best = kth_best.repeat([1, logits.shape[1]]).float()
            ignore = torch.lt(logits, kth_best)
            logits = logits.masked_fill(ignore, -10000)

        dist = torch.distributions.Multinomial(logits=logits, total_count=1)
        topk_ids = torch.argmax(dist.sample(), dim=1, keepdim=True)
        topk_scores = logits.gather(dim=1, index=topk_ids)

    return topk_ids, topk_scores


class GreedySearch(DecodeStrategy):
    """Select next tokens randomly from the top k possible next tokens.
    """

    def __init__(self, pad, bos, eos, batch_size, min_length, max_length,
                 return_attention=False, return_hidden=False, sampling_temp=1, keep_topk=1):
        super().__init__(
            pad, bos, eos, batch_size, 1, min_length, max_length, return_attention, return_hidden)
        self.sampling_temp = sampling_temp
        self.keep_topk = keep_topk
        self.topk_scores = None

    def initialize(self, memory_bank, device=None):
        fn_map_state = None

        if device is None:
            device = memory_bank.device

        self.memory_length = memory_bank.size(1)
        super().initialize(memory_bank, device)

        self.select_indices = torch.arange(
            self.batch_size, dtype=torch.long, device=device)
        self.original_batch_idx = torch.arange(
            self.batch_size, dtype=torch.long, device=device)

        return fn_map_state, memory_bank

    @property
    def current_predictions(self):
        return self.alive_seq[:, -1]

    @property
    def batch_offset(self):
        return self.select_indices

    def _pick(self, log_probs):
        """Function used to pick next tokens.
        """
        topk_ids, topk_scores = sample_with_temperature(
            log_probs, self.sampling_temp, self.keep_topk)
        return topk_ids, topk_scores

    def advance(self, log_probs, attn=None, hidden=None, label=None):
        """Select next tokens randomly from the top k possible next tokens.
        """
        self.ensure_min_length(log_probs)
        topk_ids, self.topk_scores = self._pick(log_probs)
        self.is_finished = topk_ids.eq(self.eos)
        if label is not None:
            label = label.view_as(self.is_finished)
            self.is_finished = label.eq(self.eos)
        self.alive_seq = torch.cat([self.alive_seq, topk_ids], -1)

        if self.return_attention:
            if self.alive_attn is None:
                self.alive_attn = attn
            else:
                self.alive_attn = torch.cat([self.alive_attn, attn], 1)
        if self.return_hidden:
            if self.alive_hidden is None:
                self.alive_hidden = hidden
            else:
                self.alive_hidden = torch.cat([self.alive_hidden, hidden], 1)
        self.ensure_max_length()

    def update_finished(self):
        """Finalize scores and predictions."""
        finished_batches = self.is_finished.view(-1).nonzero()
        for b in finished_batches.view(-1):
            b_orig = self.original_batch_idx[b]
            # scores/predictions/attention are lists,
            # (to be compatible with beam-search)
            self.scores[b_orig].append(self.topk_scores[b, 0].item())
            self.predictions[b_orig].append(self.alive_seq[b, 1:])
            self.attention[b_orig].append(
                self.alive_attn[b, :, :self.memory_length] if self.alive_attn is not None else [])
            self.hidden[b_orig].append(
                self.alive_hidden[b, :] if self.alive_hidden is not None else [])
        self.done = self.is_finished.all()
        if self.done:
            return
        is_alive = ~self.is_finished.view(-1)
        self.alive_seq = self.alive_seq[is_alive]
        if self.alive_attn is not None:
            self.alive_attn = self.alive_attn[is_alive]
        if self.alive_hidden is not None:
            self.alive_hidden = self.alive_hidden[is_alive]
        self.select_indices = is_alive.nonzero().view(-1)
        self.original_batch_idx = self.original_batch_idx[is_alive]
        # select_indices is equal to original_batch_idx for greedy search?