Spaces:
Running
Running
Update libra/model/builder.py
Browse files- libra/model/builder.py +7 -1
libra/model/builder.py
CHANGED
@@ -23,12 +23,18 @@ from libra.model import *
|
|
23 |
from libra.constants import DEFAULT_IMAGE_PATCH_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN
|
24 |
|
25 |
|
26 |
-
def load_pretrained_model(model_path, model_base, model_name, load_8bit=False, load_4bit=False, device_map="auto", device="
|
27 |
kwargs = {"device_map": device_map}
|
28 |
|
|
|
29 |
if device != "cuda":
|
30 |
kwargs['device_map'] = {"": device}
|
31 |
|
|
|
|
|
|
|
|
|
|
|
32 |
if load_8bit:
|
33 |
kwargs['load_in_8bit'] = True
|
34 |
elif load_4bit:
|
|
|
23 |
from libra.constants import DEFAULT_IMAGE_PATCH_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN
|
24 |
|
25 |
|
26 |
+
def load_pretrained_model(model_path, model_base, model_name, load_8bit=False, load_4bit=False, device_map="auto", device="cup"):
|
27 |
kwargs = {"device_map": device_map}
|
28 |
|
29 |
+
|
30 |
if device != "cuda":
|
31 |
kwargs['device_map'] = {"": device}
|
32 |
|
33 |
+
if device == "cpu":
|
34 |
+
kwargs["torch_dtype"] = torch.float32
|
35 |
+
else:
|
36 |
+
kwargs["torch_dtype"] = torch.float16
|
37 |
+
|
38 |
if load_8bit:
|
39 |
kwargs['load_in_8bit'] = True
|
40 |
elif load_4bit:
|