rjiang12 commited on
Commit
2960381
·
1 Parent(s): 1ce0263

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +1 -1
app.py CHANGED
@@ -67,7 +67,7 @@ def generate_answer_vilt(processor, model, image, question):
67
  lsm = torch.nn.LogSoftmax(dim=1)
68
  print(lsm(outputs.logits))
69
  predicted_class_idx = outputs.logits.argmax(-1).item()
70
- logitsList = outputs.logits.tolist()
71
  print(logitsList)
72
  maybeProbsList = [math.exp(i) for i in logitsList]
73
  return model.config.id2label[predicted_class_idx]
 
67
  lsm = torch.nn.LogSoftmax(dim=1)
68
  print(lsm(outputs.logits))
69
  predicted_class_idx = outputs.logits.argmax(-1).item()
70
+ logitsList = outputs.logits.flatten().tolist()
71
  print(logitsList)
72
  maybeProbsList = [math.exp(i) for i in logitsList]
73
  return model.config.id2label[predicted_class_idx]