rrivera1849 committed
Commit fcb4463 · 1 Parent(s): 1c4fba2

Upload LUAR
Files changed (3):
  1. config.json  +4 -1
  2. config.py    +6 -0
  3. model.py     +116 -6
config.json CHANGED
@@ -7,7 +7,10 @@
     "AutoModel": "model.LUAR"
   },
   "embedding_size": 512,
+  "k_bucket_size": 1024,
   "model_type": "LUAR",
+  "q_bucket_size": 512,
   "torch_dtype": "float32",
-  "transformers_version": "4.33.2"
+  "transformers_version": "4.33.2",
+  "use_memory_efficient_attention": false
 }
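
The three new keys default to the chunked attention being switched off (use_memory_efficient_attention: false). As a non-authoritative sketch of how they could be toggled at load time, the standard transformers override-via-kwargs pattern applies; the repository id below is a placeholder, not something stated in this commit:

```python
# Sketch only: "user/LUAR" is a placeholder repo id.
from transformers import AutoConfig, AutoModel

config = AutoConfig.from_pretrained(
    "user/LUAR",                          # placeholder; substitute the actual model repo
    trust_remote_code=True,               # auto_map points AutoModel at the custom model.LUAR class
    use_memory_efficient_attention=True,  # overrides the `false` default written to config.json
    q_bucket_size=512,
    k_bucket_size=1024,
)
model = AutoModel.from_pretrained("user/LUAR", config=config, trust_remote_code=True)
```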
config.py CHANGED
@@ -6,7 +6,13 @@ class LUARConfig(PretrainedConfig):
 
     def __init__(self,
         embedding_size: int = 512,
+        use_memory_efficient_attention=False,
+        q_bucket_size=512,
+        k_bucket_size=1024,
         **kwargs,
     ):
         self.embedding_size = embedding_size
+        self.use_memory_efficient_attention = use_memory_efficient_attention
+        self.q_bucket_size = q_bucket_size
+        self.k_bucket_size = k_bucket_size
         super().__init__(**kwargs)
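
The new constructor arguments mirror the keys added to config.json above, with identical defaults. A minimal sketch using only what is visible in this diff (the import path is an assumption; inside the repo the file is loaded as a module):

```python
from config import LUARConfig  # within the repo this is `from .config import LUARConfig`

cfg = LUARConfig(
    embedding_size=512,
    use_memory_efficient_attention=False,  # matches the config.json default
    q_bucket_size=512,                     # query chunk length for the chunked attention
    k_bucket_size=1024,                    # key/value chunk length
)
print(cfg.use_memory_efficient_attention, cfg.q_bucket_size, cfg.k_bucket_size)
```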
model.py CHANGED
@@ -1,29 +1,135 @@
 
 import math
+from functools import partial
 
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
 from einops import rearrange, reduce, repeat
+from torch.utils.checkpoint import checkpoint
 from transformers import AutoModel, PreTrainedModel
 
 from .config import LUARConfig
 
+# Adapted LucidRains impl. of Memory Efficient Attention
+# https://github.com/lucidrains/memory-efficient-attention-pytorch
+
+def exists(val):
+    return val is not None
+
+def summarize_qkv_chunk(
+    q, k, v,
+    mask
+):
+    """Dot-Product Attention for a chunk of queries, keys, and values.
+    """
+    weight = torch.einsum('b h i d, b h j d -> b h i j', q, k)
+
+    if exists(mask):
+        # HuggingFace masks have to be added:
+        weight += mask
+
+    weight_max = weight.amax(dim = -1, keepdim = True).detach()
+    weight = weight - weight_max
+
+    exp_weight = weight.exp()
+    weighted_value = torch.einsum('b h i j, b h j d -> b h i d', exp_weight, v)
+
+    return exp_weight.sum(dim = -1), weighted_value, rearrange(weight_max, '... 1 -> ...')
+
+checkpointed_summarize_qkv_chunk = partial(checkpoint, summarize_qkv_chunk)
+
+def memory_efficient_attention(
+    q, k, v,
+    mask = None,
+    q_bucket_size = 512,
+    k_bucket_size = 1024,
+    eps = 1e-8
+):
+    scale = q.shape[-1] ** -0.5
+    q = q * scale
+
+    # checkpoint the chunk summaries only when a backward pass is needed
+    needs_backwards = q.requires_grad or k.requires_grad or v.requires_grad
+    summarize_qkv_fn = checkpointed_summarize_qkv_chunk if needs_backwards else summarize_qkv_chunk
+
+    # chunk all the inputs
+    q_chunks = q.split(q_bucket_size, dim = -2)
+    k_chunks = k.split(k_bucket_size, dim = -2)
+    v_chunks = v.split(k_bucket_size, dim = -2)
+    mask_chunks = mask.split(k_bucket_size, dim = -1) if exists(mask) else ((None,) * len(k_chunks))
+
+    # loop through all chunks and accumulate
+    out = []
+    for q_index, q_chunk in enumerate(q_chunks):
+        exp_weights = []
+        weighted_values = []
+        weight_maxes = []
+
+        for k_index, (k_chunk, v_chunk, mask_chunk) in enumerate(zip(k_chunks, v_chunks, mask_chunks)):
+
+            exp_weight_chunk, weighted_value_chunk, weight_max_chunk = summarize_qkv_fn(
+                q_chunk,
+                k_chunk,
+                v_chunk,
+                mask_chunk,
+            )
+
+            exp_weights.append(exp_weight_chunk)
+            weighted_values.append(weighted_value_chunk)
+            weight_maxes.append(weight_max_chunk)
+
+        exp_weights = torch.stack(exp_weights, dim = -1)
+        weighted_values = torch.stack(weighted_values, dim = -1)
+        weight_maxes = torch.stack(weight_maxes, dim = -1)
+
+        global_max = weight_maxes.amax(dim = -1, keepdim = True)
+        renorm_factor = (weight_maxes - global_max).exp().detach()
+
+        exp_weights = exp_weights * renorm_factor
+        weighted_values = weighted_values * rearrange(renorm_factor, '... c -> ... 1 c')
+
+        all_values = weighted_values.sum(dim = -1)
+        all_weights = exp_weights.sum(dim = -1)
+
+        normalized_values = all_values / (rearrange(all_weights, '... -> ... 1') + eps)
+        out.append(normalized_values)
+
+    return torch.cat(out, dim=-2)
+
 class SelfAttention(nn.Module):
     """Implements Dot-Product Self-Attention as used in "Attention is all You Need".
     """
-    def __init__(self):
+    def __init__(
+        self,
+        memory_efficient_attention=False,
+        q_bucket_size=512,
+        k_bucket_size=1024,
+    ):
         super(SelfAttention, self).__init__()
+        self.use_memory_efficient_attention = memory_efficient_attention
+        self.q_bucket_size = q_bucket_size
+        self.k_bucket_size = k_bucket_size
 
     def forward(self, k, q, v):
-        if hasattr(F, "scaled_dot_product_attention") and torch.cuda.is_available():
-            with torch.backends.cuda.sdp_kernel(enable_flash=True, enable_mem_efficient=True):
-                return F.scaled_dot_product_attention(k, q, v)
+
+        if self.use_memory_efficient_attention:
+            q, k, v = map(
+                lambda t: rearrange(t, 'b n (h d) -> b h n d', h = 12),
+                (q, k, v)
+            )
+
+            out = memory_efficient_attention(
+                q, k, v,
+                q_bucket_size=self.q_bucket_size,
+                k_bucket_size=self.k_bucket_size
+            )
+            out = rearrange(out, 'b h n d -> b n (h d)')
+            return out
         else:
             d_k = q.size(-1)
             scores = torch.matmul(k, q.transpose(-2, -1)) / math.sqrt(d_k)
             p_attn = F.softmax(scores, dim=-1)
-
             return torch.matmul(p_attn, v)
 
 class LUAR(PreTrainedModel):
@@ -34,7 +140,11 @@ class LUAR(PreTrainedModel):
     def __init__(self, config):
         super().__init__(config)
         self.create_transformer()
-        self.attn_fn = SelfAttention()
+        self.attn_fn = SelfAttention(
+            config.use_memory_efficient_attention,
+            config.q_bucket_size,
+            config.k_bucket_size,
+        )
         self.linear = nn.Linear(self.hidden_size, config.embedding_size)
 
     def create_transformer(self):
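
The added memory-efficient path is the usual chunked softmax attention: each key/value bucket contributes a partial numerator (exp-weighted values) and denominator (exp-weight sums) along with its per-row maximum, and those per-chunk maxima are renormalized against the global maximum before the final division. The result should therefore match plain softmax(QK^T / sqrt(d)) V up to floating-point error, and when gradients are required the chunk summaries are wrapped in torch.utils.checkpoint to trade recomputation for activation memory. Note that the memory-efficient branch of SelfAttention.forward hard-codes h = 12 when splitting heads, so it assumes a 12-head backbone. Below is a small sanity-check sketch, not part of the commit; the import path is an assumption (the relative import of .config means in practice you would load the function through the repo package or copy it out):

```python
# Sanity-check sketch: chunked attention vs. direct softmax(QK^T / sqrt(d)) V.
import math
import torch

from model import memory_efficient_attention  # assumed import path; not part of the commit

b, h, n, d = 2, 12, 1100, 64   # n > k_bucket_size so several key/value chunks get renormalized
q, k, v = (torch.randn(b, h, n, d) for _ in range(3))

ref = torch.softmax(q @ k.transpose(-2, -1) / math.sqrt(d), dim=-1) @ v
out = memory_efficient_attention(q, k, v, q_bucket_size=512, k_bucket_size=1024)

torch.testing.assert_close(out, ref, atol=1e-4, rtol=1e-4)
```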