File size: 4,697 Bytes
d09d79e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
import gradio as gr
import open_clip
import torch
import requests
import numpy as np
from PIL import Image
from io import BytesIO
from items import ecommerce_items
import os
from dotenv import load_dotenv

# Populate os.environ from a local .env file when one is present. Passing
# override=False (python-dotenv's default) keeps any variable that is already
# set in the real environment instead of replacing it with the file's value.
load_dotenv(override=False)

# Sidebar content (rendered as Markdown in the left column of the UI).
# The emoji below were previously stored as mis-decoded UTF-8 bytes
# (e.g. "πŸ“š"), which rendered as garbage; they are now the real characters.
# TODO: the documentation/code links have empty targets and the citation
# block is empty; fill them in before publishing.
sidebar_markdown = """

Note, this demo can classify 200 items. If you didn't find what you're looking for, reach out to us on our [Community](https://join.slack.com/t/marqo-community/shared_invite/zt-2iab0260n-QJrZLUSOJYUifVxf964Gdw) and request an item to be added.

## Documentation
📚 [Blog Post]()

📝 [Use Case Blog Post]()

## Code
💻 [GitHub Repo]()

🤝 [Google Colab]()

🤗 [Hugging Face Collection]()

## Citation
If you use Marqo-Ecommerce-L or Marqo-Ecommerce-B, please cite us:
```

```
"""

from huggingface_hub import login

# Get your Hugging Face API key (ensure it is set in your environment variables).
api_key = os.getenv("HF_API_TOKEN")

# Reject an unset *or empty* value up front (e.g. a bare `HF_API_TOKEN=` line
# in .env). Checking only for None let an empty string reach login(), where it
# failed with a much less helpful error.
if not api_key:
    raise ValueError("Hugging Face API key not found. Please set the 'HF_API_TOKEN' environment variable.")

# Authenticate this process with the Hugging Face Hub using the token.
login(token=api_key)

# Initialize the model and tokenizer
def load_model(progress=gr.Progress()):
    """Load the Marqo ecommerce CLIP model and pre-encode every item label.

    The label embeddings depend only on ``ecommerce_items``, so they are
    computed once here and reused by every prediction.

    Args:
        progress: Gradio progress tracker (Gradio's default-argument pattern);
            used only to report the loading stages.

    Returns:
        ``(model, preprocess_val, text_features)``: the model in eval mode, the
        evaluation-time image transform, and L2-normalised text embeddings,
        one row per entry of ``ecommerce_items``.
    """
    progress(0, "Initializing model...")
    model_name = 'hf-hub:Marqo/marqo-ecommerce-embeddings-B'
    # The middle return value is the training-time transform; inference never uses it.
    model, _, preprocess_val = open_clip.create_model_and_transforms(model_name)
    # open_clip creates models in training mode. Switch to eval so training-only
    # behaviour (e.g. dropout / stochastic depth) is off while encoding.
    model.eval()

    progress(0.5, "Loading tokenizer...")
    tokenizer = open_clip.get_tokenizer(model_name)

    text = tokenizer(ecommerce_items)

    progress(0.75, "Encoding text features...")
    # Only request CUDA autocast when CUDA exists. Otherwise torch emits a
    # warning and disables it anyway, so the numerics are unchanged.
    with torch.no_grad(), torch.amp.autocast('cuda', enabled=torch.cuda.is_available()):
        text_features = model.encode_text(text)
        # L2-normalise so a dot product with normalised image features is cosine similarity.
        text_features /= text_features.norm(dim=-1, keepdim=True)

    progress(1.0, "Model loaded successfully!")

    return model, preprocess_val, text_features


# Load model and prepare interface (runs once at import time).
model, preprocess_val, text_features = load_model()

# Prediction function
def predict(image, url):
    """Classify an ecommerce item image against the pre-encoded item labels.

    Args:
        image: PIL image from the upload widget, or None.
        url: Optional image URL. When non-blank it takes precedence over
            ``image``.

    Returns:
        ``(image, confidences)``. ``image`` is the image that was classified,
        so a URL-fetched image is shown in the upload widget. ``confidences``
        maps the ten highest-scoring item names to their softmax probabilities.

    Raises:
        gr.Error: if neither an image nor a URL is provided, or the URL cannot
            be fetched.
    """
    url = (url or "").strip()
    if url:
        # NOTE(security): this fetches a user-supplied URL from the server, so it
        # can reach any host the server can (SSRF). A timeout at least keeps an
        # unresponsive host from hanging the worker indefinitely.
        try:
            response = requests.get(url, timeout=10)
            response.raise_for_status()
            image = Image.open(BytesIO(response.content))
        except (requests.RequestException, OSError) as err:
            raise gr.Error(f"Could not load an image from the URL: {err}") from err

    if image is None:
        raise gr.Error("Please upload an image or provide an image URL.")

    processed_image = preprocess_val(image).unsqueeze(0)

    with torch.no_grad(), torch.amp.autocast('cuda', enabled=torch.cuda.is_available()):
        image_features = model.encode_image(processed_image)
        image_features /= image_features.norm(dim=-1, keepdim=True)
        # Scale cosine similarities by a fixed factor of 100, then softmax over all items.
        text_probs = (100 * image_features @ text_features.T).softmax(dim=-1)

    # Map item name -> probability. As with the previous dict literal, a
    # duplicated item name keeps its last score. Then keep the ten best.
    confidences = dict(zip(ecommerce_items, text_probs[0].tolist()))
    top_10_confidences = dict(
        sorted(confidences.items(), key=lambda entry: entry[1], reverse=True)[:10]
    )

    return image, top_10_confidences

# Clear function
def clear_fields():
    """Return the reset values for the inputs: no image and a blank URL box."""
    cleared_image = None
    cleared_url = ""
    return cleared_image, cleared_url

# Gradio interface
title = "Ecommerce Item Classifier with Marqo-Ecommerce Embedding Models"
# The examples span many product categories (laptop, grater, coffee, ...), so
# the user-facing copy says "ecommerce item" rather than "fashion item".
description = "Upload an image or provide a URL of an ecommerce item to classify it using Marqo-Ecommerce Models!"

# Each row is [image path, display name].
# NOTE(review): gr.Examples below binds only one input component, so the second
# column is not fed to any component. Confirm it is only meant as a label.
examples = [
    ["images/laptop.png", "Laptop"],
    ["images/grater.png", "Grater"],
    ["images/flip-flops.jpg", "Flip Flops"],
    ["images/bike-helmet.png", "Bike Helmet"],
    ["images/sleeping-bag.png", "Sleeping Bag"],
    ["images/cutting-board.png", "Cutting Board"],
    ["images/iron.png", "Iron"],
    ["images/coffee.png", "Coffee"],
]

with gr.Blocks(css="""
    .remove-btn {
        font-size: 24px !important; /* Increase the font size of the cross button */
        line-height: 24px !important;
        width: 30px !important; /* Increase the width */
        height: 30px !important; /* Increase the height */
    }
""") as demo:
    with gr.Row():
        # Left column: title, description and sidebar links.
        with gr.Column(scale=1):
            gr.Markdown(f"# {title}")
            gr.Markdown(description)
            gr.Markdown(sidebar_markdown)
            gr.Markdown(" ", elem_id="vertical-line")  # Add an empty Markdown with a custom ID
        # Right column: inputs, actions, examples and the prediction output.
        with gr.Column(scale=2):
            input_image = gr.Image(type="pil", label="Upload Ecommerce Item Image", height=312)
            input_url = gr.Textbox(label="Or provide an image URL")
            with gr.Row():
                predict_button = gr.Button("Classify")
                clear_button = gr.Button("Clear")
            gr.Markdown("Or click on one of the images below to classify it:")
            gr.Examples(examples=examples, inputs=input_image)
            # NOTE: predict() returns ten scored items, but only the top 6 are displayed here.
            output_label = gr.Label(num_top_classes=6)
            # predict() also writes the image back so a URL-fetched image is shown.
            predict_button.click(predict, inputs=[input_image, input_url], outputs=[input_image, output_label])
            clear_button.click(clear_fields, outputs=[input_image, input_url])

# Launch the interface
demo.launch()