# ubamba98's picture
# Create app.py
# 409c8d1
# raw
# history blame
# 1.57 kB
import functools
import io
import urllib.request

import gradio as gr
import numpy as np
import pandas as pd
import torch
from PIL import Image
from transformers import CLIPProcessor, CLIPModel
def find_similar(image):
device = "cuda" if torch.cuda.is_available() else "cpu"
## Define model
model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
model = model.to(device)
## Load data
photos = pd.read_csv("./photos.tsv000", sep='\t', header=0)
photo_features = np.load("./features.npy")
photo_ids = pd.read_csv("./photo_ids.csv")
photo_ids = list(photo_ids['photo_id'])
## Inference
with torch.no_grad():
photo_preprocessed = processor(text=None, images=image, return_tensors="pt", padding=True)["pixel_values"]
search_photo_feature = model.get_image_features(photos_preprocessed.to(device))
search_photo_feature /= search_photo_feature.norm(dim=-1, keepdim=True)
search_photos_feature = search_photos_feature.cpu().numpy()
## Find similarity
similarities = list((search_photos_features @ photo_features.T).squeeze(0))
## Return best image :)
best_photo = sorted(zip(similarities, range(photo_features.shape[0])), key=lambda x: x[0], reverse=True)[0]
idx = best_photos[1]
photo_id = photo_ids[idx]
photo_data = photos[photos["photo_id"] == photo_id].iloc[0]
return Image(url=photo_data["photo_image_url"] + "?w=640")
# Build the Gradio UI around find_similar (the original referenced an
# undefined name, bg_remove) and keep the Interface object itself in `iface`
# rather than the return value of launch().
iface = gr.Interface(fn=find_similar, inputs="image", outputs="image")
iface.launch()