Update vision_tower_builder.py
Browse files- vision_tower_builder.py +9 -2
vision_tower_builder.py
CHANGED
@@ -24,8 +24,10 @@ import torch.utils.checkpoint as checkpoint
|
|
24 |
from functools import partial
|
25 |
try:
|
26 |
from flash_attn import flash_attn_qkvpacked_func
|
|
|
27 |
except:
|
28 |
-
|
|
|
29 |
from timm.layers import drop_path, to_2tuple, trunc_normal_
|
30 |
|
31 |
|
@@ -67,6 +69,12 @@ class Attention(nn.Module):
|
|
67 |
self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0.,
|
68 |
proj_drop=0., attn_head_dim=None,
|
69 |
attn_type='flash_v2'):
|
|
|
|
|
|
|
|
|
|
|
|
|
70 |
super().__init__()
|
71 |
self.num_heads = num_heads
|
72 |
head_dim = dim // num_heads
|
@@ -613,7 +621,6 @@ def build_vision_tower(vision_tower_cfg, **kwargs):
|
|
613 |
if "umt-hd" in vision_tower:
|
614 |
return UMTVisionTower(vision_tower, vision_tower_cfg=vision_tower_cfg, image_size=448, **kwargs)
|
615 |
elif "umt" in vision_tower:
|
616 |
-
raise NotImplementedError
|
617 |
return UMTVisionTower(vision_tower, vision_tower_cfg=vision_tower_cfg, **kwargs)
|
618 |
|
619 |
raise ValueError(f"Unknown vision tower: {vision_tower}")
|
|
|
24 |
from functools import partial
|
25 |
try:
|
26 |
from flash_attn import flash_attn_qkvpacked_func
|
27 |
+
use_flash_attn = True
|
28 |
except:
|
29 |
+
use_flash_attn = False
|
30 |
+
print("You need to install flash_attn to be faster!")
|
31 |
from timm.layers import drop_path, to_2tuple, trunc_normal_
|
32 |
|
33 |
|
|
|
69 |
self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0.,
|
70 |
proj_drop=0., attn_head_dim=None,
|
71 |
attn_type='flash_v2'):
|
72 |
+
|
73 |
+
if use_flash_attn:
|
74 |
+
attn_type = attn_type
|
75 |
+
else:
|
76 |
+
attn_type = 'origin'
|
77 |
+
|
78 |
super().__init__()
|
79 |
self.num_heads = num_heads
|
80 |
head_dim = dim // num_heads
|
|
|
621 |
if "umt-hd" in vision_tower:
|
622 |
return UMTVisionTower(vision_tower, vision_tower_cfg=vision_tower_cfg, image_size=448, **kwargs)
|
623 |
elif "umt" in vision_tower:
|
|
|
624 |
return UMTVisionTower(vision_tower, vision_tower_cfg=vision_tower_cfg, **kwargs)
|
625 |
|
626 |
raise ValueError(f"Unknown vision tower: {vision_tower}")
|