ZeroShotTagger / app.py
naveenus's picture
Update app.py
6c67d38 verified
raw
history blame
4.77 kB
import json, time, csv, os
import gradio as gr
from transformers import pipeline
# ————————————————
# 1) Load taxonomies
# ————————————————
def _load_json(path):
    # Small helper so both taxonomy files are read the same way.
    with open(path) as fh:
        return json.load(fh)

# coarse_labels: list of subjects; fine_map: subject -> list of fine topics.
coarse_labels = _load_json("coarse_labels.json")
fine_map = _load_json("fine_labels.json")
# ————————————————
# 2) Available zero-shot models
# ————————————————
# NOTE(review): only the NLI/XNLI-style checkpoints are guaranteed to work
# with the "zero-shot-classification" pipeline; the generative/embedding
# entries are kept as originally listed but may fail to load — confirm.
# The previous duplicate "facebook/bart-large-mnli" entry (a leftover
# allow_custom_value test) has been removed; duplicates add nothing to the
# dropdown and custom values are already enabled on the component.
MODEL_CHOICES = [
    "facebook/bart-large-mnli",
    "roberta-large-mnli",
    "joeddav/xlm-roberta-large-xnli",
    "mistralai/Mistral-7B-Instruct",
    "huggyllama/llama-2-7b-chat",
    "google/flan-t5-large",
    "google/flan-ul2",
    "clare-ai/llama-2-13b-instruct",
    "allenai/longformer-base-4096",
    "valhalla/t5-base-qa-qg-hl",
    "EleutherAI/gpt-neox-20b",
    "EleutherAI/gpt-j-6b",
    "bigscience/bloom-1b1",
    "bigscience/bloom-560m",
    "bigscience/bloom-3b",
    "Salesforce/codegen-2B-multi",
    "Salesforce/codegen-350M-multi",
    "madlag/llama2-7b-finetuned-qa",
    "tiiuae/falcon-7b-instruct",
    "tiiuae/falcon-40b-instruct",
    "milvus/milvus-embed-english",
    "sentence-transformers/all-MiniLM-L6-v2",
    "YOUR-OWN-CUSTOM-MODEL",
]
# ————————————————
# Helper: ensure log files exist
# ————————————————
LOG_FILE = "logs.csv"
FEEDBACK_FILE = "feedback.csv"

# Create each CSV once with its header row; later runs only append.
for fn, hdr in [
    (LOG_FILE, ["timestamp", "model", "question", "subject", "top3_topics", "duration"]),
    (FEEDBACK_FILE, ["timestamp", "question", "pred_subject", "pred_topics", "corrected"]),
]:
    if not os.path.exists(fn):
        # newline="" is required by the csv module; explicit utf-8 keeps
        # non-ASCII questions intact regardless of the platform default.
        with open(fn, "w", newline="", encoding="utf-8") as f:
            csv.writer(f).writerow(hdr)
# ————————————————
# 3) Build the interface logic
# ————————————————
# One pipeline per model id: switching back to an already-used model no
# longer reloads its weights on every request (the original rebuilt the
# pipeline on each call).
_PIPELINE_CACHE = {}


def hierarchical_tag(question, model_name):
    """Two-stage zero-shot tagging: coarse subject first, then fine topics.

    Parameters
    ----------
    question : str
        Free-text question to classify.
    model_name : str
        Hugging Face model id for the zero-shot-classification pipeline.

    Returns
    -------
    tuple
        (subject, {topic: rounded_score} for the top-3 topics,
        "⏱ <seconds>s" timing string).
    """
    start = time.time()
    # 3.1 Reuse a cached classifier when possible; build and cache otherwise.
    clf = _PIPELINE_CACHE.get(model_name)
    if clf is None:
        clf = pipeline("zero-shot-classification", model=model_name)
        _PIPELINE_CACHE[model_name] = clf
    # 3.2 Stage 1: coarse label
    coarse_out = clf(question, candidate_labels=coarse_labels)
    subject = coarse_out["labels"][0]
    # 3.3 Stage 2: fine labels within the chosen subject.
    # Guard: a subject missing from fine_map yields an empty candidate list,
    # which the pipeline rejects — return no topics instead of crashing.
    fine_labels = fine_map.get(subject, [])
    if fine_labels:
        fine_out = clf(question, candidate_labels=fine_labels)
        top3 = fine_out["labels"][:3]
        top3_scores = fine_out["scores"][:3]
    else:
        top3, top3_scores = [], []
    duration = round(time.time() - start, 3)
    # 3.4 Log the run (newlines flattened so each record stays on one CSV row).
    with open(LOG_FILE, "a", newline="", encoding="utf-8") as f:
        csv.writer(f).writerow([
            time.strftime("%Y-%m-%d %H:%M:%S"),
            model_name,
            question.replace("\n", " "),
            subject,
            ";".join(top3),
            duration,
        ])
    # 3.5 Return for display
    return (
        subject,
        {lbl: round(score, 3) for lbl, score in zip(top3, top3_scores)},
        f"⏱ {duration}s",
    )
def submit_feedback(question, subject, topics, corrected):
    """Append a user correction to the feedback CSV and acknowledge it.

    Parameters
    ----------
    question : str
        The original question (newlines flattened to keep one CSV row).
    subject : str
        The predicted subject shown in the UI.
    topics : dict | list | str
        Predicted topics. NOTE(review): a gr.Label wired as an input
        presumably delivers a dict keyed by label — verify against the UI;
        all three shapes are normalized to a list of topic strings here.
    corrected : str
        Free-form correction, e.g. "Correct subject;topic1;topic2;topic3".

    Returns
    -------
    str
        Confirmation message for the UI.
    """
    ts = time.strftime("%Y-%m-%d %H:%M:%S")
    # Normalize: joining a raw str would scatter its characters ("a;b;c"),
    # and a dict should contribute its label keys only.
    if isinstance(topics, dict):
        topic_list = list(topics)
    elif isinstance(topics, str):
        topic_list = [topics]
    else:
        topic_list = list(topics)
    with open(FEEDBACK_FILE, "a", newline="", encoding="utf-8") as f:
        csv.writer(f).writerow(
            [ts, question.replace("\n", " "), subject, ";".join(topic_list), corrected]
        )
    return "Thank you for your feedback!"
# ————————————————
# 4) Define the Gradio UI
# ————————————————
with gr.Blocks() as demo:
    gr.Markdown("## Hierarchical Zero-Shot Tagger with Model Selection & Logging")

    # Input row: question text plus model picker (custom ids allowed).
    with gr.Row():
        question_box = gr.Textbox(lines=3, label="Enter your question")
        model_dropdown = gr.Dropdown(
            label="Choose model",
            choices=MODEL_CHOICES,
            value=MODEL_CHOICES[0],
            allow_custom_value=True,
        )

    tag_btn = gr.Button("Tag Question")
    subject_box = gr.Textbox(label="Predicted Subject")
    topics_label = gr.Label(label="Top-3 Topics")
    timing_box = gr.Textbox(label="Inference Time")

    # Tagging action: (question, model) -> (subject, topics, elapsed time).
    tag_btn.click(
        hierarchical_tag,
        inputs=[question_box, model_dropdown],
        outputs=[subject_box, topics_label, timing_box],
    )

    gr.Markdown("---")
    gr.Markdown("### Not quite right? Submit your corrections below:")
    correction_box = gr.Textbox(lines=1, placeholder="Correct subject;topic1;topic2;topic3")
    feedback_btn = gr.Button("Submit Feedback")
    feedback_msg = gr.Textbox(label="")

    # Feedback path reuses the displayed predictions as context for the log.
    feedback_btn.click(
        submit_feedback,
        inputs=[question_box, subject_box, topics_label, correction_box],
        outputs=[feedback_msg],
    )

demo.launch()