draw_to_search / app.py
osanseviero's picture
Update app.py
1c8a40e
raw
history blame
1.59 kB
import os

import gradio as gr
import numpy as np
import pandas as pd
import requests
import torch
from transformers import CLIPProcessor, CLIPTextModel, CLIPModel
# Pretrained CLIP checkpoint and its paired pre-processor, loaded once at
# import time and shared by every request.
model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")

# Per-corpus metadata tables, keyed by corpus index (0 or 1).
df = {
    0: pd.read_csv('data.csv'),
    1: pd.read_csv('data2.csv'),
}

# Per-corpus embedding matrices, keyed by the same corpus index.
# NOTE(review): index 0 pairs data.csv with embeddings2.npy and index 1 pairs
# data2.csv with embeddings.npy -- confirm this crossing is intentional.
embeddings = {
    0: np.load('embeddings2.npy'),
    1: np.load('embeddings.npy'),
}

# Rescale every row of each matrix to unit L2 norm (row / sqrt(sum(row**2))).
for k in (0, 1):
    embeddings[k] = embeddings[k] / np.sqrt(
        np.square(embeddings[k]).sum(axis=1, keepdims=True)
    )
def compute_text_embeddings(list_of_strings):
    """Embed a batch of strings with the CLIP text encoder.

    Args:
        list_of_strings: Texts to embed; forwarded to the module-level
            ``processor`` as its ``text`` argument.

    Returns:
        The value returned by ``model.get_text_features`` for the tokenized
        batch. It is computed with gradient tracking disabled, so it carries
        no autograd graph.
    """
    # truncation=True clips over-long queries to the tokenizer's maximum
    # length instead of letting the model fail on them.
    inputs = processor(
        text=list_of_strings,
        return_tensors="pt",
        padding=True,
        truncation=True,
    )
    # Inference only: skip autograd bookkeeping so no graph is built or kept.
    with torch.no_grad():
        return model.get_text_features(**inputs)
def download_img(path, timeout=30):
    """Download the image at URL *path* into the current working directory.

    The local file name is the last ``/``-separated segment of the URL with
    ``.jpg`` appended; an existing file of that name is overwritten.

    Args:
        path: URL of the image to fetch.
        timeout: Seconds to wait on the server before giving up; passed
            straight to ``requests.get``.

    Returns:
        The relative path of the file that was written.

    Raises:
        requests.RequestException: If the connection fails, the request
            times out, or the server answers with a 4xx/5xx status.
    """
    response = requests.get(path, timeout=timeout)
    # Fail loudly instead of saving an HTML error page under a .jpg name.
    response.raise_for_status()

    local_path = path.split("/")[-1] + ".jpg"
    with open(local_path, 'wb') as handler:
        handler.write(response.content)
    return local_path
def predict(query):
    """Return local file paths of the images that best match *query*.

    The query is embedded with ``compute_text_embeddings``, scored against
    the selected corpus' embedding matrix by matrix product, and the three
    highest-scoring rows are downloaded via ``download_img``.

    Args:
        query: Free-text search string.

    Returns:
        A list of downloaded file paths, best match first.
    """
    corpus = 'Movies'
    n_results = 3

    # Corpus index: 0 is used only for 'Unsplash', anything else maps to 1.
    k = 1 if corpus != 'Unsplash' else 0

    query_vectors = compute_text_embeddings([query]).detach().numpy()
    scores = (embeddings[k] @ query_vectors.T)[:, 0]

    # argsort is ascending, so reverse it and keep the leading n_results.
    best_rows = np.argsort(scores)[::-1][:n_results]

    table = df[k]
    paths = []
    for row in best_rows:
        paths.append(download_img(table.iloc[row]['path']))

    print(paths)
    return paths
# Gradio front end: one multi-line text box in, three image slots out
# (one slot per path returned by predict).
title = "Draw to Search"
iface = gr.Interface(
    title=title,
    fn=predict,
    inputs=[gr.inputs.Textbox(label="text", lines=3)],
    outputs=[gr.outputs.Image(type="file") for _ in range(3)],
    examples=[["Sunset"]],
)
iface.launch(debug=True)