naveenus committed on
Commit
6c67d38
·
verified ·
1 Parent(s): 62ee195

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +133 -35
app.py CHANGED
@@ -1,40 +1,138 @@
1
- # app.py
2
- import json, gradio as gr
3
  from transformers import pipeline
4
 
 
5
# 1) Load taxonomies
def _read_json(path):
    """Load and return the parsed JSON document stored at *path*."""
    with open(path) as fh:
        return json.load(fh)

coarse_labels = _read_json("coarse_labels.json")
fine_map = _read_json("fine_labels.json")

# 2) Init classifier (single shared zero-shot pipeline)
classifier = pipeline(
    "zero-shot-classification",
    model="facebook/bart-large-mnli",
)
14
-
15
# 3) Tagging fn
def hierarchical_tag(question):
    """Tag *question* in two zero-shot stages.

    Stage 1 picks the best coarse subject from ``coarse_labels``;
    stage 2 scores the fine-grained topics registered for that subject
    in ``fine_map``.

    Returns:
        dict with key ``"Subject"`` plus one ``label -> rounded score``
        entry per fine-grained topic.
    """
    # Stage 1: coarse subject (top-ranked label wins).
    subject = classifier(question, candidate_labels=coarse_labels)["labels"][0]

    # Stage 2: fine-grained topics restricted to that subject's entry.
    fine_out = classifier(question, candidate_labels=fine_map.get(subject, []))
    topic_scores = {
        label: round(score, 3)
        for label, score in zip(fine_out["labels"], fine_out["scores"])
    }

    return {"Subject": subject, **topic_scores}
29
-
30
# 4) Build UI — a single-question-in, JSON-out Gradio interface.
iface = gr.Interface(
    fn=hierarchical_tag,
    inputs=gr.Textbox(lines=3, label="Enter your question"),
    outputs=gr.JSON(label="Hierarchical Tags"),
    title="Two-Stage Zero-Shot Question Tagger",
    description="Stage 1: classify subject; Stage 2: classify topic within subject.",
)

# Only start the server when run as a script, not on import.
if __name__ == "__main__":
    iface.launch()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json, time, csv, os
2
+ import gradio as gr
3
  from transformers import pipeline
4
 
5
# ------------------------------------------------------------------
# 1) Load taxonomies
# ------------------------------------------------------------------
def _load_taxonomy(path):
    """Read one taxonomy JSON file and return its parsed content."""
    with open(path) as fh:
        return json.load(fh)

coarse_labels = _load_taxonomy("coarse_labels.json")
fine_map = _load_taxonomy("fine_labels.json")
12
+
13
# ------------------------------------------------------------------
# 2) Available zero-shot models
# ------------------------------------------------------------------
# Candidate model ids offered in the dropdown; the Dropdown is created
# with allow_custom_value=True, so users may also type any other Hub id.
# Fixed: removed the literal duplicate "facebook/bart-large-mnli" entry
# (it was marked "duplicate to test allow_custom_value" — duplicates in a
# choices list are redundant and confusing).
# NOTE(review): several ids below (GPT/codegen/bloom/sentence-transformers,
# etc.) do not look like NLI checkpoints and would likely fail when passed
# to pipeline("zero-shot-classification") — confirm before keeping them.
MODEL_CHOICES = [
    "facebook/bart-large-mnli",
    "roberta-large-mnli",
    "joeddav/xlm-roberta-large-xnli",
    "mistralai/Mistral-7B-Instruct",
    "huggyllama/llama-2-7b-chat",
    "google/flan-t5-large",
    "google/flan-ul2",
    "clare-ai/llama-2-13b-instruct",
    "allenai/longformer-base-4096",
    "valhalla/t5-base-qa-qg-hl",
    "EleutherAI/gpt-neox-20b",
    "EleutherAI/gpt-j-6b",
    "bigscience/bloom-1b1",
    "bigscience/bloom-560m",
    "bigscience/bloom-3b",
    "Salesforce/codegen-2B-multi",
    "Salesforce/codegen-350M-multi",
    "madlag/llama2-7b-finetuned-qa",
    "tiiuae/falcon-7b-instruct",
    "tiiuae/falcon-40b-instruct",
    "milvus/milvus-embed-english",
    "sentence-transformers/all-MiniLM-L6-v2",
    "YOUR-OWN-CUSTOM-MODEL",
]
42
+
43
# ------------------------------------------------------------------
# Helper: ensure log files exist
# ------------------------------------------------------------------
LOG_FILE = "logs.csv"
FEEDBACK_FILE = "feedback.csv"

# Map each log file to its CSV header row; create any missing file with
# just the header so later appends never produce a header-less CSV.
_LOG_HEADERS = {
    LOG_FILE: ["timestamp","model","question","subject","top3_topics","duration"],
    FEEDBACK_FILE: ["timestamp","question","pred_subject","pred_topics","corrected"],
}
for _path, _header in _LOG_HEADERS.items():
    if not os.path.exists(_path):
        with open(_path, "w", newline="") as _fh:
            csv.writer(_fh).writerow(_header)
55
# ------------------------------------------------------------------
# 3) Build the interface logic
# ------------------------------------------------------------------

# Fixed: the original built a fresh pipeline() on EVERY click, re-loading a
# multi-GB model each time. Cache one classifier per model name instead;
# model switching via the dropdown still works.
_CLASSIFIER_CACHE = {}

def _get_classifier(model_name):
    """Return a cached zero-shot-classification pipeline for *model_name*."""
    if model_name not in _CLASSIFIER_CACHE:
        _CLASSIFIER_CACHE[model_name] = pipeline(
            "zero-shot-classification", model=model_name
        )
    return _CLASSIFIER_CACHE[model_name]

def hierarchical_tag(question, model_name):
    """Two-stage zero-shot tagging with per-run logging.

    Stage 1 picks a coarse subject from ``coarse_labels``; stage 2 ranks the
    fine-grained topics registered for that subject in ``fine_map``.

    Args:
        question: free-text question to classify.
        model_name: Hub id of the zero-shot model to use.

    Returns:
        (subject, {topic: rounded score} for the top-3 topics,
         human-readable duration string).
    """
    start = time.time()
    clf = _get_classifier(model_name)

    # Stage 1: coarse subject (top-ranked label wins).
    coarse_out = clf(question, candidate_labels=coarse_labels)
    subject = coarse_out["labels"][0]

    # Stage 2: fine labels within the chosen subject.
    # Fixed: guard the empty case — a subject missing from fine_map would
    # hand the pipeline an empty candidate list and raise; return no topics
    # instead of crashing the UI.
    fine_labels = fine_map.get(subject, [])
    if fine_labels:
        fine_out = clf(question, candidate_labels=fine_labels)
        top_labels = fine_out["labels"][:3]
        top_scores = fine_out["scores"][:3]
    else:
        top_labels, top_scores = [], []

    duration = round(time.time() - start, 3)

    # Log the run (append-only CSV; header written at startup).
    with open(LOG_FILE, "a", newline="") as f:
        csv.writer(f).writerow([
            time.strftime("%Y-%m-%d %H:%M:%S"),
            model_name,
            question.replace("\n", " "),  # keep the question on one CSV row
            subject,
            ";".join(top_labels),
            duration,
        ])

    topics = {lbl: round(score, 3) for lbl, score in zip(top_labels, top_scores)}
    return subject, topics, f"⏱ {duration}s"
91
+
92
def submit_feedback(question, subject, topics, corrected):
    """Append one user correction to the feedback CSV and acknowledge it.

    Args:
        question: the question that was tagged (newlines flattened for CSV).
        subject: the predicted coarse subject.
        topics: iterable of predicted topic labels, joined with ';'.
        corrected: the user's free-text correction.
    """
    row = [
        time.strftime("%Y-%m-%d %H:%M:%S"),
        question.replace("\n", " "),
        subject,
        ";".join(topics),
        corrected,
    ]
    with open(FEEDBACK_FILE, "a", newline="") as fh:
        csv.writer(fh).writerow(row)
    return "Thank you for your feedback!"
97
+
98
+ # β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”
99
+ # 4) Define the Gradio UI
100
+ # β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”
101
+ with gr.Blocks() as demo:
102
+ gr.Markdown("## Hierarchical Zero-Shot Tagger with Model Selection & Logging")
103
+
104
+ with gr.Row():
105
+ question_input = gr.Textbox(lines=3, label="Enter your question")
106
+ model_input = gr.Dropdown(
107
+ label="Choose model",
108
+ choices=MODEL_CHOICES,
109
+ value=MODEL_CHOICES[0],
110
+ allow_custom_value=True
111
+ )
112
+
113
+ run_button = gr.Button("Tag Question")
114
+
115
+ subject_out = gr.Textbox(label="Predicted Subject")
116
+ topics_out = gr.Label(label="Top-3 Topics")
117
+ time_out = gr.Textbox(label="Inference Time")
118
+
119
+ run_button.click(
120
+ hierarchical_tag,
121
+ inputs=[question_input, model_input],
122
+ outputs=[subject_out, topics_out, time_out]
123
+ )
124
+
125
+ gr.Markdown("---")
126
+ gr.Markdown("### Not quite right? Submit your corrections below:")
127
+
128
+ corrected_input = gr.Textbox(lines=1, placeholder="Correct subject;topic1;topic2;topic3")
129
+ feedback_button = gr.Button("Submit Feedback")
130
+ feedback_status = gr.Textbox(label="")
131
+
132
+ feedback_button.click(
133
+ submit_feedback,
134
+ inputs=[question_input, subject_out, topics_out, corrected_input],
135
+ outputs=[feedback_status]
136
+ )
137
+
138
+ demo.launch()