Upload app.py with huggingface_hub
Browse files
app.py
CHANGED
@@ -482,43 +482,111 @@ except Exception as e:
|
|
482 |
deals_cache = process_deals_data(SAMPLE_DEALS)
|
483 |
print(f"Initialized with {len(deals_cache)} sample deals")
|
484 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
485 |
def classify_text(text, fetch_deals=True):
|
486 |
"""
|
487 |
Classify the text using the model and fetch relevant deals
|
488 |
"""
|
489 |
-
global deals_cache
|
490 |
|
491 |
# Get the top categories based on the model type
|
492 |
if using_recommended_models:
|
493 |
# Using BART for zero-shot classification
|
494 |
-
|
495 |
-
|
496 |
-
# Extract categories and scores
|
497 |
-
top_categories = []
|
498 |
-
for i, (category, score) in enumerate(zip(result['labels'], result['scores'])):
|
499 |
-
if score > 0.1: # Lower threshold for zero-shot classification
|
500 |
-
top_categories.append((category, score))
|
501 |
|
502 |
-
#
|
503 |
-
|
504 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
505 |
else:
|
506 |
-
|
507 |
-
|
508 |
-
|
509 |
-
|
510 |
-
|
511 |
-
|
512 |
-
|
513 |
-
|
514 |
-
|
515 |
-
|
516 |
-
|
517 |
-
|
518 |
-
|
519 |
-
|
520 |
-
|
521 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
522 |
|
523 |
# Format the classification results
|
524 |
if top_categories:
|
|
|
482 |
deals_cache = process_deals_data(SAMPLE_DEALS)
|
483 |
print(f"Initialized with {len(deals_cache)} sample deals")
|
484 |
|
485 |
+
# Global variables for models
|
486 |
+
classifier = None
|
487 |
+
sentence_model = None
|
488 |
+
tokenizer = None
|
489 |
+
model = None
|
490 |
+
simple_classify = None
|
491 |
+
|
492 |
def classify_text(text, fetch_deals=True):
|
493 |
"""
|
494 |
Classify the text using the model and fetch relevant deals
|
495 |
"""
|
496 |
+
global deals_cache, classifier, sentence_model, tokenizer, model, simple_classify
|
497 |
|
498 |
# Get the top categories based on the model type
|
499 |
if using_recommended_models:
|
500 |
# Using BART for zero-shot classification
|
501 |
+
try:
|
502 |
+
result = classifier(text, categories, multi_label=True)
|
|
|
|
|
|
|
|
|
|
|
503 |
|
504 |
+
# Extract categories and scores
|
505 |
+
top_categories = []
|
506 |
+
for i, (category, score) in enumerate(zip(result['labels'], result['scores'])):
|
507 |
+
if score > 0.1: # Lower threshold for zero-shot classification
|
508 |
+
top_categories.append((category, score))
|
509 |
+
|
510 |
+
# Limit to top 3 categories
|
511 |
+
if i >= 2:
|
512 |
+
break
|
513 |
+
except Exception as e:
|
514 |
+
print(f"Error using zero-shot classification: {str(e)}")
|
515 |
+
# Fallback to simple keyword-based classification
|
516 |
+
top_categories = []
|
517 |
+
for category, terms in {
|
518 |
+
"electronics": ["electronics", "gadget", "device", "tech", "electronic"],
|
519 |
+
"computers": ["computer", "laptop", "desktop", "pc", "monitor"],
|
520 |
+
"mobile": ["phone", "mobile", "smartphone", "cell", "iphone", "android"],
|
521 |
+
"audio": ["audio", "headphone", "speaker", "earbud", "sound"],
|
522 |
+
"clothing": ["clothing", "clothes", "shirt", "pants", "dress", "wear"],
|
523 |
+
"footwear": ["shoe", "boot", "sneaker", "footwear", "sandal"],
|
524 |
+
"home": ["home", "furniture", "decor", "house", "living"],
|
525 |
+
"kitchen": ["kitchen", "cook", "appliance", "food", "dining"],
|
526 |
+
"toys": ["toy", "game", "play", "kid", "child"],
|
527 |
+
"sports": ["sport", "fitness", "exercise", "workout", "athletic"],
|
528 |
+
"beauty": ["beauty", "makeup", "cosmetic", "skin", "hair"],
|
529 |
+
"books": ["book", "read", "novel", "textbook", "ebook"]
|
530 |
+
}.items():
|
531 |
+
score = 0
|
532 |
+
for term in terms:
|
533 |
+
if term in text.lower():
|
534 |
+
score += 1
|
535 |
+
if score > 0:
|
536 |
+
top_categories.append((category, score/5))
|
537 |
+
|
538 |
+
# Sort by score
|
539 |
+
top_categories.sort(key=lambda x: x[1], reverse=True)
|
540 |
+
top_categories = top_categories[:3] # Limit to top 3
|
541 |
+
elif simple_classify is not None:
|
542 |
+
# Using simple keyword-based classification
|
543 |
+
top_categories = simple_classify(text)
|
544 |
else:
|
545 |
+
try:
|
546 |
+
# Using the original classification model
|
547 |
+
inputs = tokenizer(text, return_tensors="pt", truncation=True, padding=True)
|
548 |
+
|
549 |
+
# Get the model prediction
|
550 |
+
with torch.no_grad():
|
551 |
+
outputs = model(**inputs)
|
552 |
+
predictions = torch.sigmoid(outputs.logits)
|
553 |
+
|
554 |
+
# Get the top categories
|
555 |
+
top_categories = []
|
556 |
+
for i, score in enumerate(predictions[0]):
|
557 |
+
if score > 0.5: # Threshold for multi-label classification
|
558 |
+
top_categories.append((categories[i], score.item()))
|
559 |
+
|
560 |
+
# Sort by score
|
561 |
+
top_categories.sort(key=lambda x: x[1], reverse=True)
|
562 |
+
except Exception as e:
|
563 |
+
print(f"Error using local model: {str(e)}")
|
564 |
+
# Fallback to simple keyword-based classification
|
565 |
+
top_categories = []
|
566 |
+
for category, terms in {
|
567 |
+
"electronics": ["electronics", "gadget", "device", "tech", "electronic"],
|
568 |
+
"computers": ["computer", "laptop", "desktop", "pc", "monitor"],
|
569 |
+
"mobile": ["phone", "mobile", "smartphone", "cell", "iphone", "android"],
|
570 |
+
"audio": ["audio", "headphone", "speaker", "earbud", "sound"],
|
571 |
+
"clothing": ["clothing", "clothes", "shirt", "pants", "dress", "wear"],
|
572 |
+
"footwear": ["shoe", "boot", "sneaker", "footwear", "sandal"],
|
573 |
+
"home": ["home", "furniture", "decor", "house", "living"],
|
574 |
+
"kitchen": ["kitchen", "cook", "appliance", "food", "dining"],
|
575 |
+
"toys": ["toy", "game", "play", "kid", "child"],
|
576 |
+
"sports": ["sport", "fitness", "exercise", "workout", "athletic"],
|
577 |
+
"beauty": ["beauty", "makeup", "cosmetic", "skin", "hair"],
|
578 |
+
"books": ["book", "read", "novel", "textbook", "ebook"]
|
579 |
+
}.items():
|
580 |
+
score = 0
|
581 |
+
for term in terms:
|
582 |
+
if term in text.lower():
|
583 |
+
score += 1
|
584 |
+
if score > 0:
|
585 |
+
top_categories.append((category, score/5))
|
586 |
+
|
587 |
+
# Sort by score
|
588 |
+
top_categories.sort(key=lambda x: x[1], reverse=True)
|
589 |
+
top_categories = top_categories[:3] # Limit to top 3
|
590 |
|
591 |
# Format the classification results
|
592 |
if top_categories:
|