Spaces:
Sleeping
Sleeping
File size: 1,278 Bytes
e5b7bb1 b90e13b e5b7bb1 b90e13b e5b7bb1 b90e13b e5b7bb1 b90e13b e5b7bb1 b90e13b e5b7bb1 b90e13b |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 |
# app.py
import json
import gradio as gr
from transformers import pipeline
# 1) Load the base candidate labels from JSON.
#    The path is relative, so it resolves against the current working
#    directory. An explicit UTF-8 encoding (the encoding JSON text uses)
#    keeps non-ASCII labels from being decoded with the host's locale default.
#    NOTE(review): assumes labels.json holds a JSON array of strings — confirm.
with open("labels.json", "r", encoding="utf-8") as f:
    base_labels = json.load(f)

# 2) Pre-fill the labels textbox with the base labels, comma-separated.
default_label_str = ", ".join(base_labels)
# 3) Create the zero-shot classification pipeline once, at import time;
#    tag_question reuses this same object for every request.
classifier = pipeline(
    "zero-shot-classification",
    model="facebook/bart-large-mnli",
)
# 4) Interface function that merges runtime labels
def tag_question(question: str, labels_str: str):
    """Score *question* against the comma-separated labels in *labels_str*.

    Args:
        question: Free-text question to classify.
        labels_str: Candidate labels separated by commas. Surrounding
            whitespace, empty entries and repeated labels are ignored;
            ``None`` is treated as an empty string.

    Returns:
        A dict mapping each candidate label to its zero-shot score rounded
        to 3 decimals, in the order the classifier reports them. Every label
        is returned; the ``gr.Label`` output decides how many to display.
        An empty dict is returned when the question is blank or no usable
        label remains, without calling the classifier (the transformers
        zero-shot pipeline raises ``ValueError`` on an empty label list).
    """
    # Split on commas, trim whitespace, drop empty entries, and de-duplicate
    # while keeping first-seen order (dict keys preserve insertion order).
    # Without this, a repeated label would be sent to the classifier twice
    # and then collapse into a single key of the result dict.
    labels = list(dict.fromkeys(
        lbl.strip() for lbl in (labels_str or "").split(",") if lbl.strip()
    ))
    if not labels or not (question or "").strip():
        # Nothing meaningful to classify; skip the model call.
        return {}

    out = classifier(question, candidate_labels=labels)
    return {
        lbl: round(score, 3)
        for lbl, score in zip(out["labels"], out["scores"])
    }
# 5) Assemble the Gradio UI.
#    Inputs: the question, and a candidate-label box pre-filled with the
#    labels read from labels.json (the user may edit it before submitting).
#    Output: a Label component showing the three highest-scoring tags.
_question_input = gr.Textbox(lines=3, label="Question")
_labels_input = gr.Textbox(
    lines=2,
    label="Candidate Labels (comma-separated)",
    value=default_label_str,
)
_tag_output = gr.Label(num_top_classes=3)

iface = gr.Interface(
    fn=tag_question,
    inputs=[_question_input, _labels_input],
    outputs=_tag_output,
    title="Hybrid Zero-Shot Question Tagger",
    description="Loaded labels from `labels.json`, editable at runtime.",
)

# Start the web server only when executed as a script, not on import.
if __name__ == "__main__":
    iface.launch()
|