# Hugging Face Spaces status: Runtime error
import gradio as gr | |
import os | |
import skimage | |
import matplotlib.pyplot as plt | |
from PIL import Image | |
import numpy as np | |
from collections import OrderedDict | |
import torch | |
from imagebind import data | |
from imagebind.models import imagebind_model | |
from imagebind.models.imagebind_model import ModalityType | |
import torch.nn as nn | |
import pickle | |
# Inference device. Forced to CPU; to use a GPU when one is available, set it to
# "cuda:0" if torch.cuda.is_available() else "cpu".
device = "cpu"

# Pretrained ImageBind "huge" model, switched to eval mode and moved to `device`.
model = imagebind_model.imagebind_huge(pretrained=True)
model.eval()
model.to(device)

# Precomputed retrieval index: image embeddings (file name suggests they are
# already L2-normalised) and the list of image paths they correspond to.
# NOTE: pickle can execute arbitrary code on load -- these files must be trusted
# assets shipped with the app, never user-supplied data.
# NOTE(review): if image_features was pickled from a CUDA tensor, unpickling it
# on a CPU-only host may fail -- TODO confirm how the asset was produced.
# Context managers ensure the file handles are closed once loading finishes.
with open("./assets/image_features_norm.pkl", "rb") as _fh:
    image_features = pickle.load(_fh)
with open("./assets/image_paths.pkl", "rb") as _fh:
    image_paths = pickle.load(_fh)
def generate_image(text):
    """Return the stored image whose embedding best matches *text*.

    The query is embedded with the ImageBind text encoder, L2-normalised, and
    scored against the module-level ``image_features`` by matrix product. The
    highest-scoring entry of ``image_paths`` is then loaded from
    ``./assets/images`` (looked up by file name only) and returned.

    Args:
        text: Free-form query string from the Gradio text box.

    Returns:
        A ``PIL.Image.Image`` converted to RGB mode.
    """
    inputs = {
        ModalityType.TEXT: data.load_and_transform_text([text], device),
    }
    with torch.no_grad():
        embeddings = model(inputs)
    text_features = embeddings[ModalityType.TEXT]
    # Unit-normalise the query so the product below is a cosine similarity
    # (assumes image_features are stored already normalised -- TODO confirm).
    text_features /= text_features.norm(dim=-1, keepdim=True)
    similarity = text_features.cpu().numpy() @ image_features.cpu().numpy().T
    # Take the index of the highest score. Only one text is embedded, so the
    # flat argmax presumably equals the image column index.
    index_img = int(np.argmax(similarity))
    # Only the basename of the stored path is used; the image is always read
    # from ./assets/images regardless of the directory recorded in the pickle.
    img_name = os.path.basename(image_paths[index_img])
    # Context manager releases the file handle even for multi-frame formats,
    # where Pillow does not close it automatically after load().
    with Image.open(f"./assets/images/{img_name}") as im:
        return im.convert("RGB")
# Gradio UI: one text input and one image output, wired to generate_image.
# The title/description are user-facing Portuguese strings
# ("Text to Image" / "Type a text and get an image with the text").
# NOTE(review): generate_image returns the closest stored image rather than
# rendering the text, so the description may be worth rewording.
_interface_options = dict(
    fn=generate_image,
    inputs="text",
    outputs="image",
    title="Texto para Imagem",
    description="Digite um texto e obtenha uma imagem com o texto.",
)
iface = gr.Interface(**_interface_options)

# Run the Gradio server.
iface.launch()