rjiang12 commited on
Commit
9cd6c28
·
1 Parent(s): 7eb4eb6

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +28 -25
app.py CHANGED
@@ -31,33 +31,36 @@ git_model_base.to(device)
31
  # vilt_model.to(device)
32
 
33
def generate_answer_git(processor, model, image, question):
    """Answer *question* about *image* with a generative (GIT-style) VQA model.

    Args:
        processor: Hugging Face processor providing image preprocessing,
            tokenization (``processor.tokenizer``) and ``batch_decode``.
        model: generative model exposing ``generate(pixel_values, input_ids, ...)``.
        image: input image in any format the processor accepts (PIL image /
            array — assumed; confirm against the caller).
        question: natural-language question as a string.

    Returns:
        The decoded answer string (previously this function discarded the
        decoded answer and returned a hard-coded placeholder).
    """
    # Prepare the image: the processor returns a batched pixel tensor.
    pixel_values = processor(images=image, return_tensors="pt").pixel_values

    # Prepare the question: tokenize without special tokens, then prepend
    # [CLS] manually as the GIT text prompt expects, and add a batch dim.
    input_ids = processor(text=question, add_special_tokens=False).input_ids
    input_ids = [processor.tokenizer.cls_token_id] + input_ids
    input_ids = torch.tensor(input_ids).unsqueeze(0)

    # Generate answer token ids; plain output (no score dict) is all we need.
    generated_ids = model.generate(
        pixel_values=pixel_values, input_ids=input_ids, max_length=50
    )

    # Decode the single batch element back to text and return it.
    generated_answer = processor.batch_decode(generated_ids, skip_special_tokens=True)
    return generated_answer[0]
 
31
  # vilt_model.to(device)
32
 
33
def generate_answer_git(processor, model, image, question):
    """Answer *question* about *image* with a classification-style VQA model.

    Despite the ``_git`` suffix, the call pattern here (``model(**encoding)``
    with ``logits`` + ``config.id2label``) is a classifier-head VQA model
    such as ViLT — NOTE(review): confirm the bound model matches.

    Args:
        processor: Hugging Face processor combining image + text into a
            single ``return_tensors="pt"`` encoding.
        model: model whose forward pass returns an object with ``.logits``
            of shape (batch, num_answers) and whose ``config.id2label``
            maps class index -> answer string.
        image: input image in a format the processor accepts.
        question: natural-language question as a string.

    Returns:
        The answer string for the highest-scoring class.
    """
    encoding = processor(images=image, text=question, return_tensors="pt")

    # Inference only — disable autograd to save memory and time.
    with torch.no_grad():
        outputs = model(**encoding)

    # Highest-scoring answer class; removed the unreachable placeholder
    # return and the block of commented-out generation/debug code.
    predicted_class_idx = outputs.logits.argmax(-1).item()
    return model.config.id2label[predicted_class_idx]