File size: 4,427 Bytes
6c67d38
 
b90e13b
 
6c67d38
a389429
6c67d38
42f0920
6c67d38
42f0920
6c67d38
 
 
a389429
6c67d38
 
 
 
 
a389429
 
6c67d38
 
 
a389429
6c67d38
 
 
a389429
 
 
 
6c67d38
 
a389429
6c67d38
 
a389429
6c67d38
a389429
 
6c67d38
 
a389429
 
 
 
6c67d38
a389429
 
 
 
6c67d38
a389429
 
 
 
6c67d38
 
 
 
 
 
a389429
6c67d38
 
a389429
6c67d38
a389429
6c67d38
a389429
 
 
 
 
 
 
6c67d38
 
a389429
6c67d38
 
a389429
6c67d38
 
 
 
a389429
6c67d38
a389429
6c67d38
a389429
 
 
 
6c67d38
a389429
 
 
 
 
 
 
 
6c67d38
a389429
6c67d38
a389429
 
 
 
6c67d38
 
 
a389429
6c67d38
a389429
 
 
 
6c67d38
a389429
 
 
 
6c67d38
 
a389429
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
import json, time, csv, os
import gradio as gr
from transformers import pipeline

# ----------------------------------------
# Load taxonomies
# ----------------------------------------
# coarse_labels: labels for the Stage 1 (subject-level) classifier.
# fine_map: mapping of subject -> list of fine-grained topic labels,
#           consumed by Stage 2 via fine_map.get(subject, []).
# JSON is UTF-8 by specification, so pin the encoding rather than rely on
# the platform default (e.g. cp1252 on some Windows installs).
with open("coarse_labels.json", encoding="utf-8") as f:
    coarse_labels = json.load(f)
with open("fine_labels.json", encoding="utf-8") as f:
    fine_map = json.load(f)

# ----------------------------------------
# Model choices (5 only)
# ----------------------------------------
# Zero-shot NLI checkpoints offered in the UI dropdown; each string is
# passed verbatim to transformers.pipeline("zero-shot-classification",
# model=...) by both inference stages.
MODEL_CHOICES = [
    "facebook/bart-large-mnli",
    "roberta-large-mnli",
    "joeddav/xlm-roberta-large-xnli",
    "valhalla/distilbart-mnli-12-4",
    "educationfoundation/Phantom-7B-JEE"  # placeholder -- replace with real phantom model
]

# ----------------------------------------
# Ensure log files exist
# ----------------------------------------
LOG_FILE      = "logs.csv"
FEEDBACK_FILE = "feedback.csv"

# CSV header row for each log file, keyed by path.
_LOG_HEADERS = {
    LOG_FILE:      ["timestamp", "model", "question", "chosen_subject", "top3_topics", "duration"],
    FEEDBACK_FILE: ["timestamp", "question", "subject_feedback", "topic_feedback"],
}

# Create each file with its header once; existing files are left untouched
# so prior log rows survive restarts.
for _path, _header in _LOG_HEADERS.items():
    if os.path.exists(_path):
        continue
    with open(_path, "w", newline="") as _fh:
        csv.writer(_fh).writerow(_header)

# ----------------------------------------
# Inference functions
# ----------------------------------------

# Cache of loaded pipelines, keyed by model name.  Building a transformers
# pipeline reloads the full model weights, so doing it once per model
# (instead of once per click, as before) makes repeat runs far faster.
_CLASSIFIER_CACHE = {}

def _get_classifier(model_name):
    """Return a cached zero-shot-classification pipeline for *model_name*."""
    if model_name not in _CLASSIFIER_CACHE:
        _CLASSIFIER_CACHE[model_name] = pipeline(
            "zero-shot-classification", model=model_name
        )
    return _CLASSIFIER_CACHE[model_name]

def run_stage1(question, model_name):
    """Classify *question* against the coarse subject taxonomy.

    Returns:
        (labels, duration): the top-3 coarse subject labels, best first,
        and the elapsed wall-clock seconds rounded to 3 decimal places.
    """
    start = time.time()
    clf = _get_classifier(model_name)
    out = clf(question, candidate_labels=coarse_labels)
    labels = out["labels"][:3]  # scores are not surfaced in Stage 1
    duration = round(time.time() - start, 3)
    return labels, duration

def run_stage2(question, model_name, subject):
    """Classify *question* into the fine-grained topics of *subject*.

    Logs the combined run to LOG_FILE and returns a (label -> score) dict
    for the top-3 topics plus a formatted duration string.  If *subject*
    is empty or missing from fine_map, fails soft with a hint instead of
    crashing the pipeline on an empty candidate list.
    """
    fine_labels = fine_map.get(subject, [])
    if not fine_labels:
        # No radio selection yet, or a subject absent from fine_labels.json.
        return {}, "Select a subject from Stage 1 first"
    start = time.time()
    clf = _get_classifier(model_name)
    out = clf(question, candidate_labels=fine_labels)
    labels, scores = out["labels"][:3], out["scores"][:3]
    duration = round(time.time() - start, 3)
    # Log combined run; utf-8 so non-ASCII questions can't crash the append
    # on platforms whose default encoding can't represent them.
    with open(LOG_FILE, "a", newline="", encoding="utf-8") as f:
        csv.writer(f).writerow([
            time.strftime("%Y-%m-%d %H:%M:%S"),
            model_name,
            question.replace("\n", " "),
            subject,
            ";".join(labels),
            duration
        ])
    return {lbl: round(score, 3) for lbl, score in zip(labels, scores)}, f"⏱ {duration}s"

def submit_feedback(question, subject_fb, topic_fb):
    """Append one user-correction row to the feedback CSV and confirm."""
    # Flatten the question to a single line so it stays one CSV field.
    row = [
        time.strftime("%Y-%m-%d %H:%M:%S"),
        question.replace("\n", " "),
        subject_fb,
        topic_fb,
    ]
    with open(FEEDBACK_FILE, "a", newline="") as fh:
        csv.writer(fh).writerow(row)
    return "βœ… Feedback recorded!"

# ----------------------------------------
# Build Gradio UI
# ----------------------------------------
def _on_stage1(question, model_name):
    """Adapter between run_stage1 and the UI widgets.

    Returning a bare list to a gr.Radio output only sets the radio's
    *value*; gr.update(choices=...) is required to repopulate the choice
    list (it starts empty), otherwise the top-3 subjects never appear.
    The best subject is preselected so Stage 2 can run immediately.
    """
    labels, duration = run_stage1(question, model_name)
    selected = labels[0] if labels else None
    return gr.update(choices=labels, value=selected), f"⏱ {duration}s"

with gr.Blocks() as demo:
    gr.Markdown("## Hierarchical Zero-Shot Tagger with Subject Toggle & Feedback")

    with gr.Row():
        question_input = gr.Textbox(lines=3, label="Enter your question")
        model_input    = gr.Dropdown(
            choices=MODEL_CHOICES, value=MODEL_CHOICES[0], label="Choose model"
        )
        go_button      = gr.Button("Run Stage 1")

    # Stage 1 outputs
    subj_radio     = gr.Radio(choices=[], label="Top-3 Subjects",
                              info="Select to re-run Stage 2 for a different subject")
    stage1_time    = gr.Textbox(label="Stage 1 Time")

    go_button.click(
        fn=_on_stage1,
        inputs=[question_input, model_input],
        outputs=[subj_radio, stage1_time]
    )

    # Stage 2 UI
    go2_button  = gr.Button("Run Stage 2")
    topics_out  = gr.Label(label="Top-3 Topics")
    stage2_time = gr.Textbox(label="Stage 2 Time")

    go2_button.click(
        fn=run_stage2,
        inputs=[question_input, model_input, subj_radio],
        outputs=[topics_out, stage2_time]
    )

    gr.Markdown("---")
    gr.Markdown("### Feedback / Correction")

    subject_fb = gr.Textbox(label="Correct Subject")
    topic_fb   = gr.Textbox(label="Correct Topic(s)")
    fb_button  = gr.Button("Submit Feedback")
    fb_status  = gr.Textbox(label="")

    fb_button.click(
        fn=submit_feedback,
        inputs=[question_input, subject_fb, topic_fb],
        outputs=[fb_status]
    )

# share=True publishes a temporary public tunnel URL; ssr=False opts out of
# server-side rendering for compatibility.
demo.launch(share=True, ssr=False)