m7n's picture
first commit
d1ed09d
raw
history blame
5.04 kB
import functools
import threading
from .composition import parse_composition, parse_heads_from_composition
class AdapterSetup:
"""
Represents an adapter setup of a model including active adapters and active heads. This class is intended to be
used as a context manager using the ``with`` statement. The setup defined by the ``AdapterSetup`` context will
override static adapter setups defined in a model (i.e. setups specified via ``active_adapters``).
Example::
with AdapterSetup(Stack("a", "b")):
# will use the adapter stack "a" and "b" outputs = model(**inputs)
Note that the context manager is thread-local, i.e. it can be used with different setups in a multi-threaded
environment.
"""
# thread-local storage that holds a stack of active contexts
storage = threading.local()
def __init__(self, adapter_setup, head_setup=None, ignore_empty: bool = False):
self.adapter_setup = parse_composition(adapter_setup)
if head_setup:
self.head_setup = head_setup
else:
self.head_setup = parse_heads_from_composition(self.adapter_setup)
self._empty = ignore_empty and self.adapter_setup is None and self.head_setup is None
def __enter__(self):
if not self._empty:
AdapterSetup.get_contexts().append(self)
return self
def __exit__(self, type, value, traceback):
if not self._empty:
AdapterSetup.get_contexts().pop()
@classmethod
def get_contexts(cls):
if not hasattr(cls.storage, "contexts"):
cls.storage.contexts = []
return cls.storage.contexts
@classmethod
def get_context(cls):
try:
return cls.get_contexts()[-1]
except IndexError:
return None
@classmethod
def get_context_adapter_setup(cls):
context = cls.get_context()
if context:
return context.adapter_setup
return None
@classmethod
def get_context_head_setup(cls):
context = cls.get_context()
if context:
return context.head_setup
return None
class ForwardContext:
"""
Holds context information during a forward pass through a model. This class should be used via the
``ForwardContext.wrap()`` method.
Note that the context is thread-local.
"""
# thread-local storage that holds a stack of active contexts
storage = threading.local()
context_attributes = [
"adapter_gating_scores",
"adapter_fusion_attentions",
"adapter_input_parallelized",
]
# Additional used attributes not exposed to the user
# - prompt_tokens_length: length of the prompt tokens
def __init__(self, model, *args, **kwargs):
# If the model has a method ``forward_context()``, use it to create the context.
if hasattr(model, "forward_context"):
model.forward_context(self, *args, **kwargs)
def __enter__(self):
ForwardContext.get_contexts().append(self)
return self
def __exit__(self, type, value, traceback):
ForwardContext.get_contexts().pop()
@classmethod
def wrap(cls, f):
"""
Decorator method that wraps a ``forward()`` function of a model class.
"""
@functools.wraps(f)
def wrapper_func(self, *args, **kwargs):
if self.adapters_config is not None:
with cls(self, *args, **kwargs) as ctx:
# whether to output the context attributes
output_context = kwargs.pop("output_context", False)
kwargs = {
k: v for k, v in kwargs.items() if k.replace("output_", "") not in cls.context_attributes
}
results = f(self, *args, **kwargs)
# append output attributes
if isinstance(results, tuple):
for attr in cls.context_attributes:
if getattr(ctx, "output_" + attr, False):
results = results + (dict(getattr(ctx, attr)),)
else:
for attr in cls.context_attributes:
if getattr(ctx, "output_" + attr, False):
results[attr] = dict(getattr(ctx, attr))
if output_context:
context_dict = ctx.__dict__
if output_context:
return results, context_dict
else:
return results
else:
return f(self, *args, **kwargs)
return wrapper_func
@classmethod
def get_contexts(cls):
if not hasattr(cls.storage, "contexts"):
cls.storage.contexts = []
return cls.storage.contexts
@classmethod
def get_context(cls):
try:
return cls.get_contexts()[-1]
except IndexError:
return None