Spaces:
Sleeping
Sleeping
import gradio as gr | |
from transformers import pipeline | |
# Build one text-classification pipeline per fine-tuned SDG checkpoint.
# The checkpoints are instantiated in order, so model1..model4 line up with
# the order of the names in the tuple below.
model1, model2, model3, model4 = (
    pipeline("text-classification", model=checkpoint)
    for checkpoint in (
        "albertmartinez/bert-sdg-classification",
        "albertmartinez/bert-multilingual-sdg-classification",
        "albertmartinez/distilbert-multilingual-sdg-classification",
        "albertmartinez/xlm-roberta-large-sdg-classification",
    )
)
def classify_text(text, model, *, top_k=16, max_length=512):
    """Classify *text* with *model* and return a label-to-score mapping.

    Args:
        text: Input text to classify.
        model: Callable invoked as
            ``model(text, top_k=..., truncation=True, max_length=...)``.
            It must return an iterable of mappings that each have a
            ``"label"`` and a ``"score"`` key; the text-classification
            pipelines created in this file are passed in this way.
        top_k: Number of predictions requested from the model. Defaults to
            16, the value that was previously hard-coded.
        max_length: Maximum length forwarded to the model together with
            ``truncation=True``. Defaults to 512, the value that was
            previously hard-coded.

    Returns:
        A dict mapping each predicted label to its score.
    """
    predictions = model(
        text,
        top_k=top_k,
        truncation=True,
        max_length=max_length,
    )
    return {item["label"]: item["score"] for item in predictions}
def classify_all(text):
    """Classify *text* with every loaded model.

    Args:
        text: Input text to classify.

    Returns:
        A list of four label-to-score dicts, ordered model1, model2, model3,
        model4.
    """
    # Delegate to classify_text so the pipeline arguments (top_k, truncation,
    # max_length) live in one place instead of being repeated per model.
    return [
        classify_text(text, model)
        for model in (model1, model2, model3, model4)
    ]
# Interface that runs the same input text through all four models at once.
ifaceall = gr.Interface(
    title="SDG text classification",
    description="Enter a text and see the text classification result!",
    fn=classify_all,
    inputs=gr.Textbox(lines=2, label="Text", placeholder="Enter text here..."),
    # One Label component per model, in the order classify_all returns them.
    outputs=[
        gr.Label(label=model_label)
        for model_label in (
            "bert",
            "bert-multilingual",
            "distilbert-multilingual",
            "xlm-roberta-large",
        )
    ],
    api_name="classify_all",
    flagging_mode="never",
)
# Single-model interface backed by model1 (bert-sdg-classification).
iface1 = gr.Interface(
    title="BERT SDG classification",
    description="Enter a text and see the text classification result!",
    fn=lambda text: classify_text(text, model1),
    inputs=gr.Textbox(
        lines=2,
        label="Text",
        placeholder="Enter text here...",
    ),
    outputs=gr.Label(label="Top SDG Predicted"),
    api_name="classify_bert",
    flagging_mode="never",
)
# Single-model interface backed by model2 (bert-multilingual-sdg-classification).
iface2 = gr.Interface(
    title="BERT multilingual SDG classification",
    description="Enter a text and see the text classification result!",
    fn=lambda text: classify_text(text, model2),
    inputs=gr.Textbox(
        lines=2,
        label="Text",
        placeholder="Enter text here...",
    ),
    outputs=gr.Label(label="Top SDG Predicted"),
    api_name="classify_bert-multilingual",
    flagging_mode="never",
)
# Single-model interface backed by model3
# (distilbert-multilingual-sdg-classification).
iface3 = gr.Interface(
    title="DISTILBERT multilingual SDG classification",
    description="Enter a text and see the text classification result!",
    fn=lambda text: classify_text(text, model3),
    inputs=gr.Textbox(
        lines=2,
        label="Text",
        placeholder="Enter text here...",
    ),
    outputs=gr.Label(label="Top SDG Predicted"),
    api_name="classify_distilbert-multilingual",
    flagging_mode="never",
)
# Single-model interface backed by model4 (xlm-roberta-large-sdg-classification).
iface4 = gr.Interface(
    title="XLM-ROBERTA-LARGE SDG classification",
    description="Enter a text and see the text classification result!",
    fn=lambda text: classify_text(text, model4),
    inputs=gr.Textbox(
        lines=2,
        label="Text",
        placeholder="Enter text here...",
    ),
    outputs=gr.Label(label="Top SDG Predicted"),
    api_name="classify_xlm-roberta-large",
    flagging_mode="never",
)
# Top-level app container; `demo` is what the entry point launches.
with gr.Blocks() as demo:
    # Combine all five interfaces (the "ALL" view plus one per model) into a
    # tabbed interface. Tab names are matched to interfaces by position.
    # NOTE(review): `theme` is passed to the nested TabbedInterface rather
    # than to the outer gr.Blocks() -- confirm it actually takes effect here.
    gr.TabbedInterface(
        title="Sustainable Development Goals (SDG) Text Classifier App",
        interface_list=[ifaceall, iface1, iface2, iface3, iface4],
        tab_names=[
            "ALL",
            "bert-sdg-classification",
            "bert-multilingual-sdg-classification",
            "distilbert-multilingual-sdg-classification",
            "xlm-roberta-large-sdg-classification",
        ],
        theme="base",
    )
if __name__ == "__main__":
    # Print the installed Gradio version first, then start the app.
    print(gr.__version__)
    demo.launch()