# Spaces: Runtime error
import gradio as gr | |
import os | |
import skimage | |
import matplotlib.pyplot as plt | |
from PIL import Image | |
import numpy as np | |
from collections import OrderedDict | |
import torch | |
from imagebind import data | |
from imagebind.models import imagebind_model | |
from imagebind.models.imagebind_model import ModalityType | |
import torch.nn as nn | |
# Select the compute device once at import time: the first CUDA device when
# torch reports one as available, otherwise the CPU.
if torch.cuda.is_available():
    device = "cuda:0"
else:
    device = "cpu"

# Build the ImageBind "huge" model with pretrained=True, switch it to eval
# mode, and move it to the selected device. The instance is module-global so
# every request handled by image_text_zeroshot reuses the same model.
model = imagebind_model.imagebind_huge(pretrained=True)
model.eval()
model.to(device)
def image_text_zeroshot(image, text_list):
    """Score one image against a set of candidate text labels.

    Args:
        image: The image reference handed to
            ``data.load_and_transform_vision_data`` inside a one-element
            list (presumably a file path -- confirm against the caller).
        text_list: A single string of candidate labels separated by ``|``.
            Spaces around the whole string and around each label are
            stripped.

    Returns:
        A dict mapping each label to its softmax score over all labels.
        Repeated labels collapse to one key; the last occurrence wins.
    """
    # Split the raw string on "|" and trim surrounding spaces per label.
    trimmed = text_list.strip(" ")
    labels = [piece.strip(" ") for piece in trimmed.split("|")]

    # One text batch (all labels) and one vision batch (the single image).
    text_batch = data.load_and_transform_text(labels, device)
    vision_batch = data.load_and_transform_vision_data([image], device)
    inputs = {
        ModalityType.TEXT: text_batch,
        ModalityType.VISION: vision_batch,
    }

    # Inference only: no autograd bookkeeping is needed.
    with torch.no_grad():
        embeddings = model(inputs)

    # Image-to-text similarity, normalised across labels.
    similarity = embeddings[ModalityType.VISION] @ embeddings[ModalityType.TEXT].T
    probabilities = torch.softmax(similarity, dim=-1)
    # squeeze(0) drops the leading axis when it has size 1 (a single image is
    # passed in above), leaving one score per label.
    scores = probabilities.squeeze(0).tolist()

    return dict(zip(labels, scores))
def main():
    """Build and launch the Gradio demo around ``image_text_zeroshot``.

    The interface collects an image and a one-line text box of candidate
    labels, passes them to ``image_text_zeroshot`` and renders the returned
    label-to-score mapping with Gradio's ``"label"`` output component.
    """
    # Gradio passes component values to the callback positionally, so the
    # components must follow the callback's parameter order:
    # image_text_zeroshot(image, text_list).
    # NOTE(review): the image component is an input although its label reads
    # "Output image"; the label text is kept as-is -- confirm the wording.
    # NOTE(review): gr.inputs is Gradio's legacy component namespace; confirm
    # it exists in the Gradio version pinned for this app.
    inputs = [
        gr.inputs.Image(type="filepath", label="Output image"),
        gr.inputs.Textbox(lines=1, label="texts"),
    ]

    iface = gr.Interface(
        # Pass the function object itself. Calling it here would run it while
        # the interface is being built, with `image` and `text_list`
        # undefined in this scope (NameError).
        image_text_zeroshot,
        inputs,
        "label",
        description="""...""",
        title="ImageBind",
    )
    iface.launch()