# Provenance (from the Hugging Face file page this was copied from):
# uploaded by selvaonline, commit 01a3727 (verified) "Upload app.py with huggingface_hub", 2.76 kB.
import gradio as gr
import torch
from transformers import AutoTokenizer, AutoModelForSequenceClassification
import json
# Load the model and tokenizer once at import time; every request reuses them.
model_id = "selvaonline/shopping-assistant"
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForSequenceClassification.from_pretrained(model_id)

# Category names used when categories.json cannot be fetched or parsed.
# NOTE(review): assumed to be in the same order as the model's outputs -- confirm
# against the label order used in training.
DEFAULT_CATEGORIES = ["electronics", "clothing", "home", "kitchen", "toys", "other"]

# Load the display names for the model's outputs. classify_text indexes this list
# by output position, so it must be a JSON list in logit order.
try:
    from huggingface_hub import hf_hub_download

    categories_path = hf_hub_download(repo_id=model_id, filename="categories.json")
    # Explicit UTF-8: the platform default encoding is not UTF-8 everywhere.
    with open(categories_path, "r", encoding="utf-8") as f:
        categories = json.load(f)
    # Reject shapes that classify_text cannot index by integer position.
    if not isinstance(categories, list):
        raise ValueError(
            f"categories.json must contain a JSON list, got {type(categories).__name__}"
        )
except Exception as e:  # best-effort: any import/network/parse failure -> fallback
    print(f"Error loading categories: {e}")
    categories = list(DEFAULT_CATEGORIES)

# Surface a label/output size mismatch at startup instead of mid-request.
if len(categories) != model.config.num_labels:
    print(
        f"Warning: {len(categories)} categories loaded but the model has "
        f"{model.config.num_labels} output labels."
    )
def _category_name(index):
    """Return the display name for model output ``index``.

    Uses ``categories`` when it has an entry for ``index``. Otherwise it falls back
    to the model config's ``id2label`` entry, or ``label_<index>``, instead of
    raising IndexError in the middle of a request.
    """
    if index < len(categories):
        return categories[index]
    return model.config.id2label.get(index, f"label_{index}")


def classify_text(text, threshold=0.5):
    """
    Classify a shopping query into categories and format the result as Markdown.

    The model's logits go through a sigmoid and are treated as independent
    per-category scores (multi-label classification).

    Args:
        text: The user's shopping query.
        threshold: Minimum sigmoid score for a category to be reported.
            Defaults to 0.5, the value previously hard-coded here.

    Returns:
        A Markdown string. It lists every category scoring above ``threshold``,
        highest score first, followed by a recommendation for the best one. If no
        category qualifies, or the query is empty, it is a hint message instead.
    """
    # Nothing meaningful to classify for a blank (or missing) query.
    if text is None or not text.strip():
        return "Please enter a shopping query."

    # Tokenize the single query; truncation keeps long input within the model limit.
    inputs = tokenizer(text, return_tensors="pt", truncation=True, padding=True)

    # Inference only: no autograd bookkeeping needed.
    with torch.no_grad():
        outputs = model(**inputs)
    # One row per input; a single query was sent, so take row 0 as Python floats.
    scores = torch.sigmoid(outputs.logits)[0].tolist()

    # Keep the categories above the threshold, best first. The sort is stable, so
    # ties keep model output order.
    top_categories = sorted(
        (
            (_category_name(i), score)
            for i, score in enumerate(scores)
            if score > threshold
        ),
        key=lambda item: item[1],
        reverse=True,
    )

    if not top_categories:
        return f"No categories found for '{text}'. Please try a different query."

    lines = [f"Top categories for '{text}':", ""]
    lines.extend(f"- {category}: {score:.4f}" for category, score in top_categories)
    lines.append("")
    lines.append(
        "Based on your query, I would recommend looking for deals in the "
        f"**{top_categories[0][0]}** category."
    )
    return "\n".join(lines)
# Example queries, used both as clickable examples and in the description text.
_EXAMPLE_QUERIES = [
    "I'm looking for headphones",
    "Do you have any kitchen appliance deals?",
    "Show me the best laptop deals",
    "I need a new smart TV",
]

# Intro text for the interface. Its bullet list is generated from _EXAMPLE_QUERIES
# so the description and the clickable examples cannot drift apart.
_DESCRIPTION = (
    "\n"
    "This demo shows how to use the Shopping Assistant model to classify "
    "shopping queries into categories.\n"
    "Enter a shopping query below to see which categories it belongs to.\n"
    "Examples:\n"
    + "".join(f'- "{query}"\n' for query in _EXAMPLE_QUERIES)
)

# Input: free-text query box. Output: classify_text's Markdown result.
_query_box = gr.Textbox(
    lines=2,
    placeholder="Enter your shopping query here...",
    label="Shopping Query",
)
_results_view = gr.Markdown(label="Results")

# Wire classify_text into a single-input, single-output web UI.
demo = gr.Interface(
    fn=classify_text,
    inputs=_query_box,
    outputs=_results_view,
    title="Shopping Assistant",
    description=_DESCRIPTION,
    examples=[[query] for query in _EXAMPLE_QUERIES],
    theme=gr.themes.Soft(),
)

# Start the server only when executed as a script, not when imported.
if __name__ == "__main__":
    demo.launch()