File size: 1,618 Bytes
a234e10
b5861b0
 
 
a234e10
b5861b0
 
 
 
a234e10
b5861b0
 
 
 
 
 
 
 
 
a234e10
b5861b0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
import gradio as gr
import torch
from PIL import Image
from transformers import AutoProcessor, AutoModel

# Fashion-CLIP checkpoint identifier on the Hugging Face Hub. The processor
# (preprocessing for images and text) and the model are created once, at
# import time, and shared by every call to the functions below.
model_id = "patrickjohncyh/fashion-clip"
processor, model = (
    AutoProcessor.from_pretrained(model_id),
    AutoModel.from_pretrained(model_id),
)

def compute_embeddings(input_data, input_type="image"):
    """Embed an image or a text query with the module-level Fashion-CLIP model.

    Args:
        input_data: When ``input_type == "image"``: a path / file object accepted
            by ``PIL.Image.open``, or an already-opened ``PIL.Image.Image``.
            Otherwise: the text (a string, or presumably a list of strings) to
            embed.
        input_type: ``"image"`` embeds an image; any other value embeds text.

    Returns:
        A NumPy array holding the output of ``model.get_image_features`` or
        ``model.get_text_features`` (presumably shaped ``(batch, embed_dim)`` --
        the caller compares rows along ``dim=1``).
    """
    # Pure inference: don't build an autograd graph we would only discard.
    with torch.no_grad():
        if input_type == "image":
            if isinstance(input_data, Image.Image):
                inputs = processor(images=input_data, return_tensors="pt")
            else:
                # Close the file handle PIL opens; preprocessing reads the
                # pixel data while the file is still open.
                with Image.open(input_data) as image:
                    inputs = processor(images=image, return_tensors="pt")
            outputs = model.get_image_features(**inputs)
        else:  # text
            # truncation: over-long queries would otherwise exceed the text
            # encoder's maximum sequence length and raise.
            # padding: lets a list of differently sized strings form one batch.
            inputs = processor(
                text=input_data,
                return_tensors="pt",
                padding=True,
                truncation=True,
            )
            outputs = model.get_text_features(**inputs)
    return outputs.detach().cpu().numpy()

def image_text_search(query, image):
    """Score how well a text query matches an image.

    Args:
        query: Free-text search query (e.g. ``"red dress"``); ``None`` or a
            blank string is treated as missing.
        image: Path to the image file (the UI's ``gr.Image`` uses
            ``type="filepath"``), or ``None`` when nothing was uploaded.

    Returns:
        ``"Similarity score: X.XXX"`` with the cosine similarity between the
        text and image embeddings, or a short prompt naming the missing input.
    """
    # Gradio hands empty inputs through as None / "". Without these guards the
    # request dies inside PIL or the tokenizer with an opaque error.
    if not query or not query.strip():
        return "Please enter a search query."
    if image is None:
        return "Please upload an image."

    text_emb = torch.from_numpy(compute_embeddings(query, input_type="text"))
    image_emb = torch.from_numpy(compute_embeddings(image, input_type="image"))

    # Cosine similarity between the text and image embedding rows.
    similarity = torch.nn.functional.cosine_similarity(text_emb, image_emb, dim=1)
    return f"Similarity score: {similarity.item():.3f}"

# Build the Gradio interface: query textbox and image upload on one row, a
# Search button below, and a textbox that displays the string returned by
# image_text_search.
with gr.Blocks() as demo:
    gr.Markdown("# Fashion-CLIP Demo 🛍️")

    with gr.Row():
        text_input = gr.Textbox(
            placeholder="e.g., 'red dress'",
            label="Search Query",
        )
        image_input = gr.Image(type="filepath", label="Upload Fashion Item")

    submit_btn = gr.Button("Search")
    output = gr.Textbox(label="Similarity Score")

    # Wire the button: (query, image path) -> similarity text.
    submit_btn.click(image_text_search, [text_input, image_input], output)

demo.launch()