import csv
import json
import os
import time

import gradio as gr
from transformers import pipeline

# --------------------------------
# 1) Load taxonomies
# --------------------------------
with open("coarse_labels.json") as f:
    coarse_labels = json.load(f)
with open("fine_labels.json") as f:
    fine_map = json.load(f)
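
# The taxonomy files are assumed to look roughly like this (hypothetical
# example; the real label sets live in the two JSON files above):
#   coarse_labels.json: ["Mathematics", "Physics", "Chemistry", ...]
#   fine_labels.json:   {"Mathematics": ["Algebra", "Calculus", ...],
#                        "Physics": ["Mechanics", "Optics", ...], ...}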

# --------------------------------
# 2) Available zero-shot models
# --------------------------------
# Only models with an NLI (entailment) classification head work with the
# "zero-shot-classification" pipeline; generative or embedding checkpoints
# (e.g. instruct LLMs, sentence encoders) will fail to load here.
MODEL_CHOICES = [
    "facebook/bart-large-mnli",
    "roberta-large-mnli",
    "microsoft/deberta-large-mnli",
    "joeddav/xlm-roberta-large-xnli",
    "MoritzLaurer/mDeBERTa-v3-base-mnli-xnli",
    "valhalla/distilbart-mnli-12-3",
    "typeform/distilbert-base-uncased-mnli",
    "cross-encoder/nli-deberta-v3-base",
    "YOUR-OWN-CUSTOM-MODEL",  # placeholder; allow_custom_value lets users type any Hub ID
]
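
# To vet a custom model ID before relying on it, a one-off smoke test like
# this is usually enough (sketch; downloads the checkpoint on first use):
#   pipeline("zero-shot-classification", model="<your-model-id>")(
#       "test sentence", candidate_labels=["a", "b"])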

# --------------------------------
# Helper: ensure log files exist
# --------------------------------
LOG_FILE      = "logs.csv"
FEEDBACK_FILE = "feedback.csv"
for fn, hdr in [(LOG_FILE, ["timestamp","model","question","subject","top3_topics","duration"]),
                (FEEDBACK_FILE, ["timestamp","question","pred_subject","pred_topics","corrected"])]:
    if not os.path.exists(fn):
        with open(fn, "w", newline="") as f:
            writer = csv.writer(f)
            writer.writerow(hdr)
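
# Cache loaded pipelines so switching back to a previously used model doesn't
# re-read the checkpoint from disk; loading per request would dominate latency
# for large models. (Added here for efficiency; trades memory for speed.)
_PIPELINE_CACHE = {}

def get_classifier(model_name):
    if model_name not in _PIPELINE_CACHE:
        _PIPELINE_CACHE[model_name] = pipeline("zero-shot-classification", model=model_name)
    return _PIPELINE_CACHE[model_name]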

# --------------------------------
# 3) Build the interface logic
# --------------------------------

def hierarchical_tag(question, model_name):
    start = time.time()
    # 3.1 Look up the classifier for the selected model (cached above, so
    # switching models stays dynamic without reloading on every run)
    clf = get_classifier(model_name)

    # 3.2 Stage 1: coarse label
    coarse_out = clf(question, candidate_labels=coarse_labels)
    subject = coarse_out["labels"][0]

    # 3.3 Stage 2: fine labels within the chosen subject (fall back to the
    # coarse result if the subject has none; the pipeline rejects empty lists)
    fine_labels = fine_map.get(subject, [])
    fine_out = (clf(question, candidate_labels=fine_labels) if fine_labels
                else {"labels": [subject], "scores": [coarse_out["scores"][0]]})
    top3 = fine_out["labels"][:3]

    duration = round(time.time() - start, 3)

    # 3.4 Log the run
    with open(LOG_FILE, "a", newline="") as f:
        csv.writer(f).writerow([
            time.strftime("%Y-%m-%d %H:%M:%S"),
            model_name,
            question.replace("\n"," "),
            subject,
            ";".join(top3),
            duration
        ])

    # 3.5 Return values for display: subject, {label: score} for gr.Label, timing
    return (
        subject,
        {lbl: round(score, 3)
         for lbl, score in zip(fine_out["labels"][:3], fine_out["scores"][:3])},
        f"⏱ {duration}s",
    )
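
# The tagger can also be exercised without the UI, e.g. (hypothetical input):
#   subject, topics, t = hierarchical_tag("What is the derivative of x^2?",
#                                         MODEL_CHOICES[0])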

def submit_feedback(question, subject, topics, corrected):
    ts = time.strftime("%Y-%m-%d %H:%M:%S")
    # gr.Label may hand its value back as a {label: confidence} dict or a bare
    # string; normalise to a list of label names before joining.
    topics = list(topics) if isinstance(topics, dict) else [topics] if isinstance(topics, str) else topics
    with open(FEEDBACK_FILE, "a", newline="") as f:
        csv.writer(f).writerow([ts, question.replace("\n", " "), subject, ";".join(topics), corrected])
    return "Thank you for your feedback!"
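
# Both CSVs are append-only and easy to analyse offline, e.g. (sketch; assumes
# pandas is installed, which the app itself does not require):
#   import pandas as pd
#   pd.read_csv(LOG_FILE).groupby("model")["duration"].mean()
#   pd.read_csv(FEEDBACK_FILE).tail()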

# --------------------------------
# 4) Define the Gradio UI
# --------------------------------
with gr.Blocks() as demo:
    gr.Markdown("## Hierarchical Zero-Shot Tagger with Model Selection & Logging")

    with gr.Row():
        question_input = gr.Textbox(lines=3, label="Enter your question")
        model_input    = gr.Dropdown(
            label="Choose model",
            choices=MODEL_CHOICES,
            value=MODEL_CHOICES[0],
            allow_custom_value=True
        )

    run_button = gr.Button("Tag Question")

    subject_out = gr.Textbox(label="Predicted Subject")
    topics_out  = gr.Label(label="Top-3 Topics")
    time_out    = gr.Textbox(label="Inference Time")

    run_button.click(
        hierarchical_tag,
        inputs=[question_input, model_input],
        outputs=[subject_out, topics_out, time_out]
    )

    gr.Markdown("---")
    gr.Markdown("### Not quite right? Submit your corrections below:")

    corrected_input = gr.Textbox(
        lines=1, label="Correction",
        placeholder="Correct subject;topic1;topic2;topic3"
    )
    feedback_button = gr.Button("Submit Feedback")
    feedback_status = gr.Textbox(label="Status", interactive=False)

    feedback_button.click(
        submit_feedback,
        inputs=[question_input, subject_out, topics_out, corrected_input],
        outputs=[feedback_status]
    )
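
# demo.queue() can be chained before launch() if long model loads cause request
# timeouts; launch(share=True) would expose the demo via Gradio's public tunnel.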

demo.launch()