aps's picture
Set eval mode
559ac78
raw
history blame
1.57 kB
import gradio as gr
from transformers import FlavaModel, BertTokenizer, FlavaFeatureExtractor
import numpy as np
from PIL import Image
import torch
# Hugging Face Hub checkpoint shared by the model, feature extractor and tokenizer.
MODEL_NAME = "facebook/flava-full"

model = FlavaModel.from_pretrained(MODEL_NAME)
# The model is only used for inference: switch off training-time behaviour
# (e.g. dropout) so predictions are deterministic.
model.eval()
fe = FlavaFeatureExtractor.from_pretrained(MODEL_NAME)
tokenizer = BertTokenizer.from_pretrained(MODEL_NAME)
def shot(image, labels_text):
    """Score an image against comma-separated candidate labels with FLAVA.

    Each label is wrapped in the prompt "This is a photo of a {label}", and
    the image and prompts are embedded with the module-level FLAVA ``model``.
    The cosine similarities, scaled by the model's learned temperature, are
    turned into a probability distribution over the labels.

    Args:
        image: Image array as delivered by the Gradio ``"image"`` input
            (converted with ``np.uint8`` and ``PIL.Image.fromarray``).
        labels_text: Candidate labels separated by commas, e.g.
            ``"dog, cat, bird"``. Surrounding whitespace is stripped and
            empty entries are ignored.

    Returns:
        dict mapping each label to its softmax probability (suitable for a
        Gradio ``"label"`` output).

    Raises:
        ValueError: if no image is given or no non-empty label remains.
    """
    if image is None:
        raise ValueError("Please provide an image.")
    labels = [label.strip() for label in (labels_text or "").split(",")]
    labels = [label for label in labels if label]
    if not labels:
        raise ValueError("Please provide at least one label.")

    pil_image = Image.fromarray(np.uint8(image)).convert("RGB")
    label_with_template = [f"This is a photo of a {label}" for label in labels]
    image_input = fe([pil_image], return_tensors="pt")
    text_inputs = tokenizer(
        label_with_template,
        padding="max_length",
        truncation=True,  # guard against prompts longer than the model limit
        return_tensors="pt",
    )

    # Pure inference: skip autograd bookkeeping to save memory and time.
    with torch.no_grad():
        # Take the first token's features (presumably the CLS token) for each input.
        image_embeddings = model.get_image_features(**image_input)[:, 0, :]
        text_embeddings = model.get_text_features(**text_inputs)[:, 0, :]
        # Mirror FLAVA's contrastive head: cosine similarity of L2-normalized
        # embeddings, scaled by the learned temperature.
        image_embeddings = torch.nn.functional.normalize(image_embeddings, dim=-1)
        text_embeddings = torch.nn.functional.normalize(text_embeddings, dim=-1)
        logits = image_embeddings @ text_embeddings.T * model.logit_scale.exp()
        # logits has one row (the single image) and one column per label.
        probabilities = torch.nn.functional.softmax(logits, dim=-1)[0].tolist()

    return dict(zip(labels, probabilities))
# Web UI: an image and a comma-separated label string in, per-label scores out.
iface = gr.Interface(
    shot,
    ["image", "text"],
    "label",
    examples=[
        ["dog.jpg", "dog,cat,bird"],
        ["germany.jpg", "germany,belgium,colombia"],
        ["rocket.jpg", "car,rocket,train"],
    ],
    description="Add a picture and a list of labels separated by commas",
    title="FLAVA Zero-shot Image Classification",
)

# Start the server only when run as a script, so importing this module does
# not launch a web server.
if __name__ == "__main__":
    iface.launch()