rjiang12 commited on
Commit
c1b19c7
·
1 Parent(s): a444ecf

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +3 -2
app.py CHANGED
@@ -66,8 +66,9 @@ def generate_answer_vilt(processor, model, image, question):
66
  print(outputs.logits)
67
  print(torch.softmax(outputs.logits, dim=1))
68
  predicted_class_idx = outputs.logits.argmax(-1).item()
69
- print(f"top2: {model.config.id2label[predicted_class_idx]}, {model.config.id2label[outputs.logits.argmax(-2).item()]}")
70
- print(f"top2_: {outputs.logits.argmax(-1)}, {outputs.logits.argmax(-2)}")
 
71
  return model.config.id2label[predicted_class_idx]
72
 
73
 
 
66
  print(outputs.logits)
67
  print(torch.softmax(outputs.logits, dim=1))
68
  predicted_class_idx = outputs.logits.argmax(-1).item()
69
+ print(f"prdicted_class_idx: {predicted_class_idx}")
70
+ runnerup_class_idx = outputs.logits.argmax(-2).item()
71
+ print(model.config.id2label[predicted_class_idx], model.config.id2label[runnerup_class_idx])
72
  return model.config.id2label[predicted_class_idx]
73
 
74