import csv
import json
import os
import time

import gradio as gr
from transformers import pipeline

# ----------------------------
# Load taxonomies
# ----------------------------
# coarse_labels.json: flat list of subjects; fine_labels.json: subject -> list of topics.
with open("coarse_labels.json") as f:
    coarse_labels = json.load(f)
with open("fine_labels.json") as f:
    fine_map = json.load(f)

# ----------------------------
# Model choices (5 only)
# ----------------------------
MODEL_CHOICES = [
    "facebook/bart-large-mnli",
    "roberta-large-mnli",
    "joeddav/xlm-roberta-large-xnli",
    "valhalla/distilbart-mnli-12-4",
    "FractalAIResearch/Fathom-R1-14B",  # placeholder; replace with a real zero-shot/NLI model
]

# ----------------------------
# Ensure log files exist
# ----------------------------
LOG_FILE = "logs.csv"
FEEDBACK_FILE = "feedback.csv"
for fn, hdr in [
    (LOG_FILE, ["timestamp", "model", "question", "chosen_subject", "top3_topics", "duration"]),
    (FEEDBACK_FILE, ["timestamp", "question", "subject_feedback", "topic_feedback"]),
]:
    if not os.path.exists(fn):
        with open(fn, "w", newline="") as f:
            csv.writer(f).writerow(hdr)

# ----------------------------
# Inference functions
# ----------------------------
_PIPELINES = {}  # cache pipelines so each model is loaded only once


def get_classifier(model_name):
    """Return a cached zero-shot-classification pipeline for the given model."""
    if model_name not in _PIPELINES:
        _PIPELINES[model_name] = pipeline("zero-shot-classification", model=model_name)
    return _PIPELINES[model_name]


def run_stage1(question, model_name):
    """Classify against the coarse taxonomy; return a Radio update with the top-3 subjects plus duration."""
    start = time.time()
    clf = get_classifier(model_name)
    out = clf(question, candidate_labels=coarse_labels)
    labels = out["labels"][:3]
    duration = round(time.time() - start, 3)
    # Populate the subject Radio with the top-3 coarse subjects and preselect the best one.
    return gr.update(choices=labels, value=labels[0]), f"⏱ {duration}s"


def run_stage2(question, model_name, subject):
    """Classify against the fine topics of the chosen subject; return top-3 topics plus duration."""
    start = time.time()
    clf = get_classifier(model_name)
    fine_labels = fine_map.get(subject, [])
    if not fine_labels:
        return {}, f"⚠️ No fine-grained topics defined for subject: {subject}"
    out = clf(question, candidate_labels=fine_labels)
    labels, scores = out["labels"][:3], out["scores"][:3]
    duration = round(time.time() - start, 3)
    # Log combined run
    with open(LOG_FILE, "a", newline="") as f:
        csv.writer(f).writerow([
            time.strftime("%Y-%m-%d %H:%M:%S"),
            model_name,
            question.replace("\n", " "),
            subject,
            ";".join(labels),
            duration,
        ])
    return {lbl: round(score, 3) for lbl, score in zip(labels, scores)}, f"⏱ {duration}s"


def submit_feedback(question, subject_fb, topic_fb):
    """Append a user correction to the feedback CSV."""
    with open(FEEDBACK_FILE, "a", newline="") as f:
        csv.writer(f).writerow([
            time.strftime("%Y-%m-%d %H:%M:%S"),
            question.replace("\n", " "),
            subject_fb,
            topic_fb,
        ])
    return "✅ Feedback recorded!"


# ----------------------------
# Build Gradio UI
# ----------------------------
with gr.Blocks() as demo:
    gr.Markdown("## Hierarchical Zero-Shot Tagger with Subject Toggle & Feedback")

    with gr.Row():
        question_input = gr.Textbox(lines=3, label="Enter your question")
        model_input = gr.Dropdown(
            choices=MODEL_CHOICES,
            value=MODEL_CHOICES[0],
            label="Choose model",
        )
    go_button = gr.Button("Run Stage 1")

    # Stage 1 outputs
    subj_radio = gr.Radio(
        choices=[],
        label="Top-3 Subjects",
        info="Select to re-run Stage 2 for a different subject",
    )
    stage1_time = gr.Textbox(label="Stage 1 Time")

    go_button.click(
        fn=run_stage1,
        inputs=[question_input, model_input],
        outputs=[subj_radio, stage1_time],
    )

    # Stage 2 UI
    go2_button = gr.Button("Run Stage 2")
    topics_out = gr.Label(label="Top-3 Topics")
    stage2_time = gr.Textbox(label="Stage 2 Time")

    go2_button.click(
        fn=run_stage2,
        inputs=[question_input, model_input, subj_radio],
        outputs=[topics_out, stage2_time],
    )

    gr.Markdown("---")
    gr.Markdown("### Feedback / Correction")
    subject_fb = gr.Textbox(label="Correct Subject")
    topic_fb = gr.Textbox(label="Correct Topic(s)")
    fb_button = gr.Button("Submit Feedback")
    fb_status = gr.Textbox(label="Status")

    fb_button.click(
        fn=submit_feedback,
        inputs=[question_input, subject_fb, topic_fb],
        outputs=[fb_status],
    )

demo.launch(share=True, ssr_mode=False)
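
# ----------------------------
# Illustrative taxonomy file shapes (a sketch, not part of the app)
# ----------------------------
# The script assumes coarse_labels.json holds a flat JSON list of subject names and that
# fine_labels.json maps each subject to its list of topics (this is how fine_map.get(subject, [])
# is used above). The subjects and topics below are made-up examples, not the real taxonomy.
#
# coarse_labels.json:
#   ["Mathematics", "Physics", "Computer Science"]
#
# fine_labels.json:
#   {
#     "Mathematics": ["Algebra", "Calculus", "Probability"],
#     "Physics": ["Mechanics", "Optics", "Thermodynamics"],
#     "Computer Science": ["Algorithms", "Databases", "Machine Learning"]
#   }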