elliesleightholm commited on
Commit
fe26b6e
Β·
1 Parent(s): c969994

adding both models for classification

Browse files
Files changed (1) hide show
  1. app.py +39 -31
app.py CHANGED
@@ -19,44 +19,40 @@ sidebar_markdown = """
19
  Note, this demo can classify 200 items. If you didn't find what you're looking for, reach out to us on our [Community](https://join.slack.com/t/marqo-community/shared_invite/zt-2iab0260n-QJrZLUSOJYUifVxf964Gdw) and request an item to be added.
20
 
21
  ## Documentation
22
- πŸ“š [Blog Post]()
23
 
24
- πŸ“ [Use Case Blog Post]()
 
 
25
 
26
  ## Code
27
- πŸ’» [GitHub Repo]()
28
 
29
- 🀝 [Google Colab]()
30
 
31
- πŸ€— [Hugging Face Collection]()
32
 
33
  ## Citation
34
  If you use Marqo-Ecommerce-L or Marqo-Ecommerce-B, please cite us:
35
  ```
36
-
 
 
 
 
 
 
 
37
  ```
38
  """
39
 
40
- from huggingface_hub import login
41
-
42
- # Get your Hugging Face API key (ensure it is set in your environment variables)
43
- api_key = os.getenv("HF_API_TOKEN")
44
-
45
- if api_key is None:
46
- raise ValueError("Hugging Face API key not found. Please set the 'HF_API_TOKEN' environment variable.")
47
-
48
- # Login using the token
49
- login(token=api_key)
50
-
51
- # Initialize the model and tokenizer
52
- @spaces.GPU
53
- def load_model(progress=gr.Progress()):
54
- progress(0, "Initializing model...")
55
- model_name = 'hf-hub:Marqo/marqo-ecommerce-embeddings-B'
56
- model, preprocess_train, preprocess_val = open_clip.create_model_and_transforms(model_name)
57
 
58
  progress(0.5, "Loading tokenizer...")
59
- tokenizer = open_clip.get_tokenizer(model_name)
60
 
61
  text = tokenizer(ecommerce_items)
62
 
@@ -65,16 +61,23 @@ def load_model(progress=gr.Progress()):
65
  text_features = model.encode_text(text)
66
  text_features /= text_features.norm(dim=-1, keepdim=True)
67
 
68
- progress(1.0, "Model loaded successfully!")
69
 
70
  return model, preprocess_val, text_features
71
 
72
- # Load model and prepare interface
73
- model, preprocess_val, text_features = load_model()
 
 
74
 
75
  # Prediction function
76
  @spaces.GPU
77
- def predict(image, url):
 
 
 
 
 
78
  if url:
79
  response = requests.get(url)
80
  image = Image.open(BytesIO(response.content))
@@ -133,14 +136,19 @@ with gr.Blocks(css="""
133
  with gr.Column(scale=2):
134
  input_image = gr.Image(type="pil", label="Upload Ecommerce Item Image", height=312)
135
  input_url = gr.Textbox(label="Or provide an image URL")
 
 
 
 
 
136
  with gr.Row():
137
  predict_button = gr.Button("Classify")
138
  clear_button = gr.Button("Clear")
139
  gr.Markdown("Or click on one of the images below to classify it:")
140
  gr.Examples(examples=examples, inputs=input_image)
141
  output_label = gr.Label(num_top_classes=6)
142
- predict_button.click(predict, inputs=[input_image, input_url], outputs=[input_image, output_label])
143
- clear_button.click(clear_fields, outputs=[input_image, input_url])
144
 
145
  # Launch the interface
146
- demo.launch()
 
19
  Note, this demo can classify 200 items. If you didn't find what you're looking for, reach out to us on our [Community](https://join.slack.com/t/marqo-community/shared_invite/zt-2iab0260n-QJrZLUSOJYUifVxf964Gdw) and request an item to be added.
20
 
21
  ## Documentation
22
+ πŸ“š [Blog Post](https://www.marqo.ai/blog/introducing-marqos-ecommerce-embedding-models)
23
 
24
+ πŸ“ [Classification Use Case Blog Post](https://www.marqo.ai/blog/ecommerce-image-classification-with-huggingface-transformers)
25
+
26
+ πŸ”Ž [Image Search Use Case Blog Post](https://www.marqo.ai/blog/how-to-build-an-ecommerce-image-search-application)
27
 
28
  ## Code
29
+ πŸ’» [GitHub Repo](https://github.com/marqo-ai/marqo-ecommerce-embeddings)
30
 
31
+ 🀝 [Google Colab](https://colab.research.google.com/drive/1ctqDrXs_P-RIOPc9xcUF83WLdYQ0wf-8?usp=sharing)
32
 
33
+ πŸ€— [Hugging Face Collection](https://huggingface.co/collections/Marqo/marqo-ecommerce-embeddings-66f611b9bb9d035a8d164fbb)
34
 
35
  ## Citation
36
  If you use Marqo-Ecommerce-L or Marqo-Ecommerce-B, please cite us:
37
  ```
38
+ @software{zhu2024marqoecommembed_2024,
39
+ author = {Tianyu Zhu and and Jesse Clark},
40
+ month = oct,
41
+ title = {{Marqo Ecommerce Embeddings - Foundation Model for Product Embeddings}},
42
+ url = {https://github.com/marqo-ai/marqo-ecommerce-embeddings/},
43
+ version = {1.0.0},
44
+ year = {2024}
45
+ }
46
  ```
47
  """
48
 
49
+ # Function to initialize a model, preprocess, and text features
50
+ def initialize_model(model_name, progress=gr.Progress()):
51
+ progress(0, f"Initializing model: {model_name}...")
52
+ model, preprocess_train, preprocess_val = open_clip.create_model_and_transforms(f"hf-hub:Marqo/{model_name}")
 
 
 
 
 
 
 
 
 
 
 
 
 
53
 
54
  progress(0.5, "Loading tokenizer...")
55
+ tokenizer = open_clip.get_tokenizer(f"hf-hub:Marqo/{model_name}")
56
 
57
  text = tokenizer(ecommerce_items)
58
 
 
61
  text_features = model.encode_text(text)
62
  text_features /= text_features.norm(dim=-1, keepdim=True)
63
 
64
+ progress(1.0, f"Model {model_name} loaded successfully!")
65
 
66
  return model, preprocess_val, text_features
67
 
68
+ # Load L model first, followed by B model
69
+ progress_bar = gr.Progress()
70
+ model_l, preprocess_val_l, text_features_l = initialize_model("marqo-ecommerce-embeddings-L", progress=progress_bar)
71
+ model_b, preprocess_val_b, text_features_b = initialize_model("marqo-ecommerce-embeddings-B", progress=progress_bar)
72
 
73
  # Prediction function
74
  @spaces.GPU
75
+ def predict(image, url, model_name):
76
+ if model_name == "marqo-ecommerce-embeddings-B":
77
+ model, preprocess_val, text_features = model_b, preprocess_val_b, text_features_b
78
+ else:
79
+ model, preprocess_val, text_features = model_l, preprocess_val_l, text_features_l
80
+
81
  if url:
82
  response = requests.get(url)
83
  image = Image.open(BytesIO(response.content))
 
136
  with gr.Column(scale=2):
137
  input_image = gr.Image(type="pil", label="Upload Ecommerce Item Image", height=312)
138
  input_url = gr.Textbox(label="Or provide an image URL")
139
+ model_selector = gr.Dropdown(
140
+ choices=["marqo-ecommerce-embeddings-L", "marqo-ecommerce-embeddings-B"],
141
+ value="marqo-ecommerce-embeddings-L",
142
+ label="Select Model"
143
+ )
144
  with gr.Row():
145
  predict_button = gr.Button("Classify")
146
  clear_button = gr.Button("Clear")
147
  gr.Markdown("Or click on one of the images below to classify it:")
148
  gr.Examples(examples=examples, inputs=input_image)
149
  output_label = gr.Label(num_top_classes=6)
150
+ predict_button.click(predict, inputs=[input_image, input_url, model_selector], outputs=[input_image, output_label])
151
+ clear_button.click(clear_fields, outputs=[input_image, input_url, model_selector])
152
 
153
  # Launch the interface
154
+ demo.launch()