# Hugging Face Space app: hierarchical zero-shot question tagger with logging.
import json, time, csv, os
import gradio as gr
from transformers import pipeline
# ----------------
# 1) Load taxonomies
# ----------------
def _load_json(path):
    """Read and parse one JSON file from *path*."""
    with open(path) as fh:
        return json.load(fh)

coarse_labels = _load_json("coarse_labels.json")  # coarse subject labels
fine_map = _load_json("fine_labels.json")         # subject -> fine-topic labels
# ----------------
# 2) Available zero-shot models
# ----------------
# Dropdown choices for the model selector. The list is duplicate-free:
# Gradio renders every entry verbatim, so a repeated id simply appears
# twice (the previous "facebook/bart-large-mnli" duplicate is removed).
# Arbitrary other model ids can still be typed because the Dropdown is
# created with allow_custom_value=True.
MODEL_CHOICES = [
    "facebook/bart-large-mnli",
    "roberta-large-mnli",
    "joeddav/xlm-roberta-large-xnli",
    "mistralai/Mistral-7B-Instruct",
    "huggyllama/llama-2-7b-chat",
    "google/flan-t5-large",
    "google/flan-ul2",
    "clare-ai/llama-2-13b-instruct",
    "allenai/longformer-base-4096",
    "valhalla/t5-base-qa-qg-hl",
    "EleutherAI/gpt-neox-20b",
    "EleutherAI/gpt-j-6b",
    "bigscience/bloom-1b1",
    "bigscience/bloom-560m",
    "bigscience/bloom-3b",
    "Salesforce/codegen-2B-multi",
    "Salesforce/codegen-350M-multi",
    "madlag/llama2-7b-finetuned-qa",
    "tiiuae/falcon-7b-instruct",
    "tiiuae/falcon-40b-instruct",
    "milvus/milvus-embed-english",
    "sentence-transformers/all-MiniLM-L6-v2",
    "YOUR-OWN-CUSTOM-MODEL",  # placeholder hinting that a custom id may be typed
]
# ----------------
# Helper: ensure log files exist
# ----------------
LOG_FILE = "logs.csv"
FEEDBACK_FILE = "feedback.csv"

# Each CSV gets its header row written exactly once; later code only appends.
_CSV_HEADERS = {
    LOG_FILE: ["timestamp", "model", "question", "subject", "top3_topics", "duration"],
    FEEDBACK_FILE: ["timestamp", "question", "pred_subject", "pred_topics", "corrected"],
}
for _path, _header in _CSV_HEADERS.items():
    if os.path.exists(_path):
        continue
    with open(_path, "w", newline="") as fh:
        csv.writer(fh).writerow(_header)
# ----------------
# 3) Build the interface logic
# ----------------
# Cache instantiated pipelines so that re-running with the same model does
# not re-load the weights on every request (the original rebuilt the
# pipeline per call, which is very slow for large checkpoints).
_PIPELINE_CACHE = {}

def _get_classifier(model_name):
    """Return a cached zero-shot-classification pipeline for *model_name*."""
    if model_name not in _PIPELINE_CACHE:
        _PIPELINE_CACHE[model_name] = pipeline("zero-shot-classification", model=model_name)
    return _PIPELINE_CACHE[model_name]

def hierarchical_tag(question, model_name):
    """Tag *question* in two zero-shot stages and log the run.

    Stage 1 picks the best coarse subject from ``coarse_labels``; stage 2
    ranks that subject's fine topics from ``fine_map``. Each run is
    appended to ``LOG_FILE``.

    Returns:
        tuple: (predicted subject, ``{topic: score}`` dict for the top-3
        fine topics, human-readable duration string).
    """
    start = time.time()
    clf = _get_classifier(model_name)

    # Stage 1: coarse subject.
    coarse_out = clf(question, candidate_labels=coarse_labels)
    subject = coarse_out["labels"][0]

    # Stage 2: fine topics within the chosen subject. Guard the case where
    # the subject has no fine taxonomy — the pipeline raises on an empty
    # candidate_labels list.
    fine_labels = fine_map.get(subject, [])
    if fine_labels:
        fine_out = clf(question, candidate_labels=fine_labels)
        top3 = fine_out["labels"][:3]
        topic_scores = {lbl: round(score, 3)
                        for lbl, score in zip(top3, fine_out["scores"][:3])}
    else:
        top3, topic_scores = [], {}

    duration = round(time.time() - start, 3)

    # Log the run; newlines are flattened so the CSV stays one row per run.
    with open(LOG_FILE, "a", newline="") as f:
        csv.writer(f).writerow([
            time.strftime("%Y-%m-%d %H:%M:%S"),
            model_name,
            question.replace("\n", " "),
            subject,
            ";".join(top3),
            duration,
        ])

    return subject, topic_scores, f"β± {duration}s"
def submit_feedback(question, subject, topics, corrected):
    """Append one user correction to the feedback CSV and acknowledge it."""
    row = [
        time.strftime("%Y-%m-%d %H:%M:%S"),
        question.replace("\n", " "),
        subject,
        ";".join(topics),
        corrected,
    ]
    with open(FEEDBACK_FILE, "a", newline="") as f:
        csv.writer(f).writerow(row)
    return "Thank you for your feedback!"
# ----------------
# 4) Define the Gradio UI
# ----------------
with gr.Blocks() as demo:
    gr.Markdown("## Hierarchical Zero-Shot Tagger with Model Selection & Logging")

    # Input row: the question text plus the model selector.
    with gr.Row():
        q_box = gr.Textbox(lines=3, label="Enter your question")
        model_dd = gr.Dropdown(
            label="Choose model",
            choices=MODEL_CHOICES,
            value=MODEL_CHOICES[0],
            allow_custom_value=True,
        )

    tag_btn = gr.Button("Tag Question")
    subject_box = gr.Textbox(label="Predicted Subject")
    topics_label = gr.Label(label="Top-3 Topics")
    timing_box = gr.Textbox(label="Inference Time")

    # Run the two-stage tagger on click.
    tag_btn.click(
        hierarchical_tag,
        inputs=[q_box, model_dd],
        outputs=[subject_box, topics_label, timing_box],
    )

    gr.Markdown("---")
    gr.Markdown("### Not quite right? Submit your corrections below:")
    correction_box = gr.Textbox(lines=1, placeholder="Correct subject;topic1;topic2;topic3")
    fb_btn = gr.Button("Submit Feedback")
    fb_status = gr.Textbox(label="")

    # Record the correction against the question and current predictions.
    fb_btn.click(
        submit_feedback,
        inputs=[q_box, subject_box, topics_label, correction_box],
        outputs=[fb_status],
    )

demo.launch()