try:
    import xformers
    import xformers.ops  # memory_efficient_attention lives in the xformers.ops submodule
except ImportError:
    # xformers is optional; without it, only the PyTorch-based functions below are usable.
    pass
import torch


def attention_xformers(
    q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, heads: int, mask=None, skip_reshape=False
) -> torch.Tensor:
    """#### Make an attention call using xformers. Fastest attention implementation.

    #### Args:
        - `q` (torch.Tensor): The query tensor.
        - `k` (torch.Tensor): The key tensor, must have the same shape as `q`.
        - `v` (torch.Tensor): The value tensor, must have the same shape as `q`.
        - `heads` (int): The number of heads, must be a divisor of the hidden dimension.
        - `mask` (torch.Tensor, optional): The mask tensor. Defaults to `None`.
        - `skip_reshape` (bool, optional): Unused in this implementation. Defaults to `False`.

    #### Returns:
        - `torch.Tensor`: The output tensor.
    """
    b, _, dim_head = q.shape
    dim_head //= heads

    q, k, v = map(
        lambda t: t.unsqueeze(3)
        .reshape(b, -1, heads, dim_head)
        .permute(0, 2, 1, 3)
        .reshape(b * heads, -1, dim_head)
        .contiguous(),
        (q, k, v),
    )

    out = xformers.ops.memory_efficient_attention(q, k, v, attn_bias=mask)

    out = (
        out.unsqueeze(0)
        .reshape(b, heads, -1, dim_head)
        .permute(0, 2, 1, 3)
        .reshape(b, -1, heads * dim_head)
    )
    return out


def attention_pytorch(
    q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, heads: int, mask=None, skip_reshape=False
) -> torch.Tensor:
    """#### Make an attention call using PyTorch.

    #### Args:
        - `q` (torch.Tensor): The query tensor.
        - `k` (torch.Tensor): The key tensor, must have the same shape as `q`.
        - `v` (torch.Tensor): The value tensor, must have the same shape as `q`.
        - `heads` (int): The number of heads, must be a divisor of the hidden dimension.
        - `mask` (torch.Tensor, optional): The mask tensor. Defaults to `None`.
        - `skip_reshape` (bool, optional): Unused in this implementation. Defaults to `False`.

    #### Returns:
        - `torch.Tensor`: The output tensor.
    """
    b, _, dim_head = q.shape
    dim_head //= heads
    q, k, v = map(
        lambda t: t.view(b, -1, heads, dim_head).transpose(1, 2),
        (q, k, v),
    )

    out = torch.nn.functional.scaled_dot_product_attention(
        q, k, v, attn_mask=mask, dropout_p=0.0, is_causal=False
    )
    out = out.transpose(1, 2).reshape(b, -1, heads * dim_head)
    return out
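
# Shape convention shared by attention_xformers and attention_pytorch: q, k and v are laid
# out as (batch, tokens, heads * dim_head) and the result uses the same layout. Per head,
# both paths compute softmax(q @ k.transpose(-2, -1) / sqrt(dim_head)) @ v; the default
# 1 / sqrt(dim_head) scaling is applied by memory_efficient_attention and
# scaled_dot_product_attention themselves, so no explicit scale is passed above.
# Example call with arbitrary sizes (batch=2, 77 tokens, 8 heads of width 64):
#     q = k = v = torch.randn(2, 77, 8 * 64)
#     out = attention_pytorch(q, k, v, heads=8)   # -> shape (2, 77, 8 * 64)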


def xformers_attention(
    q: torch.Tensor, k: torch.Tensor, v: torch.Tensor
) -> torch.Tensor:
    """#### Compute attention using xformers.

    #### Args:
        - `q` (torch.Tensor): The query tensor.
        - `k` (torch.Tensor): The key tensor, must have the same shape as `q`.
        - `v` (torch.Tensor): The value tensor, must have the same shape as `q`.

    #### Returns:
        - `torch.Tensor`: The output tensor.
    """
    B, C, H, W = q.shape
    q, k, v = map(
        lambda t: t.view(B, C, -1).transpose(1, 2).contiguous(),
        (q, k, v),
    )
    out = xformers.ops.memory_efficient_attention(q, k, v, attn_bias=None)
    out = out.transpose(1, 2).reshape(B, C, H, W)
    return out


def pytorch_attention(
    q: torch.Tensor, k: torch.Tensor, v: torch.Tensor
) -> torch.Tensor:
    """#### Compute attention using PyTorch.

    #### Args:
        - `q` (torch.Tensor): The query tensor.
        - `k` (torch.Tensor): The key tensor, must have the same shape as `q`.
        - `v` (torch.Tensor): The value tensor, must have the same shape as `q`.

    #### Returns:
        - `torch.Tensor`: The output tensor.
    """
    B, C, H, W = q.shape
    q, k, v = map(
        lambda t: t.view(B, 1, C, -1).transpose(2, 3).contiguous(),
        (q, k, v),
    )
    out = torch.nn.functional.scaled_dot_product_attention(
        q, k, v, attn_mask=None, dropout_p=0.0, is_causal=False
    )
    out = out.transpose(2, 3).reshape(B, C, H, W)
    return out
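

if __name__ == "__main__":
    # Minimal shape smoke test; the sizes below are arbitrary example values.
    heads, dim_head = 4, 32
    q_seq = torch.randn(2, 16, heads * dim_head)  # (batch, tokens, heads * dim_head)
    out_seq = attention_pytorch(q_seq, q_seq, q_seq, heads)
    assert out_seq.shape == (2, 16, heads * dim_head)

    q_spatial = torch.randn(2, 32, 8, 8)  # (batch, channels, height, width)
    out_spatial = pytorch_attention(q_spatial, q_spatial, q_spatial)
    assert out_spatial.shape == (2, 32, 8, 8)

    # The xformers variants use the same calling conventions but need the optional
    # xformers package and, in practice, a CUDA device.
    if "xformers" in globals() and torch.cuda.is_available():
        q_cuda = q_seq.half().cuda()
        out_x = attention_xformers(q_cuda, q_cuda, q_cuda, heads)
        assert out_x.shape == (2, 16, heads * dim_head)

    print("attention smoke test passed")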