X-iZhang commited on
Commit
43bcde1
·
verified ·
1 Parent(s): fbb58ea

Update libra/model/builder.py

Browse files
Files changed (1) hide show
  1. libra/model/builder.py +3 -4
libra/model/builder.py CHANGED
@@ -23,14 +23,13 @@ from libra.model import *
23
  from libra.constants import DEFAULT_IMAGE_PATCH_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN
24
 
25
 
26
- def load_pretrained_model(model_path, model_base, model_name, load_8bit=False, load_4bit=False, device_map="auto", device="cpu"):
27
- kwargs = {"device_map": device_map}
28
 
29
 
30
  if device != "cuda":
31
  kwargs['device_map'] = {"": device}
32
 
33
-
34
  if load_8bit:
35
  kwargs['load_in_8bit'] = True
36
  elif load_4bit:
@@ -125,7 +124,7 @@ def load_pretrained_model(model_path, model_base, model_name, load_8bit=False, l
125
 
126
  vision_tower = model.get_vision_tower()
127
  if not vision_tower.is_loaded:
128
- vision_tower.load_model()
129
  vision_tower.to(device=device, dtype=torch.float16)
130
  image_processor = vision_tower.image_processor
131
 
 
23
  from libra.constants import DEFAULT_IMAGE_PATCH_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN
24
 
25
 
26
+ def load_pretrained_model(model_path, model_base, model_name, load_8bit=False, load_4bit=True, device_map="auto", device="cpu"):
27
+ kwargs = {"device_map": device_map, **kwargs}
28
 
29
 
30
  if device != "cuda":
31
  kwargs['device_map'] = {"": device}
32
 
 
33
  if load_8bit:
34
  kwargs['load_in_8bit'] = True
35
  elif load_4bit:
 
124
 
125
  vision_tower = model.get_vision_tower()
126
  if not vision_tower.is_loaded:
127
+ vision_tower.load_model(device_map=device_map)
128
  vision_tower.to(device=device, dtype=torch.float16)
129
  image_processor = vision_tower.image_processor
130