selvaonline commited on
Commit
e667020
·
verified ·
1 Parent(s): 6f89f62

Upload app.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. app.py +62 -18
app.py CHANGED
@@ -83,20 +83,40 @@ def process_deals_data(deals_data):
83
 
84
  return processed_deals
85
 
86
- # Load the model and tokenizer
87
- model_id = "selvaonline/shopping-assistant"
88
- tokenizer = AutoTokenizer.from_pretrained(model_id)
89
- model = AutoModelForSequenceClassification.from_pretrained(model_id)
90
-
91
- # Load the categories
92
  try:
93
- from huggingface_hub import hf_hub_download
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
94
  categories_path = hf_hub_download(repo_id=model_id, filename="categories.json")
95
  with open(categories_path, "r") as f:
96
- categories = json.load(f)
97
- except Exception as e:
98
- print(f"Error loading categories: {str(e)}")
99
- categories = ["electronics", "clothing", "home", "kitchen", "toys", "other"]
100
 
101
  # Global variable to store deals data
102
  deals_cache = None
@@ -113,13 +133,37 @@ def classify_text(text, fetch_deals=True):
113
  # Get the model prediction
114
  with torch.no_grad():
115
  outputs = model(**inputs)
116
- predictions = torch.sigmoid(outputs.logits)
117
-
118
- # Get the top categories
119
- top_categories = []
120
- for i, score in enumerate(predictions[0]):
121
- if score > 0.5: # Threshold for multi-label classification
122
- top_categories.append((categories[i], score.item()))
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
123
 
124
  # Sort by score
125
  top_categories.sort(key=lambda x: x[1], reverse=True)
 
83
 
84
  return processed_deals
85
 
86
+ # Load the e-commerce specific model and tokenizer
 
 
 
 
 
87
  try:
88
+ # Try to load the e-commerce BERT model
89
+ tokenizer = AutoTokenizer.from_pretrained("prithivida/ecommerce-bert-base-uncased")
90
+ model = AutoModelForSequenceClassification.from_pretrained("prithivida/ecommerce-bert-base-uncased")
91
+
92
+ # E-commerce BERT categories
93
+ categories = [
94
+ "electronics", "computers", "mobile_phones", "accessories",
95
+ "clothing", "footwear", "watches", "jewelry",
96
+ "home", "kitchen", "furniture", "decor",
97
+ "beauty", "personal_care", "health", "wellness",
98
+ "toys", "games", "sports", "outdoors",
99
+ "books", "stationery", "music", "movies"
100
+ ]
101
+ print("Using e-commerce BERT model")
102
+ except Exception as e:
103
+ # Fall back to local model if e-commerce BERT fails to load
104
+ print(f"Error loading e-commerce BERT model: {str(e)}")
105
+ print("Falling back to local model")
106
+
107
+ model_id = "selvaonline/shopping-assistant"
108
+ tokenizer = AutoTokenizer.from_pretrained(model_id)
109
+ model = AutoModelForSequenceClassification.from_pretrained(model_id)
110
+
111
+ # Load the local categories
112
+ try:
113
+ from huggingface_hub import hf_hub_download
114
  categories_path = hf_hub_download(repo_id=model_id, filename="categories.json")
115
  with open(categories_path, "r") as f:
116
+ categories = json.load(f)
117
+ except Exception as e:
118
+ print(f"Error loading categories: {str(e)}")
119
+ categories = ["electronics", "clothing", "home", "kitchen", "toys", "other"]
120
 
121
  # Global variable to store deals data
122
  deals_cache = None
 
133
  # Get the model prediction
134
  with torch.no_grad():
135
  outputs = model(**inputs)
136
+
137
+ # Handle different model output formats
138
+ if hasattr(outputs, 'logits'):
139
+ # For models that return logits
140
+ if outputs.logits.shape[1] == len(categories):
141
+ # Multi-label classification
142
+ predictions = torch.sigmoid(outputs.logits)
143
+
144
+ # Get the top categories
145
+ top_categories = []
146
+ for i, score in enumerate(predictions[0]):
147
+ if score > 0.3: # Lower threshold for e-commerce model
148
+ top_categories.append((categories[i], score.item()))
149
+ else:
150
+ # Single-label classification
151
+ probabilities = torch.softmax(outputs.logits, dim=1)
152
+ values, indices = torch.topk(probabilities, 3)
153
+
154
+ top_categories = []
155
+ for i, idx in enumerate(indices[0]):
156
+ if idx < len(categories):
157
+ top_categories.append((categories[idx.item()], values[0][i].item()))
158
+ else:
159
+ # Fallback for other model formats
160
+ predictions = torch.sigmoid(outputs[0])
161
+
162
+ # Get the top categories
163
+ top_categories = []
164
+ for i, score in enumerate(predictions[0]):
165
+ if score > 0.5:
166
+ top_categories.append((categories[i], score.item()))
167
 
168
  # Sort by score
169
  top_categories.sort(key=lambda x: x[1], reverse=True)