X-iZhang commited on
Commit
4e94887
·
verified ·
1 Parent(s): 4de70fa

Update libra/model/builder.py

Browse files
Files changed (1) hide show
  1. libra/model/builder.py +9 -20
libra/model/builder.py CHANGED
@@ -23,25 +23,14 @@ from libra.model import *
23
  from libra.constants import DEFAULT_IMAGE_PATCH_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN
24
 
25
 
26
- def load_pretrained_model(model_path, model_base, model_name, load_8bit=False, load_4bit=True, device_map="auto", device="cpu"):
27
- kwargs = {"device_map": device_map}
 
 
 
 
28
 
29
-
30
- if device != "cuda":
31
- kwargs['device_map'] = {"": device}
32
 
33
- if load_8bit:
34
- kwargs['load_in_8bit'] = True
35
- elif load_4bit:
36
- kwargs['load_in_4bit'] = True
37
- kwargs['quantization_config'] = BitsAndBytesConfig(
38
- load_in_4bit=True,
39
- bnb_4bit_compute_dtype=torch.float16,
40
- bnb_4bit_use_double_quant=True,
41
- bnb_4bit_quant_type='nf4'
42
- )
43
- else:
44
- kwargs['torch_dtype'] = torch.float16
45
 
46
  if 'libra' in model_name.lower():
47
  # Load Libra model
@@ -92,7 +81,7 @@ def load_pretrained_model(model_path, model_base, model_name, load_8bit=False, l
92
  model.load_state_dict(mm_projector_weights, strict=False)
93
  else:
94
  tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=False)
95
- model = LibraLlamaForCausalLM.from_pretrained(model_path, low_cpu_mem_usage=True, **kwargs)
96
  else:
97
  # Load language model
98
  if model_base is not None:
@@ -124,8 +113,8 @@ def load_pretrained_model(model_path, model_base, model_name, load_8bit=False, l
124
 
125
  vision_tower = model.get_vision_tower()
126
  if not vision_tower.is_loaded:
127
- vision_tower.load_model(device_map=device_map)
128
- vision_tower.to(device=device, dtype=torch.float16)
129
  image_processor = vision_tower.image_processor
130
 
131
  if hasattr(model.config, "max_sequence_length"):
 
23
  from libra.constants import DEFAULT_IMAGE_PATCH_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN
24
 
25
 
26
+ def load_pretrained_model(model_path, model_base, model_name, device="cpu"):
27
+ device_map = {"": device}
28
+ kwargs = {
29
+ "device_map": device_map,
30
+ "torch_dtype": torch.float32 # for CPU, float32 or bfloat16 is recommended
31
+ }
32
 
 
 
 
33
 
 
 
 
 
 
 
 
 
 
 
 
 
34
 
35
  if 'libra' in model_name.lower():
36
  # Load Libra model
 
81
  model.load_state_dict(mm_projector_weights, strict=False)
82
  else:
83
  tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=False)
84
+ model = LibraLlamaForCausalLM.from_pretrained(model_path, **kwargs)
85
  else:
86
  # Load language model
87
  if model_base is not None:
 
113
 
114
  vision_tower = model.get_vision_tower()
115
  if not vision_tower.is_loaded:
116
+ vision_tower.load_model()
117
+ vision_tower.to(device=device, dtype=torch.float32)
118
  image_processor = vision_tower.image_processor
119
 
120
  if hasattr(model.config, "max_sequence_length"):