Upload app.py with huggingface_hub
Browse files
app.py
CHANGED
@@ -83,25 +83,50 @@ def process_deals_data(deals_data):
|
|
83 |
|
84 |
return processed_deals
|
85 |
|
86 |
-
#
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
87 |
try:
|
88 |
-
#
|
89 |
-
|
90 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
91 |
|
92 |
-
#
|
93 |
-
|
94 |
-
|
95 |
-
|
96 |
-
|
97 |
-
|
98 |
-
|
99 |
-
|
100 |
-
|
101 |
-
|
102 |
except Exception as e:
|
103 |
-
# Fall back to local model if
|
104 |
-
print(f"Error loading
|
105 |
print("Falling back to local model")
|
106 |
|
107 |
model_id = "selvaonline/shopping-assistant"
|
@@ -117,6 +142,9 @@ except Exception as e:
|
|
117 |
except Exception as e:
|
118 |
print(f"Error loading categories: {str(e)}")
|
119 |
categories = ["electronics", "clothing", "home", "kitchen", "toys", "other"]
|
|
|
|
|
|
|
120 |
|
121 |
# Global variable to store deals data
|
122 |
deals_cache = None
|
@@ -127,46 +155,37 @@ def classify_text(text, fetch_deals=True):
|
|
127 |
"""
|
128 |
global deals_cache
|
129 |
|
130 |
-
#
|
131 |
-
|
132 |
-
|
133 |
-
|
134 |
-
with torch.no_grad():
|
135 |
-
outputs = model(**inputs)
|
136 |
|
137 |
-
#
|
138 |
-
|
139 |
-
|
140 |
-
if
|
141 |
-
|
142 |
-
predictions = torch.sigmoid(outputs.logits)
|
143 |
-
|
144 |
-
# Get the top categories
|
145 |
-
top_categories = []
|
146 |
-
for i, score in enumerate(predictions[0]):
|
147 |
-
if score > 0.3: # Lower threshold for e-commerce model
|
148 |
-
top_categories.append((categories[i], score.item()))
|
149 |
-
else:
|
150 |
-
# Single-label classification
|
151 |
-
probabilities = torch.softmax(outputs.logits, dim=1)
|
152 |
-
values, indices = torch.topk(probabilities, 3)
|
153 |
-
|
154 |
-
top_categories = []
|
155 |
-
for i, idx in enumerate(indices[0]):
|
156 |
-
if idx < len(categories):
|
157 |
-
top_categories.append((categories[idx.item()], values[0][i].item()))
|
158 |
-
else:
|
159 |
-
# Fallback for other model formats
|
160 |
-
predictions = torch.sigmoid(outputs[0])
|
161 |
|
162 |
-
#
|
163 |
-
|
164 |
-
|
165 |
-
|
166 |
-
|
167 |
-
|
168 |
-
|
169 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
170 |
|
171 |
# Format the classification results
|
172 |
if top_categories:
|
@@ -188,57 +207,79 @@ def classify_text(text, fetch_deals=True):
|
|
188 |
deals_data = fetch_deals_data(num_pages=2) # Limit to 2 pages for faster response
|
189 |
deals_cache = process_deals_data(deals_data)
|
190 |
|
191 |
-
#
|
192 |
-
|
193 |
-
|
194 |
-
|
195 |
-
|
196 |
-
|
197 |
-
|
198 |
-
|
199 |
-
|
200 |
-
|
201 |
-
|
202 |
-
|
203 |
-
expanded_terms.extend(['mixer', 'blender', 'toaster', 'microwave', 'oven'])
|
204 |
-
|
205 |
-
# Score deals based on relevance to the query
|
206 |
-
scored_deals = []
|
207 |
-
for deal in deals_cache:
|
208 |
-
title = deal['title'].lower()
|
209 |
-
content = deal['content'].lower()
|
210 |
-
excerpt = deal['excerpt'].lower()
|
211 |
|
212 |
-
|
|
|
213 |
|
214 |
-
#
|
215 |
-
|
216 |
-
if term in title:
|
217 |
-
score += 10
|
218 |
-
if term in content:
|
219 |
-
score += 3
|
220 |
-
if term in excerpt:
|
221 |
-
score += 3
|
222 |
|
223 |
-
#
|
224 |
-
for
|
225 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
226 |
if term in title:
|
227 |
-
score +=
|
228 |
if term in content:
|
229 |
-
score +=
|
230 |
if term in excerpt:
|
231 |
-
score +=
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
232 |
|
233 |
-
#
|
234 |
-
|
235 |
-
|
236 |
-
|
237 |
-
|
238 |
-
scored_deals.sort(key=lambda x: x[1], reverse=True)
|
239 |
-
|
240 |
-
# Extract the deals from the scored list
|
241 |
-
relevant_deals = [deal for deal, _ in scored_deals[:5]]
|
242 |
|
243 |
if relevant_deals:
|
244 |
for i, deal in enumerate(relevant_deals, 1):
|
|
|
83 |
|
84 |
return processed_deals
|
85 |
|
86 |
+
# Define product categories
|
87 |
+
category_descriptions = {
|
88 |
+
"electronics": "Electronic devices like headphones, speakers, TVs, smartphones, and gadgets",
|
89 |
+
"computers": "Laptops, desktops, computer parts, monitors, and computing accessories",
|
90 |
+
"mobile": "Mobile phones, smartphones, phone cases, screen protectors, and chargers",
|
91 |
+
"audio": "Headphones, earbuds, speakers, microphones, and audio equipment",
|
92 |
+
"clothing": "Clothes, shirts, pants, dresses, and fashion items",
|
93 |
+
"footwear": "Shoes, boots, sandals, slippers, and all types of footwear",
|
94 |
+
"home": "Home decor, furniture, bedding, and household items",
|
95 |
+
"kitchen": "Kitchen appliances, cookware, utensils, and kitchen gadgets",
|
96 |
+
"toys": "Toys, games, and children's entertainment items",
|
97 |
+
"sports": "Sports equipment, fitness gear, and outdoor recreation items",
|
98 |
+
"beauty": "Beauty products, makeup, skincare, and personal care items",
|
99 |
+
"books": "Books, e-books, audiobooks, and reading materials"
|
100 |
+
}
|
101 |
+
|
102 |
+
# List of categories
|
103 |
+
categories = list(category_descriptions.keys())
|
104 |
+
|
105 |
+
# Try to load the recommended models
|
106 |
try:
|
107 |
+
# 1. Load BART model for zero-shot classification
|
108 |
+
from transformers import pipeline
|
109 |
+
|
110 |
+
# Initialize the zero-shot classification pipeline
|
111 |
+
classifier = pipeline("zero-shot-classification", model="facebook/bart-large-mnli")
|
112 |
+
print("Using facebook/bart-large-mnli for classification")
|
113 |
+
|
114 |
+
# 2. Load MPNet model for semantic search
|
115 |
+
from sentence_transformers import SentenceTransformer, util
|
116 |
|
117 |
+
# Load the sentence transformer model
|
118 |
+
sentence_model = SentenceTransformer('sentence-transformers/all-mpnet-base-v2')
|
119 |
+
print("Using sentence-transformers/all-mpnet-base-v2 for semantic search")
|
120 |
+
|
121 |
+
# Pre-compute embeddings for category descriptions
|
122 |
+
category_texts = list(category_descriptions.values())
|
123 |
+
category_embeddings = sentence_model.encode(category_texts, convert_to_tensor=True)
|
124 |
+
|
125 |
+
# Using recommended models
|
126 |
+
using_recommended_models = True
|
127 |
except Exception as e:
|
128 |
+
# Fall back to local model if recommended models fail to load
|
129 |
+
print(f"Error loading recommended models: {str(e)}")
|
130 |
print("Falling back to local model")
|
131 |
|
132 |
model_id = "selvaonline/shopping-assistant"
|
|
|
142 |
except Exception as e:
|
143 |
print(f"Error loading categories: {str(e)}")
|
144 |
categories = ["electronics", "clothing", "home", "kitchen", "toys", "other"]
|
145 |
+
|
146 |
+
# Not using recommended models
|
147 |
+
using_recommended_models = False
|
148 |
|
149 |
# Global variable to store deals data
|
150 |
deals_cache = None
|
|
|
155 |
"""
|
156 |
global deals_cache
|
157 |
|
158 |
+
# Get the top categories based on the model type
|
159 |
+
if using_recommended_models:
|
160 |
+
# Using BART for zero-shot classification
|
161 |
+
result = classifier(text, categories, multi_label=True)
|
|
|
|
|
162 |
|
163 |
+
# Extract categories and scores
|
164 |
+
top_categories = []
|
165 |
+
for i, (category, score) in enumerate(zip(result['labels'], result['scores'])):
|
166 |
+
if score > 0.1: # Lower threshold for zero-shot classification
|
167 |
+
top_categories.append((category, score))
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
168 |
|
169 |
+
# Limit to top 3 categories
|
170 |
+
if i >= 2:
|
171 |
+
break
|
172 |
+
else:
|
173 |
+
# Using the original classification model
|
174 |
+
inputs = tokenizer(text, return_tensors="pt", truncation=True, padding=True)
|
175 |
+
|
176 |
+
# Get the model prediction
|
177 |
+
with torch.no_grad():
|
178 |
+
outputs = model(**inputs)
|
179 |
+
predictions = torch.sigmoid(outputs.logits)
|
180 |
+
|
181 |
+
# Get the top categories
|
182 |
+
top_categories = []
|
183 |
+
for i, score in enumerate(predictions[0]):
|
184 |
+
if score > 0.5: # Threshold for multi-label classification
|
185 |
+
top_categories.append((categories[i], score.item()))
|
186 |
+
|
187 |
+
# Sort by score
|
188 |
+
top_categories.sort(key=lambda x: x[1], reverse=True)
|
189 |
|
190 |
# Format the classification results
|
191 |
if top_categories:
|
|
|
207 |
deals_data = fetch_deals_data(num_pages=2) # Limit to 2 pages for faster response
|
208 |
deals_cache = process_deals_data(deals_data)
|
209 |
|
210 |
+
# Using MPNet for semantic search if available
|
211 |
+
if using_recommended_models:
|
212 |
+
# Create deal texts for semantic search
|
213 |
+
deal_texts = []
|
214 |
+
for deal in deals_cache:
|
215 |
+
# Combine title and excerpt for better matching
|
216 |
+
deal_text = f"{deal['title']} {deal['excerpt']}"
|
217 |
+
deal_texts.append(deal_text)
|
218 |
+
|
219 |
+
# Encode the query and deals
|
220 |
+
query_embedding = sentence_model.encode(text, convert_to_tensor=True)
|
221 |
+
deal_embeddings = sentence_model.encode(deal_texts, convert_to_tensor=True)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
222 |
|
223 |
+
# Calculate semantic similarity
|
224 |
+
similarities = util.cos_sim(query_embedding, deal_embeddings)[0]
|
225 |
|
226 |
+
# Get top 5 most similar deals
|
227 |
+
top_indices = torch.topk(similarities, k=min(5, len(deals_cache))).indices
|
|
|
|
|
|
|
|
|
|
|
|
|
228 |
|
229 |
+
# Extract the relevant deals
|
230 |
+
relevant_deals = [deals_cache[idx] for idx in top_indices]
|
231 |
+
else:
|
232 |
+
# Fallback to keyword-based search
|
233 |
+
query_terms = text.lower().split()
|
234 |
+
expanded_terms = list(query_terms)
|
235 |
+
|
236 |
+
# Add related terms based on the query
|
237 |
+
if any(term in text.lower() for term in ['headphone', 'headphones']):
|
238 |
+
expanded_terms.extend(['earbuds', 'earphones', 'earpods', 'airpods', 'audio', 'bluetooth', 'wireless'])
|
239 |
+
elif any(term in text.lower() for term in ['laptop', 'computer']):
|
240 |
+
expanded_terms.extend(['notebook', 'macbook', 'chromebook', 'pc'])
|
241 |
+
elif any(term in text.lower() for term in ['tv', 'television']):
|
242 |
+
expanded_terms.extend(['smart tv', 'roku', 'streaming'])
|
243 |
+
elif any(term in text.lower() for term in ['kitchen', 'appliance']):
|
244 |
+
expanded_terms.extend(['mixer', 'blender', 'toaster', 'microwave', 'oven'])
|
245 |
+
|
246 |
+
# Score deals based on relevance to the query
|
247 |
+
scored_deals = []
|
248 |
+
for deal in deals_cache:
|
249 |
+
title = deal['title'].lower()
|
250 |
+
content = deal['content'].lower()
|
251 |
+
excerpt = deal['excerpt'].lower()
|
252 |
+
|
253 |
+
score = 0
|
254 |
+
|
255 |
+
# Check original query terms (higher weight)
|
256 |
+
for term in query_terms:
|
257 |
if term in title:
|
258 |
+
score += 10
|
259 |
if term in content:
|
260 |
+
score += 3
|
261 |
if term in excerpt:
|
262 |
+
score += 3
|
263 |
+
|
264 |
+
# Check expanded terms (lower weight)
|
265 |
+
for term in expanded_terms:
|
266 |
+
if term not in query_terms: # Skip original terms
|
267 |
+
if term in title:
|
268 |
+
score += 5
|
269 |
+
if term in content:
|
270 |
+
score += 1
|
271 |
+
if term in excerpt:
|
272 |
+
score += 1
|
273 |
+
|
274 |
+
# Add to scored deals if it has any relevance
|
275 |
+
if score > 0:
|
276 |
+
scored_deals.append((deal, score))
|
277 |
|
278 |
+
# Sort by score (descending)
|
279 |
+
scored_deals.sort(key=lambda x: x[1], reverse=True)
|
280 |
+
|
281 |
+
# Extract the deals from the scored list
|
282 |
+
relevant_deals = [deal for deal, _ in scored_deals[:5]]
|
|
|
|
|
|
|
|
|
283 |
|
284 |
if relevant_deals:
|
285 |
for i, deal in enumerate(relevant_deals, 1):
|