"""Hierarchical zero-shot question tagger.

Gradio app that tags a free-text question in two stages with a user-selected
zero-shot-classification model: first a coarse subject from
``coarse_labels.json``, then the top-3 fine topics for that subject from
``fine_labels.json``. Every run is appended to ``logs.csv``; user corrections
go to ``feedback.csv``.
"""

import csv
import functools
import json
import os
import time

import gradio as gr
from transformers import pipeline

# ————————————————
# 1) Load taxonomies
# ————————————————
with open("coarse_labels.json", encoding="utf-8") as f:
    coarse_labels = json.load(f)          # flat list of subject labels
with open("fine_labels.json", encoding="utf-8") as f:
    fine_map = json.load(f)               # subject -> list of fine topic labels

# ————————————————
# 2) Available zero-shot models (custom values are also allowed in the UI)
# ————————————————
MODEL_CHOICES = [
    "facebook/bart-large-mnli",
    "roberta-large-mnli",
    "joeddav/xlm-roberta-large-xnli",
    "mistralai/Mistral-7B-Instruct",
    "huggyllama/llama-2-7b-chat",
    "google/flan-t5-large",
    "google/flan-ul2",
    "clare-ai/llama-2-13b-instruct",
    "allenai/longformer-base-4096",
    "valhalla/t5-base-qa-qg-hl",
    "EleutherAI/gpt-neox-20b",
    "EleutherAI/gpt-j-6b",
    "bigscience/bloom-1b1",
    "bigscience/bloom-560m",
    "bigscience/bloom-3b",
    "Salesforce/codegen-2B-multi",
    "Salesforce/codegen-350M-multi",
    "madlag/llama2-7b-finetuned-qa",
    "tiiuae/falcon-7b-instruct",
    "tiiuae/falcon-40b-instruct",
    "milvus/milvus-embed-english",
    "sentence-transformers/all-MiniLM-L6-v2",
    "YOUR-OWN-CUSTOM-MODEL",
]

# ————————————————
# Helper: ensure log files exist with their headers
# ————————————————
LOG_FILE = "logs.csv"
FEEDBACK_FILE = "feedback.csv"
for fn, hdr in [
    (LOG_FILE, ["timestamp", "model", "question", "subject", "top3_topics", "duration"]),
    (FEEDBACK_FILE, ["timestamp", "question", "pred_subject", "pred_topics", "corrected"]),
]:
    if not os.path.exists(fn):
        with open(fn, "w", newline="", encoding="utf-8") as f:
            csv.writer(f).writerow(hdr)


@functools.lru_cache(maxsize=2)
def _get_classifier(model_name):
    """Load a zero-shot-classification pipeline, caching it per model name.

    The original code rebuilt the pipeline on every request, re-loading the
    full model weights each time. Caching keeps model switching possible
    while making repeat runs fast; maxsize=2 bounds memory (models are large).
    """
    return pipeline("zero-shot-classification", model=model_name)


# ————————————————
# 3) Interface logic
# ————————————————
def hierarchical_tag(question, model_name):
    """Tag *question* in two zero-shot stages and log the run.

    Stage 1 picks the best coarse subject; stage 2 ranks the fine topics
    listed for that subject. Returns a triple for the UI:
    (subject, {topic: rounded score} for the top 3, timing string).
    """
    start = time.time()
    clf = _get_classifier(model_name)

    # Stage 1: coarse subject — top-ranked label wins.
    coarse_out = clf(question, candidate_labels=coarse_labels)
    subject = coarse_out["labels"][0]

    # Stage 2: fine topics within the chosen subject. The pipeline raises on
    # an empty candidate list, so skip the stage if no fine taxonomy exists.
    fine_labels = fine_map.get(subject, [])
    if fine_labels:
        fine_out = clf(question, candidate_labels=fine_labels)
        top3 = fine_out["labels"][:3]
        top3_scores = {
            lbl: round(score, 3)
            for lbl, score in zip(top3, fine_out["scores"][:3])
        }
    else:
        top3 = []
        top3_scores = {}

    duration = round(time.time() - start, 3)

    # Log the run (newlines flattened so the CSV stays one row per run).
    with open(LOG_FILE, "a", newline="", encoding="utf-8") as f:
        csv.writer(f).writerow([
            time.strftime("%Y-%m-%d %H:%M:%S"),
            model_name,
            question.replace("\n", " "),
            subject,
            ";".join(top3),
            duration,
        ])

    return subject, top3_scores, f"⏱ {duration}s"


def submit_feedback(question, subject, topics, corrected):
    """Append a user correction to the feedback CSV and acknowledge it.

    *topics* comes from a ``gr.Label`` component, so its runtime type is not
    guaranteed — it may be a mapping of label -> confidence, a bare string,
    or a list. Normalize before joining: ``";".join`` on a raw string would
    otherwise splice a separator between every character.
    """
    if isinstance(topics, dict):
        topic_list = [str(t) for t in topics]
    elif isinstance(topics, str):
        topic_list = [topics]
    else:
        topic_list = [str(t) for t in (topics or [])]

    ts = time.strftime("%Y-%m-%d %H:%M:%S")
    with open(FEEDBACK_FILE, "a", newline="", encoding="utf-8") as f:
        csv.writer(f).writerow([
            ts,
            question.replace("\n", " "),
            subject,
            ";".join(topic_list),
            corrected,
        ])
    return "Thank you for your feedback!"


# ————————————————
# 4) Gradio UI
# ————————————————
with gr.Blocks() as demo:
    gr.Markdown("## Hierarchical Zero-Shot Tagger with Model Selection & Logging")

    with gr.Row():
        question_input = gr.Textbox(lines=3, label="Enter your question")
        model_input = gr.Dropdown(
            label="Choose model",
            choices=MODEL_CHOICES,
            value=MODEL_CHOICES[0],
            allow_custom_value=True,
        )

    run_button = gr.Button("Tag Question")
    subject_out = gr.Textbox(label="Predicted Subject")
    topics_out = gr.Label(label="Top-3 Topics")
    time_out = gr.Textbox(label="Inference Time")

    run_button.click(
        hierarchical_tag,
        inputs=[question_input, model_input],
        outputs=[subject_out, topics_out, time_out],
    )

    gr.Markdown("---")
    gr.Markdown("### Not quite right? Submit your corrections below:")
    corrected_input = gr.Textbox(
        lines=1, placeholder="Correct subject;topic1;topic2;topic3"
    )
    feedback_button = gr.Button("Submit Feedback")
    feedback_status = gr.Textbox(label="")

    feedback_button.click(
        submit_feedback,
        inputs=[question_input, subject_out, topics_out, corrected_input],
        outputs=[feedback_status],
    )

# Guard the launch so importing this module (e.g. from tests) has no side
# effect beyond building the Blocks graph.
if __name__ == "__main__":
    demo.launch()