import itertools
import warnings
from collections.abc import Sequence
from typing import List, Optional, Set, Tuple, Union

import torch


class AdapterCompositionBlock(Sequence):
    """Base class of all adapter composition blocks. Children are adapter names (str) or nested blocks."""

    def __init__(self, *children):
        self.children = [parse_composition(b, None) for b in children]

    def __getitem__(self, key):
        return self.children[key]

    def __len__(self):
        return len(self.children)

    def __eq__(self, o: object) -> bool:
        if isinstance(o, type(self)):
            return all([c1 == c2 for c1, c2 in zip(self.children, o.children)])
        else:
            return False

    def __repr__(self):
        child_repr = ", ".join(map(str, self.children))
        return f"{self.__class__.__name__}[{child_repr}]"

    def first(self):
        if not isinstance(self.children[0], AdapterCompositionBlock):
            return self.children[0]
        else:
            return self.children[0].first()

    def last(self):
        if not isinstance(self.children[-1], AdapterCompositionBlock):
            return self.children[-1]
        else:
            return self.children[-1].last()

    @property
    def parallel_channels(self):
        return max([b.parallel_channels if isinstance(b, AdapterCompositionBlock) else 1 for b in self.children])

    def flatten(self) -> Set[str]:
        return set(itertools.chain(*[[b] if isinstance(b, str) else b.flatten() for b in self.children]))


class Parallel(AdapterCompositionBlock):
    def __init__(self, *parallel_adapters: List[str]):
        """
        Can be used to perform inference for multiple tasks (i.e., adapters) in parallel (for the same input).

        See AdapterDrop https://arxiv.org/abs/2010.11918
        """
        super().__init__(*parallel_adapters)

    @property
    def parallel_channels(self):
        return len(self.children)
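

# Illustrative usage sketch, not part of this module: the adapter names and the
# `set_active_adapters` call below are assumptions about the surrounding library.
# With two adapters "sentiment" and "topic" added to a model, activating
#     model.set_active_adapters(Parallel("sentiment", "topic"))
# runs both adapters on the same input in a single forward pass by replicating
# the batch once per child (see the AdapterDrop reference above).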


class Stack(AdapterCompositionBlock):
    def __init__(self, *stack_layers: List[Union[AdapterCompositionBlock, str]]):
        super().__init__(*stack_layers)


class Fuse(AdapterCompositionBlock):
    def __init__(self, *fuse_stacks: List[Union[AdapterCompositionBlock, str]]):
        super().__init__(*fuse_stacks)

    # TODO-V2 pull this up to all block classes?
    @property
    def name(self):
        return ",".join([c if isinstance(c, str) else c.last() for c in self.children])


class Split(AdapterCompositionBlock):
    def __init__(self, *split_adapters: List[Union[AdapterCompositionBlock, str]], splits: Union[List[int], int]):
        super().__init__(*split_adapters)
        self.splits = splits if isinstance(splits, list) else [splits] * len(split_adapters)


class BatchSplit(AdapterCompositionBlock):
    def __init__(self, *split_adapters: List[Union[AdapterCompositionBlock, str]], batch_sizes: Union[List[int], int]):
        super().__init__(*split_adapters)
        self.batch_sizes = batch_sizes if isinstance(batch_sizes, list) else [batch_sizes] * len(split_adapters)


class Average(AdapterCompositionBlock):
    def __init__(
        self,
        *average_adapters: List[Union[AdapterCompositionBlock, str]],
        weights: Optional[List[float]] = None,
        normalize_weights: bool = True,
    ):
        super().__init__(*average_adapters)
        if weights is not None:
            # normalize weights
            if normalize_weights:
                sum_weights = sum(weights) if weights else 1
                self.weights = [w / sum_weights for w in weights]
            else:
                self.weights = weights
        else:
            self.weights = [1 / len(average_adapters)] * len(average_adapters)
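

# Illustrative sketch of the weight handling above (adapter names are assumptions):
#     Average("a", "b", "c", weights=[1, 1, 2])   -> self.weights == [0.25, 0.25, 0.5]
#     Average("a", "b", "c")                      -> uniform weights [1/3, 1/3, 1/3]
#     Average("a", "b", weights=[0.7, 0.3], normalize_weights=False) keeps [0.7, 0.3] as given.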


# Mapping each composition block type to the allowed nested types
ALLOWED_NESTINGS = {
    Stack: [str, Fuse, Split, Parallel, BatchSplit, Average],
    Fuse: [str, Stack],
    Split: [str, Split, Stack, BatchSplit, Average],
    Parallel: [str, Stack, BatchSplit, Average],
    BatchSplit: [str, Stack, Split, BatchSplit, Average],
    Average: [str, Stack, Split, BatchSplit],
}
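
# Illustrative reading of the nesting table above (adapter names are assumptions):
#     Stack("a", Fuse("b", "c"))               is valid, since Fuse may be nested inside Stack.
#     Fuse("a", Split("b", "c", splits=64))    is rejected by validate_composition() below,
#                                              because Split is not an allowed child of Fuse.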


# Some composition blocks might not be supported by all models.
# Add a whitelist of models for those here.
SUPPORTED_MODELS = {
    Parallel: [
        "albert",
        "bert",
        "roberta",
        "distilbert",
        "deberta-v2",
        "deberta",
        "bart",
        "mbart",
        "mt5",
        "plbart",
        "gpt2",
        "gptj",
        "t5",
        "vit",
        "xlm-roberta",
        "bert-generation",
        "llama",
        "mistral",
        "electra",
        "whisper",
        "xmod",
    ],
}


def validate_composition(adapter_composition: AdapterCompositionBlock, level=0, model_type=None):
    if level > 1 and not (isinstance(adapter_composition, Stack) or isinstance(adapter_composition, str)):
        raise ValueError(f"Adapter setup is too deep. Cannot have {adapter_composition} at level {level}.")
    if isinstance(adapter_composition, AdapterCompositionBlock):
        block_type = type(adapter_composition)
        if model_type and block_type in SUPPORTED_MODELS:
            if model_type not in SUPPORTED_MODELS[block_type]:
                raise ValueError(
                    f"Models of type {model_type} don't support adapter composition using {block_type.__name__}."
                )
        for child in adapter_composition:
            if not type(child) in ALLOWED_NESTINGS[type(adapter_composition)]:
                raise ValueError(f"Adapter setup is invalid. Cannot nest {child} in {adapter_composition}")
            # recursively validate children
            validate_composition(child, level=level + 1)
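

# Illustrative behaviour sketch (the adapter names and model type are assumptions):
# per the whitelist above,
#     validate_composition(Parallel("a", "b"), model_type="bert")
# passes, while the same call with a model type missing from SUPPORTED_MODELS[Parallel]
# raises a ValueError.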


def parse_composition(adapter_composition, level=0, model_type=None) -> AdapterCompositionBlock:
    """
    Parses and validates a setup of adapters.

    Args:
        adapter_composition: The adapter setup to be parsed.
        level (int, optional): If set to None, disables validation. Defaults to 0.
    """
    if not adapter_composition:
        return None
    elif isinstance(adapter_composition, AdapterCompositionBlock):
        if level is not None:
            validate_composition(adapter_composition, level=level, model_type=model_type)
        return adapter_composition
    elif isinstance(adapter_composition, str):
        if level == 0:
            return Stack(adapter_composition)
        else:
            return adapter_composition
    elif isinstance(adapter_composition, Sequence):
        # Functionality of adapter-transformers v1.x
        warnings.warn(
            "Passing list objects for adapter activation is deprecated. Please use Stack or Fuse explicitly.",
            category=FutureWarning,
        )
        # for backwards compatibility
        if level == 1:
            block_class = Fuse
        else:
            block_class = Stack
        level = level + 1 if level is not None else None
        return block_class(*[parse_composition(b, level) for b in adapter_composition])
    else:
        raise TypeError(adapter_composition)
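

# Illustrative parsing sketch (adapter names are assumptions):
#     parse_composition("a")              -> Stack("a")           (single name wrapped at level 0)
#     parse_composition(Stack("a", "b"))  -> the block itself, after validation
#     parse_composition(["a", "b"])       -> Stack("a", "b"), with a FutureWarning for the
#                                            deprecated v1.x list format
#     parse_composition([["a", "b"]])     -> Stack(Fuse("a", "b")), nested lists at level 1
#                                            become a Fuse block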


def parse_heads_from_composition(adapter_composition, reference_heads: list = None):
    """
    Parses a potential head configuration from a setup of adapters.

    Args:
        adapter_composition: The adapter setup to be parsed.
        reference_heads: The list of available heads to validate the retrieved head configuration against.
    """
    final_block = adapter_composition
    if isinstance(final_block, Stack):
        final_block = final_block.children[-1]

    if isinstance(final_block, str) and (reference_heads is None or final_block in reference_heads):
        return final_block
    elif isinstance(final_block, Parallel):
        return [a if isinstance(a, str) else a.last() for a in final_block.children]
    elif isinstance(final_block, BatchSplit):
        # Convert BatchSplit of adapters to a BatchSplit of heads.
        blocks = [block.last() if isinstance(block, AdapterCompositionBlock) else block for block in final_block]
        head_setup = BatchSplit(*blocks, batch_sizes=final_block.batch_sizes)
        if reference_heads is None or all(head in reference_heads for head in head_setup):
            return head_setup
        else:
            raise ValueError(
                "Missing at least one head for the given BatchSplit setup. Expected heads: {}".format(blocks)
            )
    else:
        return None
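

# Illustrative sketch of the head resolution above (adapter/head names are assumptions):
#     parse_heads_from_composition(Stack("a", "b"))     -> "b"          (last child is a plain name)
#     parse_heads_from_composition(Parallel("a", "b"))  -> ["a", "b"]
#     parse_heads_from_composition(BatchSplit("a", "b", batch_sizes=[2, 2]))
#                                                       -> BatchSplit("a", "b", batch_sizes=[2, 2])
#     parse_heads_from_composition(Fuse("a", "b"))      -> None         (no head can be derived)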


def adjust_tensors_for_parallel(hidden_states, *tensors):
    """
    Replicates a given list of tensors based on the shape of the reference tensor (first argument).
    """
    outputs = []
    for tensor in tensors:
        if tensor is not None and hidden_states.shape[0] >= tensor.shape[0]:
            repeats = [1] * len(tensor.shape)
            repeats[0] = hidden_states.shape[0] // tensor.shape[0]
            new_tensor = tensor.repeat(*repeats)
            outputs.append(new_tensor)
        else:
            outputs.append(tensor)
    return tuple(outputs)
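

# Illustrative shape sketch (names and sizes are assumptions): if hidden_states has
# batch size 8 (e.g. 2 parallel channels over an input batch of 4) while an attention
# mask still has batch size 4, the mask is tiled along dim 0:
#     hidden_states: (8, seq, dim), attention_mask: (4, 1, 1, seq)
#     (attention_mask,) = adjust_tensors_for_parallel(hidden_states, attention_mask)
#     attention_mask.shape[0] == 8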


def adjust_tensors_for_parallel_(hidden_states, *tensors):
    """
    In-place version of adjust_tensors_for_parallel().
    """
    for tensor in tensors:
        if tensor is not None and hidden_states.shape[0] >= tensor.shape[0]:
            repeats = [1] * len(tensor.shape)
            repeats[0] = hidden_states.shape[0] // tensor.shape[0]
            new_tensor = tensor.repeat(*repeats)
            tensor.set_(new_tensor)


def match_attn_matrices_for_parallel(query, key, value) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """
    Matches the shapes of query, key and value matrices for parallel composition.
    """
    max_bsz = max(query.shape[0], key.shape[0], value.shape[0])

    query = query.repeat(max_bsz // query.shape[0], *([1] * len(query.shape[1:])))
    key = key.repeat(max_bsz // key.shape[0], *([1] * len(key.shape[1:])))
    value = value.repeat(max_bsz // value.shape[0], *([1] * len(value.shape[1:])))

    return query, key, value
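

# Illustrative shape sketch (sizes are assumptions): with 2 parallel channels the query
# may already have batch size 8 while key and value still have batch size 4; the two
# smaller matrices are repeated along dim 0 so that all three share the maximum batch size:
#     q: (8, heads, seq, d), k/v: (4, heads, seq, d)
#     q, k, v = match_attn_matrices_for_parallel(q, k, v)   # all now (8, heads, seq, d)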