Spaces:
Sleeping
Sleeping
File size: 4,427 Bytes
6c67d38 b90e13b 6c67d38 a389429 6c67d38 42f0920 6c67d38 42f0920 6c67d38 a389429 6c67d38 a389429 6c67d38 a389429 6c67d38 a389429 6c67d38 a389429 6c67d38 a389429 6c67d38 a389429 6c67d38 a389429 6c67d38 a389429 6c67d38 a389429 6c67d38 a389429 6c67d38 a389429 6c67d38 a389429 6c67d38 a389429 6c67d38 a389429 6c67d38 a389429 6c67d38 a389429 6c67d38 a389429 6c67d38 a389429 6c67d38 a389429 6c67d38 a389429 6c67d38 a389429 6c67d38 a389429 6c67d38 a389429 6c67d38 a389429 6c67d38 a389429 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 |
import json, time, csv, os
import gradio as gr
from transformers import pipeline
# ----------------
# Load taxonomies
# ----------------
# coarse_labels: flat list of subject names used as Stage-1 candidate labels.
# fine_map: dict mapping subject -> list of fine-grained topic labels (Stage 2).
# Explicit encoding: JSON is UTF-8 by spec; the platform default locale
# encoding may not be.
with open("coarse_labels.json", encoding="utf-8") as f:
    coarse_labels = json.load(f)
with open("fine_labels.json", encoding="utf-8") as f:
    fine_map = json.load(f)
# ----------------
# Model choices (5 only)
# ----------------
# Zero-shot classification checkpoints offered in the UI dropdown;
# the first entry is used as the default selection.
MODEL_CHOICES = [
    "facebook/bart-large-mnli",
    "roberta-large-mnli",
    "joeddav/xlm-roberta-large-xnli",
    "valhalla/distilbart-mnli-12-4",
    "educationfoundation/Phantom-7B-JEE"  # placeholder — replace with real phantom model
]
# ----------------
# Ensure log files exist
# ----------------
LOG_FILE = "logs.csv"
FEEDBACK_FILE = "feedback.csv"

# Header row for each log file, written exactly once on first run.
_LOG_HEADERS = {
    LOG_FILE: ["timestamp", "model", "question", "chosen_subject", "top3_topics", "duration"],
    FEEDBACK_FILE: ["timestamp", "question", "subject_feedback", "topic_feedback"],
}
for _path, _header in _LOG_HEADERS.items():
    if os.path.exists(_path):
        continue  # never clobber an existing log
    with open(_path, "w", newline="") as _fh:
        csv.writer(_fh).writerow(_header)
# ----------------
# Inference functions
# ----------------
# Cache of Stage-1 zero-shot pipelines keyed by model name, so each model is
# loaded once instead of being re-instantiated on every button click.
_STAGE1_PIPELINES = {}

def run_stage1(question, model_name):
    """Classify *question* against the coarse subject taxonomy.

    Args:
        question: free-text question to tag.
        model_name: Hugging Face model id for the zero-shot pipeline.

    Returns:
        (labels, duration): top-3 coarse subject labels (best first) and the
        wall-clock time in seconds, rounded to 3 decimals. The first call for
        a given model includes its load time; later calls hit the cache.
    """
    start = time.time()
    clf = _STAGE1_PIPELINES.get(model_name)
    if clf is None:
        clf = pipeline("zero-shot-classification", model=model_name)
        _STAGE1_PIPELINES[model_name] = clf
    out = clf(question, candidate_labels=coarse_labels)
    labels = out["labels"][:3]
    duration = round(time.time() - start, 3)
    return labels, duration
# Cache of Stage-2 zero-shot pipelines keyed by model name (kept separate
# from Stage 1 so this block is self-contained).
_STAGE2_PIPELINES = {}

def run_stage2(question, model_name, subject):
    """Classify *question* into the fine topics of *subject* and log the run.

    Args:
        question: free-text question to tag.
        model_name: Hugging Face model id for the zero-shot pipeline.
        subject: coarse subject chosen in Stage 1, used to look up the
            fine-grained candidate labels in ``fine_map``.

    Returns:
        ({label: rounded_score}, time_str): top-3 fine topics and a
        human-readable duration string for the UI.
    """
    start = time.time()
    fine_labels = fine_map.get(subject, [])
    if not fine_labels:
        # Unknown subject: no candidates to score; an empty candidate list
        # would make the pipeline call fail.
        return {}, f"⏱ {round(time.time() - start, 3)}s"
    clf = _STAGE2_PIPELINES.get(model_name)
    if clf is None:
        # Load each model once and reuse it across clicks.
        clf = pipeline("zero-shot-classification", model=model_name)
        _STAGE2_PIPELINES[model_name] = clf
    out = clf(question, candidate_labels=fine_labels)
    labels, scores = out["labels"][:3], out["scores"][:3]
    duration = round(time.time() - start, 3)
    # Append one audit row per Stage-2 run; newlines are flattened so the
    # CSV stays one row per run.
    with open(LOG_FILE, "a", newline="") as f:
        csv.writer(f).writerow([
            time.strftime("%Y-%m-%d %H:%M:%S"),
            model_name,
            question.replace("\n", " "),
            subject,
            ";".join(labels),
            duration,
        ])
    return {lbl: round(score, 3) for lbl, score in zip(labels, scores)}, f"⏱ {duration}s"
def submit_feedback(question, subject_fb, topic_fb):
    """Append a user correction to FEEDBACK_FILE and return a status message.

    Args:
        question: the question the feedback refers to (newlines flattened
            so the CSV stays one row per entry).
        subject_fb: corrected subject supplied by the user.
        topic_fb: corrected topic(s) supplied by the user.

    Returns:
        A short confirmation string shown in the UI.
    """
    with open(FEEDBACK_FILE, "a", newline="") as f:
        csv.writer(f).writerow([
            time.strftime("%Y-%m-%d %H:%M:%S"),
            question.replace("\n", " "),
            subject_fb,
            topic_fb,
        ])
    # The scraped source had this literal broken across two lines (with the
    # emoji mojibake'd), which is a SyntaxError — restored as one literal.
    return "✅ Feedback recorded!"
# ----------------
# Build Gradio UI
# ----------------
def _stage1_to_ui(question, model_name):
    """Adapter between run_stage1 and the Stage-1 Gradio outputs.

    Returns a gr.update(...) for the Radio so its *choices* are replaced with
    the top-3 subjects. (Returning a plain list, as the original lambda did,
    only tries to set the Radio's value — the choices list stays empty.)
    """
    labels, duration = run_stage1(question, model_name)
    value = labels[0] if labels else None
    return gr.update(choices=labels, value=value), duration

with gr.Blocks() as demo:
    gr.Markdown("## Hierarchical Zero-Shot Tagger with Subject Toggle & Feedback")

    with gr.Row():
        question_input = gr.Textbox(lines=3, label="Enter your question")
        model_input = gr.Dropdown(
            choices=MODEL_CHOICES, value=MODEL_CHOICES[0], label="Choose model"
        )
    go_button = gr.Button("Run Stage 1")

    # Stage 1 outputs: subject picker (populated by _stage1_to_ui) + timing.
    subj_radio = gr.Radio(choices=[], label="Top-3 Subjects",
                          info="Select to re-run Stage 2 for a different subject")
    stage1_time = gr.Textbox(label="Stage 1 Time")
    go_button.click(
        fn=_stage1_to_ui,
        inputs=[question_input, model_input],
        outputs=[subj_radio, stage1_time]
    )

    # Stage 2 UI: fine-topic classification for the selected subject.
    go2_button = gr.Button("Run Stage 2")
    topics_out = gr.Label(label="Top-3 Topics")
    stage2_time = gr.Textbox(label="Stage 2 Time")
    go2_button.click(
        fn=run_stage2,
        inputs=[question_input, model_input, subj_radio],
        outputs=[topics_out, stage2_time]
    )

    gr.Markdown("---")
    gr.Markdown("### Feedback / Correction")
    subject_fb = gr.Textbox(label="Correct Subject")
    topic_fb = gr.Textbox(label="Correct Topic(s)")
    fb_button = gr.Button("Submit Feedback")
    fb_status = gr.Textbox(label="")
    fb_button.click(
        fn=submit_feedback,
        inputs=[question_input, subject_fb, topic_fb],
        outputs=[fb_status]
    )

# NOTE(review): recent Gradio releases spell this kwarg `ssr_mode`, not `ssr`
# — confirm against the pinned gradio version before changing.
demo.launch(share=True, ssr=False)
|