Spaces:
Running
Running
Update libra/model/builder.py
Browse files- libra/model/builder.py +3 -4
libra/model/builder.py
CHANGED
@@ -23,14 +23,13 @@ from libra.model import *
|
|
23 |
from libra.constants import DEFAULT_IMAGE_PATCH_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN
|
24 |
|
25 |
|
26 |
-
def load_pretrained_model(model_path, model_base, model_name, load_8bit=False, load_4bit=
|
27 |
-
kwargs = {"device_map": device_map}
|
28 |
|
29 |
|
30 |
if device != "cuda":
|
31 |
kwargs['device_map'] = {"": device}
|
32 |
|
33 |
-
|
34 |
if load_8bit:
|
35 |
kwargs['load_in_8bit'] = True
|
36 |
elif load_4bit:
|
@@ -125,7 +124,7 @@ def load_pretrained_model(model_path, model_base, model_name, load_8bit=False, l
|
|
125 |
|
126 |
vision_tower = model.get_vision_tower()
|
127 |
if not vision_tower.is_loaded:
|
128 |
-
vision_tower.load_model()
|
129 |
vision_tower.to(device=device, dtype=torch.float16)
|
130 |
image_processor = vision_tower.image_processor
|
131 |
|
|
|
23 |
from libra.constants import DEFAULT_IMAGE_PATCH_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN
|
24 |
|
25 |
|
26 |
+
def load_pretrained_model(model_path, model_base, model_name, load_8bit=False, load_4bit=True, device_map="auto", device="cpu"):
|
27 |
+
kwargs = {"device_map": device_map, **kwargs}
|
28 |
|
29 |
|
30 |
if device != "cuda":
|
31 |
kwargs['device_map'] = {"": device}
|
32 |
|
|
|
33 |
if load_8bit:
|
34 |
kwargs['load_in_8bit'] = True
|
35 |
elif load_4bit:
|
|
|
124 |
|
125 |
vision_tower = model.get_vision_tower()
|
126 |
if not vision_tower.is_loaded:
|
127 |
+
vision_tower.load_model(device_map=device_map)
|
128 |
vision_tower.to(device=device, dtype=torch.float16)
|
129 |
image_processor = vision_tower.image_processor
|
130 |
|