from pathlib import Path

import gradio as gr
import numpy as np
import pandas as pd
import requests
import torch
from torch import nn
from transformers import CLIPModel, CLIPProcessor
LABELS = Path('class_names.txt').read_text().splitlines()
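
# Small CNN sketch classifier; the 1152-unit linear layer implies a 28x28
# grayscale input (three 2x poolings: 28 -> 14 -> 7 -> 3, and 128 * 3 * 3 = 1152).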
class_model = nn.Sequential(
    nn.Conv2d(1, 32, 3, padding='same'),
    nn.ReLU(),
    nn.MaxPool2d(2),
    nn.Conv2d(32, 64, 3, padding='same'),
    nn.ReLU(),
    nn.MaxPool2d(2),
    nn.Conv2d(64, 128, 3, padding='same'),
    nn.ReLU(),
    nn.MaxPool2d(2),
    nn.Flatten(),
    nn.Linear(1152, 256),
    nn.ReLU(),
    nn.Linear(256, len(LABELS)),
)
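
# Load the pretrained classifier weights and switch to inference mode.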
state_dict = torch.load('pytorch_model.bin', map_location='cpu')
class_model.load_state_dict(state_dict, strict=False)
class_model.eval()
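
# CLIP model and processor, used to embed the predicted label as a text query.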
model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
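
# Image metadata (the 'path' column holds image URLs) and their precomputed embeddings.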
df = pd.read_csv('data2.csv')
embeddings_npy = np.load('embeddings.npy')
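# L2-normalize each embedding so the dot product below ranks by cosine similarity.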
embeddings = np.divide(embeddings_npy, np.sqrt(np.sum(embeddings_npy**2, axis=1, keepdims=True)))
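

# Embed a list of strings with the CLIP text encoder.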
def compute_text_embeddings(list_of_strings):
    inputs = processor(text=list_of_strings, return_tensors="pt", padding=True)
    return model.get_text_features(**inputs)
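

# Download an image from a URL and return the local file path.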
def download_img(path):
    img_data = requests.get(path).content
    local_path = path.split("/")[-1]
    with open(local_path, 'wb') as handler:
        handler.write(img_data)
    return local_path
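

# Gradio callback: classify the sketch, then retrieve the images closest to the predicted label.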
def predict(im):
    # Scale the sketch to [0, 1] and add batch and channel dimensions: (1, 1, H, W).
    x = torch.tensor(im, dtype=torch.float32).unsqueeze(0).unsqueeze(0) / 255.
    with torch.no_grad():
        out = class_model(x)
    probabilities = torch.nn.functional.softmax(out[0], dim=0)
    values, indices = torch.topk(probabilities, 5)
    # Use the top predicted class name as a text query against the image embeddings.
    query = LABELS[indices[0]]
    n_results = 3
    text_embeddings = compute_text_embeddings([query]).detach().numpy()
    # Indices of the n_results most similar images, best match first.
    results = np.argsort((embeddings @ text_embeddings.T)[:, 0])[-1:-n_results - 1:-1]
    outputs = [download_img(df.iloc[i]['path']) for i in results]
    # The first output feeds the Label component with a {class: probability} dict.
    outputs.insert(0, {LABELS[i]: v.item() for i, v in zip(indices, values)})
    print(outputs)
    return outputs

title = "Draw to Search"
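
# Sketchpad input; outputs are the top-3 class probabilities plus the three retrieved images.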
iface = gr.Interface(
    fn=predict,
    inputs='sketchpad',
    outputs=[
        gr.outputs.Label(num_top_classes=3),
        gr.outputs.Image(type="file"),
        gr.outputs.Image(type="file"),
        gr.outputs.Image(type="file"),
    ],
    title=title,
)
iface.launch(debug=True)