rjiang12 commited on
Commit
bbec7cd
·
1 Parent(s): b395c00

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +3 -1
app.py CHANGED
@@ -39,6 +39,7 @@ def generate_answer_git(processor, model, image, question):
39
  input_ids = torch.tensor(input_ids).unsqueeze(0)
40
 
41
  generated_ids = model.generate(pixel_values=pixel_values, input_ids=input_ids, max_length=50)
 
42
  generated_answer = processor.batch_decode(generated_ids, skip_special_tokens=True)
43
 
44
  return generated_answer
@@ -49,6 +50,7 @@ def generate_answer_blip(processor, model, image, question):
49
  inputs = processor(images=image, text=question, return_tensors="pt")
50
 
51
  generated_ids = model.generate(**inputs, max_length=50)
 
52
  generated_answer = processor.batch_decode(generated_ids, skip_special_tokens=True)
53
 
54
  return generated_answer
@@ -60,7 +62,7 @@ def generate_answer_vilt(processor, model, image, question):
60
 
61
  with torch.no_grad():
62
  outputs = model(**encoding)
63
-
64
  predicted_class_idx = outputs.logits.argmax(-1).item()
65
 
66
  return model.config.id2label[predicted_class_idx]
 
39
  input_ids = torch.tensor(input_ids).unsqueeze(0)
40
 
41
  generated_ids = model.generate(pixel_values=pixel_values, input_ids=input_ids, max_length=50)
42
+ print(generated_ids.logits)
43
  generated_answer = processor.batch_decode(generated_ids, skip_special_tokens=True)
44
 
45
  return generated_answer
 
50
  inputs = processor(images=image, text=question, return_tensors="pt")
51
 
52
  generated_ids = model.generate(**inputs, max_length=50)
53
+ print(generated_ids.logits)
54
  generated_answer = processor.batch_decode(generated_ids, skip_special_tokens=True)
55
 
56
  return generated_answer
 
62
 
63
  with torch.no_grad():
64
  outputs = model(**encoding)
65
+ print(outputs.logits)
66
  predicted_class_idx = outputs.logits.argmax(-1).item()
67
 
68
  return model.config.id2label[predicted_class_idx]