import torch
import gradio as gr
from PIL import Image
from transformers import AutoProcessor, SiglipModel
import faiss
import numpy as np
from huggingface_hub import hf_hub_download
import pandas as pd
import requests
from io import BytesIO
# Download the prebuilt FAISS index and the WikiArt metadata CSV from the Hub
hf_hub_download("merve/siglip-faiss-wikiart", "siglip_10k_latest.index", local_dir="./")
hf_hub_download("merve/siglip-faiss-wikiart", "wikiart_10k_latest.csv", local_dir="./")
# Load the FAISS index, the artwork metadata, and the SigLIP model and processor
index = faiss.read_index("./siglip_10k_latest.index")
df = pd.read_csv("./wikiart_10k_latest.csv")
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
processor = AutoProcessor.from_pretrained("google/siglip-base-patch16-224")
model = SiglipModel.from_pretrained("google/siglip-base-patch16-224").to(device)
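# Note: vector i in the FAISS index corresponds to row i of the dataframe,
# which is what lets a search hit be mapped back to its artwork URL with
# df.iloc[v]["Link"] below.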
def read_image_from_url(url):
    # Fetch an artwork image from its WikiArt URL
    response = requests.get(url)
    img = Image.open(BytesIO(response.content)).convert("RGB")
    return img
def extract_features_siglip(image):
    # Embed a PIL image with SigLIP; no gradients are needed at inference time
    with torch.no_grad():
        inputs = processor(images=image, return_tensors="pt").to(device)
        image_features = model.get_image_features(**inputs)
    return image_features
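# Hedged sketch (not part of the original Space): the index on the Hub was
# presumably built along these lines -- embed every image, L2-normalize, and
# add the vectors to a flat inner-product index so that search scores behave
# like cosine similarities. The function name and the IndexFlatIP choice are
# assumptions, not confirmed by the Space.
def build_index_sketch(images, out_path="siglip_10k_latest.index"):
    feats = [extract_features_siglip(img).cpu().numpy() for img in images]
    xb = np.float32(np.vstack(feats))
    faiss.normalize_L2(xb)
    idx = faiss.IndexFlatIP(xb.shape[1])  # inner product over unit vectors = cosine
    idx.add(xb)
    faiss.write_index(idx, out_path)
    return idx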
def infer(input_image):
    # Embed the user's sketch (the "composite" layer of the Gradio ImageEditor)
    input_features = extract_features_siglip(input_image["composite"].convert("RGB"))
    input_features = input_features.detach().cpu().numpy()
    input_features = np.float32(input_features)
    # L2-normalize the query vector before searching the index
    faiss.normalize_L2(input_features)
    # Retrieve the 3 nearest artworks; the returned distances are not used here
    distances, indices = index.search(input_features, 3)
    gallery_output = []
    for v in indices[0]:
        image_url = df.iloc[v]["Link"]
        img_retrieved = read_image_from_url(image_url)
        gallery_output.append(img_retrieved)
    return gallery_output
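# Hedged convenience example (not in the original Space): the same retrieval
# path can be exercised without the UI, e.g. against an image fetched by URL.
# The function name and the default k are illustrative.
def search_similar_by_url(url, k=3):
    query = extract_features_siglip(read_image_from_url(url))
    query = np.float32(query.detach().cpu().numpy())
    faiss.normalize_L2(query)
    _, idxs = index.search(query, k)
    return [df.iloc[v]["Link"] for v in idxs[0]]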
description = "Draw a sketch and find the closest artwork among 10k works from the WikiArt dataset. Built on the 🤗 Transformers integration of Google's SigLIP model, with FAISS for indexing."
sketchpad = gr.ImageEditor(type="pil")
gr.Interface(infer, sketchpad, "gallery", description=description, title="Draw to Search Art 🖼️").launch()
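# Run locally with `python app.py`; Gradio serves the interface at
# http://localhost:7860 by default.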