# ZeroShotTagger / app.py
# Source: Hugging Face Space by naveenus — commit 42f0920 ("Update app.py"), 1.37 kB.
# app.py
import json
from pathlib import Path

import gradio as gr
from transformers import pipeline
# 1) Load taxonomies
# Resolve the JSON files relative to this script (not the current working
# directory) so the app also starts when launched from another directory.
_APP_DIR = Path(__file__).resolve().parent


def _load_json(filename):
    """Read and parse the UTF-8 JSON file *filename* located next to this script."""
    # Explicit encoding: the platform default (e.g. cp1252 on Windows) would
    # mis-decode non-ASCII label names.
    with open(_APP_DIR / filename, encoding="utf-8") as fh:
        return json.load(fh)


# Candidate subject labels for stage 1 (passed directly as candidate_labels).
coarse_labels = _load_json("coarse_labels.json")
# Maps a subject name (one of coarse_labels) to its fine-grained topic labels
# for stage 2.
fine_map = _load_json("fine_labels.json")

# 2) Init classifier
# Created once at import time so every request reuses the same loaded model.
classifier = pipeline(
    "zero-shot-classification",
    model="facebook/bart-large-mnli",
)
# 3) Tagging fn
def hierarchical_tag(question, *, clf=None, coarse=None, fine=None):
    """Tag *question* with a subject and scored fine-grained topics.

    Stage 1 asks the zero-shot classifier to pick the best subject from the
    coarse labels; stage 2 scores the question against only that subject's
    fine-grained labels.

    Args:
        question: Free-text question (from the UI textbox).
        clf: Optional zero-shot classifier callable
            ``clf(text, candidate_labels=[...]) -> {"labels": [...], "scores": [...]}``;
            defaults to the module-level ``classifier`` pipeline.
        coarse: Optional candidate subject labels; defaults to ``coarse_labels``.
        fine: Optional mapping of subject -> fine labels; defaults to ``fine_map``.

    Returns:
        A dict whose first key is ``"Subject"`` (the top-ranked coarse label),
        followed by one ``label: score`` entry per fine label, scores rounded
        to 3 decimals, in the order the classifier returned them. If the chosen
        subject has no fine labels, only ``"Subject"`` is returned.
    """
    clf = classifier if clf is None else clf
    coarse = coarse_labels if coarse is None else coarse
    fine = fine_map if fine is None else fine

    # Stage 1: the pipeline returns labels sorted by descending score, so the
    # first one is the best-matching subject.
    coarse_out = clf(question, candidate_labels=coarse)
    chosen = coarse_out["labels"][0]

    tags = {"Subject": chosen}

    # Stage 2: fine-grained tags within that subject. A subject missing from
    # the taxonomy (or mapped to an empty list) would hand the pipeline an
    # empty candidate_labels list, which it rejects with ValueError -- skip
    # stage 2 instead of failing the whole request.
    fine_labels = fine.get(chosen, [])
    if not fine_labels:
        return tags

    fine_out = clf(question, candidate_labels=fine_labels)
    tags.update(
        (label, round(score, 3))
        for label, score in zip(fine_out["labels"], fine_out["scores"])
    )
    return tags
# 4) Build UI
# One multi-line textbox in, the hierarchical tag dict rendered as JSON out.
question_box = gr.Textbox(lines=3, label="Enter your question")
tags_view = gr.JSON(label="Hierarchical Tags")

iface = gr.Interface(
    fn=hierarchical_tag,
    inputs=question_box,
    outputs=tags_view,
    title="Two-Stage Zero-Shot Question Tagger",
    description="Stage 1: classify subject; Stage 2: classify topic within subject.",
)

# Start the web server only when run as a script, not when imported.
if __name__ == "__main__":
    iface.launch()