X-iZhang committed
Commit d57ad7c · verified · 1 parent: ffbe576

Update libra/model/builder.py

Files changed (1): libra/model/builder.py (+5 −5)
libra/model/builder.py CHANGED
@@ -26,14 +26,14 @@ from libra.constants import DEFAULT_IMAGE_PATCH_TOKEN, DEFAULT_IM_START_TOKEN, D
 def load_pretrained_model(model_path, model_base, model_name, device="cpu"):
     quantization_config = BitsAndBytesConfig(
         load_in_4bit=True,
-        bnb_4bit_compute_dtype=torch.bfloat16,
+        bnb_4bit_compute_dtype=torch.float16,
         bnb_4bit_use_double_quant=True,
         bnb_4bit_quant_type='nf4'
     )
     device_map = {"": device}
     kwargs = {
         "device_map": device_map,
-        "torch_dtype": torch.bfloat16
+        "torch_dtype": torch.float16
     }
 
 
@@ -83,7 +83,7 @@ def load_pretrained_model(model_path, model_base, model_name, device="cpu"):
             model = LibraLlamaForCausalLM.from_pretrained(model_base, low_cpu_mem_usage=True, config=cfg_pretrained, **kwargs)
 
             mm_projector_weights = torch.load(os.path.join(model_path, 'mm_projector.bin'), map_location='cpu')
-            mm_projector_weights = {k: v.to(torch.bfloat16) for k, v in mm_projector_weights.items()}
+            mm_projector_weights = {k: v.to(torch.float16) for k, v in mm_projector_weights.items()}
             model.load_state_dict(mm_projector_weights, strict=False)
         else:
             tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=False)
@@ -94,13 +94,13 @@ def load_pretrained_model(model_path, model_base, model_name, device="cpu"):
             # PEFT model
             from peft import PeftModel
             tokenizer = AutoTokenizer.from_pretrained(model_base, use_fast=False)
-            model = AutoModelForCausalLM.from_pretrained(model_base, quantization_config=quantization_config, torch_dtype=torch.bfloat16, low_cpu_mem_usage=True, device_map="auto")
+            model = AutoModelForCausalLM.from_pretrained(model_base, quantization_config=quantization_config, torch_dtype=torch.float16, low_cpu_mem_usage=True, device_map="auto")
             print(f"Loading LoRA weights from {model_path}")
             model = PeftModel.from_pretrained(model, model_path)
             print(f"Merging weights")
             model = model.merge_and_unload()
             print('Convert to FP16...')
-            model.to(torch.bfloat16)
+            model.to(torch.float16)
         else:
             use_fast = False
             tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=False)
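For reference, a minimal sketch of how the updated loader could be invoked; the checkpoint path, device, and variable name below are placeholders for illustration, and the function's return value is not shown in this diff, so it is left unannotated:

```python
# Minimal usage sketch, not taken from this commit: the checkpoint path and device
# are assumptions; only the function signature comes from the diff above.
from libra.model.builder import load_pretrained_model

loaded = load_pretrained_model(
    model_path="path/to/libra-checkpoint",  # placeholder; point this at an actual checkpoint directory
    model_base=None,                        # or a base LLM path when loading LoRA / projector-only weights
    model_name="libra",
    device="cuda",                          # 4-bit NF4 quantization with fp16 compute targets a CUDA device
)
```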