Spaces:
Runtime error
Runtime error
File size: 5,397 Bytes
b395c00 cbadb1a b395c00 de2dda2 b395c00 de2dda2 b395c00 de2dda2 b395c00 a2b1833 fe6ca74 b395c00 a2b1833 fe6ca74 b395c00 bbec7cd a444ecf b395c00 c1b19c7 9a456dc 42f0fa3 9a456dc b395c00 de2dda2 b395c00 de2dda2 b395c00 2ce2973 b395c00 c567393 b395c00 2753619 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 |
import gradio as gr
from transformers import AutoProcessor, AutoModelForCausalLM, BlipForQuestionAnswering, ViltForQuestionAnswering
import torch
import math
# Download the example images into the working directory under the fixed
# filenames that the demo's example inputs refer to.
for _image_url, _image_filename in (
    ('http://images.cocodataset.org/val2017/000000039769.jpg', 'cats.jpg'),
    ('https://huggingface.co/datasets/nielsr/textcaps-sample/resolve/main/stop_sign.png', 'stop_sign.png'),
    ('https://cdn.openai.com/dall-e-2/demos/text2im/astronaut/horse/photo/0.jpg', 'astronaut.jpg'),
):
    torch.hub.download_url_to_file(_image_url, _image_filename)

# GIT (base) fine-tuned on VQAv2.
git_processor_base = AutoProcessor.from_pretrained("microsoft/git-base-vqav2")
git_model_base = AutoModelForCausalLM.from_pretrained("microsoft/git-base-vqav2")
# The large GIT checkpoint is currently disabled:
# git_processor_large = AutoProcessor.from_pretrained("microsoft/git-large-vqav2")
# git_model_large = AutoModelForCausalLM.from_pretrained("microsoft/git-large-vqav2")

# BLIP (base) VQA checkpoint.
blip_processor_base = AutoProcessor.from_pretrained("Salesforce/blip-vqa-base")
blip_model_base = BlipForQuestionAnswering.from_pretrained("Salesforce/blip-vqa-base")
# The large BLIP checkpoint is currently disabled:
# blip_processor_large = AutoProcessor.from_pretrained("Salesforce/blip-vqa-capfilt-large")
# blip_model_large = BlipForQuestionAnswering.from_pretrained("Salesforce/blip-vqa-capfilt-large")

# ViLT fine-tuned for VQA.
vilt_processor = AutoProcessor.from_pretrained("dandelin/vilt-b32-finetuned-vqa")
vilt_model = ViltForQuestionAnswering.from_pretrained("dandelin/vilt-b32-finetuned-vqa")

# Use the GPU when one is available, otherwise stay on the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"

# Move every loaded model to the selected device (the disabled large
# variants, git_model_large and blip_model_large, are not loaded).
for _loaded_model in (git_model_base, blip_model_base, vilt_model):
    _loaded_model.to(device)
def generate_answer_git(processor, model, image, question):
    """Generate an answer to ``question`` about ``image`` with a GIT-style model.

    The text prompt is built as the tokenizer's CLS token id followed by the
    question's token ids (tokenized without special tokens), and the model
    generates from that prompt together with the image's pixel values.

    Args:
        processor: Callable processor that yields ``pixel_values`` for images
            and ``input_ids`` for text, and exposes ``tokenizer.cls_token_id``
            and ``batch_decode``.
        model: Model exposing ``device`` and ``generate``.
        image: Image to ask about (forwarded to ``processor`` as ``images``).
        question: Question text.

    Returns:
        The result of ``processor.batch_decode`` on the generated ids, with
        special tokens skipped.
    """
    # Inputs must be on the same device as the model's weights; leaving them
    # on the CPU fails as soon as the model has been moved to a GPU.
    target_device = model.device

    # prepare image
    pixel_values = processor(images=image, return_tensors="pt").pixel_values
    pixel_values = pixel_values.to(target_device)

    # prepare question: CLS token id + question token ids, as a batch of one
    question_ids = processor(text=question, add_special_tokens=False).input_ids
    prompt_ids = [processor.tokenizer.cls_token_id] + question_ids
    input_ids = torch.tensor(prompt_ids, device=target_device).unsqueeze(0)

    generated_ids = model.generate(
        pixel_values=pixel_values,
        input_ids=input_ids,
        max_length=50,
        output_scores=True,
    )
    print(generated_ids)
    return processor.batch_decode(generated_ids, skip_special_tokens=True)
def generate_answer_blip(processor, model, image, question):
    """Generate an answer to ``question`` about ``image`` with a BLIP-style model.

    Args:
        processor: Callable processor that encodes image and text together and
            exposes ``batch_decode``; its output must support ``.to(device)``
            and ``**`` unpacking.
        model: Model exposing ``device`` and ``generate``.
        image: Image to ask about (forwarded to ``processor`` as ``images``).
        question: Question text.

    Returns:
        The result of ``processor.batch_decode`` on the generated ids, with
        special tokens skipped.
    """
    # prepare image + question, placed on the model's device so generation
    # does not fail with a device mismatch when the model lives on a GPU
    inputs = processor(images=image, text=question, return_tensors="pt")
    inputs = inputs.to(model.device)

    generated_ids = model.generate(**inputs, max_length=50, output_scores=True)
    print(generated_ids)
    return processor.batch_decode(generated_ids, skip_special_tokens=True)
def generate_answer_vilt(processor, model, image, question):
    """Classify the answer to ``question`` about ``image`` with a ViLT-style model.

    Runs a single forward pass without gradients, prints debugging details
    (raw logits, softmax over dim 1, the winning index and, when one exists,
    the runner-up), and maps the arg-max index to its label.

    Args:
        processor: Callable processor that encodes image and text together;
            its output must support ``.to(device)`` and ``**`` unpacking.
        model: Callable model exposing ``device`` and ``config.id2label`` whose
            output has a ``logits`` tensor.
        image: Image to ask about (forwarded to ``processor`` as ``images``).
        question: Question text.

    Returns:
        ``model.config.id2label`` entry for the highest-scoring class.
    """
    # prepare image + question, placed on the model's device so the forward
    # pass does not fail with a device mismatch when the model lives on a GPU
    encoding = processor(images=image, text=question, return_tensors="pt")
    encoding = encoding.to(model.device)

    with torch.no_grad():
        outputs = model(**encoding)
    logits = outputs.logits

    print(logits)
    print(torch.softmax(logits, dim=1))

    # .item() requires exactly one arg-max result, i.e. a single example.
    predicted_class_idx = logits.argmax(-1).item()
    print(f"prdicted_class_idx: {predicted_class_idx}")

    logits_list = logits.flatten().tolist()
    best_logit = max(logits_list)
    print(f"predicted_class_idx_in_list = {logits_list.index(best_logit)}")

    # Runner-up = largest logit strictly below the best one. It does not exist
    # when there is a single class or every logit ties with the maximum, so
    # only report it when present instead of failing on the lookup.
    runner_up_logit = max(
        (value for value in logits_list if value < best_logit),
        default=None,
    )
    if runner_up_logit is not None:
        runner_up_idx = logits_list.index(runner_up_logit)
        print(f"runnerup_idx_in_list = {runner_up_idx}")
        print(f"runnerup val: {model.config.id2label[runner_up_idx]}")

    return model.config.id2label[predicted_class_idx]
def generate_answers(image, question):
    """Ask every loaded VQA model the same question about the same image.

    The models are queried in the order GIT-base, BLIP-base, ViLT; the large
    GIT and BLIP variants are currently disabled.

    Args:
        image: Image to ask about.
        question: Question text.

    Returns:
        A 3-tuple ``(GIT-base answer, BLIP-base answer, ViLT answer)``.
    """
    # Disabled large variants:
    # answer_git_large = generate_answer_git(git_processor_large, git_model_large, image, question)
    # answer_blip_large = generate_answer_blip(blip_processor_large, blip_model_large, image, question)
    return (
        generate_answer_git(git_processor_base, git_model_base, image, question),
        generate_answer_blip(blip_processor_base, blip_model_base, image, question),
        generate_answer_vilt(vilt_processor, vilt_model, image, question),
    )
# Example (image file, question) pairs shown in the UI.
examples = [
    ["cats.jpg", "How many cats are there?"],
    ["stop_sign.png", "What's behind the stop sign?"],
    ["astronaut.jpg", "What's the astronaut riding on?"],
]

# One textbox per model, in the order generate_answers returns its answers.
# Components come from the top-level gradio namespace: the legacy
# gr.inputs / gr.outputs namespaces are deprecated and absent from newer
# Gradio releases.
outputs = [
    gr.Textbox(label="Answer generated by GIT-base"),
    gr.Textbox(label="Answer generated by BLIP-base"),
    gr.Textbox(label="Answer generated by ViLT"),
]

title = "Interactive demo: comparing visual question answering (VQA) models"
description = "Gradio Demo to compare GIT, BLIP and ViLT, 3 state-of-the-art vision+language models. To use it, simply upload your image and click 'submit', or click one of the examples to load them. Read more at the links below."
article = "<p style='text-align: center'><a href='https://huggingface.co/docs/transformers/main/model_doc/blip' target='_blank'>BLIP docs</a> | <a href='https://huggingface.co/docs/transformers/main/model_doc/git' target='_blank'>GIT docs</a></p>"

interface = gr.Interface(
    fn=generate_answers,
    inputs=[gr.Image(type="pil"), gr.Textbox(label="Question")],
    outputs=outputs,
    examples=examples,
    title=title,
    description=description,
    article=article,
)
# Queueing is enabled through queue() rather than the legacy enable_queue
# keyword, which newer Gradio releases no longer accept.
interface.queue()
interface.launch(debug=True)