Spaces:
Running
Running
# mypy: ignore-errors | |
from inspect import getattr_static | |
from ..bytecode_transformation import create_call_function | |
from ..exc import Unsupported | |
from .base import VariableTracker | |
class SDPAParamsVariable(VariableTracker):
    """Represents the c++ params struct for scaled dot product attention.

    This is a read-only container: attribute reads are traced as getattr
    nodes on ``proxy`` (see ``var_getattr``), and ``reconstruct`` rebuilds
    the object by calling ``torch._C._SDPAParams`` on ``param_vars``.
    """

    @staticmethod
    def create(tx, value, source):
        """Trace construction of an SDPAParams object from a concrete ``value``.

        Each field of ``value`` is wrapped with ``VariableBuilder`` using an
        ``AttrSource`` rooted at ``source``. The resulting variables are then
        passed positionally to ``SDPAParams`` via
        ``TorchInGraphFunctionVariable.call_function``.

        Args:
            tx: Tracing context forwarded to ``VariableBuilder`` and
                ``call_function`` (presumably the instruction translator).
            value: Concrete params object exposing the attributes named in
                ``field_names`` below.
            source: Source describing where ``value`` came from; each field
                gets an ``AttrSource`` derived from it.

        Returns:
            Whatever ``TorchInGraphFunctionVariable(SDPAParams).call_function``
            returns for these arguments.
        """
        from torch.backends.cuda import SDPAParams

        from ..source import AttrSource
        from .builder import VariableBuilder
        from .torch import TorchInGraphFunctionVariable

        # These are passed positionally to SDPAParams below, so the order
        # matters. A list comprehension (evaluated eagerly, left to right)
        # keeps the same VariableBuilder call order as one statement per field.
        field_names = ("query", "key", "value", "attn_mask", "dropout", "is_causal")
        param_vars = [
            VariableBuilder(tx, AttrSource(source, field))(getattr(value, field))
            for field in field_names
        ]
        return TorchInGraphFunctionVariable(SDPAParams).call_function(
            tx, param_vars, {}
        )

    def __init__(self, proxy, param_vars, **kwargs):
        """Store the graph proxy and the constructor-argument variables.

        Args:
            proxy: Graph proxy for the traced params object. It is returned
                by ``as_proxy`` and used as the base of getattr nodes.
            param_vars: VariableTrackers for the constructor arguments,
                replayed by ``reconstruct``.
            **kwargs: Forwarded to ``VariableTracker.__init__``.
        """
        self.proxy = proxy
        self.param_vars = param_vars
        super().__init__(**kwargs)

    def reconstruct(self, codegen):
        """Emit bytecode that calls ``torch._C._SDPAParams(*param_vars)``.

        Only valid for a variable without a source, and only when the
        constructor arguments in ``param_vars`` are available.
        """
        assert self.source is None
        assert self.param_vars is not None
        codegen.load_import_from("torch._C", "_SDPAParams")
        codegen.foreach(self.param_vars)
        # One positional argument per stored parameter variable.
        codegen.extend_output(create_call_function(len(self.param_vars), True))

    def as_proxy(self):
        """Return the graph proxy stored at construction."""
        return self.proxy

    def var_getattr(self, tx, name: str) -> VariableTracker:
        """Trace a read of attribute ``name`` on the params object.

        Args:
            tx: Tracing context forwarded to ``wrap_fx_proxy``.
            name: Attribute being read.

        Returns:
            The variable produced by ``wrap_fx_proxy`` for a getattr node on
            this object's proxy. It carries an ``AttrSource`` only when this
            variable has a source.

        Raises:
            Unsupported: If ``torch._C._SDPAParams`` has no attribute ``name``
                according to a static lookup (``inspect.getattr_static``).
        """
        import torch._C

        from ..source import AttrSource
        from .builder import wrap_fx_proxy
        from .misc import GetAttrVariable

        try:
            getattr_static(torch._C._SDPAParams, name)
        except AttributeError:
            # Using raise from is too verbose here
            raise Unsupported(  # noqa: TRY200
                f"Unsupported torch._C._SDPAParams attribute {name}"
            )

        proxy = GetAttrVariable.create_getattr_proxy(self.as_proxy(), name)
        if self.source is None:
            return wrap_fx_proxy(tx=tx, proxy=proxy)
        return wrap_fx_proxy(
            tx=tx, proxy=proxy, source=AttrSource(self.source, name)
        )

    @staticmethod
    def is_sdpa_params(value):
        """Return True if ``value`` is the ``SDPAParams`` class object itself.

        This is an identity check against the class, not an ``isinstance``
        check on params instances.
        """
        from torch.backends.cuda import SDPAParams

        return value is SDPAParams