naveenus committed on
Commit a389429 · verified · 1 Parent(s): 6c67d38

Update app.py

Files changed (1):
  1. app.py +66 -76
app.py CHANGED
@@ -3,7 +3,7 @@ import gradio as gr
 from transformers import pipeline
 
 # ————————————————
-# 1) Load taxonomies
+# Load taxonomies
 # ————————————————
 with open("coarse_labels.json") as f:
     coarse_labels = json.load(f)
@@ -11,128 +11,118 @@ with open("fine_labels.json") as f:
     fine_map = json.load(f)
 
 # ————————————————
-# 2) Available zero-shot models
+# Model choices (5 only)
 # ————————————————
 MODEL_CHOICES = [
     "facebook/bart-large-mnli",
     "roberta-large-mnli",
     "joeddav/xlm-roberta-large-xnli",
-    "mistralai/Mistral-7B-Instruct",
-    "huggyllama/llama-2-7b-chat",
-    "google/flan-t5-large",
-    "google/flan-ul2",
-    "clare-ai/llama-2-13b-instruct",
-    "allenai/longformer-base-4096",
-    "facebook/bart-large-mnli", # duplicate to test allow_custom_value
-    "valhalla/t5-base-qa-qg-hl",
-    "EleutherAI/gpt-neox-20b",
-    "EleutherAI/gpt-j-6b",
-    "bigscience/bloom-1b1",
-    "bigscience/bloom-560m",
-    "bigscience/bloom-3b",
-    "Salesforce/codegen-2B-multi",
-    "Salesforce/codegen-350M-multi",
-    "madlag/llama2-7b-finetuned-qa",
-    "tiiuae/falcon-7b-instruct",
-    "tiiuae/falcon-40b-instruct",
-    "milvus/milvus-embed-english",
-    "sentence-transformers/all-MiniLM-L6-v2",
-    "YOUR-OWN-CUSTOM-MODEL"
+    "valhalla/distilbart-mnli-12-4",
+    "educationfoundation/Phantom-7B-JEE" # placeholder — replace with real phantom model
 ]
 
 # ————————————————
-# Helper: ensure log files exist
+# Ensure log files exist
 # ————————————————
 LOG_FILE = "logs.csv"
 FEEDBACK_FILE = "feedback.csv"
-for fn, hdr in [(LOG_FILE, ["timestamp","model","question","subject","top3_topics","duration"]),
-                (FEEDBACK_FILE, ["timestamp","question","pred_subject","pred_topics","corrected"])]:
+for fn, hdr in [
+    (LOG_FILE, ["timestamp","model","question","chosen_subject","top3_topics","duration"]),
+    (FEEDBACK_FILE, ["timestamp","question","subject_feedback","topic_feedback"])
+]:
     if not os.path.exists(fn):
         with open(fn, "w", newline="") as f:
-            writer = csv.writer(f)
-            writer.writerow(hdr)
+            csv.writer(f).writerow(hdr)
 
 # ————————————————
-# 3) Build the interface logic
+# Inference functions
 # ————————————————
-
-def hierarchical_tag(question, model_name):
+def run_stage1(question, model_name):
+    """Return top3 coarse subjects + duration."""
     start = time.time()
-    # 3.1 Instantiate classifier per-run (to change models dynamically)
     clf = pipeline("zero-shot-classification", model=model_name)
+    out = clf(question, candidate_labels=coarse_labels)
+    labels, scores = out["labels"][:3], out["scores"][:3]
+    duration = round(time.time()-start,3)
+    return labels, duration
 
-    # 3.2 Stage 1: coarse label
-    coarse_out = clf(question, candidate_labels=coarse_labels)
-    subject = coarse_out["labels"][0]
-
-    # 3.3 Stage 2: fine labels within chosen subject
+def run_stage2(question, model_name, subject):
+    """Return top3 fine topics + duration."""
+    start = time.time()
+    clf = pipeline("zero-shot-classification", model=model_name)
     fine_labels = fine_map.get(subject, [])
-    fine_out = clf(question, candidate_labels=fine_labels)
-    top3 = fine_out["labels"][:3]
-
-    duration = round(time.time() - start, 3)
-
-    # 3.4 Log the run
+    out = clf(question, candidate_labels=fine_labels)
+    labels, scores = out["labels"][:3], out["scores"][:3]
+    duration = round(time.time()-start,3)
+    # Log combined run
     with open(LOG_FILE, "a", newline="") as f:
         csv.writer(f).writerow([
             time.strftime("%Y-%m-%d %H:%M:%S"),
             model_name,
             question.replace("\n"," "),
             subject,
-            ";".join(top3),
+            ";".join(labels),
             duration
         ])
+    return {lbl: round(score,3) for lbl,score in zip(labels, scores)}, f"⏱ {duration}s"
 
-    # 3.5 Return for display
-    return subject, {lbl: round(score,3)
-                     for lbl,score in zip(fine_out["labels"][:3],
-                                          fine_out["scores"][:3]
-                     )}, f"⏱ {duration}s"
-
-def submit_feedback(question, subject, topics, corrected):
-    ts = time.strftime("%Y-%m-%d %H:%M:%S")
+def submit_feedback(question, subject_fb, topic_fb):
     with open(FEEDBACK_FILE, "a", newline="") as f:
-        csv.writer(f).writerow([ts, question.replace("\n"," "), subject, ";".join(topics), corrected])
-    return "Thank you for your feedback!"
+        csv.writer(f).writerow([
+            time.strftime("%Y-%m-%d %H:%M:%S"),
+            question.replace("\n"," "),
+            subject_fb,
+            topic_fb
+        ])
+    return "✅ Feedback recorded!"
 
 # ————————————————
-# 4) Define the Gradio UI
+# Build Gradio UI
 # ————————————————
 with gr.Blocks() as demo:
-    gr.Markdown("## Hierarchical Zero-Shot Tagger with Model Selection & Logging")
+    gr.Markdown("## Hierarchical Zero-Shot Tagger with Subject Toggle & Feedback")
 
     with gr.Row():
        question_input = gr.Textbox(lines=3, label="Enter your question")
        model_input = gr.Dropdown(
-            label="Choose model",
-            choices=MODEL_CHOICES,
-            value=MODEL_CHOICES[0],
-            allow_custom_value=True
+            choices=MODEL_CHOICES, value=MODEL_CHOICES[0], label="Choose model"
        )
+       go_button = gr.Button("Run Stage 1")
 
-    run_button = gr.Button("Tag Question")
+    # Stage 1 outputs
+    subj_radio = gr.Radio(choices=[], label="Top-3 Subjects",
+                          info="Select to re-run Stage 2 for a different subject")
+    stage1_time = gr.Textbox(label="Stage 1 Time")
 
-    subject_out = gr.Textbox(label="Predicted Subject")
+    go_button.click(
+        fn=lambda q,m: (*run_stage1(q,m),),
+        inputs=[question_input, model_input],
+        outputs=[subj_radio, stage1_time]
+    )
+
+    # Stage 2 UI
+    go2_button = gr.Button("Run Stage 2")
     topics_out = gr.Label(label="Top-3 Topics")
-    time_out = gr.Textbox(label="Inference Time")
+    stage2_time = gr.Textbox(label="Stage 2 Time")
 
-    run_button.click(
-        hierarchical_tag,
-        inputs=[question_input, model_input],
-        outputs=[subject_out, topics_out, time_out]
+    go2_button.click(
+        fn=run_stage2,
+        inputs=[question_input, model_input, subj_radio],
+        outputs=[topics_out, stage2_time]
    )
 
     gr.Markdown("---")
-    gr.Markdown("### Not quite right? Submit your corrections below:")
+    gr.Markdown("### Feedback / Correction")
 
-    corrected_input = gr.Textbox(lines=1, placeholder="Correct subject;topic1;topic2;topic3")
-    feedback_button = gr.Button("Submit Feedback")
-    feedback_status = gr.Textbox(label="")
+    subject_fb = gr.Textbox(label="Correct Subject")
+    topic_fb = gr.Textbox(label="Correct Topic(s)")
+    fb_button = gr.Button("Submit Feedback")
+    fb_status = gr.Textbox(label="")
 
-    feedback_button.click(
-        submit_feedback,
-        inputs=[question_input, subject_out, topics_out, corrected_input],
-        outputs=[feedback_status]
+    fb_button.click(
+        fn=submit_feedback,
+        inputs=[question_input, subject_fb, topic_fb],
+        outputs=[fb_status]
    )
 
-demo.launch()
+demo.launch(share=True, ssr=False)
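For reference, the two-stage flow this commit splits into run_stage1 and run_stage2 can be exercised outside Gradio. A minimal standalone sketch, assuming coarse_labels.json holds a list of subject names, fine_labels.json maps each subject to its topic list, and the sample question text is invented for illustration:

import json
from transformers import pipeline

# Load the same taxonomies the app reads at startup.
with open("coarse_labels.json") as f:
    coarse_labels = json.load(f)
with open("fine_labels.json") as f:
    fine_map = json.load(f)

clf = pipeline("zero-shot-classification", model="facebook/bart-large-mnli")
question = "A ball is thrown up at 20 m/s. What maximum height does it reach?"

# Stage 1: rank coarse subjects and keep the top three.
stage1 = clf(question, candidate_labels=coarse_labels)
top_subjects = stage1["labels"][:3]

# Stage 2: rank fine topics within the best (or a user-chosen) subject.
subject = top_subjects[0]
stage2 = clf(question, candidate_labels=fine_map.get(subject, []))
top_topics = {lbl: round(s, 3) for lbl, s in zip(stage2["labels"][:3], stage2["scores"][:3])}
print(subject, top_topics)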
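Note that both run_stage1 and run_stage2 rebuild the pipeline on every button press, so each click pays the full model-load cost. If that becomes noticeable, memoizing one classifier per model name is a common remedy; get_classifier below is a hypothetical helper, not part of this commit:

from functools import lru_cache
from transformers import pipeline

@lru_cache(maxsize=2)  # hypothetical sizing: keep at most two loaded models resident
def get_classifier(model_name):
    # First call per model loads the weights; repeat calls return the cached pipeline.
    return pipeline("zero-shot-classification", model=model_name)

The two stage functions would then call get_classifier(model_name) in place of pipeline(...).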
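One caveat on the new wiring: go_button.click feeds run_stage1's label list straight into subj_radio, which was built with choices=[]. Depending on the Gradio version, refreshing a Radio's choices (not just its selected value) may need an explicit component update; a hedged variant of the handler, assuming a Gradio release that supports gr.update:

import gradio as gr

def stage1_for_ui(question, model_name):
    labels, duration = run_stage1(question, model_name)
    # gr.update lets the handler replace the Radio's choices and preselect the top subject.
    return gr.update(choices=labels, value=labels[0]), f"⏱ {duration}s"

go_button.click would then point at stage1_for_ui with the same inputs and outputs.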