selvaonline committed
Commit b1b6f63 · verified · 1 Parent(s): e667020

Upload app.py with huggingface_hub

Files changed (1):
  app.py +139 -98
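
The commit message indicates the upload was done programmatically. For reference, a minimal sketch of how such a push is typically made with the huggingface_hub client — the repo id, repo type, and token handling here are assumptions for illustration, not details recorded in this commit:

from huggingface_hub import HfApi

api = HfApi()  # picks up the token stored by `huggingface-cli login`
api.upload_file(
    path_or_fileobj="app.py",                   # local file to push
    path_in_repo="app.py",                      # destination path inside the repo
    repo_id="selvaonline/shopping-assistant",   # assumed from the model id in app.py
    repo_type="space",                          # assumed: app.py looks like a Space app
    commit_message="Upload app.py with huggingface_hub",
)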
app.py CHANGED
@@ -83,25 +83,50 @@ def process_deals_data(deals_data):
 
     return processed_deals
 
-# Load the e-commerce specific model and tokenizer
+# Define product categories
+category_descriptions = {
+    "electronics": "Electronic devices like headphones, speakers, TVs, smartphones, and gadgets",
+    "computers": "Laptops, desktops, computer parts, monitors, and computing accessories",
+    "mobile": "Mobile phones, smartphones, phone cases, screen protectors, and chargers",
+    "audio": "Headphones, earbuds, speakers, microphones, and audio equipment",
+    "clothing": "Clothes, shirts, pants, dresses, and fashion items",
+    "footwear": "Shoes, boots, sandals, slippers, and all types of footwear",
+    "home": "Home decor, furniture, bedding, and household items",
+    "kitchen": "Kitchen appliances, cookware, utensils, and kitchen gadgets",
+    "toys": "Toys, games, and children's entertainment items",
+    "sports": "Sports equipment, fitness gear, and outdoor recreation items",
+    "beauty": "Beauty products, makeup, skincare, and personal care items",
+    "books": "Books, e-books, audiobooks, and reading materials"
+}
+
+# List of categories
+categories = list(category_descriptions.keys())
+
+# Try to load the recommended models
 try:
-    # Try to load the e-commerce BERT model
-    tokenizer = AutoTokenizer.from_pretrained("prithivida/ecommerce-bert-base-uncased")
-    model = AutoModelForSequenceClassification.from_pretrained("prithivida/ecommerce-bert-base-uncased")
+    # 1. Load BART model for zero-shot classification
+    from transformers import pipeline
+
+    # Initialize the zero-shot classification pipeline
+    classifier = pipeline("zero-shot-classification", model="facebook/bart-large-mnli")
+    print("Using facebook/bart-large-mnli for classification")
+
+    # 2. Load MPNet model for semantic search
+    from sentence_transformers import SentenceTransformer, util
 
-    # E-commerce BERT categories
-    categories = [
-        "electronics", "computers", "mobile_phones", "accessories",
-        "clothing", "footwear", "watches", "jewelry",
-        "home", "kitchen", "furniture", "decor",
-        "beauty", "personal_care", "health", "wellness",
-        "toys", "games", "sports", "outdoors",
-        "books", "stationery", "music", "movies"
-    ]
-    print("Using e-commerce BERT model")
+    # Load the sentence transformer model
+    sentence_model = SentenceTransformer('sentence-transformers/all-mpnet-base-v2')
+    print("Using sentence-transformers/all-mpnet-base-v2 for semantic search")
+
+    # Pre-compute embeddings for category descriptions
+    category_texts = list(category_descriptions.values())
+    category_embeddings = sentence_model.encode(category_texts, convert_to_tensor=True)
+
+    # Using recommended models
+    using_recommended_models = True
 except Exception as e:
-    # Fall back to local model if e-commerce BERT fails to load
-    print(f"Error loading e-commerce BERT model: {str(e)}")
+    # Fall back to local model if recommended models fail to load
+    print(f"Error loading recommended models: {str(e)}")
     print("Falling back to local model")
 
     model_id = "selvaonline/shopping-assistant"
@@ -117,6 +142,9 @@ except Exception as e:
     except Exception as e:
         print(f"Error loading categories: {str(e)}")
         categories = ["electronics", "clothing", "home", "kitchen", "toys", "other"]
+
+    # Not using recommended models
+    using_recommended_models = False
 
 # Global variable to store deals data
 deals_cache = None
@@ -127,46 +155,37 @@ def classify_text(text, fetch_deals=True):
     """
    global deals_cache
 
-    # Prepare the input for classification
-    inputs = tokenizer(text, return_tensors="pt", truncation=True, padding=True)
-
-    # Get the model prediction
-    with torch.no_grad():
-        outputs = model(**inputs)
+    # Get the top categories based on the model type
+    if using_recommended_models:
+        # Using BART for zero-shot classification
+        result = classifier(text, categories, multi_label=True)
 
-    # Handle different model output formats
-    if hasattr(outputs, 'logits'):
-        # For models that return logits
-        if outputs.logits.shape[1] == len(categories):
-            # Multi-label classification
-            predictions = torch.sigmoid(outputs.logits)
-
-            # Get the top categories
-            top_categories = []
-            for i, score in enumerate(predictions[0]):
-                if score > 0.3:  # Lower threshold for e-commerce model
-                    top_categories.append((categories[i], score.item()))
-        else:
-            # Single-label classification
-            probabilities = torch.softmax(outputs.logits, dim=1)
-            values, indices = torch.topk(probabilities, 3)
-
-            top_categories = []
-            for i, idx in enumerate(indices[0]):
-                if idx < len(categories):
-                    top_categories.append((categories[idx.item()], values[0][i].item()))
-    else:
-        # Fallback for other model formats
-        predictions = torch.sigmoid(outputs[0])
+        # Extract categories and scores
+        top_categories = []
+        for i, (category, score) in enumerate(zip(result['labels'], result['scores'])):
+            if score > 0.1:  # Lower threshold for zero-shot classification
+                top_categories.append((category, score))
 
-        # Get the top categories
-        top_categories = []
-        for i, score in enumerate(predictions[0]):
-            if score > 0.5:
-                top_categories.append((categories[i], score.item()))
-
-    # Sort by score
-    top_categories.sort(key=lambda x: x[1], reverse=True)
+            # Limit to top 3 categories
+            if i >= 2:
+                break
+    else:
+        # Using the original classification model
+        inputs = tokenizer(text, return_tensors="pt", truncation=True, padding=True)
+
+        # Get the model prediction
+        with torch.no_grad():
+            outputs = model(**inputs)
+        predictions = torch.sigmoid(outputs.logits)
+
+        # Get the top categories
+        top_categories = []
+        for i, score in enumerate(predictions[0]):
+            if score > 0.5:  # Threshold for multi-label classification
+                top_categories.append((categories[i], score.item()))
+
+        # Sort by score
+        top_categories.sort(key=lambda x: x[1], reverse=True)
 
     # Format the classification results
     if top_categories:
@@ -188,57 +207,79 @@ def classify_text(text, fetch_deals=True):
         deals_data = fetch_deals_data(num_pages=2)  # Limit to 2 pages for faster response
         deals_cache = process_deals_data(deals_data)
 
-        # Extract query terms and expand with related terms
-        query_terms = text.lower().split()
-        expanded_terms = list(query_terms)
-
-        # Add related terms based on the query
-        if any(term in text.lower() for term in ['headphone', 'headphones']):
-            expanded_terms.extend(['earbuds', 'earphones', 'earpods', 'airpods', 'audio', 'bluetooth', 'wireless'])
-        elif any(term in text.lower() for term in ['laptop', 'computer']):
-            expanded_terms.extend(['notebook', 'macbook', 'chromebook', 'pc'])
-        elif any(term in text.lower() for term in ['tv', 'television']):
-            expanded_terms.extend(['smart tv', 'roku', 'streaming'])
-        elif any(term in text.lower() for term in ['kitchen', 'appliance']):
-            expanded_terms.extend(['mixer', 'blender', 'toaster', 'microwave', 'oven'])
-
-        # Score deals based on relevance to the query
-        scored_deals = []
-        for deal in deals_cache:
-            title = deal['title'].lower()
-            content = deal['content'].lower()
-            excerpt = deal['excerpt'].lower()
+        # Using MPNet for semantic search if available
+        if using_recommended_models:
+            # Create deal texts for semantic search
+            deal_texts = []
+            for deal in deals_cache:
+                # Combine title and excerpt for better matching
+                deal_text = f"{deal['title']} {deal['excerpt']}"
+                deal_texts.append(deal_text)
+
+            # Encode the query and deals
+            query_embedding = sentence_model.encode(text, convert_to_tensor=True)
+            deal_embeddings = sentence_model.encode(deal_texts, convert_to_tensor=True)
 
-            score = 0
+            # Calculate semantic similarity
+            similarities = util.cos_sim(query_embedding, deal_embeddings)[0]
 
-            # Check original query terms (higher weight)
-            for term in query_terms:
-                if term in title:
-                    score += 10
-                if term in content:
-                    score += 3
-                if term in excerpt:
-                    score += 3
+            # Get top 5 most similar deals
+            top_indices = torch.topk(similarities, k=min(5, len(deals_cache))).indices
 
-            # Check expanded terms (lower weight)
-            for term in expanded_terms:
-                if term not in query_terms:  # Skip original terms
+            # Extract the relevant deals
+            relevant_deals = [deals_cache[idx] for idx in top_indices]
+        else:
+            # Fallback to keyword-based search
+            query_terms = text.lower().split()
+            expanded_terms = list(query_terms)
+
+            # Add related terms based on the query
+            if any(term in text.lower() for term in ['headphone', 'headphones']):
+                expanded_terms.extend(['earbuds', 'earphones', 'earpods', 'airpods', 'audio', 'bluetooth', 'wireless'])
+            elif any(term in text.lower() for term in ['laptop', 'computer']):
+                expanded_terms.extend(['notebook', 'macbook', 'chromebook', 'pc'])
+            elif any(term in text.lower() for term in ['tv', 'television']):
+                expanded_terms.extend(['smart tv', 'roku', 'streaming'])
+            elif any(term in text.lower() for term in ['kitchen', 'appliance']):
+                expanded_terms.extend(['mixer', 'blender', 'toaster', 'microwave', 'oven'])
+
+            # Score deals based on relevance to the query
+            scored_deals = []
+            for deal in deals_cache:
+                title = deal['title'].lower()
+                content = deal['content'].lower()
+                excerpt = deal['excerpt'].lower()
+
+                score = 0
+
+                # Check original query terms (higher weight)
+                for term in query_terms:
                 if term in title:
-                    score += 5
+                    score += 10
                 if term in content:
-                    score += 1
+                    score += 3
                 if term in excerpt:
-                    score += 1
+                    score += 3
+
+                # Check expanded terms (lower weight)
+                for term in expanded_terms:
+                    if term not in query_terms:  # Skip original terms
+                        if term in title:
+                            score += 5
+                        if term in content:
+                            score += 1
+                        if term in excerpt:
+                            score += 1
+
+                # Add to scored deals if it has any relevance
+                if score > 0:
+                    scored_deals.append((deal, score))
 
-            # Add to scored deals if it has any relevance
-            if score > 0:
-                scored_deals.append((deal, score))
-
-        # Sort by score (descending)
-        scored_deals.sort(key=lambda x: x[1], reverse=True)
-
-        # Extract the deals from the scored list
-        relevant_deals = [deal for deal, _ in scored_deals[:5]]
+            # Sort by score (descending)
+            scored_deals.sort(key=lambda x: x[1], reverse=True)
+
+            # Extract the deals from the scored list
+            relevant_deals = [deal for deal, _ in scored_deals[:5]]
 
         if relevant_deals:
            for i, deal in enumerate(relevant_deals, 1):
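
A note on the classification path added above: facebook/bart-large-mnli is an NLI model, and the zero-shot pipeline scores each candidate label as an entailment hypothesis. With multi_label=True each label is scored independently, and the returned labels come pre-sorted by descending score — which is why the `if i >= 2: break` in the diff amounts to keeping at most the top 3 labels above the 0.1 threshold. A standalone sketch (the query string is invented for illustration):

from transformers import pipeline

classifier = pipeline("zero-shot-classification", model="facebook/bart-large-mnli")

result = classifier(
    "looking for wireless headphones under $100",   # hypothetical user query
    ["electronics", "audio", "kitchen", "books"],   # candidate labels
    multi_label=True,  # score labels independently instead of softmax across them
)
# result["labels"] is sorted by descending score, so filtering by the
# threshold and slicing mirrors the app's top-3 logic.
top = [(l, s) for l, s in zip(result["labels"], result["scores"]) if s > 0.1][:3]
print(top)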
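The deal-retrieval path is standard embedding search with sentence-transformers: encode the query and the deal texts, rank by cosine similarity, keep the top k. A minimal sketch using the same model, with hypothetical deal texts:

import torch
from sentence_transformers import SentenceTransformer, util

model = SentenceTransformer("sentence-transformers/all-mpnet-base-v2")

deals = [  # hypothetical "title + excerpt" strings, as built in the app
    "Sony WH-1000XM4 wireless headphones 40% off",
    "KitchenAid stand mixer deal of the day",
    "Dell XPS 13 laptop clearance sale",
]

query_embedding = model.encode("noise cancelling headphones", convert_to_tensor=True)
deal_embeddings = model.encode(deals, convert_to_tensor=True)

similarities = util.cos_sim(query_embedding, deal_embeddings)[0]  # one score per deal
top_indices = torch.topk(similarities, k=min(5, len(deals))).indices
print([deals[int(i)] for i in top_indices])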
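Finally, the keyword fallback kept in the else: branch reduces to a weighted term-match score — original query terms weigh 10/3/3 across title/content/excerpt, expansion terms 5/1/1. The same logic as a standalone helper (the function name is mine; the weights come from the diff):

def keyword_score(deal, query_terms, expanded_terms):
    """Weighted keyword match: original query terms count more than expansions."""
    title = deal["title"].lower()
    content = deal["content"].lower()
    excerpt = deal["excerpt"].lower()

    score = 0
    for term in query_terms:        # original terms: 10 / 3 / 3
        score += 10 * (term in title) + 3 * (term in content) + 3 * (term in excerpt)
    for term in expanded_terms:     # expansion terms: 5 / 1 / 1
        if term not in query_terms:
            score += 5 * (term in title) + 1 * (term in content) + 1 * (term in excerpt)
    return score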