import gradio as gr
import open_clip
import torch
import requests
from PIL import Image
from io import BytesIO
from items import ecommerce_items
import os
from dotenv import load_dotenv
from huggingface_hub import login

# Load environment variables from the .env file
load_dotenv()

# Sidebar content
sidebar_markdown = """

Note: this demo can classify 200 items. If you can't find what you're looking for, reach out to us on our [Community](https://join.slack.com/t/marqo-community/shared_invite/zt-2iab0260n-QJrZLUSOJYUifVxf964Gdw) and request that an item be added.

## Documentation
📚 [Blog Post]()

📝 [Use Case Blog Post]()

## Code
💻 [GitHub Repo]()

🤝 [Google Colab]()

🤗 [Hugging Face Collection]()

## Citation
If you use Marqo-Ecommerce-L or Marqo-Ecommerce-B, please cite us:
```
```
"""

# Get your Hugging Face API key (ensure it is set in your environment variables)
api_key = os.getenv("HF_API_TOKEN")
if api_key is None:
    raise ValueError("Hugging Face API key not found. Please set the 'HF_API_TOKEN' environment variable.")

# Log in to the Hugging Face Hub using the token
login(token=api_key)


# Initialize the model and tokenizer
def load_model(progress=gr.Progress()):
    progress(0, "Initializing model...")
    model_name = 'hf-hub:Marqo/marqo-ecommerce-embeddings-B'
    model, preprocess_train, preprocess_val = open_clip.create_model_and_transforms(model_name)

    progress(0.5, "Loading tokenizer...")
    tokenizer = open_clip.get_tokenizer(model_name)
    text = tokenizer(ecommerce_items)

    progress(0.75, "Encoding text features...")
    with torch.no_grad(), torch.amp.autocast('cuda'):
        # Pre-compute and L2-normalize the text embeddings for all class names
        text_features = model.encode_text(text)
        text_features /= text_features.norm(dim=-1, keepdim=True)

    progress(1.0, "Model loaded successfully!")
    return model, preprocess_val, text_features


# Load the model and prepare the interface
model, preprocess_val, text_features = load_model()


# Prediction function
def predict(image, url):
    if image is None and not url:
        raise gr.Error("Please upload an image or provide an image URL.")

    # Prefer the URL if one was provided; otherwise use the uploaded image
    if url:
        response = requests.get(url, timeout=10)
        response.raise_for_status()
        image = Image.open(BytesIO(response.content))

    # Convert to RGB so palette/RGBA images don't break the CLIP preprocessing
    processed_image = preprocess_val(image.convert("RGB")).unsqueeze(0)
    with torch.no_grad(), torch.amp.autocast('cuda'):
        image_features = model.encode_image(processed_image)
        image_features /= image_features.norm(dim=-1, keepdim=True)
        # Scaled cosine similarity against every class embedding, softmaxed
        # into a probability distribution over the item names
        text_probs = (100 * image_features @ text_features.T).softmax(dim=-1)

    sorted_confidences = sorted(
        {ecommerce_items[i]: float(text_probs[0, i]) for i in range(len(ecommerce_items))}.items(),
        key=lambda x: x[1],
        reverse=True,
    )
    top_10_confidences = dict(sorted_confidences[:10])

    return image, top_10_confidences


# Clear function
def clear_fields():
    return None, ""


# Gradio interface
title = "Ecommerce Item Classifier with Marqo-Ecommerce Embedding Models"
description = "Upload an image or provide a URL of a product to classify it using the Marqo-Ecommerce models!"
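# Quick sanity check from a Python REPL (a sketch, not part of the app; it
# assumes an images/laptop.png file exists, matching the examples list below):
#
#   from PIL import Image
#   img = Image.open("images/laptop.png")
#   _, top10 = predict(img, "")
#   print(top10)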
examples = [
    ["images/laptop.png", "Laptop"],
    ["images/grater.png", "Grater"],
    ["images/flip-flops.jpg", "Flip Flops"],
    ["images/bike-helmet.png", "Bike Helmet"],
    ["images/sleeping-bag.png", "Sleeping Bag"],
    ["images/cutting-board.png", "Cutting Board"],
    ["images/iron.png", "Iron"],
    ["images/coffee.png", "Coffee"],
]

with gr.Blocks(css="""
    .remove-btn {
        font-size: 24px !important;  /* Increase the font size of the cross button */
        line-height: 24px !important;
        width: 30px !important;   /* Increase the width */
        height: 30px !important;  /* Increase the height */
    }
""") as demo:
    with gr.Row():
        with gr.Column(scale=1):
            gr.Markdown(f"# {title}")
            gr.Markdown(description)
            gr.Markdown(sidebar_markdown)
            gr.Markdown(" ", elem_id="vertical-line")  # Add an empty Markdown with a custom ID

        with gr.Column(scale=2):
            input_image = gr.Image(type="pil", label="Upload Product Image", height=312)
            input_url = gr.Textbox(label="Or provide an image URL")

            with gr.Row():
                predict_button = gr.Button("Classify")
                clear_button = gr.Button("Clear")

            gr.Markdown("Or click on one of the images below to classify it:")
            gr.Examples(examples=examples, inputs=input_image)

            output_label = gr.Label(num_top_classes=6)

    predict_button.click(predict, inputs=[input_image, input_url], outputs=[input_image, output_label])
    clear_button.click(clear_fields, outputs=[input_image, input_url])

# Launch the interface
demo.launch()
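# To run this demo locally (a sketch; assumes this file is saved as app.py,
# that items.py and an images/ directory sit alongside it, and that a .env
# file in the same directory defines HF_API_TOKEN=<your token>):
#
#   pip install gradio open_clip_torch torch requests pillow python-dotenv huggingface_hub
#   python app.py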