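"""Draw to Search.

A Gradio demo: a hand-drawn sketch is classified by a small CNN, and the
predicted label is then used as a CLIP text query to retrieve the most
similar images from a precomputed embedding index.
"""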
from pathlib import Path

import gradio as gr
import numpy as np
import pandas as pd
import requests
import torch
from torch import nn
from transformers import CLIPModel, CLIPProcessor

LABELS = Path('class_names.txt').read_text().splitlines()
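
# Small CNN sketch classifier. The Flatten size (1152 = 128 * 3 * 3) implies a
# 28x28 grayscale input, reduced to 3x3 by the three 2x2 max-pool layers.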
class_model = nn.Sequential(
    nn.Conv2d(1, 32, 3, padding='same'),
    nn.ReLU(),
    nn.MaxPool2d(2),
    nn.Conv2d(32, 64, 3, padding='same'),
    nn.ReLU(),
    nn.MaxPool2d(2),
    nn.Conv2d(64, 128, 3, padding='same'),
    nn.ReLU(),
    nn.MaxPool2d(2),
    nn.Flatten(),
    nn.Linear(1152, 256),
    nn.ReLU(),
    nn.Linear(256, len(LABELS)),
)

state_dict = torch.load('pytorch_model.bin', map_location='cpu')
class_model.load_state_dict(state_dict, strict=False)
class_model.eval()
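
# CLIP is used only as a text encoder: the predicted sketch label becomes a
# text query that is matched against precomputed image embeddings.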
model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
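
# data2.csv lists the candidate images (the 'path' column holds their URLs);
# embeddings.npy is assumed to hold the corresponding precomputed CLIP image
# embeddings, one row per image.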
df = pd.read_csv('data2.csv')
embeddings_npy = np.load('embeddings.npy')
# L2-normalize each image embedding so that ranking by dot product against a
# text embedding is equivalent to cosine-similarity ranking.
embeddings = embeddings_npy / np.linalg.norm(embeddings_npy, axis=1, keepdims=True)

def compute_text_embeddings(list_of_strings):
    # Encode a list of strings with CLIP's text tower.
    inputs = processor(text=list_of_strings, return_tensors="pt", padding=True)
    return model.get_text_features(**inputs)

def download_img(path):
    # Download a result image by URL and save it locally so Gradio can serve
    # it as a file output.
    img_data = requests.get(path).content
    local_path = path.split("/")[-1]
    with open(local_path, 'wb') as handler:
        handler.write(img_data)
    return local_path

def predict(im):
    # The sketchpad provides a grayscale array in [0, 255]; rescale to [0, 1]
    # and add batch and channel dimensions.
    x = torch.tensor(im, dtype=torch.float32).unsqueeze(0).unsqueeze(0) / 255.
    with torch.no_grad():
        out = class_model(x)
    probabilities = torch.nn.functional.softmax(out[0], dim=0)
    values, indices = torch.topk(probabilities, 5)

    # Use the top predicted label as a text query against the image index.
    query = LABELS[indices[0]]
    n_results = 3
    text_embeddings = compute_text_embeddings([query]).detach().numpy()
    results = np.argsort((embeddings @ text_embeddings.T)[:, 0])[-1:-n_results - 1:-1]

    # Download the retrieved images; the label-probability dict goes first in
    # the outputs so it feeds the Label component.
    outputs = [download_img(df.iloc[i]['path']) for i in results]
    outputs.insert(0, {LABELS[i]: v.item() for i, v in zip(indices, values)})
    return outputs
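
# Gradio UI: a sketchpad input; outputs are the top-class probabilities plus
# the three retrieved images.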
title = "Draw to Search"

iface = gr.Interface(
    fn=predict,
    inputs='sketchpad',
    outputs=[
        gr.outputs.Label(num_top_classes=3),
        gr.outputs.Image(type="file"),
        gr.outputs.Image(type="file"),
        gr.outputs.Image(type="file"),
    ],
    title=title,
    live=True,
)

iface.launch(debug=True)