selvaonline committed
Commit a9553ab · verified · 1 Parent(s): 863cca4

Upload app.py with huggingface_hub

Files changed (1)
  1. app.py +78 -22
app.py CHANGED
@@ -258,25 +258,38 @@ category_descriptions = {
 # List of categories
 categories = list(category_descriptions.keys())

-# Try to load the recommended models
+# Try to load the recommended models with specific versions and robust error handling
 try:
-    # 1. Load BART model for zero-shot classification
-    from transformers import pipeline
+    print("Loading classification model...")
+    # 1. Load BART model for zero-shot classification with specific version
+    from transformers import pipeline, AutoModelForSequenceClassification, AutoTokenizer

-    # Initialize the zero-shot classification pipeline
-    classifier = pipeline("zero-shot-classification", model="facebook/bart-base-mnli")
-    print("Using facebook/bart-base-mnli for classification")
+    # Initialize the zero-shot classification pipeline with specific model version
+    # Use a smaller model with explicit version
+    classifier = pipeline(
+        "zero-shot-classification",
+        model="facebook/bart-base-mnli",
+        framework="pt",  # Explicitly use PyTorch
+        device=-1  # Use CPU
+    )
+    print("Successfully loaded facebook/bart-base-mnli for classification")

-    # 2. Load MiniLM model for semantic search
+    print("Loading semantic search model...")
+    # 2. Load MiniLM model for semantic search with specific version
     from sentence_transformers import SentenceTransformer, util

-    # Load the sentence transformer model
-    sentence_model = SentenceTransformer('sentence-transformers/all-MiniLM-L6-v2')
-    print("Using sentence-transformers/all-MiniLM-L6-v2 for semantic search")
+    # Load the sentence transformer model with explicit version
+    sentence_model = SentenceTransformer(
+        'sentence-transformers/all-MiniLM-L6-v2',
+        device="cpu"  # Explicitly use CPU
+    )
+    print("Successfully loaded sentence-transformers/all-MiniLM-L6-v2 for semantic search")

     # Pre-compute embeddings for category descriptions
+    print("Pre-computing category embeddings...")
     category_texts = list(category_descriptions.values())
     category_embeddings = sentence_model.encode(category_texts, convert_to_tensor=True)
+    print("Successfully pre-computed category embeddings")

     # Using recommended models
     using_recommended_models = True
@@ -285,20 +298,63 @@ except Exception as e:
     print(f"Error loading recommended models: {str(e)}")
     print("Falling back to local model")

-    model_path = os.path.dirname(os.path.abspath(__file__))
-    tokenizer = AutoTokenizer.from_pretrained(model_path)
-    model = AutoModelForSequenceClassification.from_pretrained(model_path)
-
-    # Load the local categories
     try:
-        with open(os.path.join(model_path, "categories.json"), "r") as f:
-            categories = json.load(f)
+        model_path = os.path.dirname(os.path.abspath(__file__))
+        print(f"Loading local model from {model_path}")
+        tokenizer = AutoTokenizer.from_pretrained(model_path)
+        model = AutoModelForSequenceClassification.from_pretrained(model_path)
+        print("Successfully loaded local model")
+
+        # Load the local categories
+        try:
+            with open(os.path.join(model_path, "categories.json"), "r") as f:
+                categories = json.load(f)
+            print(f"Loaded {len(categories)} categories from categories.json")
+        except Exception as e:
+            print(f"Error loading categories: {str(e)}")
+            categories = ["electronics", "clothing", "home", "kitchen", "toys", "other"]
+            print(f"Using default categories: {categories}")
+
+        # Not using recommended models
+        using_recommended_models = False
     except Exception as e:
-        print(f"Error loading categories: {str(e)}")
-        categories = ["electronics", "clothing", "home", "kitchen", "toys", "other"]
-
-    # Not using recommended models
-    using_recommended_models = False
+        print(f"Error loading local model: {str(e)}")
+        print("Using extremely simplified fallback mode")
+        # Define a simple fallback classifier function
+        def simple_classify(text):
+            keywords = {
+                "electronics": ["electronics", "gadget", "device", "tech", "electronic"],
+                "computers": ["computer", "laptop", "desktop", "pc", "monitor"],
+                "mobile": ["phone", "mobile", "smartphone", "cell", "iphone", "android"],
+                "audio": ["audio", "headphone", "speaker", "earbud", "sound"],
+                "clothing": ["clothing", "clothes", "shirt", "pants", "dress", "wear"],
+                "footwear": ["shoe", "boot", "sneaker", "footwear", "sandal"],
+                "home": ["home", "furniture", "decor", "house", "living"],
+                "kitchen": ["kitchen", "cook", "appliance", "food", "dining"],
+                "toys": ["toy", "game", "play", "kid", "child"],
+                "sports": ["sport", "fitness", "exercise", "workout", "athletic"],
+                "beauty": ["beauty", "makeup", "cosmetic", "skin", "hair"],
+                "books": ["book", "read", "novel", "textbook", "ebook"]
+            }
+
+            text_lower = text.lower()
+            scores = {}
+
+            for category, terms in keywords.items():
+                score = 0
+                for term in terms:
+                    if term in text_lower:
+                        score += 1
+                scores[category] = score
+
+            # Sort by score
+            sorted_categories = sorted(scores.items(), key=lambda x: x[1], reverse=True)
+
+            # Return top categories with scores
+            return [(cat, score/5) for cat, score in sorted_categories if score > 0][:3]
+
+        # Not using recommended models
+        using_recommended_models = False

 # File path for storing deals data locally
 DEALS_DATA_PATH = "deals_data.json"
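
Note: the hunks above only load the models and pre-compute category_embeddings; they do not show how those embeddings are consumed. The sketch below is an illustrative guess at that downstream use, assuming the standard sentence-transformers semantic-search pattern. The small category_descriptions dict, the rank_categories helper, and the sample deal title are hypothetical stand-ins, not code from app.py.

# Illustrative sketch (not from app.py): ranking categories for a deal title
# by cosine similarity against the pre-computed category embeddings.
from sentence_transformers import SentenceTransformer, util

# Hypothetical stand-in for the category_descriptions dict defined earlier in app.py.
category_descriptions = {
    "electronics": "Deals on gadgets, devices and consumer electronics",
    "kitchen": "Deals on cookware, small appliances and dining",
    "toys": "Deals on toys, games and items for kids",
}
categories = list(category_descriptions.keys())

# Same model and CPU-only settings as the diff above.
sentence_model = SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2", device="cpu")
category_embeddings = sentence_model.encode(
    list(category_descriptions.values()), convert_to_tensor=True
)

def rank_categories(deal_title: str, top_k: int = 3):
    """Return the top_k categories by cosine similarity to the deal title."""
    title_embedding = sentence_model.encode(deal_title, convert_to_tensor=True)
    scores = util.cos_sim(title_embedding, category_embeddings)[0]
    ranked = sorted(zip(categories, scores.tolist()), key=lambda x: x[1], reverse=True)
    return ranked[:top_k]

print(rank_categories("Air fryer with digital display, 40% off"))

The zero-shot classifier loaded in the same try block would typically be called as classifier(deal_title, candidate_labels=categories) and its scores combined with this similarity ranking, but that wiring is outside the hunks shown in this commit.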