rjiang12 commited on
Commit
832c797
·
1 Parent(s): c2e8e56

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +1 -1
app.py CHANGED
@@ -60,7 +60,7 @@ def generate_answer_git(processor, model, image, question):
60
  with torch.no_grad():
61
  outputs = model(**encoding)
62
  print(outputs.logits)
63
- predicted_class_idx = outputs.logits.argmax(-1)
64
  # return model.config.id2label[predicted_class_idx]
65
  print(predicted_class_idx)
66
  print(model.config.id2label)
 
60
  with torch.no_grad():
61
  outputs = model(**encoding)
62
  print(outputs.logits)
63
+ predicted_class_idx = outputs.logits[0].argmax(-1).item
64
  # return model.config.id2label[predicted_class_idx]
65
  print(predicted_class_idx)
66
  print(model.config.id2label)