# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.

import logging
from collections.abc import Iterable
from itertools import repeat
from typing import List, Optional, Tuple

import torch
from torch import Tensor


# ------------------------------------------------------------------------------
# assert_equal()
# ------------------------------------------------------------------------------
def assert_equal(value1, value2, name1=None, name2=None):
    """Assert that two values are equal; otherwise raise a ValueError."""
    str_name1 = "" if name1 is None else "{} ".format(name1)
    str_name2 = "" if name2 is None else "{} ".format(name2)
    if value1 != value2:
        str_value1 = "{}" if name1 is None else "({})"
        str_value1 = str_value1.format(value1)
        str_value2 = "{}" if name2 is None else "({})"
        str_value2 = str_value2.format(value2)
        raise ValueError(
            "Expected {}{} == {}{}".format(str_name1, str_value1, str_name2, str_value2)
        )


def fill_config(config, key, value):
    """Set config[key] to value if it is unset; otherwise assert the two agree."""
    if value is not None:
        if key not in config or config[key] is None:
            config[key] = value
        assert_equal(value, config[key], "value", f'config["{key}"]')
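
# A minimal usage sketch (assumed values, doctest-style):
#
#   >>> config = {"input_dim": None}
#   >>> fill_config(config, "input_dim", 80)   # fills the unset key
#   >>> config["input_dim"]
#   80
#   >>> fill_config(config, "input_dim", 80)   # matching value passes the check
#   >>> assert_equal(1, 2, "a", "b")           # mismatch raises
#   Traceback (most recent call last):
#       ...
#   ValueError: Expected a (1) == b (2)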


# ------------------------------------------------------------------------------
# check_and_return_expected()
# ------------------------------------------------------------------------------
def check_and_return_expected(value, undefined_value, expected_value, name=None):
    """
    Return the expected value while checking if the given value is undefined or
    equal to the expected value.
    """
    if (undefined_value is None and value is None) or (undefined_value == value):
        return expected_value
    if value != expected_value:
        str_name = "" if name is None else "{} ".format(name)
        str_value = "{}" if name is None else "({})"
        str_value = str_value.format(value)
        raise ValueError(
            "Expected {}{} == {}".format(str_name, str_value, expected_value)
        )
    return expected_value
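
# A minimal usage sketch (assumed values):
#
#   >>> check_and_return_expected(None, None, 512)  # undefined -> expected
#   512
#   >>> check_and_return_expected(512, None, 512)   # already matches
#   512
#   >>> check_and_return_expected(256, None, 512, name="dim")  # mismatch raises
#   Traceback (most recent call last):
#       ...
#   ValueError: Expected dim (256) == 512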


# ------------------------------------------------------------------------------
# get_time_axis()
# ------------------------------------------------------------------------------
def get_time_axis(layout):
    """
    Extract the time axis from the layout, for example for breaking a sequence
    into segments.
    """
    if layout in ["TB", "TBD"]:
        return 0
    if layout in ["BT", "BTD"]:
        return 1
    if layout in ["BCTD"]:
        return 2
    raise ValueError("Unsupported layout = {}".format(layout))


# ------------------------------------------------------------------------------
# get_batch_axis()
# ------------------------------------------------------------------------------
def get_batch_axis(layout):
    """
    Extract the batch axis from the layout.
    """
    if layout in ["TB", "TBD"]:
        return 1
    if layout in ["BT", "BTD", "BCTD"]:
        return 0
    raise ValueError("Unsupported layout = {}".format(layout))
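
# A minimal usage sketch: the layout strings name the axes, e.g. "TBD" is
# (time, batch, feature) and "BCTD" is (batch, channel, time, feature).
#
#   >>> get_time_axis("TBD"), get_batch_axis("TBD")
#   (0, 1)
#   >>> get_time_axis("BCTD"), get_batch_axis("BCTD")
#   (2, 0)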


# ------------------------------------------------------------------------------
# monotonically_increasing_and_bounded()
# ------------------------------------------------------------------------------
def monotonically_increasing_and_bounded(iterable, min=None, max=None):
    """
    Check if the elements in the given iterable are monotonically increasing and
    bounded by upper/lower bounds.
    """
    if not isinstance(iterable, Iterable):
        raise TypeError(
            "Expected iterable to be of type Iterable, got ({})".format(
                iterable.__class__.__name__
            )
        )
    # Track the previous element so this works for any iterable, not only
    # sequences that support len() and indexing.
    prev = None
    for i, elem in enumerate(iterable):
        if min is not None and elem < min:
            return False
        if max is not None and elem > max:
            return False
        if i > 0 and elem <= prev:
            return False
        prev = elem
    return True
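
# A minimal usage sketch (assumed values):
#
#   >>> monotonically_increasing_and_bounded([1, 2, 5], min=0, max=10)
#   True
#   >>> monotonically_increasing_and_bounded([1, 1, 2])   # not strictly increasing
#   False
#   >>> monotonically_increasing_and_bounded([1, 2, 99], max=10)  # out of bounds
#   False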


# ------------------------------------------------------------------------------
# to_pair()
# ------------------------------------------------------------------------------
def to_pair(value, name):
    """Make a pair (of type tuple) of the given value."""
    if isinstance(value, Iterable):
        if len(value) != 2:
            raise ValueError(
                "Expected `{}` to have exactly 2 elements, got: ({})".format(
                    name, value
                )
            )
        return tuple(value)
    return tuple(repeat(value, 2))
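
# A minimal usage sketch, e.g. normalizing a conv kernel_size argument:
#
#   >>> to_pair(3, "kernel_size")
#   (3, 3)
#   >>> to_pair([3, 5], "kernel_size")
#   (3, 5)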


# ------------------------------------------------------------------------------
# infer_conv_output_attrs()
# ------------------------------------------------------------------------------
# TODO(cfyeh): figure out if we can get `output_dim` without calling the module.
def infer_conv_output_attrs(
    module, input_channels, input_dim, batch_size=1, max_length=8
):
    """Infer the output channels/dim of a module by running a dummy input through it."""
    input = torch.randn(batch_size, input_channels, max_length, input_dim)
    output = module(input)
    output_channels = output.shape[1]
    output_dim = output.shape[-1]
    return output_channels, output_dim
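
# A minimal usage sketch (assumed Conv2d configuration):
#
#   >>> conv = torch.nn.Conv2d(1, 32, kernel_size=3, stride=2)
#   >>> infer_conv_output_attrs(conv, input_channels=1, input_dim=80)
#   (32, 39)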


# ------------------------------------------------------------------------------
# NoOp
# ------------------------------------------------------------------------------
class NoOp(torch.nn.Module):
    """
    NoOp simply passes the input through as the output.
    """

    def __init__(self):
        super().__init__()

    def forward(self, input: Tensor) -> Tensor:
        return input


# ------------------------------------------------------------------------------
# Permute: a torch.nn.Module that applies a permutation to the input tensor.
# ------------------------------------------------------------------------------
class Permute(torch.nn.Module):
    def __init__(self, dims):
        super().__init__()
        self.dims = dims

    def forward(self, input: Tensor) -> Tensor:
        return input.permute(self.dims).contiguous()
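
# A minimal usage sketch: convert a (T, B, D) tensor to (B, T, D).
#
#   >>> x = torch.randn(10, 2, 8)     # (T, B, D)
#   >>> Permute((1, 0, 2))(x).shape
#   torch.Size([2, 10, 8])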


# ------------------------------------------------------------------------------
# lengths_to_padding_mask()
# ------------------------------------------------------------------------------
def lengths_to_padding_mask(lengths: Tensor) -> Tensor:
    """Convert lengths of shape (B,) to a (B, T) padding mask (True = padding)."""
    batch_size = lengths.shape[0]
    max_length = int(torch.max(lengths).item())
    padding_mask = torch.arange(  # [0, ..., T-1]
        max_length, device=lengths.device, dtype=lengths.dtype
    ).expand(batch_size, max_length) >= lengths.unsqueeze(1)
    return padding_mask
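
# A minimal usage sketch (assumed lengths); True marks padding positions:
#
#   >>> lengths_to_padding_mask(torch.tensor([2, 4]))
#   tensor([[False, False,  True,  True],
#           [False, False, False, False]])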


# ------------------------------------------------------------------------------
# lengths_to_attention_mask()
# ------------------------------------------------------------------------------
def lengths_to_attention_mask(
    lengths: Tensor,
    left_context: Optional[int] = None,
    right_context: Optional[int] = None,
) -> Optional[Tensor]:
    """
    Generate an attention mask based on (lengths, left_context, right_context).
    left_context is None means unlimited left context.
    right_context is None means unlimited right context.
    """
    if left_context is None and right_context is None:
        return None
    max_length = int(torch.max(lengths).item())
    # For example, with `max_length` == 5,
    # indices = tensor([
    #     [ 0,  1,  2,  3,  4],
    #     [-1,  0,  1,  2,  3],
    #     [-2, -1,  0,  1,  2],
    #     [-3, -2, -1,  0,  1],
    #     [-4, -3, -2, -1,  0],
    # ])
    # In some cases the second torch.arange is created on cpu, which causes a
    # failure. Adding the device option to guard against it.
    indices = torch.arange(
        max_length, device=lengths.device, dtype=lengths.dtype
    ).expand(max_length, max_length) - torch.arange(
        max_length, device=lengths.device
    ).view(max_length, -1)
    # For example, with `max_length` == 5,
    # bool_mask = tensor([
    #     [True, True, True, True, True],
    #     [True, True, True, True, True],
    #     [True, True, True, True, True],
    #     [True, True, True, True, True],
    #     [True, True, True, True, True],
    # ])
    bool_mask = (
        torch.tensor([True]).to(device=lengths.device).expand(max_length, max_length)
    )
    # For example, with `max_length` == 5, left_context == 2
    # left_mask = tensor([
    #     [ True,  True, True, True, True],
    #     [ True,  True, True, True, True],
    #     [ True,  True, True, True, True],
    #     [False,  True, True, True, True],
    #     [False, False, True, True, True],
    # ])
    if left_context is not None:
        left_mask = indices >= -left_context
        bool_mask = bool_mask & left_mask
    # For example, with `max_length` == 5, right_context == 1
    # right_mask = tensor([
    #     [True, True, False, False, False],
    #     [True, True,  True, False, False],
    #     [True, True,  True,  True, False],
    #     [True, True,  True,  True,  True],
    #     [True, True,  True,  True,  True],
    # ])
    if right_context is not None:
        right_mask = indices <= right_context
        bool_mask = bool_mask & right_mask
    # Flip the convention: True marks positions that may NOT be attended to.
    bool_mask = (~bool_mask).to(device=lengths.device)
    return bool_mask
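
# A minimal usage sketch (assumed values): with left_context == 1 and
# unlimited right context, position i may not attend to positions < i - 1.
#
#   >>> lengths_to_attention_mask(torch.tensor([3]), left_context=1)
#   tensor([[False, False, False],
#           [False, False, False],
#           [ True, False, False]])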


# ------------------------------------------------------------------------------
# infer_output_norm()
# ------------------------------------------------------------------------------
def infer_output_norm(module, output_norm=None):
    """
    Infer the output norm (string and module) needed on the module given the
    desired output normalization.
    """
    if output_norm == module.output_norm():
        # output_norm already matches module.output_norm().
        return (None, NoOp())
    if output_norm is None and module.output_norm() is not None:
        logger = logging.getLogger("infer_output_norm()")
        logger.warning(
            "trying to set output_norm ({}) ".format(output_norm)
            + "but got module.output_norm() ({}), ".format(module.output_norm())
            + "the combined output_norm() will be ({})".format(module.output_norm())
        )
        return (None, NoOp())
    if output_norm == "log_softmax":
        if module.output_norm() is not None:
            raise ValueError(
                "incompatible output_norm ({}) ".format(output_norm)
                + "and module.output_norm() ({})".format(module.output_norm())
            )
        else:
            return ("log_softmax", torch.nn.LogSoftmax(dim=-1))
    if output_norm == "softmax":
        if module.output_norm() is not None:
            raise ValueError(
                "incompatible output_norm ({}) ".format(output_norm)
                + "and module.output_norm() ({})".format(module.output_norm())
            )
        else:
            return ("softmax", torch.nn.Softmax(dim=-1))
    raise ValueError(
        "output_norm ({}) not in ".format(output_norm)
        + "supported list = [None, softmax, log_softmax]"
    )
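
# A minimal usage sketch with a hypothetical `module` whose output_norm()
# returns None (i.e. the module applies no normalization itself):
#
#   >>> norm_str, norm_module = infer_output_norm(module, "log_softmax")
#   >>> norm_str
#   'log_softmax'
#   >>> norm_module
#   LogSoftmax(dim=-1)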


# ------------------------------------------------------------------------------
# infer_channels_from_layout()
# ------------------------------------------------------------------------------
def infer_channels_from_layout(layout, channels):
    """Extract the number of channels from the layout."""
    if layout in ("TBD", "BTD"):
        if channels is not None and channels != 1:
            raise ValueError(
                "Expected channels ({}) to be 1 for layout = {}".format(
                    channels, layout
                )
            )
        if channels is None:
            return 1
    return channels
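
# A minimal usage sketch: channel-less layouts imply a single channel.
#
#   >>> infer_channels_from_layout("BTD", None)
#   1
#   >>> infer_channels_from_layout("BCTD", 4)
#   4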


# ------------------------------------------------------------------------------
# pad_sequence()
# ------------------------------------------------------------------------------
def pad_sequence(
    sequence: Tensor,
    time_axis: int,
    extra_left_context: int = 0,
    extra_right_context: int = 0,
) -> Tensor:
    """Pad extra left/right contexts onto the sequence."""
    if extra_left_context == 0 and extra_right_context == 0:
        return sequence
    tensors_to_concat = []
    if extra_left_context:
        # Replicate the first frame `extra_left_context` times on the left.
        size = (extra_left_context,)
        fill_value = 0
        indices = torch.full(
            size=size,
            fill_value=fill_value,
            dtype=torch.long,
            device=sequence.device,
        )
        left_padding = torch.index_select(sequence, time_axis, indices)
        tensors_to_concat.append(left_padding)
    tensors_to_concat.append(sequence)
    # NOTE(cfyeh): for efficiency reasons we pad zeros instead of replicating
    # the last frame for the extra right context.
    if extra_right_context:
        size = list(sequence.shape)
        size[time_axis] = extra_right_context
        right_padding = torch.zeros(size, dtype=sequence.dtype, device=sequence.device)
        tensors_to_concat.append(right_padding)
    padded_sequence = torch.cat(tensors_to_concat, dim=time_axis)
    return padded_sequence
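
# A minimal usage sketch (assumed values): the first frame is replicated on
# the left, zeros are appended on the right.
#
#   >>> seq = torch.tensor([[1, 2, 3]])   # (B, T) with time_axis=1
#   >>> pad_sequence(seq, time_axis=1, extra_left_context=2, extra_right_context=1)
#   tensor([[1, 1, 1, 2, 3, 0]])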


# ------------------------------------------------------------------------------
# sequence_to_segments()
# ------------------------------------------------------------------------------
def sequence_to_segments(
    sequence: Tensor,
    time_axis: int,
    lengths: Tensor,
    segment_size: Optional[int] = None,
    extra_left_context: int = 0,
    extra_right_context: int = 0,
) -> List[Tuple[Tensor, Tensor]]:
    """Break a sequence into segments of at most `segment_size` steps plus contexts."""
    sequence = pad_sequence(
        sequence=sequence,
        time_axis=time_axis,
        extra_left_context=extra_left_context,
        extra_right_context=extra_right_context,
    )
    lengths = lengths + extra_left_context + extra_right_context
    segments: List[Tuple[Tensor, Tensor]] = []
    if segment_size is None:
        segments.append((sequence, lengths))
        return segments
    offset = 0
    end = sequence.shape[time_axis]
    step = segment_size
    size = extra_left_context + segment_size + extra_right_context
    while offset + extra_left_context + extra_right_context < end:
        clamped_size = min(size, end - offset)
        segment_lengths = torch.clamp(lengths - offset, min=0, max=clamped_size)
        indices = torch.arange(
            start=offset,
            end=(offset + clamped_size),
            step=1,
            dtype=torch.long,
            device=sequence.device,
        )
        segment_tensor = torch.index_select(sequence, time_axis, indices)
        segments.append((segment_tensor, segment_lengths))
        offset = offset + step
    return segments


# ------------------------------------------------------------------------------
# segments_to_sequence()
# ------------------------------------------------------------------------------
def segments_to_sequence(
    segments: List[Tuple[Tensor, Tensor]], time_axis: int
) -> Tuple[Tensor, Tensor]:
    """Concatenate segments into a full sequence."""
    if len(segments) == 1:
        return segments[0]
    tensors_to_concat: List[Tensor] = []
    lengths_to_stack: List[Tensor] = []
    for tensor, lengths in segments:
        tensors_to_concat.append(tensor)
        lengths_to_stack.append(lengths)
    sequence = torch.cat(tensors_to_concat, dim=time_axis)
    lengths = torch.stack(lengths_to_stack, dim=0)
    lengths = torch.sum(lengths, dim=0)
    return sequence, lengths
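
# A minimal roundtrip sketch (assumed values, no extra contexts): splitting a
# (B, T) = (1, 6) sequence into segments of 2 and stitching it back together.
#
#   >>> seq = torch.arange(6).view(1, 6)
#   >>> lengths = torch.tensor([6])
#   >>> segs = sequence_to_segments(seq, time_axis=1, lengths=lengths, segment_size=2)
#   >>> [s.tolist() for s, _ in segs]
#   [[[0, 1]], [[2, 3]], [[4, 5]]]
#   >>> out, out_lengths = segments_to_sequence(segs, time_axis=1)
#   >>> out.tolist(), out_lengths.tolist()
#   ([[0, 1, 2, 3, 4, 5]], [6])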


# ------------------------------------------------------------------------------
# lengths_to_encoder_padding_mask()
# ------------------------------------------------------------------------------
def lengths_to_encoder_padding_mask(lengths, batch_first: bool = False):
    """
    Convert lengths (a 1-D Long/Int tensor) to a 2-D binary tensor.

    Args:
        lengths: a (B, )-shaped tensor
        batch_first: whether to return a (B, T) tensor

    Return:
        encoder_padding_mask: a (max_length, B) binary mask, where
        [t, b] = False for t < lengths[b] and True otherwise

        max_length: maximum length of B sequences

    TODO:
        kernelize this function if benchmarking shows this function is slow
    """
    max_lengths = torch.max(lengths).item()
    bsz = lengths.size(0)
    encoder_padding_mask = (
        torch.arange(max_lengths)  # a (T, ) tensor with [0, ..., T-1]
        .to(lengths.device)  # move to the right device
        .view(1, max_lengths)  # reshape to a (1, T)-shaped tensor
        .expand(bsz, -1)  # expand to a (B, T)-shaped tensor
        >= lengths.view(bsz, 1).expand(-1, max_lengths)
    )
    if not batch_first:
        return encoder_padding_mask.t(), max_lengths
    else:
        return encoder_padding_mask, max_lengths
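
# A minimal usage sketch (assumed lengths): the default (T, B) layout is the
# transpose of the (B, T) mask produced by lengths_to_padding_mask above.
#
#   >>> mask, max_len = lengths_to_encoder_padding_mask(torch.tensor([2, 4]))
#   >>> mask.shape, max_len
#   (torch.Size([4, 2]), 4)
#   >>> mask.t()
#   tensor([[False, False,  True,  True],
#           [False, False, False, False]])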


# ------------------------------------------------------------------------------
# attention_suppression()
# ------------------------------------------------------------------------------
def attention_suppression(attention_weights: Tensor, scale: float):
    # B, H, qlen, klen -> B, H, qlen, 1
    attention_prob = torch.nn.functional.softmax(attention_weights.float(), dim=-1)
    attention_nozeros = attention_prob.to(torch.bool)
    nozeros_sum = torch.sum(attention_nozeros.to(torch.float), dim=-1, keepdim=True)

    # For very sparse distributions we need to guard against zero entries.
    key_sum = torch.sum(attention_prob, dim=-1, keepdim=True)

    # nozeros_sum should be > 1
    key_mean = key_sum / (nozeros_sum + 1e-8)

    # squared deviation from the mean
    dis = (attention_prob - key_mean) * (attention_prob - key_mean)

    # zero out the deviation for the zero-probability entries
    dis_masked = torch.where(
        attention_nozeros, dis, attention_prob.new_zeros(attention_prob.size())
    )

    key_var = torch.sum(dis_masked, dim=-1, keepdim=True)
    key_var = key_var / (nozeros_sum - 1.0 + 1e-8)
    key_std = torch.sqrt(key_var)
    key_threshold = key_mean - scale * key_std

    # keep attention_weights[i] if attention_prob[i] >= key_threshold,
    # otherwise replace it with "-inf"
    inf_tensor = attention_prob.new_zeros(attention_prob.size()).detach()
    inf_tensor[:] = float("-inf")
    attention_weights_float = torch.where(
        attention_prob < key_threshold,
        inf_tensor,
        attention_weights.float(),
    )

    return attention_weights_float.type_as(attention_weights)
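
# A minimal usage sketch (assumed shapes): entries whose softmax probability
# falls more than `scale` standard deviations below the per-row mean are
# replaced with -inf, so a subsequent softmax suppresses them.
#
#   >>> weights = torch.randn(2, 4, 5, 5)   # (B, H, qlen, klen)
#   >>> suppressed = attention_suppression(weights, scale=0.5)
#   >>> suppressed.shape
#   torch.Size([2, 4, 5, 5])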


# ------------------------------------------------------------------------------
# layer_norm_backward_hook()
# ------------------------------------------------------------------------------
def layer_norm_backward_hook(module, grad_input, grad_output, clamp_value):
    """Clamp the gradients flowing into a module to [-clamp_value, clamp_value]."""
    return tuple(torch.clamp(v, min=-clamp_value, max=clamp_value) for v in grad_input)
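
# A minimal usage sketch (assumed module and clamp value): bind `clamp_value`
# with functools.partial and register the result as a backward hook.
#
#   >>> import functools
#   >>> ln = torch.nn.LayerNorm(8)
#   >>> hook = functools.partial(layer_norm_backward_hook, clamp_value=1.0)
#   >>> handle = ln.register_backward_hook(hook)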