selvaonline commited on
Commit
01a3727
·
verified ·
1 Parent(s): ca99429

Upload app.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. app.py +86 -0
app.py ADDED
@@ -0,0 +1,86 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ import gradio as gr
3
+ import torch
4
+ from transformers import AutoTokenizer, AutoModelForSequenceClassification
5
+ import json
6
+
7
+ # Load the model and tokenizer
8
+ model_id = "selvaonline/shopping-assistant"
9
+ tokenizer = AutoTokenizer.from_pretrained(model_id)
10
+ model = AutoModelForSequenceClassification.from_pretrained(model_id)
11
+
12
+ # Load the categories
13
+ try:
14
+ from huggingface_hub import hf_hub_download
15
+ categories_path = hf_hub_download(repo_id=model_id, filename="categories.json")
16
+ with open(categories_path, "r") as f:
17
+ categories = json.load(f)
18
+ except Exception as e:
19
+ print(f"Error loading categories: {str(e)}")
20
+ categories = ["electronics", "clothing", "home", "kitchen", "toys", "other"]
21
+
22
+ def classify_text(text):
23
+ """
24
+ Classify the text using the model
25
+ """
26
+ # Prepare the input
27
+ inputs = tokenizer(text, return_tensors="pt", truncation=True, padding=True)
28
+
29
+ # Get the model prediction
30
+ with torch.no_grad():
31
+ outputs = model(**inputs)
32
+ predictions = torch.sigmoid(outputs.logits)
33
+
34
+ # Get the top categories
35
+ top_categories = []
36
+ for i, score in enumerate(predictions[0]):
37
+ if score > 0.5: # Threshold for multi-label classification
38
+ top_categories.append((categories[i], score.item()))
39
+
40
+ # Sort by score
41
+ top_categories.sort(key=lambda x: x[1], reverse=True)
42
+
43
+ # Format the results
44
+ if top_categories:
45
+ result = f"Top categories for '{text}':\n\n"
46
+ for category, score in top_categories:
47
+ result += f"- {category}: {score:.4f}\n"
48
+
49
+ result += f"\nBased on your query, I would recommend looking for deals in the **{top_categories[0][0]}** category."
50
+ else:
51
+ result = f"No categories found for '{text}'. Please try a different query."
52
+
53
+ return result
54
+
55
+ # Create the Gradio interface
56
+ demo = gr.Interface(
57
+ fn=classify_text,
58
+ inputs=gr.Textbox(
59
+ lines=2,
60
+ placeholder="Enter your shopping query here...",
61
+ label="Shopping Query"
62
+ ),
63
+ outputs=gr.Markdown(label="Results"),
64
+ title="Shopping Assistant",
65
+ description="""
66
+ This demo shows how to use the Shopping Assistant model to classify shopping queries into categories.
67
+ Enter a shopping query below to see which categories it belongs to.
68
+
69
+ Examples:
70
+ - "I'm looking for headphones"
71
+ - "Do you have any kitchen appliance deals?"
72
+ - "Show me the best laptop deals"
73
+ - "I need a new smart TV"
74
+ """,
75
+ examples=[
76
+ ["I'm looking for headphones"],
77
+ ["Do you have any kitchen appliance deals?"],
78
+ ["Show me the best laptop deals"],
79
+ ["I need a new smart TV"]
80
+ ],
81
+ theme=gr.themes.Soft()
82
+ )
83
+
84
+ # Launch the app
85
+ if __name__ == "__main__":
86
+ demo.launch()