Vanyadoing committed
Commit 307213f · verified · 1 parent: b565be2

Update depth_anything_v2/dinov2_layers/attention.py

depth_anything_v2/dinov2_layers/attention.py CHANGED
@@ -64,18 +64,16 @@ class Attention(nn.Module):
 
 class MemEffAttention(Attention):
     def forward(self, x: Tensor, attn_bias=None) -> Tensor:
-        if not XFORMERS_AVAILABLE:
+        # If xformers is not available, or input is not CUDA, use vanilla attention
+        if (not XFORMERS_AVAILABLE) or (not x.is_cuda):
             assert attn_bias is None, "xFormers is required for nested tensors usage"
             return super().forward(x)
-
+        # Otherwise, use memory efficient attention
         B, N, C = x.shape
         qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads)
-
         q, k, v = unbind(qkv, 2)
-
         x = memory_efficient_attention(q, k, v, attn_bias=attn_bias)
         x = x.reshape([B, N, C])
-
         x = self.proj(x)
         x = self.proj_drop(x)
         return x
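
For context: xFormERS' memory_efficient_attention kernel is implemented for CUDA tensors, so the added x.is_cuda guard routes CPU inputs to the parent Attention.forward instead of failing inside the xFormers call. Below is a minimal standalone sketch of the same dispatch rule, not code from this repository; the function name dispatch_attention is hypothetical, and PyTorch's scaled_dot_product_attention stands in for the vanilla path:

    # Hypothetical standalone sketch (not from the repository) of the dispatch
    # rule this commit introduces: use xFormers' memory_efficient_attention only
    # when the library is installed AND the input lives on a CUDA device;
    # otherwise fall back to a vanilla attention path.
    import torch

    try:
        from xformers.ops import memory_efficient_attention
        XFORMERS_AVAILABLE = True
    except ImportError:
        XFORMERS_AVAILABLE = False

    def dispatch_attention(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor) -> torch.Tensor:
        if (not XFORMERS_AVAILABLE) or (not q.is_cuda):
            # Vanilla path: PyTorch's built-in scaled dot-product attention
            # (PyTorch >= 2.0), which expects [batch, heads, seq, head_dim].
            return torch.nn.functional.scaled_dot_product_attention(q, k, v)
        # xFormers path: memory_efficient_attention expects
        # [batch, seq, heads, head_dim] tensors and requires CUDA.
        return memory_efficient_attention(q, k, v)

    q = k = v = torch.randn(1, 8, 16, 32)   # CPU tensors
    out = dispatch_attention(q, k, v)       # takes the vanilla branch
    print(out.shape)                        # torch.Size([1, 8, 16, 32])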