X-iZhang commited on
Commit
bf1e548
·
verified ·
1 Parent(s): a0339e8

Update libra/eval/run_libra.py

Browse files
Files changed (1) hide show
  1. libra/eval/run_libra.py +7 -6
libra/eval/run_libra.py CHANGED
@@ -28,6 +28,7 @@ def load_model(model_path, model_base=None):
28
  disable_torch_init()
29
  model_name = get_model_name_from_path(model_path)
30
  tokenizer, model, image_processor, context_len = load_pretrained_model(model_path, model_base, model_name)
 
31
  return tokenizer, model, image_processor, context_len
32
 
33
  def load_images(image_file):
@@ -92,7 +93,7 @@ def load_images(image_file):
92
 
93
  return image
94
 
95
- def get_image_tensors(image_path, image_processor, model, device='cuda'):
96
  # Load and preprocess the images
97
  if isinstance(image_path, str):
98
  image = []
@@ -144,7 +145,7 @@ def libra_eval(
144
  ):
145
  # Model
146
  disable_torch_init()
147
-
148
  if libra_model is not None:
149
  tokenizer, model, image_processor, context_len = libra_model
150
  model_name = model.config._name_or_path
@@ -171,18 +172,18 @@ def libra_eval(
171
  conv.append_message(conv.roles[1], None)
172
  prompt = conv.get_prompt()
173
 
174
- input_ids = tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt').unsqueeze(0).cuda()
175
- attention_mask = torch.ones(input_ids.shape, dtype=torch.long)
176
  pad_token_id = tokenizer.pad_token_id
177
 
178
- image_tensor = get_image_tensors(image_file, image_processor, model)
179
 
180
  stop_str = conv.sep if conv.sep_style != SeparatorStyle.TWO else conv.sep2
181
  keywords = [stop_str]
182
  stopping_criteria = KeywordsStoppingCriteria(keywords, tokenizer, input_ids)
183
 
184
  with torch.inference_mode():
185
- torch.cuda.empty_cache()
186
  if num_beams > 1:
187
  output_ids = model.generate(
188
  input_ids=input_ids,
 
28
  disable_torch_init()
29
  model_name = get_model_name_from_path(model_path)
30
  tokenizer, model, image_processor, context_len = load_pretrained_model(model_path, model_base, model_name)
31
+ model.to("cpu")
32
  return tokenizer, model, image_processor, context_len
33
 
34
  def load_images(image_file):
 
93
 
94
  return image
95
 
96
+ def get_image_tensors(image_path, image_processor, model, device='cpu'):
97
  # Load and preprocess the images
98
  if isinstance(image_path, str):
99
  image = []
 
145
  ):
146
  # Model
147
  disable_torch_init()
148
+ device = "cpu"
149
  if libra_model is not None:
150
  tokenizer, model, image_processor, context_len = libra_model
151
  model_name = model.config._name_or_path
 
172
  conv.append_message(conv.roles[1], None)
173
  prompt = conv.get_prompt()
174
 
175
+ input_ids = tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt').unsqueeze(0).to(device)
176
+ attention_mask = torch.ones(input_ids.shape, dtype=torch.long, device=device)
177
  pad_token_id = tokenizer.pad_token_id
178
 
179
+ image_tensor = get_image_tensors(image_file, image_processor, model, , device=device)
180
 
181
  stop_str = conv.sep if conv.sep_style != SeparatorStyle.TWO else conv.sep2
182
  keywords = [stop_str]
183
  stopping_criteria = KeywordsStoppingCriteria(keywords, tokenizer, input_ids)
184
 
185
  with torch.inference_mode():
186
+
187
  if num_beams > 1:
188
  output_ids = model.generate(
189
  input_ids=input_ids,