|
import gradio as gr |
|
import torch |
|
from PIL import Image |
|
from transformers import AutoProcessor, AutoModel |
|
|
|
|
|
# Hugging Face model id for Fashion-CLIP, a CLIP variant fine-tuned on
# fashion product data (patrickjohncyh/fashion-clip).
model_id = "patrickjohncyh/fashion-clip"

# Processor handles both image preprocessing and text tokenization.
# NOTE: from_pretrained downloads weights on first run — requires network access.
processor = AutoProcessor.from_pretrained(model_id)

# CLIP-style dual encoder exposing get_image_features / get_text_features.
model = AutoModel.from_pretrained(model_id)
|
|
|
def compute_embeddings(input_data, input_type="image"):
    """Embed an image or a text query with the Fashion-CLIP dual encoder.

    Args:
        input_data: For ``input_type="image"``, a file path (or anything
            ``PIL.Image.open`` accepts). Otherwise, a string (or list of
            strings) to embed as text.
        input_type: ``"image"`` or ``"text"``; any other value is treated
            as text (matches original behavior). Defaults to ``"image"``.

    Returns:
        numpy.ndarray: Embedding(s) of shape (batch, embed_dim).
    """
    # Inference only — no_grad skips autograd bookkeeping (faster, less memory).
    with torch.no_grad():
        if input_type == "image":
            # convert("RGB") guards against RGBA/palette/grayscale uploads,
            # which would otherwise break the processor's channel handling.
            image = Image.open(input_data).convert("RGB")
            inputs = processor(images=image, return_tensors="pt")
            outputs = model.get_image_features(**inputs)
        else:
            inputs = processor(text=input_data, return_tensors="pt")
            outputs = model.get_text_features(**inputs)
    # detach() is redundant under no_grad but harmless; keeps the API
    # contract (always returns a plain numpy array).
    return outputs.detach().numpy()
|
|
|
def image_text_search(query, image):
    """Score how well an uploaded fashion image matches a text query.

    Args:
        query: Search text, e.g. "red dress".
        image: Filepath of the uploaded image (the Gradio component uses
            ``type="filepath"``); may be ``None`` if nothing was uploaded.

    Returns:
        str: A formatted cosine-similarity score, or a prompt asking for
        the missing input.
    """
    # Gradio fires the callback even when inputs are empty — guard so we
    # don't crash inside Image.open(None) or the tokenizer.
    if not query or image is None:
        return "Please provide both a search query and an image."

    text_emb = compute_embeddings(query, input_type="text")
    image_emb = compute_embeddings(image, input_type="image")

    # from_numpy shares memory with the numpy array (no copy, unlike
    # torch.tensor); cosine similarity along the embedding dimension.
    similarity = torch.nn.functional.cosine_similarity(
        torch.from_numpy(text_emb), torch.from_numpy(image_emb), dim=1
    )
    # Single text vs single image -> one score; .item() extracts the scalar.
    return f"Similarity score: {similarity.item():.3f}"
|
|
|
|
|
# ---------------------------------------------------------------------------
# Gradio UI: a query box and an image uploader feed the search callback,
# which writes its score string into the output textbox.
# ---------------------------------------------------------------------------
with gr.Blocks() as demo:
    gr.Markdown("# Fashion-CLIP Demo 🛍️")

    with gr.Row():
        query_box = gr.Textbox(label="Search Query", placeholder="e.g., 'red dress'")
        image_picker = gr.Image(label="Upload Fashion Item", type="filepath")

    search_button = gr.Button("Search")
    score_box = gr.Textbox(label="Similarity Score")

    # Wire the button: (query, image) -> similarity string.
    search_button.click(
        fn=image_text_search,
        inputs=[query_box, image_picker],
        outputs=score_box,
    )

demo.launch()