rjiang12 commited on
Commit
3eeacdb
·
1 Parent(s): dc6e60c

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +1 -17
app.py CHANGED
@@ -64,23 +64,7 @@ def generate_answer_vilt(processor, model, image, question):
64
 
65
  with torch.no_grad():
66
  outputs = model(**encoding)
67
- print(outputs.logits)
68
- print(torch.softmax(outputs.logits, dim=1))
69
- predicted_class_idx = outputs.logits.argmax(-1).item()
70
- print(f"prdicted_class_idx: {predicted_class_idx}")
71
- logitsList = torch.softmax(outputs.logits, dim=1).flatten().tolist()
72
- print(f"predicted_class_idx_in_list = {logitsList.index(max(logitsList))}")
73
- m = max(logitsList)
74
- s = -math.inf
75
- for logit in logitsList:
76
- if s <= logit < m:
77
- s = logit
78
- t = sum(logitsList)
79
- pm, ps = m/t, s/t
80
- print(f"{pm}, {ps}")
81
- print(f"scaled: {pm/(pm + ps)}, {ps/(pm + ps)}")
82
- print(f"runnerup_idx_in_list = {logitsList.index(s)}")
83
- print(f"runnerup val: {model.config.id2label[logitsList.index(s)]}")
84
  return model.config.id2label[predicted_class_idx]
85
 
86
 
 
64
 
65
  with torch.no_grad():
66
  outputs = model(**encoding)
67
+
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
68
  return model.config.id2label[predicted_class_idx]
69
 
70