# Spaces: Sleeping (page-header residue from the hosting site; not part of the module)
import torch

from .decode_strategy import DecodeStrategy
class BeamSearch(DecodeStrategy):
    """Generation with beam search.

    For every batch entry, ``beam_size`` partial hypotheses are kept alive.
    At each step the candidates are ranked by their length-averaged
    cumulative log-probability, and hypotheses whose last token is ``eos``
    are stored until the batch entry can be closed.

    NOTE(review): ``alive_seq``, ``alive_attn``, ``is_finished``, ``scores``,
    ``predictions``, ``attention``, ``batch_size``, ``eos``,
    ``return_attention``, ``ensure_min_length``, ``ensure_max_length`` and
    ``__len__`` are not defined in this class; they are presumably provided
    by ``DecodeStrategy`` -- confirm against that class.

    Args:
        pad, bos, eos, batch_size, min_length, return_attention, max_length:
            Forwarded unchanged to ``DecodeStrategy.__init__``.
        beam_size (int): Number of hypotheses kept per batch entry.
        n_best (int): Number of finished hypotheses reported per batch entry.
    """

    def __init__(self, pad, bos, eos, batch_size, beam_size, n_best, min_length,
                 return_attention, max_length):
        super().__init__(
            pad, bos, eos, batch_size, beam_size, min_length,
            return_attention, max_length)
        self.beam_size = beam_size
        self.n_best = n_best

        # Result caching: finished (score, tokens, attention) tuples, one
        # list per original batch entry.
        self.hypotheses = [[] for _ in range(batch_size)]

        # Beam state.
        # True once the best-ranked beam of a batch entry has emitted EOS.
        self.top_beam_finished = torch.zeros([batch_size], dtype=torch.bool)
        # Maps a row of the (shrinking) active batch to its original index.
        self._batch_offset = torch.arange(batch_size, dtype=torch.long)
        self.select_indices = None
        self.done = False

    def initialize(self, memory_bank, device=None):
        """Repeat src objects `beam_size` times and allocate the beam state.

        Args:
            memory_bank (Tensor): Repeated ``beam_size`` times along dim 0;
                its size along dim 1 is stored as ``self.memory_length``.
            device (torch.device, optional): Device for the beam-state
                tensors. Defaults to the device of ``memory_bank``.

        Returns:
            tuple: ``(fn_map_state, memory_bank)`` where
            ``fn_map_state(state, dim)`` repeats ``state`` ``beam_size``
            times along ``dim`` and ``memory_bank`` is the repeated tensor.
        """
        def fn_map_state(state, dim):
            return torch.repeat_interleave(state, self.beam_size, dim=dim)

        memory_bank = torch.repeat_interleave(memory_bank, self.beam_size, dim=0)
        if device is None:
            device = memory_bank.device
        self.memory_length = memory_bank.size(1)
        super().initialize(memory_bank, device)

        self.best_scores = torch.full(
            [self.batch_size], -1e10, dtype=torch.float, device=device)
        # Index of the first beam of every batch entry in the flattened
        # (batch_size * beam_size) layout.
        self._beam_offset = torch.arange(
            0, self.batch_size * self.beam_size, step=self.beam_size,
            dtype=torch.long, device=device)
        # Only the first beam of each batch entry starts with a finite score,
        # so the first top-k cannot pick duplicates of the same start state.
        # Created as (batch_size, beam_size) so that it already has the shape
        # written by ``torch.mul(..., out=...)`` in ``advance``; resizing a
        # non-empty ``out`` tensor is deprecated in PyTorch.
        self.topk_log_probs = torch.tensor(
            [0.0] + [float("-inf")] * (self.beam_size - 1), device=device
        ).repeat(self.batch_size).reshape(self.batch_size, self.beam_size)

        # Buffers for the topk scores and 'backpointer'.
        self.topk_scores = torch.empty(
            (self.batch_size, self.beam_size), dtype=torch.float, device=device)
        self.topk_ids = torch.empty(
            (self.batch_size, self.beam_size), dtype=torch.long, device=device)
        self._batch_index = torch.empty(
            [self.batch_size, self.beam_size], dtype=torch.long, device=device)
        return fn_map_state, memory_bank

    def current_predictions(self):
        """Return the last token of every alive sequence."""
        return self.alive_seq[:, -1]

    def current_backptr(self):
        """Return ``select_indices`` viewed as (batch_size, beam_size).

        For testing.
        """
        return self.select_indices.view(self.batch_size, self.beam_size)

    def batch_offset(self):
        """Return the original batch index of every still-active batch row."""
        return self._batch_offset

    def _pick(self, log_probs):
        """Return token decision for a step.

        Args:
            log_probs (FloatTensor): (B * beam_size, vocab_size)

        Returns:
            topk_scores (FloatTensor): (B, beam_size)
            topk_ids (LongTensor): (B, beam_size), indices into the
                flattened ``beam_size * vocab_size`` candidates of each
                batch entry.
        """
        vocab_size = log_probs.size(-1)
        # Flatten probs so all beams of a batch entry compete in one top-k.
        curr_scores = log_probs.reshape(-1, self.beam_size * vocab_size)
        topk_scores, topk_ids = torch.topk(curr_scores, self.beam_size, dim=-1)
        return topk_scores, topk_ids

    def advance(self, log_probs, attn):
        """Extend every alive hypothesis by one token.

        Args:
            log_probs: (B * beam_size, vocab_size). Modified in place: the
                cumulative beam scores are added to it.
            attn: Attention for the current step; only read when
                ``self.return_attention`` is set, where it is indexed along
                dim 1 by beam and concatenated along dim 0 over steps.
        """
        vocab_size = log_probs.size(-1)
        # (non-finished) batch_size
        _B = log_probs.shape[0] // self.beam_size
        step = len(self)  # alive_seq
        self.ensure_min_length(log_probs)

        # Multiply probs by the beam probability (addition in log space).
        log_probs += self.topk_log_probs.view(_B * self.beam_size, 1)
        curr_length = step + 1
        curr_scores = log_probs / curr_length  # avg log_prob
        # topk_scores/topk_ids: (batch_size, beam_size)
        self.topk_scores, self.topk_ids = self._pick(curr_scores)

        # Recover log probs (undo the length averaging).
        torch.mul(self.topk_scores, curr_length, out=self.topk_log_probs)

        # Resolve beam origin and map to batch index flat representation.
        self._batch_index = self.topk_ids // vocab_size
        self._batch_index += self._beam_offset[:_B].unsqueeze(1)
        self.select_indices = self._batch_index.view(_B * self.beam_size)
        self.topk_ids.fmod_(vocab_size)  # resolve true word ids

        # Append last prediction.
        self.alive_seq = torch.cat(
            [self.alive_seq.index_select(0, self.select_indices),
             self.topk_ids.view(_B * self.beam_size, 1)], -1)

        if self.return_attention:
            current_attn = attn.index_select(1, self.select_indices)
            if step == 1:
                self.alive_attn = current_attn
            else:
                # Reorder the attention history to follow the surviving beams.
                self.alive_attn = self.alive_attn.index_select(
                    1, self.select_indices)
                self.alive_attn = torch.cat([self.alive_attn, current_attn], 0)

        self.is_finished = self.topk_ids.eq(self.eos)
        self.ensure_max_length()

    def update_finished(self):
        """Store finished hypotheses and drop completed batch entries.

        A batch entry is completed once its top beam has finished and at
        least ``n_best`` hypotheses have been collected for it; its
        ``n_best`` highest-scoring hypotheses are then appended to
        ``self.scores``, ``self.predictions`` and ``self.attention``. When
        no batch entry remains, ``self.done`` is set to True.
        """
        _B_old = self.topk_log_probs.shape[0]
        step = self.alive_seq.shape[-1]  # len(self)
        # Penalize finished beams so they are not extended further.
        self.topk_log_probs.masked_fill_(self.is_finished, -1e10)
        self.is_finished = self.is_finished.to('cpu')
        self.top_beam_finished |= self.is_finished[:, 0].eq(1)
        predictions = self.alive_seq.view(_B_old, self.beam_size, step)
        attention = (
            self.alive_attn.view(
                step - 1, _B_old, self.beam_size, self.alive_attn.size(-1))
            if self.alive_attn is not None else None)

        non_finished_batch = []
        for i in range(self.is_finished.size(0)):
            b = self._batch_offset[i]
            finished_hyp = self.is_finished[i].nonzero(as_tuple=False).view(-1)
            # Store finished hypothesis for this batch.
            for j in finished_hyp:  # Beam level: finished beam j in batch i
                self.hypotheses[b].append((
                    self.topk_scores[i, j],
                    predictions[i, j, 1:],  # Ignore start token
                    attention[:, i, j, :self.memory_length]
                    if attention is not None else None))
            # End condition is the top beam finished and we can return
            # n_best hypotheses.
            finish_flag = self.top_beam_finished[i] != 0
            if finish_flag and len(self.hypotheses[b]) >= self.n_best:
                best_hyp = sorted(
                    self.hypotheses[b], key=lambda x: x[0], reverse=True)
                for score, pred, attn in best_hyp[:self.n_best]:
                    self.scores[b].append(score.item())
                    self.predictions[b].append(pred)
                    self.attention[b].append(
                        attn if attn is not None else [])
            else:
                non_finished_batch.append(i)

        if not non_finished_batch:
            self.done = True
            return

        non_finished = torch.tensor(non_finished_batch)
        _B_new = non_finished.shape[0]
        # Remove finished batches for the next step.
        # These two tensors are created without a device in __init__, so they
        # are indexed before `non_finished` is moved to the decoding device.
        self.top_beam_finished = self.top_beam_finished.index_select(0, non_finished)
        self._batch_offset = self._batch_offset.index_select(0, non_finished)
        non_finished = non_finished.to(self.topk_ids.device)
        self.topk_log_probs = self.topk_log_probs.index_select(0, non_finished)
        self._batch_index = self._batch_index.index_select(0, non_finished)
        self.select_indices = self._batch_index.view(_B_new * self.beam_size)
        self.alive_seq = predictions.index_select(0, non_finished) \
            .view(-1, self.alive_seq.size(-1))
        self.topk_scores = self.topk_scores.index_select(0, non_finished)
        self.topk_ids = self.topk_ids.index_select(0, non_finished)
        if self.alive_attn is not None:
            inp_seq_len = self.alive_attn.size(-1)
            self.alive_attn = attention.index_select(1, non_finished) \
                .view(step - 1, _B_new * self.beam_size, inp_seq_len)