sk-test / app.py
dmayboroda's picture
updated gradio
e63c501
raw
history blame
1.75 kB
import os
import torch
import clip
import transformers
import numpy as np
import gradio as gr
from PIL import Image
from multilingual_clip import pt_multilingual_clip
from torch.utils.data import DataLoader
from datasets import load_dataset
from usearch.index import Index
# Load the image dataset and the OpenAI CLIP model used for both image
# embedding (below) and text embedding (in get_similar).
dataset = load_dataset("dmayboroda/sk-test_1")
device = "cuda" if torch.cuda.is_available() else "cpu"
clipmodel, preprocess = clip.load("ViT-L/14", device=device)

# NOTE(review): the multilingual M-CLIP model and tokenizer below are loaded
# (and moved to `device`) but are not referenced anywhere else in this file;
# confirm whether they were meant to replace clipmodel.encode_text for
# non-English queries, or drop them to save startup time and memory.
model_name = 'M-CLIP/LABSE-Vit-L-14'
model = pt_multilingual_clip.MultilingualCLIP.from_pretrained(model_name)
tokenizer = transformers.AutoTokenizer.from_pretrained(model_name)
model.to(device)

# Cosine-similarity vector index; ndim must equal the width of the vectors
# produced by clipmodel.encode_image (768 for ViT-L/14).
index = Index(ndim=768, metric='cos', dtype='f32')

# img_embeddings[i] is the embedding of dataset row i and is stored in the
# index under key i. emb_to_img maps each embedding tensor *object* to its
# image: torch tensors hash by identity, so lookups must use the exact tensor
# objects held in img_embeddings.
img_embeddings = []
emb_to_img = {}

print('Encoding images...')
# Inference only: no autograd graph is needed for any of the encodings.
with torch.no_grad():
    for img in dataset['train']:
        image = preprocess(img['image']).unsqueeze(0).to(device)
        image_features = clipmodel.encode_image(image)
        img_embeddings.append(image_features)
        emb_to_img[image_features] = img['image']

for key, embedding in enumerate(img_embeddings):
    # encode_image may return float16 (e.g. CLIP on CUDA); cast explicitly so
    # the vector matches the index's 'f32' dtype.
    vector = embedding.squeeze(0).float().cpu().numpy()
    index.add(key, vector)
def get_similar(text, num_sim):
    """Return up to ``num_sim`` dataset images most similar to ``text``.

    The query is embedded with CLIP's text encoder and searched against the
    cosine index built from the dataset's image embeddings.

    Args:
        text: Free-text query. Inputs longer than CLIP's context length are
            truncated instead of raising.
        num_sim: Number of images requested. Gradio's Number component may
            deliver a float (or None), so it is coerced to int.

    Returns:
        A list of images from ``emb_to_img``, in the order returned by
        ``index.search``. Empty when ``num_sim`` is not a positive number.
    """
    count = int(num_sim) if num_sim is not None else 0
    if count <= 0:
        return []

    # Inference only: avoid building an autograd graph for the query.
    with torch.no_grad():
        tokens = clip.tokenize(text, truncate=True).to(device)
        text_features = clipmodel.encode_text(tokens)
    # encode_text may return float16; cast to match the index's 'f32' dtype.
    query = text_features.squeeze(0).float().cpu().numpy()

    matches = index.search(query, count)
    similar = []
    for match in matches:
        # Index keys are positions in img_embeddings (assigned when the
        # index was populated), which in turn key emb_to_img.
        key = int(match.key)
        similar.append(emb_to_img[img_embeddings[key]])
    return similar
# Gradio UI: a text query and a result count in, the matching images out.
iface = gr.Interface(
    fn=get_similar,
    inputs=[
        gr.Textbox(label="Enter Text Here..."),
        # `value` sets the initial number (`default` is no longer accepted);
        # precision=0 makes Gradio pass an int to get_similar.
        gr.Number(label="Number of Images", value=15, precision=0),
    ],
    # get_similar returns a list of images, which a single-image output
    # cannot display; a Gallery renders the whole list.
    outputs=gr.Gallery(),
    title="Model Testing",
)
iface.launch()