# Hugging Face Spaces app: hierarchical zero-shot question tagger.
# (The original paste carried HF Spaces UI artifacts here — "Spaces / Sleeping".)
import csv
import json
import os
import time

import gradio as gr
from transformers import pipeline

# ----------------------------------------------------------------------
# 1) Load taxonomies
# ----------------------------------------------------------------------
# coarse_labels.json: stage-1 candidate labels (presumably a flat list of
# subject names — confirm against the file).
# fine_labels.json: mapping of subject -> list of fine-grained topic labels,
# as used by fine_map.get(subject, []) below.
with open("coarse_labels.json", encoding="utf-8") as f:
    coarse_labels = json.load(f)
with open("fine_labels.json", encoding="utf-8") as f:
    fine_map = json.load(f)
# ----------------------------------------------------------------------
# 2) Available zero-shot models
# ----------------------------------------------------------------------
# NOTE(review): many entries here are generative / embedding checkpoints, not
# NLI models — the zero-shot-classification pipeline below may reject or
# mis-handle them. Confirm which entries are actually intended to work.
MODEL_CHOICES = [
    "facebook/bart-large-mnli",
    "roberta-large-mnli",
    "joeddav/xlm-roberta-large-xnli",
    "mistralai/Mistral-7B-Instruct",
    "huggyllama/llama-2-7b-chat",
    "google/flan-t5-large",
    "google/flan-ul2",
    "clare-ai/llama-2-13b-instruct",
    "allenai/longformer-base-4096",
    "facebook/bart-large-mnli",  # duplicate to test allow_custom_value
    "valhalla/t5-base-qa-qg-hl",
    "EleutherAI/gpt-neox-20b",
    "EleutherAI/gpt-j-6b",
    "bigscience/bloom-1b1",
    "bigscience/bloom-560m",
    "bigscience/bloom-3b",
    "Salesforce/codegen-2B-multi",
    "Salesforce/codegen-350M-multi",
    "madlag/llama2-7b-finetuned-qa",
    "tiiuae/falcon-7b-instruct",
    "tiiuae/falcon-40b-instruct",
    "milvus/milvus-embed-english",
    "sentence-transformers/all-MiniLM-L6-v2",
    "YOUR-OWN-CUSTOM-MODEL",
]
# ----------------------------------------------------------------------
# Helper: ensure log files exist (write the CSV header row on first run)
# ----------------------------------------------------------------------
LOG_FILE = "logs.csv"
FEEDBACK_FILE = "feedback.csv"

# Map each log file to its header row so the bootstrap loop is data-driven.
_LOG_HEADERS = {
    LOG_FILE: ["timestamp", "model", "question", "subject", "top3_topics", "duration"],
    FEEDBACK_FILE: ["timestamp", "question", "pred_subject", "pred_topics", "corrected"],
}
for fn, hdr in _LOG_HEADERS.items():
    if not os.path.exists(fn):
        # newline="" per the csv module docs; utf-8 so logged questions
        # survive round-trips regardless of platform default encoding.
        with open(fn, "w", newline="", encoding="utf-8") as f:
            csv.writer(f).writerow(hdr)
# ----------------------------------------------------------------------
# 3) Build the interface logic
# ----------------------------------------------------------------------
from functools import lru_cache


@lru_cache(maxsize=1)
def _get_classifier(model_name):
    """Build and cache a zero-shot-classification pipeline for *model_name*.

    The original code rebuilt the pipeline on every click; caching keeps the
    last-used checkpoint loaded so repeated runs are fast, while maxsize=1
    still lets the user switch models (the previous model is evicted rather
    than accumulating multiple large checkpoints in memory).
    """
    return pipeline("zero-shot-classification", model=model_name)


def hierarchical_tag(question, model_name):
    """Two-stage zero-shot tagging of *question*.

    Stage 1 classifies against the coarse subject list; stage 2 ranks the
    fine-grained topics registered for the winning subject. The run is
    appended to LOG_FILE.

    Returns (subject, {topic: score for top-3}, human-readable duration).
    """
    start = time.time()
    clf = _get_classifier(model_name)

    # Stage 1: pick the best coarse subject.
    coarse_out = clf(question, candidate_labels=coarse_labels)
    subject = coarse_out["labels"][0]

    # Stage 2: rank fine topics within that subject. Skip the classifier
    # call entirely when the subject has no registered fine labels —
    # candidate_labels=[] would otherwise error out.
    fine_labels = fine_map.get(subject, [])
    if fine_labels:
        fine_out = clf(question, candidate_labels=fine_labels)
        top3 = fine_out["labels"][:3]
        top3_scores = fine_out["scores"][:3]
    else:
        top3, top3_scores = [], []

    duration = round(time.time() - start, 3)

    # Log the run (question flattened to one line so the CSV stays one
    # row per run).
    with open(LOG_FILE, "a", newline="", encoding="utf-8") as f:
        csv.writer(f).writerow([
            time.strftime("%Y-%m-%d %H:%M:%S"),
            model_name,
            question.replace("\n", " "),
            subject,
            ";".join(top3),
            duration,
        ])

    # NOTE(review): "β±" below looks like a mojibake'd stopwatch glyph from
    # the original paste — kept byte-identical; confirm the intended emoji.
    return subject, {lbl: round(score, 3)
                     for lbl, score in zip(top3, top3_scores)}, f"β± {duration}s"
def submit_feedback(question, subject, topics, corrected):
    """Append a user correction to FEEDBACK_FILE and return a thank-you note.

    *topics* comes from the gr.Label output component; iterating it yields
    the topic labels whether it is a list or a dict keyed by label.
    A bare string is stored as-is — ";".join on a str would otherwise
    interleave ";" between every character.
    """
    ts = time.strftime("%Y-%m-%d %H:%M:%S")
    topic_str = topics if isinstance(topics, str) else ";".join(topics)
    with open(FEEDBACK_FILE, "a", newline="", encoding="utf-8") as f:
        # Flatten the question to one line so the CSV stays one row per entry.
        csv.writer(f).writerow([ts, question.replace("\n", " "), subject, topic_str, corrected])
    return "Thank you for your feedback!"
# ----------------------------------------------------------------------
# 4) Define the Gradio UI
# ----------------------------------------------------------------------
with gr.Blocks() as demo:
    gr.Markdown("## Hierarchical Zero-Shot Tagger with Model Selection & Logging")

    with gr.Row():
        question_input = gr.Textbox(lines=3, label="Enter your question")
        model_input = gr.Dropdown(
            label="Choose model",
            choices=MODEL_CHOICES,
            value=MODEL_CHOICES[0],
            # Let users paste any HF model id, not just the preset choices.
            allow_custom_value=True,
        )

    run_button = gr.Button("Tag Question")
    subject_out = gr.Textbox(label="Predicted Subject")
    topics_out = gr.Label(label="Top-3 Topics")
    time_out = gr.Textbox(label="Inference Time")

    run_button.click(
        hierarchical_tag,
        inputs=[question_input, model_input],
        outputs=[subject_out, topics_out, time_out],
    )

    gr.Markdown("---")
    gr.Markdown("### Not quite right? Submit your corrections below:")
    corrected_input = gr.Textbox(lines=1, placeholder="Correct subject;topic1;topic2;topic3")
    feedback_button = gr.Button("Submit Feedback")
    feedback_status = gr.Textbox(label="")

    feedback_button.click(
        submit_feedback,
        inputs=[question_input, subject_out, topics_out, corrected_input],
        outputs=[feedback_status],
    )

# Guarded so importing this module (e.g. from tests) doesn't start a server;
# running the script directly — as HF Spaces does — still launches the app.
if __name__ == "__main__":
    demo.launch()