import gradio as gr
import torch
from transformers import AutoTokenizer, AutoModelForSequenceClassification
import json
import requests
import re
# Function to extract text from HTML (from shopping_assistant.py)
def extract_text_from_html(html):
    """
    Extract text from HTML without using BeautifulSoup
    """
    # Remove HTML tags
    text = re.sub(r'<[^>]+>', ' ', html)
    # Remove extra whitespace
    text = re.sub(r'\s+', ' ', text)
    # Decode the named HTML entities WordPress typically emits
    # (decode &amp; last so double-encoded entities aren't decoded twice)
    text = text.replace('&nbsp;', ' ').replace('&lt;', '<').replace('&gt;', '>').replace('&amp;', '&')
    return text.strip()
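
# Illustrative example of the extractor above:
#   extract_text_from_html('<p>50% off &amp; free shipping</p>')
#   -> '50% off & free shipping'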
# Function to fetch deals from DealsFinders.com (from shopping_assistant.py)
def fetch_deals_data(url="https://www.dealsfinders.com/wp-json/wp/v2/posts", num_pages=2, per_page=100):
    """
    Fetch deals data exclusively from the DealsFinders API
    """
    all_deals = []

    # Fetch from the DealsFinders API
    for page in range(1, num_pages + 1):
        try:
            # Add a user agent to avoid being blocked
            headers = {
                'User-Agent': 'Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/91.0.4472.114 Safari/537.36'
            }
            # A timeout keeps a slow or unresponsive API from hanging the app
            response = requests.get(f"{url}?page={page}&per_page={per_page}", headers=headers, timeout=10)

            if response.status_code == 200:
                deals = response.json()
                all_deals.extend(deals)
                print(f"Fetched page {page} with {len(deals)} deals from DealsFinders API")

                # If we get fewer deals than requested, we've reached the end
                if len(deals) < per_page:
                    print(f"Reached the end of available deals at page {page}")
                    break
            else:
                print(f"Failed to fetch page {page} from DealsFinders API: {response.status_code}")
                break
        except Exception as e:
            print(f"Error fetching page {page} from DealsFinders API: {str(e)}")
            break

    return all_deals
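
# The WordPress REST API returns each post as a dict; the fields this app
# relies on look roughly like (values illustrative):
#   {'id': 123, 'date': '2024-01-01T00:00:00', 'link': 'https://...',
#    'title': {'rendered': '...'}, 'content': {'rendered': '<p>...</p>'},
#    'excerpt': {'rendered': '<p>...</p>'}}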
# Function to process deals data (from shopping_assistant.py)
def process_deals_data(deals_data):
    """
    Process the deals data into a structured format
    """
    processed_deals = []

    for deal in deals_data:
        try:
            # Extract relevant information using our HTML text extractor
            content_html = deal.get('content', {}).get('rendered', '')
            excerpt_html = deal.get('excerpt', {}).get('rendered', '')

            clean_content = extract_text_from_html(content_html)
            clean_excerpt = extract_text_from_html(excerpt_html)

            processed_deal = {
                'id': deal.get('id'),
                'title': deal.get('title', {}).get('rendered', ''),
                'link': deal.get('link', ''),
                'date': deal.get('date', ''),
                'content': clean_content,
                'excerpt': clean_excerpt
            }
            processed_deals.append(processed_deal)
        except Exception as e:
            print(f"Error processing deal: {str(e)}")

    return processed_deals
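
# Each processed deal is a flat dict of plain strings, e.g. (illustrative):
#   {'id': 123, 'title': 'AirPods Pro $189', 'link': 'https://...',
#    'date': '2024-01-01T00:00:00', 'content': '...', 'excerpt': '...'}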
# Load the e-commerce specific model and tokenizer
try:
    # Try to load the e-commerce BERT model
    tokenizer = AutoTokenizer.from_pretrained("prithivida/ecommerce-bert-base-uncased")
    model = AutoModelForSequenceClassification.from_pretrained("prithivida/ecommerce-bert-base-uncased")

    # E-commerce BERT categories
    categories = [
        "electronics", "computers", "mobile_phones", "accessories",
        "clothing", "footwear", "watches", "jewelry",
        "home", "kitchen", "furniture", "decor",
        "beauty", "personal_care", "health", "wellness",
        "toys", "games", "sports", "outdoors",
        "books", "stationery", "music", "movies"
    ]
    print("Using e-commerce BERT model")
except Exception as e:
    # Fall back to local model if e-commerce BERT fails to load
    print(f"Error loading e-commerce BERT model: {str(e)}")
    print("Falling back to local model")
    model_id = "selvaonline/shopping-assistant"
    tokenizer = AutoTokenizer.from_pretrained(model_id)
    model = AutoModelForSequenceClassification.from_pretrained(model_id)

    # Load the local categories
    try:
        from huggingface_hub import hf_hub_download
        categories_path = hf_hub_download(repo_id=model_id, filename="categories.json")
        with open(categories_path, "r") as f:
            categories = json.load(f)
    except Exception as e:
        print(f"Error loading categories: {str(e)}")
        categories = ["electronics", "clothing", "home", "kitchen", "toys", "other"]
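
# Note: classify_text below compares the model's output width against
# len(categories); if the loaded checkpoint's classification head has a
# different number of labels, it falls back to single-label top-k scoring.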
# Global variable to store deals data
deals_cache = None
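
# The cache is per-process and never expires; for a long-running Space one
# could store a fetch timestamp alongside it and refetch after a TTL
# (a possible refinement, not part of the original design).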
def classify_text(text, fetch_deals=True):
    """
    Classify the text using the model and fetch relevant deals
    """
    global deals_cache

    # Prepare the input for classification
    inputs = tokenizer(text, return_tensors="pt", truncation=True, padding=True)

    # Get the model prediction
    with torch.no_grad():
        outputs = model(**inputs)

    # Handle different model output formats
    if hasattr(outputs, 'logits'):
        # For models that return logits
        if outputs.logits.shape[1] == len(categories):
            # Multi-label classification
            predictions = torch.sigmoid(outputs.logits)

            # Get the top categories
            top_categories = []
            for i, score in enumerate(predictions[0]):
                if score > 0.3:  # Lower threshold for e-commerce model
                    top_categories.append((categories[i], score.item()))
        else:
            # Single-label classification; guard against heads with fewer than 3 labels
            probabilities = torch.softmax(outputs.logits, dim=1)
            values, indices = torch.topk(probabilities, min(3, probabilities.shape[1]))

            top_categories = []
            for i, idx in enumerate(indices[0]):
                if idx < len(categories):
                    top_categories.append((categories[idx.item()], values[0][i].item()))
    else:
        # Fallback for other model formats
        predictions = torch.sigmoid(outputs[0])

        # Get the top categories
        top_categories = []
        for i, score in enumerate(predictions[0]):
            if score > 0.5:
                top_categories.append((categories[i], score.item()))

    # Sort by score
    top_categories.sort(key=lambda x: x[1], reverse=True)
    # Format the classification results
    if top_categories:
        result = f"Top categories for '{text}':\n\n"
        for category, score in top_categories:
            result += f"- {category}: {score:.4f}\n"
        result += f"\nBased on your query, I would recommend looking for deals in the **{top_categories[0][0]}** category.\n\n"
    else:
        result = f"No categories found for '{text}'. Please try a different query.\n\n"
    # Fetch and display deals if requested
    if fetch_deals:
        result += "## Relevant Deals from DealsFinders.com\n\n"
        try:
            # Fetch deals data if not already cached
            if deals_cache is None:
                deals_data = fetch_deals_data(num_pages=2)  # Limit to 2 pages for faster response
                deals_cache = process_deals_data(deals_data)

            # Extract query terms and expand with related terms
            query_terms = text.lower().split()
            expanded_terms = list(query_terms)

            # Add related terms based on the query
            if any(term in text.lower() for term in ['headphone', 'headphones']):
                expanded_terms.extend(['earbuds', 'earphones', 'earpods', 'airpods', 'audio', 'bluetooth', 'wireless'])
            elif any(term in text.lower() for term in ['laptop', 'computer']):
                expanded_terms.extend(['notebook', 'macbook', 'chromebook', 'pc'])
            elif any(term in text.lower() for term in ['tv', 'television']):
                expanded_terms.extend(['smart tv', 'roku', 'streaming'])
            elif any(term in text.lower() for term in ['kitchen', 'appliance']):
                expanded_terms.extend(['mixer', 'blender', 'toaster', 'microwave', 'oven'])

            # Score deals based on relevance to the query
            scored_deals = []
            for deal in deals_cache:
                title = deal['title'].lower()
                content = deal['content'].lower()
                excerpt = deal['excerpt'].lower()
                score = 0

                # Check original query terms (higher weight)
                for term in query_terms:
                    if term in title:
                        score += 10
                    if term in content:
                        score += 3
                    if term in excerpt:
                        score += 3

                # Check expanded terms (lower weight)
                for term in expanded_terms:
                    if term not in query_terms:  # Skip original terms
                        if term in title:
                            score += 5
                        if term in content:
                            score += 1
                        if term in excerpt:
                            score += 1

                # Add to scored deals if it has any relevance
                if score > 0:
                    scored_deals.append((deal, score))

            # Sort by score (descending)
            scored_deals.sort(key=lambda x: x[1], reverse=True)

            # Extract the deals from the scored list
            relevant_deals = [deal for deal, _ in scored_deals[:5]]

            if relevant_deals:
                for i, deal in enumerate(relevant_deals, 1):
                    result += f"{i}. [{deal['title']}]({deal['link']})\n\n"
            else:
                result += "No specific deals found for your query. Try a different search term or browse the recommended category.\n\n"
        except Exception as e:
            result += f"Error fetching deals: {str(e)}\n\n"

    return result
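
# Quick sanity check without the UI (illustrative; with fetch_deals=True this
# hits the live API, so keep it commented out in the deployed Space):
#   print(classify_text("headphone deals", fetch_deals=False))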
# Create the Gradio interface
demo = gr.Interface(
    fn=classify_text,
    inputs=[
        gr.Textbox(
            lines=2,
            placeholder="Enter your shopping query here...",
            label="Shopping Query"
        ),
        gr.Checkbox(
            label="Fetch Deals",
            value=True,
            info="Check to fetch and display deals from DealsFinders.com"
        )
    ],
    outputs=gr.Markdown(label="Results"),
    title="Shopping Assistant",
    description="""
    This demo shows how to use the Shopping Assistant model to classify shopping queries into categories and find relevant deals.
    Enter a shopping query below to see which categories it belongs to and find deals from DealsFinders.com.

    Examples:
    - "I'm looking for headphones"
    - "Do you have any kitchen appliance deals?"
    - "Show me the best laptop deals"
    - "I need a new smart TV"
    """,
    examples=[
        ["I'm looking for headphones", True],
        ["Do you have any kitchen appliance deals?", True],
        ["Show me the best laptop deals", True],
        ["I need a new smart TV", True],
        ["headphone deals", True]
    ],
    theme=gr.themes.Soft()
)
# Launch the app
if __name__ == "__main__":
    demo.launch()
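
# When running locally, Gradio can also expose a temporary public link via
# demo.launch(share=True); on Hugging Face Spaces, hosting is automatic.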