rjiang12 commited on
Commit
9a456dc
·
1 Parent(s): c1b19c7

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +9 -2
app.py CHANGED
@@ -67,8 +67,15 @@ def generate_answer_vilt(processor, model, image, question):
67
  print(torch.softmax(outputs.logits, dim=1))
68
  predicted_class_idx = outputs.logits.argmax(-1).item()
69
  print(f"prdicted_class_idx: {predicted_class_idx}")
70
- runnerup_class_idx = outputs.logits.argmax(-2).item()
71
- print(model.config.id2label[predicted_class_idx], model.config.id2label[runnerup_class_idx])
 
 
 
 
 
 
 
72
  return model.config.id2label[predicted_class_idx]
73
 
74
 
 
67
  print(torch.softmax(outputs.logits, dim=1))
68
  predicted_class_idx = outputs.logits.argmax(-1).item()
69
  print(f"prdicted_class_idx: {predicted_class_idx}")
70
+ logitsList = outputs.logits.flatten().tolist()
71
+ print(f"predicted_class_idx_in_list = {logitsList.index(max(logitsList))}")
72
+ m = max(logitsList)
73
+ s = -math.infinity
74
+ for logit in logitsList:
75
+ if s <= logit < m:
76
+ s = logit
77
+ print(f"runnerup_idx_in_list = {logitsList.index(s)}")
78
+ print(f"runnerup val: {model.config.id2label[logitsList.index(s)]}")
79
  return model.config.id2label[predicted_class_idx]
80
 
81