Spaces:
Running
on
Zero
Running
on
Zero
import functools | |
import threading | |
from .composition import parse_composition, parse_heads_from_composition | |
class AdapterSetup: | |
""" | |
Represents an adapter setup of a model including active adapters and active heads. This class is intended to be | |
used as a context manager using the ``with`` statement. The setup defined by the ``AdapterSetup`` context will | |
override static adapter setups defined in a model (i.e. setups specified via ``active_adapters``). | |
Example:: | |
with AdapterSetup(Stack("a", "b")): | |
# will use the adapter stack "a" and "b" outputs = model(**inputs) | |
Note that the context manager is thread-local, i.e. it can be used with different setups in a multi-threaded | |
environment. | |
""" | |
# thread-local storage that holds a stack of active contexts | |
storage = threading.local() | |
def __init__(self, adapter_setup, head_setup=None, ignore_empty: bool = False): | |
self.adapter_setup = parse_composition(adapter_setup) | |
if head_setup: | |
self.head_setup = head_setup | |
else: | |
self.head_setup = parse_heads_from_composition(self.adapter_setup) | |
self._empty = ignore_empty and self.adapter_setup is None and self.head_setup is None | |
def __enter__(self): | |
if not self._empty: | |
AdapterSetup.get_contexts().append(self) | |
return self | |
def __exit__(self, type, value, traceback): | |
if not self._empty: | |
AdapterSetup.get_contexts().pop() | |
def get_contexts(cls): | |
if not hasattr(cls.storage, "contexts"): | |
cls.storage.contexts = [] | |
return cls.storage.contexts | |
def get_context(cls): | |
try: | |
return cls.get_contexts()[-1] | |
except IndexError: | |
return None | |
def get_context_adapter_setup(cls): | |
context = cls.get_context() | |
if context: | |
return context.adapter_setup | |
return None | |
def get_context_head_setup(cls): | |
context = cls.get_context() | |
if context: | |
return context.head_setup | |
return None | |
class ForwardContext: | |
""" | |
Holds context information during a forward pass through a model. This class should be used via the | |
``ForwardContext.wrap()`` method. | |
Note that the context is thread-local. | |
""" | |
# thread-local storage that holds a stack of active contexts | |
storage = threading.local() | |
context_attributes = [ | |
"adapter_gating_scores", | |
"adapter_fusion_attentions", | |
"adapter_input_parallelized", | |
] | |
# Additional used attributes not exposed to the user | |
# - prompt_tokens_length: length of the prompt tokens | |
def __init__(self, model, *args, **kwargs): | |
# If the model has a method ``forward_context()``, use it to create the context. | |
if hasattr(model, "forward_context"): | |
model.forward_context(self, *args, **kwargs) | |
def __enter__(self): | |
ForwardContext.get_contexts().append(self) | |
return self | |
def __exit__(self, type, value, traceback): | |
ForwardContext.get_contexts().pop() | |
def wrap(cls, f): | |
""" | |
Decorator method that wraps a ``forward()`` function of a model class. | |
""" | |
def wrapper_func(self, *args, **kwargs): | |
if self.adapters_config is not None: | |
with cls(self, *args, **kwargs) as ctx: | |
# whether to output the context attributes | |
output_context = kwargs.pop("output_context", False) | |
kwargs = { | |
k: v for k, v in kwargs.items() if k.replace("output_", "") not in cls.context_attributes | |
} | |
results = f(self, *args, **kwargs) | |
# append output attributes | |
if isinstance(results, tuple): | |
for attr in cls.context_attributes: | |
if getattr(ctx, "output_" + attr, False): | |
results = results + (dict(getattr(ctx, attr)),) | |
else: | |
for attr in cls.context_attributes: | |
if getattr(ctx, "output_" + attr, False): | |
results[attr] = dict(getattr(ctx, attr)) | |
if output_context: | |
context_dict = ctx.__dict__ | |
if output_context: | |
return results, context_dict | |
else: | |
return results | |
else: | |
return f(self, *args, **kwargs) | |
return wrapper_func | |
def get_contexts(cls): | |
if not hasattr(cls.storage, "contexts"): | |
cls.storage.contexts = [] | |
return cls.storage.contexts | |
def get_context(cls): | |
try: | |
return cls.get_contexts()[-1] | |
except IndexError: | |
return None | |