File size: 2,764 Bytes
01a3727
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87

import gradio as gr
import torch
from transformers import AutoTokenizer, AutoModelForSequenceClassification
import json

# Load the model and tokenizer from the Hugging Face Hub.
model_id = "selvaonline/shopping-assistant"
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForSequenceClassification.from_pretrained(model_id)

# Load the human-readable category names published alongside the model.
# This is best-effort: any failure (import, download, file read, malformed
# JSON) is reported and a hard-coded list is used so the app can still start.
# NOTE(review): neither source is checked against the model's number of
# output labels, yet the names are looked up by output position.
try:
    from huggingface_hub import hf_hub_download
    categories_path = hf_hub_download(repo_id=model_id, filename="categories.json")
    # Explicit UTF-8: JSON text is UTF-8, while open()'s default encoding
    # depends on the platform locale.
    with open(categories_path, "r", encoding="utf-8") as f:
        categories = json.load(f)
except Exception as e:
    print(f"Error loading categories: {e}")
    categories = ["electronics", "clothing", "home", "kitchen", "toys", "other"]

def _category_name(index):
    """
    Map classifier output position ``index`` to a category name.

    ``categories`` comes either from ``categories.json`` or from a hard-coded
    fallback list, and neither is checked against the model's output size.
    Rather than raising ``IndexError``/``KeyError`` mid-request, fall back to
    the model's own ``config.id2label`` (or a synthetic ``label_<i>`` name).
    """
    if isinstance(categories, dict):
        # JSON object keys are always strings, so look up the stringified index.
        name = categories.get(str(index))
    elif 0 <= index < len(categories):
        name = categories[index]
    else:
        name = None
    if name is None:
        name = model.config.id2label.get(index, f"label_{index}")
    return name


def classify_text(text, threshold=0.5):
    """
    Classify a shopping query and return a Markdown summary of its categories.

    The model is treated as multi-label. Each output logit goes through an
    independent sigmoid, and every category whose probability is strictly
    greater than ``threshold`` is reported, highest probability first.

    Args:
        text: The user's shopping query.
        threshold: Minimum (exclusive) sigmoid probability for a category to
            be reported. Defaults to 0.5, the previously hard-coded value.

    Returns:
        One of three Markdown strings:
        - the matching categories with their scores, plus a recommendation
          for the highest-scoring one;
        - a "no categories found" message when nothing passes ``threshold``;
        - a prompt when ``text`` is empty or blank.
    """
    # Nothing to classify: don't run the model on an empty query.
    if not text or not text.strip():
        return "Please enter a shopping query."

    inputs = tokenizer(text, return_tensors="pt", truncation=True, padding=True)

    with torch.no_grad():
        logits = model(**inputs).logits
    # A single query was tokenized, so take row 0 and convert it to plain floats.
    probabilities = torch.sigmoid(logits)[0].tolist()

    # Stable sort, highest score first (same ordering as list.sort(reverse=True)).
    top_categories = sorted(
        (
            (_category_name(i), score)
            for i, score in enumerate(probabilities)
            if score > threshold
        ),
        key=lambda item: item[1],
        reverse=True,
    )

    if not top_categories:
        return f"No categories found for '{text}'. Please try a different query."

    listing = "".join(f"- {category}: {score:.4f}\n" for category, score in top_categories)
    best = top_categories[0][0]
    return (
        f"Top categories for '{text}':\n\n"
        f"{listing}"
        f"\nBased on your query, I would recommend looking for deals in the **{best}** category."
    )

# Wire the classifier into a Gradio UI: one free-text query in, Markdown out.
demo = gr.Interface(
    fn=classify_text,
    inputs=gr.Textbox(
        label="Shopping Query",
        placeholder="Enter your shopping query here...",
        lines=2,
    ),
    outputs=gr.Markdown(label="Results"),
    title="Shopping Assistant",
    description="""
    This demo shows how to use the Shopping Assistant model to classify shopping queries into categories.
    Enter a shopping query below to see which categories it belongs to.
    
    Examples:
    - "I'm looking for headphones"
    - "Do you have any kitchen appliance deals?"
    - "Show me the best laptop deals"
    - "I need a new smart TV"
    """,
    # One row per example; each row holds the value for the single Textbox input.
    examples=[
        [query]
        for query in (
            "I'm looking for headphones",
            "Do you have any kitchen appliance deals?",
            "Show me the best laptop deals",
            "I need a new smart TV",
        )
    ],
    theme=gr.themes.Soft(),
)

# Start the web server only when executed as a script, not on import.
if __name__ == "__main__":
    demo.launch()