Commit
Β·
fe26b6e
1
Parent(s):
c969994
adding both models for classification
Browse files
app.py
CHANGED
@@ -19,44 +19,40 @@ sidebar_markdown = """
|
|
19 |
Note, this demo can classify 200 items. If you didn't find what you're looking for, reach out to us on our [Community](https://join.slack.com/t/marqo-community/shared_invite/zt-2iab0260n-QJrZLUSOJYUifVxf964Gdw) and request an item to be added.
|
20 |
|
21 |
## Documentation
|
22 |
-
π [Blog Post]()
|
23 |
|
24 |
-
π [Use Case Blog Post]()
|
|
|
|
|
25 |
|
26 |
## Code
|
27 |
-
π» [GitHub Repo]()
|
28 |
|
29 |
-
π€ [Google Colab]()
|
30 |
|
31 |
-
π€ [Hugging Face Collection]()
|
32 |
|
33 |
## Citation
|
34 |
If you use Marqo-Ecommerce-L or Marqo-Ecommerce-B, please cite us:
|
35 |
```
|
36 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
37 |
```
|
38 |
"""
|
39 |
|
40 |
-
|
41 |
-
|
42 |
-
|
43 |
-
|
44 |
-
|
45 |
-
if api_key is None:
|
46 |
-
raise ValueError("Hugging Face API key not found. Please set the 'HF_API_TOKEN' environment variable.")
|
47 |
-
|
48 |
-
# Login using the token
|
49 |
-
login(token=api_key)
|
50 |
-
|
51 |
-
# Initialize the model and tokenizer
|
52 |
-
@spaces.GPU
|
53 |
-
def load_model(progress=gr.Progress()):
|
54 |
-
progress(0, "Initializing model...")
|
55 |
-
model_name = 'hf-hub:Marqo/marqo-ecommerce-embeddings-B'
|
56 |
-
model, preprocess_train, preprocess_val = open_clip.create_model_and_transforms(model_name)
|
57 |
|
58 |
progress(0.5, "Loading tokenizer...")
|
59 |
-
tokenizer = open_clip.get_tokenizer(model_name)
|
60 |
|
61 |
text = tokenizer(ecommerce_items)
|
62 |
|
@@ -65,16 +61,23 @@ def load_model(progress=gr.Progress()):
|
|
65 |
text_features = model.encode_text(text)
|
66 |
text_features /= text_features.norm(dim=-1, keepdim=True)
|
67 |
|
68 |
-
progress(1.0, "Model loaded successfully!")
|
69 |
|
70 |
return model, preprocess_val, text_features
|
71 |
|
72 |
-
# Load model
|
73 |
-
|
|
|
|
|
74 |
|
75 |
# Prediction function
|
76 |
@spaces.GPU
|
77 |
-
def predict(image, url):
|
|
|
|
|
|
|
|
|
|
|
78 |
if url:
|
79 |
response = requests.get(url)
|
80 |
image = Image.open(BytesIO(response.content))
|
@@ -133,14 +136,19 @@ with gr.Blocks(css="""
|
|
133 |
with gr.Column(scale=2):
|
134 |
input_image = gr.Image(type="pil", label="Upload Ecommerce Item Image", height=312)
|
135 |
input_url = gr.Textbox(label="Or provide an image URL")
|
|
|
|
|
|
|
|
|
|
|
136 |
with gr.Row():
|
137 |
predict_button = gr.Button("Classify")
|
138 |
clear_button = gr.Button("Clear")
|
139 |
gr.Markdown("Or click on one of the images below to classify it:")
|
140 |
gr.Examples(examples=examples, inputs=input_image)
|
141 |
output_label = gr.Label(num_top_classes=6)
|
142 |
-
predict_button.click(predict, inputs=[input_image, input_url], outputs=[input_image, output_label])
|
143 |
-
clear_button.click(clear_fields, outputs=[input_image, input_url])
|
144 |
|
145 |
# Launch the interface
|
146 |
-
demo.launch()
|
|
|
19 |
Note, this demo can classify 200 items. If you didn't find what you're looking for, reach out to us on our [Community](https://join.slack.com/t/marqo-community/shared_invite/zt-2iab0260n-QJrZLUSOJYUifVxf964Gdw) and request an item to be added.
|
20 |
|
21 |
## Documentation
|
22 |
+
π [Blog Post](https://www.marqo.ai/blog/introducing-marqos-ecommerce-embedding-models)
|
23 |
|
24 |
+
π [Classification Use Case Blog Post](https://www.marqo.ai/blog/ecommerce-image-classification-with-huggingface-transformers)
|
25 |
+
|
26 |
+
π [Image Search Use Case Blog Post](https://www.marqo.ai/blog/how-to-build-an-ecommerce-image-search-application)
|
27 |
|
28 |
## Code
|
29 |
+
π» [GitHub Repo](https://github.com/marqo-ai/marqo-ecommerce-embeddings)
|
30 |
|
31 |
+
π€ [Google Colab](https://colab.research.google.com/drive/1ctqDrXs_P-RIOPc9xcUF83WLdYQ0wf-8?usp=sharing)
|
32 |
|
33 |
+
π€ [Hugging Face Collection](https://huggingface.co/collections/Marqo/marqo-ecommerce-embeddings-66f611b9bb9d035a8d164fbb)
|
34 |
|
35 |
## Citation
|
36 |
If you use Marqo-Ecommerce-L or Marqo-Ecommerce-B, please cite us:
|
37 |
```
|
38 |
+
@software{zhu2024marqoecommembed_2024,
|
39 |
+
author = {Tianyu Zhu and and Jesse Clark},
|
40 |
+
month = oct,
|
41 |
+
title = {{Marqo Ecommerce Embeddings - Foundation Model for Product Embeddings}},
|
42 |
+
url = {https://github.com/marqo-ai/marqo-ecommerce-embeddings/},
|
43 |
+
version = {1.0.0},
|
44 |
+
year = {2024}
|
45 |
+
}
|
46 |
```
|
47 |
"""
|
48 |
|
49 |
+
# Function to initialize a model, preprocess, and text features
|
50 |
+
def initialize_model(model_name, progress=gr.Progress()):
|
51 |
+
progress(0, f"Initializing model: {model_name}...")
|
52 |
+
model, preprocess_train, preprocess_val = open_clip.create_model_and_transforms(f"hf-hub:Marqo/{model_name}")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
53 |
|
54 |
progress(0.5, "Loading tokenizer...")
|
55 |
+
tokenizer = open_clip.get_tokenizer(f"hf-hub:Marqo/{model_name}")
|
56 |
|
57 |
text = tokenizer(ecommerce_items)
|
58 |
|
|
|
61 |
text_features = model.encode_text(text)
|
62 |
text_features /= text_features.norm(dim=-1, keepdim=True)
|
63 |
|
64 |
+
progress(1.0, f"Model {model_name} loaded successfully!")
|
65 |
|
66 |
return model, preprocess_val, text_features
|
67 |
|
68 |
+
# Load L model first, followed by B model
|
69 |
+
progress_bar = gr.Progress()
|
70 |
+
model_l, preprocess_val_l, text_features_l = initialize_model("marqo-ecommerce-embeddings-L", progress=progress_bar)
|
71 |
+
model_b, preprocess_val_b, text_features_b = initialize_model("marqo-ecommerce-embeddings-B", progress=progress_bar)
|
72 |
|
73 |
# Prediction function
|
74 |
@spaces.GPU
|
75 |
+
def predict(image, url, model_name):
|
76 |
+
if model_name == "marqo-ecommerce-embeddings-B":
|
77 |
+
model, preprocess_val, text_features = model_b, preprocess_val_b, text_features_b
|
78 |
+
else:
|
79 |
+
model, preprocess_val, text_features = model_l, preprocess_val_l, text_features_l
|
80 |
+
|
81 |
if url:
|
82 |
response = requests.get(url)
|
83 |
image = Image.open(BytesIO(response.content))
|
|
|
136 |
with gr.Column(scale=2):
|
137 |
input_image = gr.Image(type="pil", label="Upload Ecommerce Item Image", height=312)
|
138 |
input_url = gr.Textbox(label="Or provide an image URL")
|
139 |
+
model_selector = gr.Dropdown(
|
140 |
+
choices=["marqo-ecommerce-embeddings-L", "marqo-ecommerce-embeddings-B"],
|
141 |
+
value="marqo-ecommerce-embeddings-L",
|
142 |
+
label="Select Model"
|
143 |
+
)
|
144 |
with gr.Row():
|
145 |
predict_button = gr.Button("Classify")
|
146 |
clear_button = gr.Button("Clear")
|
147 |
gr.Markdown("Or click on one of the images below to classify it:")
|
148 |
gr.Examples(examples=examples, inputs=input_image)
|
149 |
output_label = gr.Label(num_top_classes=6)
|
150 |
+
predict_button.click(predict, inputs=[input_image, input_url, model_selector], outputs=[input_image, output_label])
|
151 |
+
clear_button.click(clear_fields, outputs=[input_image, input_url, model_selector])
|
152 |
|
153 |
# Launch the interface
|
154 |
+
demo.launch()
|