Upload app.py with huggingface_hub
Browse files
app.py
CHANGED
@@ -83,20 +83,40 @@ def process_deals_data(deals_data):
|
|
83 |
|
84 |
return processed_deals
|
85 |
|
86 |
-
# Load the model and tokenizer
|
87 |
-
model_id = "selvaonline/shopping-assistant"
|
88 |
-
tokenizer = AutoTokenizer.from_pretrained(model_id)
|
89 |
-
model = AutoModelForSequenceClassification.from_pretrained(model_id)
|
90 |
-
|
91 |
-
# Load the categories
|
92 |
try:
|
93 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
94 |
categories_path = hf_hub_download(repo_id=model_id, filename="categories.json")
|
95 |
with open(categories_path, "r") as f:
|
96 |
-
|
97 |
-
except Exception as e:
|
98 |
-
|
99 |
-
|
100 |
|
101 |
# Global variable to store deals data
|
102 |
deals_cache = None
|
@@ -113,13 +133,37 @@ def classify_text(text, fetch_deals=True):
|
|
113 |
# Get the model prediction
|
114 |
with torch.no_grad():
|
115 |
outputs = model(**inputs)
|
116 |
-
|
117 |
-
|
118 |
-
|
119 |
-
|
120 |
-
|
121 |
-
|
122 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
123 |
|
124 |
# Sort by score
|
125 |
top_categories.sort(key=lambda x: x[1], reverse=True)
|
|
|
83 |
|
84 |
return processed_deals
|
85 |
|
86 |
# --- Model and category setup -------------------------------------------
# Prefer the e-commerce BERT model; if it cannot be loaded, fall back to
# the project's own shopping-assistant model and its categories.json.
#
# Module-level names produced here (used by the rest of the app):
#   model_id   - repo id of whichever model actually loaded
#   tokenizer  - HF tokenizer for that model
#   model      - HF sequence-classification model
#   categories - list[str] of label names, index-aligned with model outputs

ECOMMERCE_MODEL_ID = "prithivida/ecommerce-bert-base-uncased"
FALLBACK_MODEL_ID = "selvaonline/shopping-assistant"

try:
    # Try to load the e-commerce BERT model.
    # NOTE: model_id is set on this path too — the previous version only
    # assigned it in the except branch, so a successful primary load left
    # model_id undefined for any later reference (NameError).
    model_id = ECOMMERCE_MODEL_ID
    tokenizer = AutoTokenizer.from_pretrained(model_id)
    model = AutoModelForSequenceClassification.from_pretrained(model_id)

    # Fixed label set for the e-commerce BERT model.
    categories = [
        "electronics", "computers", "mobile_phones", "accessories",
        "clothing", "footwear", "watches", "jewelry",
        "home", "kitchen", "furniture", "decor",
        "beauty", "personal_care", "health", "wellness",
        "toys", "games", "sports", "outdoors",
        "books", "stationery", "music", "movies"
    ]
    print("Using e-commerce BERT model")
except Exception as e:
    # Fall back to local model if e-commerce BERT fails to load
    print(f"Error loading e-commerce BERT model: {str(e)}")
    print("Falling back to local model")

    model_id = FALLBACK_MODEL_ID
    tokenizer = AutoTokenizer.from_pretrained(model_id)
    model = AutoModelForSequenceClassification.from_pretrained(model_id)

    # Load the local categories shipped alongside the fallback model.
    try:
        from huggingface_hub import hf_hub_download
        categories_path = hf_hub_download(repo_id=model_id, filename="categories.json")
        with open(categories_path, "r") as f:
            categories = json.load(f)
    except Exception as e:
        print(f"Error loading categories: {str(e)}")
        # Last-resort hard-coded labels so the app can still start.
        categories = ["electronics", "clothing", "home", "kitchen", "toys", "other"]

# Global variable to store deals data
deals_cache = None
|
|
|
133 |
# Get the model prediction
|
134 |
with torch.no_grad():
|
135 |
outputs = model(**inputs)
|
136 |
+
|
137 |
+
# Handle different model output formats
|
138 |
+
if hasattr(outputs, 'logits'):
|
139 |
+
# For models that return logits
|
140 |
+
if outputs.logits.shape[1] == len(categories):
|
141 |
+
# Multi-label classification
|
142 |
+
predictions = torch.sigmoid(outputs.logits)
|
143 |
+
|
144 |
+
# Get the top categories
|
145 |
+
top_categories = []
|
146 |
+
for i, score in enumerate(predictions[0]):
|
147 |
+
if score > 0.3: # Lower threshold for e-commerce model
|
148 |
+
top_categories.append((categories[i], score.item()))
|
149 |
+
else:
|
150 |
+
# Single-label classification
|
151 |
+
probabilities = torch.softmax(outputs.logits, dim=1)
|
152 |
+
values, indices = torch.topk(probabilities, 3)
|
153 |
+
|
154 |
+
top_categories = []
|
155 |
+
for i, idx in enumerate(indices[0]):
|
156 |
+
if idx < len(categories):
|
157 |
+
top_categories.append((categories[idx.item()], values[0][i].item()))
|
158 |
+
else:
|
159 |
+
# Fallback for other model formats
|
160 |
+
predictions = torch.sigmoid(outputs[0])
|
161 |
+
|
162 |
+
# Get the top categories
|
163 |
+
top_categories = []
|
164 |
+
for i, score in enumerate(predictions[0]):
|
165 |
+
if score > 0.5:
|
166 |
+
top_categories.append((categories[i], score.item()))
|
167 |
|
168 |
# Sort by score
|
169 |
top_categories.sort(key=lambda x: x[1], reverse=True)
|