X-iZhang commited on
Commit
bcc2db4
·
verified ·
1 Parent(s): 8e07496

Update libra/model/builder.py

Browse files
Files changed (1) hide show
  1. libra/model/builder.py +7 -1
libra/model/builder.py CHANGED
@@ -23,12 +23,18 @@ from libra.model import *
23
  from libra.constants import DEFAULT_IMAGE_PATCH_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN
24
 
25
 
26
- def load_pretrained_model(model_path, model_base, model_name, load_8bit=False, load_4bit=False, device_map="auto", device="cuda"):
27
  kwargs = {"device_map": device_map}
28
 
 
29
  if device != "cuda":
30
  kwargs['device_map'] = {"": device}
31
 
 
 
 
 
 
32
  if load_8bit:
33
  kwargs['load_in_8bit'] = True
34
  elif load_4bit:
 
23
  from libra.constants import DEFAULT_IMAGE_PATCH_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN
24
 
25
 
26
+ def load_pretrained_model(model_path, model_base, model_name, load_8bit=False, load_4bit=False, device_map="auto", device="cup"):
27
  kwargs = {"device_map": device_map}
28
 
29
+
30
  if device != "cuda":
31
  kwargs['device_map'] = {"": device}
32
 
33
+ if device == "cpu":
34
+ kwargs["torch_dtype"] = torch.float32
35
+ else:
36
+ kwargs["torch_dtype"] = torch.float16
37
+
38
  if load_8bit:
39
  kwargs['load_in_8bit'] = True
40
  elif load_4bit: