# Hugging Face Space status banner ("Spaces: Sleeping") captured during page
# export — not part of the application code.
import csv
import json
import os
import time
from functools import lru_cache

import gradio as gr
from transformers import pipeline
# ----------------------------------------------------------------------
# Load taxonomies
# ----------------------------------------------------------------------
# coarse_labels: list of top-level subject names (Stage 1 candidates).
# fine_map: mapping of coarse subject -> list of fine topic labels
#           (Stage 2 candidates).  Structure inferred from usage below —
#           TODO confirm against the JSON files.
# Explicit encoding: the platform default (e.g. cp1252 on Windows) can
# mis-decode non-ASCII label names in the JSON files.
with open("coarse_labels.json", encoding="utf-8") as f:
    coarse_labels = json.load(f)
with open("fine_labels.json", encoding="utf-8") as f:
    fine_map = json.load(f)
# ----------------------------------------------------------------------
# Model choices (exactly five zero-shot NLI checkpoints)
# ----------------------------------------------------------------------
MODEL_CHOICES = [
    "facebook/bart-large-mnli",        # English baseline
    "roberta-large-mnli",
    "joeddav/xlm-roberta-large-xnli",  # multilingual
    "valhalla/distilbart-mnli-12-4",   # distilled / faster
    "FractalAIResearch/Fathom-R1-14B", # placeholder — replace with real phantom model
]
# ----------------------------------------------------------------------
# Ensure log files exist (create each with its header row on first run)
# ----------------------------------------------------------------------
LOG_FILE = "logs.csv"
FEEDBACK_FILE = "feedback.csv"

_CSV_HEADERS = [
    (LOG_FILE, ["timestamp", "model", "question", "chosen_subject", "top3_topics", "duration"]),
    (FEEDBACK_FILE, ["timestamp", "question", "subject_feedback", "topic_feedback"]),
]
for fn, hdr in _CSV_HEADERS:
    if not os.path.exists(fn):
        # newline="" is required by the csv module; explicit encoding keeps
        # the logs readable regardless of the platform default.
        with open(fn, "w", newline="", encoding="utf-8") as f:
            csv.writer(f).writerow(hdr)
# ----------------------------------------------------------------------
# Inference functions
# ----------------------------------------------------------------------
@lru_cache(maxsize=None)
def _get_classifier(model_name):
    """Build and cache one zero-shot classification pipeline per model id.

    The original code rebuilt the pipeline (re-loading model weights) on
    every request; caching makes repeat runs with the same model fast.
    There are only 5 model choices, so an unbounded cache is safe.
    """
    return pipeline("zero-shot-classification", model=model_name)


def run_stage1(question, model_name):
    """Stage 1: rank the coarse subjects for *question*.

    Parameters
    ----------
    question : str
        The user's question text.
    model_name : str
        Hugging Face model id for the zero-shot pipeline.

    Returns
    -------
    (list[str], float)
        Top-3 coarse subject labels and elapsed time in seconds.
    """
    start = time.time()
    clf = _get_classifier(model_name)
    out = clf(question, candidate_labels=coarse_labels)
    # Scores are not used by the caller, so only the labels are kept
    # (the original computed an unused `scores` local).
    top_labels = out["labels"][:3]
    duration = round(time.time() - start, 3)
    return top_labels, duration
def run_stage2(question, model_name, subject):
    """Stage 2: classify *question* against the fine topics of *subject*.

    Parameters
    ----------
    question : str
        The user's question text.
    model_name : str
        Hugging Face model id for the zero-shot pipeline.
    subject : str
        Coarse subject chosen in Stage 1; used as a key into ``fine_map``.

    Returns
    -------
    (dict[str, float], str)
        Top-3 fine topic labels mapped to rounded scores, plus a display
        string with the elapsed time.  Side effect: appends one row to
        ``LOG_FILE``.
    """
    start = time.time()
    fine_labels = fine_map.get(subject, [])
    if not fine_labels:
        # Unknown subject or empty taxonomy entry: the pipeline raises on
        # an empty candidate list, so fail soft with an explanatory message.
        return {}, f"No fine topics defined for subject {subject!r}"
    # NOTE(review): building the pipeline here re-loads the model on every
    # call; consider caching per model_name as Stage 1 could.
    clf = pipeline("zero-shot-classification", model=model_name)
    out = clf(question, candidate_labels=fine_labels)
    labels, scores = out["labels"][:3], out["scores"][:3]
    duration = round(time.time() - start, 3)
    # Log the combined run; newlines are flattened so the CSV stays one
    # row per run.  Explicit encoding matches the file's creation.
    with open(LOG_FILE, "a", newline="", encoding="utf-8") as f:
        csv.writer(f).writerow([
            time.strftime("%Y-%m-%d %H:%M:%S"),
            model_name,
            question.replace("\n", " "),
            subject,
            ";".join(labels),
            duration,
        ])
    # NOTE(review): "β±" below looks like mojibake for "⏱" — confirm the
    # source file's encoding before changing it.
    return {lbl: round(score, 3) for lbl, score in zip(labels, scores)}, f"β± {duration}s"
def submit_feedback(question, subject_fb, topic_fb):
    """Append one user-correction row to ``FEEDBACK_FILE``.

    Parameters
    ----------
    question : str
        The question being corrected (newlines flattened to keep one CSV
        row per submission).
    subject_fb : str
        Free-text corrected subject.
    topic_fb : str
        Free-text corrected topic(s).

    Returns
    -------
    str
        Status message shown in the UI.
    """
    # Explicit encoding: feedback is free text and may contain non-ASCII;
    # the platform default (e.g. cp1252) could raise on write.
    with open(FEEDBACK_FILE, "a", newline="", encoding="utf-8") as f:
        csv.writer(f).writerow([
            time.strftime("%Y-%m-%d %H:%M:%S"),
            question.replace("\n", " "),
            subject_fb,
            topic_fb,
        ])
    return "β Feedback recorded!"
# ----------------------------------------------------------------------
# Build Gradio UI
# ----------------------------------------------------------------------
with gr.Blocks() as demo:
    gr.Markdown("## Hierarchical Zero-Shot Tagger with Subject Toggle & Feedback")

    # Inputs: question text plus the model dropdown (defaults to the first
    # entry of MODEL_CHOICES).
    with gr.Row():
        question_input = gr.Textbox(lines=3, label="Enter your question")
        model_input = gr.Dropdown(
            choices=MODEL_CHOICES, value=MODEL_CHOICES[0], label="Choose model"
        )
    go_button = gr.Button("Run Stage 1")

    # Stage 1 outputs: the top-3 coarse subjects and the elapsed time.
    subj_radio = gr.Radio(choices=[], label="Top-3 Subjects",
        info="Select to re-run Stage 2 for a different subject")
    stage1_time = gr.Textbox(label="Stage 1 Time")

    # NOTE(review): run_stage1 returns (labels_list, duration).  Wiring a
    # raw list into a Radio output sets its *value*, not its *choices*;
    # repopulating the options likely needs gr.update(choices=...) or a new
    # gr.Radio(...) — confirm against the installed Gradio version.
    go_button.click(
        fn=lambda q,m: (*run_stage1(q,m),),  # splat preserves the 2-tuple shape
        inputs=[question_input, model_input],
        outputs=[subj_radio, stage1_time]
    )

    # Stage 2 UI: fine-topic classification for the subject picked above.
    go2_button = gr.Button("Run Stage 2")
    topics_out = gr.Label(label="Top-3 Topics")
    stage2_time = gr.Textbox(label="Stage 2 Time")
    go2_button.click(
        fn=run_stage2,
        inputs=[question_input, model_input, subj_radio],
        outputs=[topics_out, stage2_time]
    )

    gr.Markdown("---")
    gr.Markdown("### Feedback / Correction")
    # Free-text correction fields; submit_feedback appends them to the
    # feedback CSV and returns a status string shown in fb_status.
    subject_fb = gr.Textbox(label="Correct Subject")
    topic_fb = gr.Textbox(label="Correct Topic(s)")
    fb_button = gr.Button("Submit Feedback")
    fb_status = gr.Textbox(label="")
    fb_button.click(
        fn=submit_feedback,
        inputs=[question_input, subject_fb, topic_fb],
        outputs=[fb_status]
    )

# share=True exposes a public tunnel URL (demo-style hosting).
# NOTE(review): recent Gradio versions spell the SSR switch `ssr_mode`, not
# `ssr` — confirm against the pinned gradio release.
demo.launch(share=True, ssr=False)