X-iZhang committed on
Commit
d6d6d10
·
verified ·
1 Parent(s): bf1a4c1

Update libra/model/builder.py

Browse files
Files changed (1) hide show
  1. libra/model/builder.py +6 -3
libra/model/builder.py CHANGED
@@ -24,6 +24,9 @@ from libra.constants import DEFAULT_IMAGE_PATCH_TOKEN, DEFAULT_IM_START_TOKEN, D
24
 
25
 
26
  def load_pretrained_model(model_path, model_base, model_name, device="cpu"):
 
 
 
27
  device_map = {"": device}
28
  kwargs = {
29
  "device_map": device_map,
@@ -81,14 +84,14 @@ def load_pretrained_model(model_path, model_base, model_name, device="cpu"):
81
  model.load_state_dict(mm_projector_weights, strict=False)
82
  else:
83
  tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=False)
84
- model = LibraLlamaForCausalLM.from_pretrained(model_path, low_cpu_mem_usage=True, **kwargs)
85
  else:
86
  # Load language model
87
  if model_base is not None:
88
  # PEFT model
89
  from peft import PeftModel
90
  tokenizer = AutoTokenizer.from_pretrained(model_base, use_fast=False)
91
- model = AutoModelForCausalLM.from_pretrained(model_base, torch_dtype=torch.bfloat16, low_cpu_mem_usage=True, device_map="auto")
92
  print(f"Loading LoRA weights from {model_path}")
93
  model = PeftModel.from_pretrained(model, model_path)
94
  print(f"Merging weights")
@@ -98,7 +101,7 @@ def load_pretrained_model(model_path, model_base, model_name, device="cpu"):
98
  else:
99
  use_fast = False
100
  tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=False)
101
- model = AutoModelForCausalLM.from_pretrained(model_path, low_cpu_mem_usage=True, **kwargs)
102
 
103
  image_processor = None
104
 
 
24
 
25
 
26
  def load_pretrained_model(model_path, model_base, model_name, device="cpu"):
27
+ quantization_config = BitsAndBytesConfig(
28
+ load_in_8bit=True
29
+ )
30
  device_map = {"": device}
31
  kwargs = {
32
  "device_map": device_map,
 
84
  model.load_state_dict(mm_projector_weights, strict=False)
85
  else:
86
  tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=False)
87
+ model = LibraLlamaForCausalLM.from_pretrained(model_path, quantization_config=quantization_config, low_cpu_mem_usage=True, **kwargs)
88
  else:
89
  # Load language model
90
  if model_base is not None:
91
  # PEFT model
92
  from peft import PeftModel
93
  tokenizer = AutoTokenizer.from_pretrained(model_base, use_fast=False)
94
+ model = AutoModelForCausalLM.from_pretrained(model_base, quantization_config=quantization_config, torch_dtype=torch.bfloat16, low_cpu_mem_usage=True, device_map="auto")
95
  print(f"Loading LoRA weights from {model_path}")
96
  model = PeftModel.from_pretrained(model, model_path)
97
  print(f"Merging weights")
 
101
  else:
102
  use_fast = False
103
  tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=False)
104
+ model = AutoModelForCausalLM.from_pretrained(model_path, quantization_config=quantization_config, low_cpu_mem_usage=True, **kwargs)
105
 
106
  image_processor = None
107