import gradio as gr
import open_clip
import torch
import requests
import numpy as np
from PIL import Image
from io import BytesIO
from items import ecommerce_items
import os
from dotenv import load_dotenv
# Load environment variables from the .env file
load_dotenv()
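# A .env file is expected next to this script and should define the token read
# below, e.g. (placeholder value, not a real token):
#   HF_API_TOKEN=hf_xxxxxxxxxxxxxxxx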
# Sidebar content
sidebar_markdown = """
Note: this demo can classify 200 items. If you can't find what you're looking for, reach out to us on our [Community](https://join.slack.com/t/marqo-community/shared_invite/zt-2iab0260n-QJrZLUSOJYUifVxf964Gdw) and request an item to be added.
## Documentation
[Blog Post]()
[Use Case Blog Post]()
## Code
[GitHub Repo]()
[Google Colab]()
[Hugging Face Collection]()
## Citation
If you use Marqo-Ecommerce-L or Marqo-Ecommerce-B, please cite us:
```
```
"""
from huggingface_hub import login
# Get your Hugging Face API key (ensure it is set in your environment variables)
api_key = os.getenv("HF_API_TOKEN")
if api_key is None:
    raise ValueError("Hugging Face API key not found. Please set the 'HF_API_TOKEN' environment variable.")
# Login using the token
login(token=api_key)
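# Note: authenticating is only strictly required for gated or private repositories
# on the Hugging Face Hub; public checkpoints can be downloaded without it.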
# Initialize the model and tokenizer
def load_model(progress=gr.Progress()):
    progress(0, "Initializing model...")
    model_name = 'hf-hub:Marqo/marqo-ecommerce-embeddings-B'
    model, preprocess_train, preprocess_val = open_clip.create_model_and_transforms(model_name)

    progress(0.5, "Loading tokenizer...")
    tokenizer = open_clip.get_tokenizer(model_name)
    text = tokenizer(ecommerce_items)

    progress(0.75, "Encoding text features...")
    with torch.no_grad(), torch.amp.autocast('cuda'):
        text_features = model.encode_text(text)
        text_features /= text_features.norm(dim=-1, keepdim=True)

    progress(1.0, "Model loaded successfully!")
    return model, preprocess_val, text_features
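# The label (text) embeddings are computed once at start-up and L2-normalized,
# so each request only needs a single image forward pass plus a dot product
# against the cached text features (cosine similarity on unit-norm vectors).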
# Load model and prepare interface
model, preprocess_val, text_features = load_model()
# Prediction function
def predict(image, url):
    # If a URL is provided, it takes precedence over the uploaded image
    if url:
        response = requests.get(url)
        image = Image.open(BytesIO(response.content))

    processed_image = preprocess_val(image).unsqueeze(0)
    with torch.no_grad(), torch.amp.autocast('cuda'):
        image_features = model.encode_image(processed_image)
        image_features /= image_features.norm(dim=-1, keepdim=True)
        text_probs = (100 * image_features @ text_features.T).softmax(dim=-1)

    sorted_confidences = sorted(
        {ecommerce_items[i]: float(text_probs[0, i]) for i in range(len(ecommerce_items))}.items(),
        key=lambda x: x[1],
        reverse=True
    )
    top_10_confidences = dict(sorted_confidences[:10])
    return image, top_10_confidences
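# Quick local sanity check (illustrative only; assumes "images/laptop.png"
# exists and the script is run outside the Gradio UI):
#   img, scores = predict(Image.open("images/laptop.png"), "")
#   print(list(scores.items())[:3])  # top-3 (label, probability) pairs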
# Clear function
def clear_fields():
    return None, ""
# Gradio interface
title = "Ecommerce Item Classifier with Marqo-Ecommerce Embedding Models"
description = "Upload an image or provide a URL of an ecommerce item to classify it using Marqo-Ecommerce Models!"
examples = [
    ["images/laptop.png", "Laptop"],
    ["images/grater.png", "Grater"],
    ["images/flip-flops.jpg", "Flip Flops"],
    ["images/bike-helmet.png", "Bike Helmet"],
    ["images/sleeping-bag.png", "Sleeping Bag"],
    ["images/cutting-board.png", "Cutting Board"],
    ["images/iron.png", "Iron"],
    ["images/coffee.png", "Coffee"],
]
with gr.Blocks(css="""
    .remove-btn {
        font-size: 24px !important; /* Increase the font size of the cross button */
        line-height: 24px !important;
        width: 30px !important; /* Increase the width */
        height: 30px !important; /* Increase the height */
    }
""") as demo:
    with gr.Row():
        with gr.Column(scale=1):
            gr.Markdown(f"# {title}")
            gr.Markdown(description)
            gr.Markdown(sidebar_markdown)

        gr.Markdown(" ", elem_id="vertical-line")  # Add an empty Markdown with a custom ID

        with gr.Column(scale=2):
            input_image = gr.Image(type="pil", label="Upload Ecommerce Item Image", height=312)
            input_url = gr.Textbox(label="Or provide an image URL")

            with gr.Row():
                predict_button = gr.Button("Classify")
                clear_button = gr.Button("Clear")

            gr.Markdown("Or click on one of the images below to classify it:")
            gr.Examples(examples=examples, inputs=input_image)

            output_label = gr.Label(num_top_classes=6)

            predict_button.click(predict, inputs=[input_image, input_url], outputs=[input_image, output_label])
            clear_button.click(clear_fields, outputs=[input_image, input_url])
# Launch the interface
demo.launch()
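# launch() serves the app on localhost by default; when running outside a hosted
# environment, demo.launch(share=True) can be used to get a temporary public link
# (standard Gradio behaviour, not specific to this demo).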