|
|
|
import gradio as gr |
|
import torch |
|
from transformers import AutoTokenizer, AutoModelForSequenceClassification |
|
import json |
|
|
|
|
|
# Hugging Face Hub repository that provides the tokenizer and the
# sequence-classification model (the same id is reused to fetch categories.json).
model_id = "selvaonline/shopping-assistant"

# Tuple elements evaluate left to right, so the tokenizer is still loaded first.
tokenizer, model = (
    AutoTokenizer.from_pretrained(model_id),
    AutoModelForSequenceClassification.from_pretrained(model_id),
)
|
|
|
|
|
# Best-effort load of the category names from the model repository; on any
# failure (hub unavailable, missing file, bad JSON) fall back to a built-in list
# so the demo still starts.
try:
    from huggingface_hub import hf_hub_download

    categories_path = hf_hub_download(repo_id=model_id, filename="categories.json")

    # JSON text is UTF-8; without an explicit encoding open() uses the locale's
    # default encoding (e.g. cp1252 on Windows) and can mis-decode labels.
    with open(categories_path, "r", encoding="utf-8") as f:
        categories = json.load(f)

    # NOTE(review): the schema of categories.json is not visible here. Callers
    # index this by integer position, so an {"0": ..., "1": ...} mapping is
    # converted to a list; any other dict shape raises KeyError and falls back.
    if isinstance(categories, dict):
        categories = [categories[str(i)] for i in range(len(categories))]

except Exception as e:
    print(f"Error loading categories: {e}")

    categories = ["electronics", "clothing", "home", "kitchen", "toys", "other"]
|
|
|
def _category_name(index):
    """Return the display name for model output position ``index``."""
    # The categories list (possibly the 6-item fallback) can be shorter than, or
    # shaped differently from, the model's label set; use the model config's own
    # label rather than crashing with IndexError/KeyError.
    try:
        return categories[index]
    except (IndexError, KeyError, TypeError):
        return model.config.id2label.get(index, f"LABEL_{index}")


def classify_text(text, threshold=0.5):
    """
    Classify a shopping query and format the matching categories as Markdown.

    The model's logits go through a sigmoid, so every category is scored
    independently; all categories scoring above ``threshold`` are reported,
    highest score first.

    Args:
        text: The shopping query entered by the user.
        threshold: Minimum sigmoid score for a category to be listed.
            Defaults to 0.5, the value the function always used.

    Returns:
        A Markdown string listing the matching categories plus a
        recommendation, or a message asking for a (different) query.
    """
    # An empty textbox would otherwise be classified as a meaningless input
    # (and None would crash the tokenizer).
    if not text or not text.strip():
        return "Please enter a shopping query."

    inputs = tokenizer(text, return_tensors="pt", truncation=True, padding=True)

    with torch.no_grad():
        outputs = model(**inputs)

    # Per-category probabilities for the single query in the batch.
    scores = torch.sigmoid(outputs.logits)[0].tolist()

    top_categories = sorted(
        (
            (_category_name(i), score)
            for i, score in enumerate(scores)
            if score > threshold
        ),
        key=lambda pair: pair[1],
        reverse=True,
    )

    if not top_categories:
        return f"No categories found for '{text}'. Please try a different query."

    lines = [f"Top categories for '{text}':", ""]
    lines.extend(f"- {category}: {score:.4f}" for category, score in top_categories)
    lines.append("")
    lines.append(
        "Based on your query, I would recommend looking for deals in the "
        f"**{top_categories[0][0]}** category."
    )
    return "\n".join(lines)
|
|
|
|
|
# Clickable example queries for the UI (the description text lists the same ones).
_EXAMPLE_QUERIES = [
    "I'm looking for headphones",
    "Do you have any kitchen appliance deals?",
    "Show me the best laptop deals",
    "I need a new smart TV",
]

# Single-textbox-in, Markdown-out web UI around classify_text.
demo = gr.Interface(
    fn=classify_text,
    inputs=gr.Textbox(
        lines=2, placeholder="Enter your shopping query here...", label="Shopping Query"
    ),
    outputs=gr.Markdown(label="Results"),
    title="Shopping Assistant",
    description="""

    This demo shows how to use the Shopping Assistant model to classify shopping queries into categories.

    Enter a shopping query below to see which categories it belongs to.



    Examples:

    - "I'm looking for headphones"

    - "Do you have any kitchen appliance deals?"

    - "Show me the best laptop deals"

    - "I need a new smart TV"

    """,
    # Gradio expects one inner list per example, one entry per input component.
    examples=[[query] for query in _EXAMPLE_QUERIES],
    theme=gr.themes.Soft(),
)
|
|
|
|
|
# Start the Gradio server only when run as a script; importing this module
# defines `demo` without launching it.
if __name__ == "__main__":
    demo.launch()
|
|