import gradio as gr
import torch
from PIL import Image
from transformers import AutoProcessor, AutoModel

# Load the Fashion-CLIP model and processor once at startup
model_id = "patrickjohncyh/fashion-clip"
processor = AutoProcessor.from_pretrained(model_id)
model = AutoModel.from_pretrained(model_id)
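model.eval()  # inference mode; standard practice before extracting embeddings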

def compute_embeddings(input_data, input_type="image"):
    """Return a NumPy embedding for an image filepath or a text string."""
    with torch.no_grad():  # inference only, no gradients needed
        if input_type == "image":
            image = Image.open(input_data).convert("RGB")  # input_data is a filepath
            inputs = processor(images=image, return_tensors="pt")
            outputs = model.get_image_features(**inputs)
        else:  # text
            inputs = processor(text=input_data, return_tensors="pt")
            outputs = model.get_text_features(**inputs)
    return outputs.numpy()
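
# For reference, both calls return an array of shape (1, D); D should be 512 for
# the ViT-B/32 CLIP variant Fashion-CLIP is built on (assumption, worth verifying):
#   compute_embeddings("a red dress", input_type="text")
#   compute_embeddings("dress.jpg", input_type="image")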

def image_text_search(query, image):
    # Embed the query text and the uploaded image
    text_emb = compute_embeddings(query, input_type="text")
    image_emb = compute_embeddings(image, input_type="image")
    # Cosine similarity between the two embeddings
    similarity = torch.nn.functional.cosine_similarity(
        torch.tensor(text_emb), torch.tensor(image_emb), dim=1
    )
    return f"Similarity score: {similarity.item():.3f}"

# Gradio UI
with gr.Blocks() as demo:
    gr.Markdown("# Fashion-CLIP Demo 🛍️")
    with gr.Row():
        text_input = gr.Textbox(label="Search Query", placeholder="e.g., 'red dress'")
        image_input = gr.Image(label="Upload Fashion Item", type="filepath")
    submit_btn = gr.Button("Search")
    output = gr.Textbox(label="Similarity Score")
    submit_btn.click(
        fn=image_text_search,
        inputs=[text_input, image_input],
        outputs=output,
    )

demo.launch()