# mypy: ignore-errors

from inspect import getattr_static

from ..bytecode_transformation import create_call_function
from ..exc import Unsupported
from .base import VariableTracker


class SDPAParamsVariable(VariableTracker):
    """Represents the c++ params struct for scaled dot product attention.

    This is a read-only container."""
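
    # For orientation, a hedged sketch of the eager-mode object this variable
    # models, assuming the 6-field layout handled below (newer PyTorch
    # releases extend the struct); can_use_flash_attention is the public
    # consumer of SDPAParams:
    #
    #     from torch.backends.cuda import SDPAParams, can_use_flash_attention
    #     params = SDPAParams(query, key, value, attn_mask, dropout, is_causal)
    #     if can_use_flash_attention(params, debug=False):
    #         ...  # inputs are eligible for the flash attention kernel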

    @staticmethod
    def create(tx, value, source):
        from torch.backends.cuda import SDPAParams
        from ..source import AttrSource
        from .builder import VariableBuilder
        from .torch import TorchInGraphFunctionVariable

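        # Wrap each field of the eager SDPAParams under an AttrSource rooted
        # at `source`, so each field is tracked (and guarded) like an ordinary
        # attribute read on the original object.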
        query_var = VariableBuilder(tx, AttrSource(source, "query"))(value.query)
        key_var = VariableBuilder(tx, AttrSource(source, "key"))(value.key)
        value_var = VariableBuilder(tx, AttrSource(source, "value"))(value.value)
        attn_mask_var = VariableBuilder(tx, AttrSource(source, "attn_mask"))(
            value.attn_mask
        )
        dropout_var = VariableBuilder(tx, AttrSource(source, "dropout"))(value.dropout)
        is_causal_var = VariableBuilder(tx, AttrSource(source, "is_causal"))(
            value.is_causal
        )
        param_vars = [
            query_var,
            key_var,
            value_var,
            attn_mask_var,
            dropout_var,
            is_causal_var,
        ]
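        # Re-invoke the SDPAParams constructor through the in-graph function
        # path so the resulting variable carries an FX proxy for later
        # attribute reads.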
        return TorchInGraphFunctionVariable(SDPAParams).call_function(
            tx, param_vars, {}
        )

    def __init__(self, proxy, param_vars, **kwargs):
        self.proxy = proxy
        self.param_vars = param_vars
        super().__init__(**kwargs)

    def reconstruct(self, codegen):
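        # Emit bytecode that rebuilds the object at frame exit: load
        # torch._C._SDPAParams, push each stored param var, then call the
        # constructor with that many args (the True asks codegen to also push
        # the NULL required by the Python 3.11+ CALL protocol).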
        assert self.source is None
        assert self.param_vars is not None
        codegen.load_import_from("torch._C", "_SDPAParams")
        codegen.foreach(self.param_vars)
        codegen.extend_output(create_call_function(len(self.param_vars), True))

    def as_proxy(self):
        return self.proxy

    def var_getattr(self, tx, name: str) -> VariableTracker:
        import torch._C
        from ..source import AttrSource
        from .builder import wrap_fx_proxy
        from .misc import GetAttrVariable

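        # Statically confirm the attribute exists on the C++ type before
        # building a getattr proxy for it; unknown names raise Unsupported,
        # which triggers a graph break.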
        try:
            getattr_static(torch._C._SDPAParams, name)
        except AttributeError:
            # Using raise from is too verbose here
            raise Unsupported(  # noqa: TRY200
                f"Unsupported torch._C._SDPAParams attribute {name}"
            )

        proxy = GetAttrVariable.create_getattr_proxy(self.as_proxy(), name)
        if self.source is not None:
            return wrap_fx_proxy(
                tx=tx, proxy=proxy, source=AttrSource(self.source, name)
            )
        else:
            return wrap_fx_proxy(tx=tx, proxy=proxy)

    @staticmethod
    def is_sdpa_params(value):
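        # Identity check against the type itself, not isinstance: this
        # detects the SDPAParams class (e.g. when it is about to be called),
        # not instances of it.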
        from torch.backends.cuda import SDPAParams

        return value is SDPAParams