Spaces:
Running
Running
Update app.py
Browse files
app.py
CHANGED
@@ -3,7 +3,7 @@ import gradio as gr
|
|
3 |
from transformers import pipeline
|
4 |
|
5 |
# ββββββββββββββββ
|
6 |
-
#
|
7 |
# ββββββββββββββββ
|
8 |
with open("coarse_labels.json") as f:
|
9 |
coarse_labels = json.load(f)
|
@@ -11,128 +11,118 @@ with open("fine_labels.json") as f:
|
|
11 |
fine_map = json.load(f)
|
12 |
|
13 |
# ββββββββββββββββ
|
14 |
-
#
|
15 |
# ββββββββββββββββ
|
16 |
MODEL_CHOICES = [
|
17 |
"facebook/bart-large-mnli",
|
18 |
"roberta-large-mnli",
|
19 |
"joeddav/xlm-roberta-large-xnli",
|
20 |
-
"
|
21 |
-
"
|
22 |
-
"google/flan-t5-large",
|
23 |
-
"google/flan-ul2",
|
24 |
-
"clare-ai/llama-2-13b-instruct",
|
25 |
-
"allenai/longformer-base-4096",
|
26 |
-
"facebook/bart-large-mnli", # duplicate to test allow_custom_value
|
27 |
-
"valhalla/t5-base-qa-qg-hl",
|
28 |
-
"EleutherAI/gpt-neox-20b",
|
29 |
-
"EleutherAI/gpt-j-6b",
|
30 |
-
"bigscience/bloom-1b1",
|
31 |
-
"bigscience/bloom-560m",
|
32 |
-
"bigscience/bloom-3b",
|
33 |
-
"Salesforce/codegen-2B-multi",
|
34 |
-
"Salesforce/codegen-350M-multi",
|
35 |
-
"madlag/llama2-7b-finetuned-qa",
|
36 |
-
"tiiuae/falcon-7b-instruct",
|
37 |
-
"tiiuae/falcon-40b-instruct",
|
38 |
-
"milvus/milvus-embed-english",
|
39 |
-
"sentence-transformers/all-MiniLM-L6-v2",
|
40 |
-
"YOUR-OWN-CUSTOM-MODEL"
|
41 |
]
|
42 |
|
43 |
# ββββββββββββββββ
|
44 |
-
#
|
45 |
# ββββββββββββββββ
|
46 |
LOG_FILE = "logs.csv"
|
47 |
FEEDBACK_FILE = "feedback.csv"
|
48 |
-
for fn, hdr in [
|
49 |
-
|
|
|
|
|
50 |
if not os.path.exists(fn):
|
51 |
with open(fn, "w", newline="") as f:
|
52 |
-
|
53 |
-
writer.writerow(hdr)
|
54 |
|
55 |
# ββββββββββββββββ
|
56 |
-
#
|
57 |
# ββββββββββββββββ
|
58 |
-
|
59 |
-
|
60 |
start = time.time()
|
61 |
-
# 3.1 Instantiate classifier per-run (to change models dynamically)
|
62 |
clf = pipeline("zero-shot-classification", model=model_name)
|
|
|
|
|
|
|
|
|
63 |
|
64 |
-
|
65 |
-
|
66 |
-
|
67 |
-
|
68 |
-
# 3.3 Stage 2: fine labels within chosen subject
|
69 |
fine_labels = fine_map.get(subject, [])
|
70 |
-
|
71 |
-
|
72 |
-
|
73 |
-
|
74 |
-
|
75 |
-
# 3.4 Log the run
|
76 |
with open(LOG_FILE, "a", newline="") as f:
|
77 |
csv.writer(f).writerow([
|
78 |
time.strftime("%Y-%m-%d %H:%M:%S"),
|
79 |
model_name,
|
80 |
question.replace("\n"," "),
|
81 |
subject,
|
82 |
-
";".join(
|
83 |
duration
|
84 |
])
|
|
|
85 |
|
86 |
-
|
87 |
-
return subject, {lbl: round(score,3)
|
88 |
-
for lbl,score in zip(fine_out["labels"][:3],
|
89 |
-
fine_out["scores"][:3]
|
90 |
-
)}, f"β± {duration}s"
|
91 |
-
|
92 |
-
def submit_feedback(question, subject, topics, corrected):
|
93 |
-
ts = time.strftime("%Y-%m-%d %H:%M:%S")
|
94 |
with open(FEEDBACK_FILE, "a", newline="") as f:
|
95 |
-
csv.writer(f).writerow([
|
96 |
-
|
|
|
|
|
|
|
|
|
|
|
97 |
|
98 |
# ββββββββββββββββ
|
99 |
-
#
|
100 |
# ββββββββββββββββ
|
101 |
with gr.Blocks() as demo:
|
102 |
-
gr.Markdown("## Hierarchical Zero-Shot Tagger with
|
103 |
|
104 |
with gr.Row():
|
105 |
question_input = gr.Textbox(lines=3, label="Enter your question")
|
106 |
model_input = gr.Dropdown(
|
107 |
-
label="Choose model"
|
108 |
-
choices=MODEL_CHOICES,
|
109 |
-
value=MODEL_CHOICES[0],
|
110 |
-
allow_custom_value=True
|
111 |
)
|
|
|
112 |
|
113 |
-
|
|
|
|
|
|
|
114 |
|
115 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
116 |
topics_out = gr.Label(label="Top-3 Topics")
|
117 |
-
|
118 |
|
119 |
-
|
120 |
-
|
121 |
-
inputs=[question_input, model_input],
|
122 |
-
outputs=[
|
123 |
)
|
124 |
|
125 |
gr.Markdown("---")
|
126 |
-
gr.Markdown("###
|
127 |
|
128 |
-
|
129 |
-
|
130 |
-
|
|
|
131 |
|
132 |
-
|
133 |
-
submit_feedback,
|
134 |
-
inputs=[question_input,
|
135 |
-
outputs=[
|
136 |
)
|
137 |
|
138 |
-
demo.launch()
|
|
|
3 |
from transformers import pipeline
|
4 |
|
5 |
# ββββββββββββββββ
|
6 |
+
# Load taxonomies
|
7 |
# ββββββββββββββββ
|
8 |
with open("coarse_labels.json") as f:
|
9 |
coarse_labels = json.load(f)
|
|
|
11 |
fine_map = json.load(f)
|
12 |
|
13 |
# ββββββββββββββββ
|
14 |
+
# Model choices (5 only)
|
15 |
# ββββββββββββββββ
|
16 |
MODEL_CHOICES = [
|
17 |
"facebook/bart-large-mnli",
|
18 |
"roberta-large-mnli",
|
19 |
"joeddav/xlm-roberta-large-xnli",
|
20 |
+
"valhalla/distilbart-mnli-12-4",
|
21 |
+
"educationfoundation/Phantom-7B-JEE" # placeholder β replace with real phantom model
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
22 |
]
|
23 |
|
24 |
# ββββββββββββββββ
|
25 |
+
# Ensure log files exist
|
26 |
# ββββββββββββββββ
|
27 |
LOG_FILE = "logs.csv"
|
28 |
FEEDBACK_FILE = "feedback.csv"
|
29 |
+
for fn, hdr in [
|
30 |
+
(LOG_FILE, ["timestamp","model","question","chosen_subject","top3_topics","duration"]),
|
31 |
+
(FEEDBACK_FILE, ["timestamp","question","subject_feedback","topic_feedback"])
|
32 |
+
]:
|
33 |
if not os.path.exists(fn):
|
34 |
with open(fn, "w", newline="") as f:
|
35 |
+
csv.writer(f).writerow(hdr)
|
|
|
36 |
|
37 |
# ββββββββββββββββ
|
38 |
+
# Inference functions
|
39 |
# ββββββββββββββββ
|
40 |
+
def run_stage1(question, model_name):
|
41 |
+
"""Return top3 coarse subjects + duration."""
|
42 |
start = time.time()
|
|
|
43 |
clf = pipeline("zero-shot-classification", model=model_name)
|
44 |
+
out = clf(question, candidate_labels=coarse_labels)
|
45 |
+
labels, scores = out["labels"][:3], out["scores"][:3]
|
46 |
+
duration = round(time.time()-start,3)
|
47 |
+
return labels, duration
|
48 |
|
49 |
+
def run_stage2(question, model_name, subject):
|
50 |
+
"""Return top3 fine topics + duration."""
|
51 |
+
start = time.time()
|
52 |
+
clf = pipeline("zero-shot-classification", model=model_name)
|
|
|
53 |
fine_labels = fine_map.get(subject, [])
|
54 |
+
out = clf(question, candidate_labels=fine_labels)
|
55 |
+
labels, scores = out["labels"][:3], out["scores"][:3]
|
56 |
+
duration = round(time.time()-start,3)
|
57 |
+
# Log combined run
|
|
|
|
|
58 |
with open(LOG_FILE, "a", newline="") as f:
|
59 |
csv.writer(f).writerow([
|
60 |
time.strftime("%Y-%m-%d %H:%M:%S"),
|
61 |
model_name,
|
62 |
question.replace("\n"," "),
|
63 |
subject,
|
64 |
+
";".join(labels),
|
65 |
duration
|
66 |
])
|
67 |
+
return {lbl: round(score,3) for lbl,score in zip(labels, scores)}, f"β± {duration}s"
|
68 |
|
69 |
+
def submit_feedback(question, subject_fb, topic_fb):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
70 |
with open(FEEDBACK_FILE, "a", newline="") as f:
|
71 |
+
csv.writer(f).writerow([
|
72 |
+
time.strftime("%Y-%m-%d %H:%M:%S"),
|
73 |
+
question.replace("\n"," "),
|
74 |
+
subject_fb,
|
75 |
+
topic_fb
|
76 |
+
])
|
77 |
+
return "β
Feedback recorded!"
|
78 |
|
79 |
# ββββββββββββββββ
|
80 |
+
# Build Gradio UI
|
81 |
# ββββββββββββββββ
|
82 |
with gr.Blocks() as demo:
|
83 |
+
gr.Markdown("## Hierarchical Zero-Shot Tagger with Subject Toggle & Feedback")
|
84 |
|
85 |
with gr.Row():
|
86 |
question_input = gr.Textbox(lines=3, label="Enter your question")
|
87 |
model_input = gr.Dropdown(
|
88 |
+
choices=MODEL_CHOICES, value=MODEL_CHOICES[0], label="Choose model"
|
|
|
|
|
|
|
89 |
)
|
90 |
+
go_button = gr.Button("Run Stage 1")
|
91 |
|
92 |
+
# Stage 1 outputs
|
93 |
+
subj_radio = gr.Radio(choices=[], label="Top-3 Subjects",
|
94 |
+
info="Select to re-run Stage 2 for a different subject")
|
95 |
+
stage1_time = gr.Textbox(label="Stage 1 Time")
|
96 |
|
97 |
+
go_button.click(
|
98 |
+
fn=lambda q,m: (*run_stage1(q,m),),
|
99 |
+
inputs=[question_input, model_input],
|
100 |
+
outputs=[subj_radio, stage1_time]
|
101 |
+
)
|
102 |
+
|
103 |
+
# Stage 2 UI
|
104 |
+
go2_button = gr.Button("Run Stage 2")
|
105 |
topics_out = gr.Label(label="Top-3 Topics")
|
106 |
+
stage2_time = gr.Textbox(label="Stage 2 Time")
|
107 |
|
108 |
+
go2_button.click(
|
109 |
+
fn=run_stage2,
|
110 |
+
inputs=[question_input, model_input, subj_radio],
|
111 |
+
outputs=[topics_out, stage2_time]
|
112 |
)
|
113 |
|
114 |
gr.Markdown("---")
|
115 |
+
gr.Markdown("### Feedback / Correction")
|
116 |
|
117 |
+
subject_fb = gr.Textbox(label="Correct Subject")
|
118 |
+
topic_fb = gr.Textbox(label="Correct Topic(s)")
|
119 |
+
fb_button = gr.Button("Submit Feedback")
|
120 |
+
fb_status = gr.Textbox(label="")
|
121 |
|
122 |
+
fb_button.click(
|
123 |
+
fn=submit_feedback,
|
124 |
+
inputs=[question_input, subject_fb, topic_fb],
|
125 |
+
outputs=[fb_status]
|
126 |
)
|
127 |
|
128 |
+
demo.launch(share=True, ssr=False)
|