File size: 4,927 Bytes
b395c00
 
 
cbadb1a
b395c00
 
 
 
 
 
 
 
de2dda2
 
b395c00
 
 
 
de2dda2
 
b395c00
 
 
 
 
 
 
 
de2dda2
 
b395c00
 
 
 
 
 
 
 
 
 
 
eda8f6b
9acd641
dc6e60c
4056078
b395c00
4056078
b395c00
 
 
 
 
 
a2b1833
fe6ca74
b395c00
 
 
 
 
 
 
 
 
 
 
0647c43
b395c00
 
 
 
 
 
de2dda2
b395c00
 
 
de2dda2
b395c00
 
 
2ce2973
b395c00
 
 
c567393
b395c00
 
 
 
 
 
 
 
 
 
 
 
 
a298ea6
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
import gradio as gr
from transformers import AutoProcessor, AutoModelForCausalLM, BlipForQuestionAnswering, ViltForQuestionAnswering
import torch
import math

# Sample images referenced by the `examples` list further down; each one is
# downloaded into the current working directory under the given filename.
_EXAMPLE_IMAGES = (
    ('http://images.cocodataset.org/val2017/000000039769.jpg', 'cats.jpg'),
    ('https://huggingface.co/datasets/nielsr/textcaps-sample/resolve/main/stop_sign.png', 'stop_sign.png'),
    ('https://cdn.openai.com/dall-e-2/demos/text2im/astronaut/horse/photo/0.jpg', 'astronaut.jpg'),
)
for _image_url, _image_path in _EXAMPLE_IMAGES:
    torch.hub.download_url_to_file(_image_url, _image_path)
del _image_url, _image_path

# Checkpoints compared by the demo (processor + model for each).
# NOTE: larger variants are currently disabled to keep memory use down:
#   GIT-large:  "microsoft/git-large-vqav2"          (AutoModelForCausalLM)
#   BLIP-large: "Salesforce/blip-vqa-capfilt-large"  (BlipForQuestionAnswering)
# To enable them, load and move them like the base models below and call
# them from generate_answers.
git_processor_base = AutoProcessor.from_pretrained("microsoft/git-base-vqav2")
git_model_base = AutoModelForCausalLM.from_pretrained("microsoft/git-base-vqav2")

blip_processor_base = AutoProcessor.from_pretrained("Salesforce/blip-vqa-base")
blip_model_base = BlipForQuestionAnswering.from_pretrained("Salesforce/blip-vqa-base")

vilt_processor = AutoProcessor.from_pretrained("dandelin/vilt-b32-finetuned-vqa")
vilt_model = ViltForQuestionAnswering.from_pretrained("dandelin/vilt-b32-finetuned-vqa")

# Run on the GPU when one is visible to torch, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

for _model in (git_model_base, blip_model_base, vilt_model):
    _model.to(device)
del _model

def generate_answer_git(processor, model, image, question):
    """Answer ``question`` about ``image`` with a GIT model.

    GIT is a generative (decoder-only) model: the question is given as a text
    prompt that starts with the tokenizer's [CLS] token, and the model
    continues that prompt with the answer.

    Args:
        processor: the processor paired with ``model``; used for image
            preprocessing, tokenization and decoding.
        model: a GIT causal-LM model (this app passes "microsoft/git-base-vqav2").
        image: the input image (the Gradio input is configured with type="pil").
        question: the question text.

    Returns:
        str: the decoded answer, with the echoed question prompt removed.
    """
    # Put the inputs on the model's device; otherwise generation fails with a
    # device mismatch once the model has been moved to CUDA.
    device = model.device

    # prepare image
    pixel_values = processor(images=image, return_tensors="pt").pixel_values.to(device)

    # prepare question: [CLS] followed by the question tokens, no other specials
    input_ids = processor(text=question, add_special_tokens=False).input_ids
    input_ids = [processor.tokenizer.cls_token_id] + input_ids
    input_ids = torch.tensor(input_ids).unsqueeze(0).to(device)

    generated_ids = model.generate(pixel_values=pixel_values, input_ids=input_ids, max_length=50)

    # The generated sequence begins with the prompt tokens; decode only what
    # the model produced after them so the answer does not repeat the question.
    answer_ids = generated_ids[:, input_ids.shape[1]:]
    generated_answer = processor.batch_decode(answer_ids, skip_special_tokens=True)[0]

    return generated_answer.strip()


def generate_answer_blip(processor, model, image, question):
    """Answer ``question`` about ``image`` with a BLIP VQA model.

    Args:
        processor: the processor paired with ``model``; it prepares image and
            question together and decodes the generated ids.
        model: a BLIP question-answering model (this app passes
            "Salesforce/blip-vqa-base").
        image: the input image (the Gradio input is configured with type="pil").
        question: the question text.

    Returns:
        list[str]: the decoded answers from ``processor.batch_decode``, one per
        batch element (a single element for the one image/question used here).
    """
    # prepare image + question, on the same device as the model (otherwise
    # generation fails once the model has been moved to CUDA)
    inputs = processor(images=image, text=question, return_tensors="pt").to(model.device)

    generated_ids = model.generate(**inputs, max_length=50)
    generated_answer = processor.batch_decode(generated_ids, skip_special_tokens=True)

    return generated_answer


def generate_answer_vilt(processor, model, image, question):
    """Answer ``question`` about ``image`` with a ViLT VQA classifier.

    ViLT treats VQA as classification over a fixed answer vocabulary: the
    highest-scoring logit is mapped back to its label text.

    Args:
        processor: the processor paired with ``model``; it encodes image and
            question together.
        model: a ViLT question-answering model (this app passes
            "dandelin/vilt-b32-finetuned-vqa").
        image: the input image (the Gradio input is configured with type="pil").
        question: the question text.

    Returns:
        The label from ``model.config.id2label`` for the predicted class.
    """
    # prepare image + question, on the same device as the model (otherwise the
    # forward pass fails once the model has been moved to CUDA)
    encoding = processor(images=image, text=question, return_tensors="pt").to(model.device)

    # Inference only: no gradients needed.
    with torch.no_grad():
        outputs = model(**encoding)

    # argmax over the last (class) dimension; .item() assumes a single example.
    predicted_class_idx = outputs.logits.argmax(-1).item()
    return model.config.id2label[predicted_class_idx]


def _as_text(answer):
    """Return ``answer`` as display text: strings pass through unchanged, any
    other iterable of strings (e.g. ``batch_decode`` output) is joined with ", "."""
    if isinstance(answer, str):
        return answer
    return ", ".join(answer)


def generate_answers(image, question):
    """Run every enabled VQA model on the same image/question pair.

    Each answer is converted to plain text so the Gradio textboxes show the
    answer itself rather than a Python list repr such as "['two']".

    Args:
        image: the input image from the Gradio Image component.
        question: the question text from the Gradio Textbox component.

    Returns:
        tuple[str, str, str]: the GIT-base, BLIP-base and ViLT answers, in the
        same order as the interface's output textboxes.
    """
    # NOTE: GIT-large / BLIP-large are disabled (their models are not loaded);
    # re-enable them here together with matching output textboxes.
    answer_git_base = generate_answer_git(git_processor_base, git_model_base, image, question)
    answer_blip_base = generate_answer_blip(blip_processor_base, blip_model_base, image, question)
    answer_vilt = generate_answer_vilt(vilt_processor, vilt_model, image, question)

    return _as_text(answer_git_base), _as_text(answer_blip_base), _as_text(answer_vilt)

   
# Example (image, question) pairs; the image files are downloaded at startup.
examples = [["cats.jpg", "How many cats are there?"], ["stop_sign.png", "What's behind the stop sign?"], ["astronaut.jpg", "What's the astronaut riding on?"]]

# One textbox per model, in the order generate_answers returns its answers.
# gr.Textbox / gr.Image replace the gr.inputs / gr.outputs namespaces, which
# are deprecated in Gradio 3 and removed in Gradio 4.
outputs = [gr.Textbox(label="Answer generated by GIT-base"), gr.Textbox(label="Answer generated by BLIP-base"), gr.Textbox(label="Answer generated by ViLT")]

title = "Interactive demo: comparing visual question answering (VQA) models"
description = "Gradio Demo to compare GIT, BLIP and ViLT, 3 state-of-the-art vision+language models. To use it, simply upload your image, type a question and click 'submit', or click one of the examples to load them. Read more at the links below."
article = "<p style='text-align: center'><a href='https://huggingface.co/docs/transformers/main/model_doc/blip' target='_blank'>BLIP docs</a> | <a href='https://huggingface.co/docs/transformers/main/model_doc/git' target='_blank'>GIT docs</a> | <a href='https://huggingface.co/docs/transformers/main/model_doc/vilt' target='_blank'>ViLT docs</a></p>"

interface = gr.Interface(fn=generate_answers,
                         inputs=[gr.Image(type="pil"), gr.Textbox(label="Question")],
                         outputs=outputs,
                         examples=examples,
                         title=title,
                         description=description,
                         article=article)

# Queue requests (replaces the removed `enable_queue=True` argument) so that
# slow model inference does not hit request timeouts.
interface.queue().launch()