Spaces:
Running
on
Zero
Running
on
Zero
File size: 5,038 Bytes
d1ed09d |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 |
import functools
import threading
from .composition import parse_composition, parse_heads_from_composition
class AdapterSetup:
    """
    Describes which adapters and which heads are active for a region of code.

    Instances are meant to be used with the ``with`` statement. While such a block is open, the setup carried by the
    instance overrides static adapter setups defined in a model (i.e. setups specified via ``active_adapters``).

    Example::

        with AdapterSetup(Stack("a", "b")):
            # will use the adapter stack "a" and "b"
            outputs = model(**inputs)

    The stack of open setups is kept in thread-local storage, so different threads can use different setups at the
    same time.
    """

    # thread-local storage that holds a stack of active contexts
    storage = threading.local()

    def __init__(self, adapter_setup, head_setup=None, ignore_empty: bool = False):
        """
        Args:
            adapter_setup: Adapter composition; it is normalized with ``parse_composition``.
            head_setup: Head setup to use. When falsy, it is derived from the parsed adapter setup with
                ``parse_heads_from_composition``.
            ignore_empty: If set and both resulting setups are ``None``, entering/leaving this context does not
                touch the context stack.
        """
        self.adapter_setup = parse_composition(adapter_setup)
        self.head_setup = head_setup if head_setup else parse_heads_from_composition(self.adapter_setup)
        nothing_configured = self.adapter_setup is None and self.head_setup is None
        self._empty = ignore_empty and nothing_configured

    def __enter__(self):
        # An "empty" setup is a no-op: it must not shadow an enclosing setup.
        if self._empty:
            return self
        AdapterSetup.get_contexts().append(self)
        return self

    def __exit__(self, type, value, traceback):
        # Mirror __enter__: only pop when this context was actually pushed.
        if self._empty:
            return
        AdapterSetup.get_contexts().pop()

    @classmethod
    def get_contexts(cls):
        """Return the current thread's stack (a list) of open setups, creating it on first access."""
        try:
            return cls.storage.contexts
        except AttributeError:
            cls.storage.contexts = []
            return cls.storage.contexts

    @classmethod
    def get_context(cls):
        """Return the innermost open setup of the current thread, or ``None`` if there is none."""
        open_setups = cls.get_contexts()
        return open_setups[-1] if open_setups else None

    @classmethod
    def get_context_adapter_setup(cls):
        """Return the ``adapter_setup`` of the innermost open setup, or ``None`` if there is none."""
        innermost = cls.get_context()
        return innermost.adapter_setup if innermost else None

    @classmethod
    def get_context_head_setup(cls):
        """Return the ``head_setup`` of the innermost open setup, or ``None`` if there is none."""
        innermost = cls.get_context()
        return innermost.head_setup if innermost else None
class ForwardContext:
    """
    Carries per-call state for the duration of one forward pass through a model.

    The intended entry point is the ``ForwardContext.wrap()`` decorator rather than direct instantiation.

    The stack of open contexts lives in thread-local storage.
    """

    # thread-local storage that holds a stack of active contexts
    storage = threading.local()

    context_attributes = [
        "adapter_gating_scores",
        "adapter_fusion_attentions",
        "adapter_input_parallelized",
    ]
    # Additional used attributes not exposed to the user
    # - prompt_tokens_length: length of the prompt tokens

    def __init__(self, model, *args, **kwargs):
        """
        Create the context and, if ``model`` defines ``forward_context()``, let the model populate it.

        Args:
            model: Object whose forward pass is being wrapped.
            *args, **kwargs: Forwarded unchanged to ``model.forward_context(self, ...)``.
        """
        if hasattr(model, "forward_context"):
            model.forward_context(self, *args, **kwargs)

    def __enter__(self):
        ForwardContext.get_contexts().append(self)
        return self

    def __exit__(self, type, value, traceback):
        ForwardContext.get_contexts().pop()

    @classmethod
    def wrap(cls, f):
        """
        Decorator method that wraps a ``forward()`` function of a model class.

        When ``self.adapters_config`` is ``None`` the wrapped function is called as-is. Otherwise a context is
        opened around the call, the ``output_context`` keyword and every keyword whose name (with ``"output_"``
        removed) appears in ``context_attributes`` are withheld from ``f``, and requested context attributes are
        attached to the result. With ``output_context`` set, ``(results, context.__dict__)`` is returned.
        """

        @functools.wraps(f)
        def wrapper_func(self, *args, **kwargs):
            # No adapter configuration: behave exactly like the undecorated function.
            if self.adapters_config is None:
                return f(self, *args, **kwargs)

            # The context is built from the *unfiltered* kwargs, before anything is popped.
            with cls(self, *args, **kwargs) as ctx:
                # whether to output the context attributes
                output_context = kwargs.pop("output_context", False)
                forwarded = {
                    name: value
                    for name, value in kwargs.items()
                    if name.replace("output_", "") not in cls.context_attributes
                }
                results = f(self, *args, **forwarded)

                # append output attributes: tuples are extended, anything else is item-assigned
                results_is_tuple = isinstance(results, tuple)
                for attr in cls.context_attributes:
                    if not getattr(ctx, "output_" + attr, False):
                        continue
                    snapshot = dict(getattr(ctx, attr))
                    if results_is_tuple:
                        results = results + (snapshot,)
                    else:
                        results[attr] = snapshot

            if output_context:
                return results, ctx.__dict__
            return results

        return wrapper_func

    @classmethod
    def get_contexts(cls):
        """Return the current thread's stack (a list) of open contexts, creating it on first access."""
        try:
            return cls.storage.contexts
        except AttributeError:
            cls.storage.contexts = []
            return cls.storage.contexts

    @classmethod
    def get_context(cls):
        """Return the innermost open context of the current thread, or ``None`` if there is none."""
        open_contexts = cls.get_contexts()
        return open_contexts[-1] if open_contexts else None
|