import gradio as gr
from transformers import AutoProcessor, AutoModelForCausalLM, BlipForQuestionAnswering, ViltForQuestionAnswering
import torch

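# download the example images once at startup; the example gallery below
# references these local filenames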
torch.hub.download_url_to_file('http://images.cocodataset.org/val2017/000000039769.jpg', 'cats.jpg')
torch.hub.download_url_to_file('https://huggingface.co/datasets/nielsr/textcaps-sample/resolve/main/stop_sign.png', 'stop_sign.png')
torch.hub.download_url_to_file('https://cdn.openai.com/dall-e-2/demos/text2im/astronaut/horse/photo/0.jpg', 'astronaut.jpg')

git_processor_base = AutoProcessor.from_pretrained("microsoft/git-base-vqav2")
git_model_base = AutoModelForCausalLM.from_pretrained("microsoft/git-base-vqav2")

# git_processor_large = AutoProcessor.from_pretrained("microsoft/git-large-vqav2")
# git_model_large = AutoModelForCausalLM.from_pretrained("microsoft/git-large-vqav2")

blip_processor_base = AutoProcessor.from_pretrained("Salesforce/blip-vqa-base")
blip_model_base = BlipForQuestionAnswering.from_pretrained("Salesforce/blip-vqa-base")

# blip_processor_large = AutoProcessor.from_pretrained("Salesforce/blip-vqa-capfilt-large")
# blip_model_large = BlipForQuestionAnswering.from_pretrained("Salesforce/blip-vqa-capfilt-large")

vilt_processor = AutoProcessor.from_pretrained("dandelin/vilt-b32-finetuned-vqa")
vilt_model = ViltForQuestionAnswering.from_pretrained("dandelin/vilt-b32-finetuned-vqa")

device = "cuda" if torch.cuda.is_available() else "cpu"

git_model_base.to(device)
blip_model_base.to(device)
# git_model_large.to(device)
# blip_model_large.to(device)
vilt_model.to(device)
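# the *-large checkpoints are left commented out above, presumably to keep
# the demo's memory footprint small; loading them would follow the exact
# same pattern as the base models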

def generate_answer_git(processor, model, image, question):
    # prepare image
    pixel_values = processor(images=image, return_tensors="pt").pixel_values.to(device)

    # prepare question: GIT expects the question tokens to start with the CLS token
    input_ids = processor(text=question, add_special_tokens=False).input_ids
    input_ids = [processor.tokenizer.cls_token_id] + input_ids
    input_ids = torch.tensor(input_ids).unsqueeze(0).to(device)

    generated_ids = model.generate(pixel_values=pixel_values, input_ids=input_ids, max_length=50)
    generated_answer = processor.batch_decode(generated_ids, skip_special_tokens=True)

    return generated_answer[0]

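# A minimal sketch (illustration only, not used by the demo) of how per-token
# confidences could be read off when generate() is called with
# return_dict_in_generate=True and output_scores=True:
def token_confidences(generate_output):
    # generate_output.scores holds one (batch_size, vocab_size) logits tensor
    # per generated token; the max softmax probability at each step is a rough
    # confidence signal for the chosen token
    confidences = []
    for step_scores in generate_output.scores:
        step_probs = torch.softmax(step_scores, dim=-1)
        confidences.append(step_probs.max(dim=-1).values.item())
    return confidences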

def generate_answer_blip(processor, model, image, question):
    # prepare image + question
    inputs = processor(images=image, text=question, return_tensors="pt").to(device)

    generated_ids = model.generate(**inputs, max_length=50)
    generated_answer = processor.batch_decode(generated_ids, skip_special_tokens=True)

    return generated_answer[0]


def generate_answer_vilt(processor, model, image, question):
    # prepare image + question
    encoding = processor(images=image, text=question, return_tensors="pt").to(device)

    with torch.no_grad():
        outputs = model(**encoding)

    # ViLT scores a fixed answer vocabulary, so the prediction is the argmax
    probs = torch.softmax(outputs.logits, dim=-1).flatten()
    top_probs, top_indices = torch.topk(probs, k=2)
    predicted_class_idx = top_indices[0].item()
    runnerup_class_idx = top_indices[1].item()

    # debug output: confidence of the best answer vs. the runner-up
    pm, ps = top_probs[0].item(), top_probs[1].item()
    print(f"predicted: {model.config.id2label[predicted_class_idx]} ({pm:.3f})")
    print(f"runner-up: {model.config.id2label[runnerup_class_idx]} ({ps:.3f})")
    print(f"scaled: {pm / (pm + ps):.3f} vs. {ps / (pm + ps):.3f}")

    return model.config.id2label[predicted_class_idx]


def generate_answers(image, question):
    answer_git_base = generate_answer_git(git_processor_base, git_model_base, image, question)

    # answer_git_large = generate_answer_git(git_processor_large, git_model_large, image, question)

    answer_blip_base = generate_answer_blip(blip_processor_base, blip_model_base, image, question)

    # answer_blip_large = generate_answer_blip(blip_processor_large, blip_model_large, image, question)

    answer_vilt = generate_answer_vilt(vilt_processor, vilt_model, image, question)

    return answer_git_base, answer_blip_base, answer_vilt

   
examples = [["cats.jpg", "How many cats are there?"], ["stop_sign.png", "What's behind the stop sign?"], ["astronaut.jpg", "What's the astronaut riding on?"]]
outputs = [gr.Textbox(label="Answer generated by GIT-base"), gr.Textbox(label="Answer generated by BLIP-base"), gr.Textbox(label="Answer generated by ViLT")]

title = "Interactive demo: comparing visual question answering (VQA) models"
description = "Gradio demo for comparing GIT, BLIP and ViLT, three state-of-the-art vision-and-language models. To use it, simply upload an image and type a question, then click 'Submit', or click one of the examples to load them. Read more at the links below."
article = "<p style='text-align: center'><a href='https://huggingface.co/docs/transformers/main/model_doc/blip' target='_blank'>BLIP docs</a> | <a href='https://huggingface.co/docs/transformers/main/model_doc/git' target='_blank'>GIT docs</a></p>"

interface = gr.Interface(fn=generate_answers,
                         inputs=[gr.Image(type="pil"), gr.Textbox(label="Question")],
                         outputs=outputs,
                         examples=examples,
                         title=title,
                         description=description,
                         article=article)
interface.queue().launch()
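# when running locally, launch(share=True) would expose a temporary public URL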