lixinhao committed on
Commit
78403f6
·
verified ·
1 Parent(s): 2f8163e

Update vision_tower_builder.py

Browse files
Files changed (1) hide show
  1. vision_tower_builder.py +9 -2
vision_tower_builder.py CHANGED
@@ -24,8 +24,10 @@ import torch.utils.checkpoint as checkpoint
24
  from functools import partial
25
  try:
26
  from flash_attn import flash_attn_qkvpacked_func
 
27
  except:
28
- print("You need to install flash_attn")
 
29
  from timm.layers import drop_path, to_2tuple, trunc_normal_
30
 
31
 
@@ -67,6 +69,12 @@ class Attention(nn.Module):
67
  self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0.,
68
  proj_drop=0., attn_head_dim=None,
69
  attn_type='flash_v2'):
 
 
 
 
 
 
70
  super().__init__()
71
  self.num_heads = num_heads
72
  head_dim = dim // num_heads
@@ -613,7 +621,6 @@ def build_vision_tower(vision_tower_cfg, **kwargs):
613
  if "umt-hd" in vision_tower:
614
  return UMTVisionTower(vision_tower, vision_tower_cfg=vision_tower_cfg, image_size=448, **kwargs)
615
  elif "umt" in vision_tower:
616
- raise NotImplementedError
617
  return UMTVisionTower(vision_tower, vision_tower_cfg=vision_tower_cfg, **kwargs)
618
 
619
  raise ValueError(f"Unknown vision tower: {vision_tower}")
 
24
  from functools import partial
25
  try:
26
  from flash_attn import flash_attn_qkvpacked_func
27
+ use_flash_attn = True
28
  except:
29
+ use_flash_attn = False
30
+ print("You need to install flash_attn to be faster!")
31
  from timm.layers import drop_path, to_2tuple, trunc_normal_
32
 
33
 
 
69
  self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0.,
70
  proj_drop=0., attn_head_dim=None,
71
  attn_type='flash_v2'):
72
+
73
+ if use_flash_attn:
74
+ attn_type = attn_type
75
+ else:
76
+ attn_type = 'origin'
77
+
78
  super().__init__()
79
  self.num_heads = num_heads
80
  head_dim = dim // num_heads
 
621
  if "umt-hd" in vision_tower:
622
  return UMTVisionTower(vision_tower, vision_tower_cfg=vision_tower_cfg, image_size=448, **kwargs)
623
  elif "umt" in vision_tower:
 
624
  return UMTVisionTower(vision_tower, vision_tower_cfg=vision_tower_cfg, **kwargs)
625
 
626
  raise ValueError(f"Unknown vision tower: {vision_tower}")