import gradio as gr
from transformers import ViltForQuestionAnswering, ViltProcessor
from PIL import Image
import torch

# NOTE: the checkpoint originally referenced here was "microsoft/git-base-vqav2",
# but GIT is a generative model with no answer-classification head, so the
# sigmoid/top-k/id2label logic below cannot run against it (and the class it was
# loaded with, Blip2ForVisualQuestionAnswering, does not exist in transformers).
# A ViLT VQA classifier, which does expose that interface, is used instead.
model_path = "dandelin/vilt-b32-finetuned-vqa"
# Source of the example questions and images below; the dataset itself is not
# loaded at runtime.
dataset_name = "Multimodal-Fatima/OK-VQA_train"
questions = ["What can happen the objects shown are thrown on the ground?",
"What was the machine beside the bowl used for?",
"What kind of cars are in the photo?",
"What is the hairstyle of the blond called?",
"How old do you have to be in canada to do this?",
"Can you guess the place where the man is playing?",
"What loony tune character is in this photo?",
"Whose birthday is being celebrated?",
"Where can that toilet seat be bought?",
"What do you call the kind of pants that the man on the right is wearing?"]
processor = ViltProcessor.from_pretrained(model_path)  # was missing; main() uses it
model = ViltForQuestionAnswering.from_pretrained(model_path)
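# For reference: ViltProcessor wraps a BERT tokenizer plus an image processor,
# so calling it on a (PIL image, question) pair should yield input_ids,
# token_type_ids, attention_mask, pixel_values, and pixel_mask. The model then
# returns one logit per answer in its fixed answer vocabulary, indexable
# through model.config.id2label.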
def main(select_exemple_num):
    selectednum = int(select_exemple_num)  # the slider may deliver a float
    exemple_img = f"image{selectednum}.jpg"  # example images shipped alongside the app
    img = Image.open(exemple_img)
    question = questions[selectednum - 1]

    encoding = processor(img, question, return_tensors="pt")
    with torch.no_grad():  # inference only, no gradients needed
        outputs = model(**encoding)
    logits = outputs.logits

    # ViLT's VQA head is trained as a multi-label classifier, so scores go
    # through a sigmoid rather than a softmax before taking the top 5.
    output_str = "predicted:\n"
    predicted_classes = torch.sigmoid(logits)
    probs, classes = torch.topk(predicted_classes, 5)

    ans = ""
    for prob, class_idx in zip(probs.squeeze().tolist(), classes.squeeze().tolist()):
        output_str += f"{prob} {model.config.id2label[class_idx]}\n"
        if not ans:  # keep the highest-scoring answer
            ans = model.config.id2label[class_idx]

    output_str += f"\nso I think the answer is:\n{ans}"
    return exemple_img, question, output_str
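# A minimal sketch of the same inference path outside Gradio, assuming an
# "image1.jpg" sits next to this script:
#
#   img = Image.open("image1.jpg")
#   enc = processor(img, questions[0], return_tensors="pt")
#   with torch.no_grad():
#       best = model(**enc).logits.argmax(-1).item()
#   print(model.config.id2label[best])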
demo = gr.Interface(
    fn=main,
    inputs=[gr.Slider(1, len(questions), step=1, label="Example number")],
    outputs=["image", "text", "text"],
)
demo.launch(share=True)
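# share=True asks Gradio for a temporary public *.gradio.live link; when the
# app runs as a Hugging Face Space the platform hosts it directly, so a plain
# demo.launch() would also work there.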