|
import gradio as gr |
|
from transformers import pipeline |
|
|
|
|
|
# Load the four SDG text-classification pipelines eagerly at import time, in
# order: bert, bert-multilingual, distilbert-multilingual, xlm-roberta-large.
model1, model2, model3, model4 = (
    pipeline("text-classification", model=checkpoint)
    for checkpoint in (
        "albertmartinez/bert-sdg-classification",
        "albertmartinez/bert-multilingual-sdg-classification",
        "albertmartinez/distilbert-multilingual-sdg-classification",
        "albertmartinez/xlm-roberta-large-sdg-classification",
    )
)
|
|
|
|
|
def classify_text(text, model, *, top_k=16, max_length=512):
    """Classify *text* with a text-classification pipeline.

    Args:
        text: Input string to classify.
        model: Callable pipeline that accepts ``top_k``, ``truncation`` and
            ``max_length`` keyword arguments and returns an iterable of
            ``{"label": ..., "score": ...}`` dicts.
        top_k: Number of highest-scoring labels to request. Defaults to 16,
            the value this function previously hard-coded.
        max_length: Token limit passed to the pipeline; longer inputs are
            truncated (``truncation=True``) rather than rejected.

    Returns:
        A ``{label: score}`` dict, the mapping shape ``gr.Label`` displays.
    """
    predictions = model(text, top_k=top_k, truncation=True, max_length=max_length)
    return {prediction["label"]: prediction["score"] for prediction in predictions}
|
|
|
|
|
def classify_all(text):
    """Run *text* through all four SDG classifiers.

    Args:
        text: Input string to classify.

    Returns:
        A list of four ``{label: score}`` dicts, one per model, in the order
        model1, model2, model3, model4.
    """
    # Delegate to classify_text so the pipeline arguments (top_k, truncation,
    # max_length) are defined in one place instead of being copied per model.
    return [classify_text(text, model) for model in (model1, model2, model3, model4)]
|
|
|
|
|
# Combined view: one text box feeding every model, one Label per model in
# the same order classify_all returns its results.
ifaceall = gr.Interface(
    title="SDG text classification",
    description="Enter a text and see the text classification result!",
    fn=classify_all,
    inputs=gr.Textbox(lines=2, label="Text", placeholder="Enter text here..."),
    outputs=[
        gr.Label(label=model_label)
        for model_label in (
            "bert",
            "bert-multilingual",
            "distilbert-multilingual",
            "xlm-roberta-large",
        )
    ],
    api_name="classify_all",
    flagging_mode="never",
)
|
|
|
|
|
def _single_model_interface(model, title, api_name):
    """Build a one-model classification tab.

    All per-model tabs share the same input box, output label, description
    and flagging setting; only the pipeline, title and API endpoint differ.

    Args:
        model: Pipeline passed through to ``classify_text``.
        title: Heading shown on the tab.
        api_name: Name of the API endpoint exposed for this interface.

    Returns:
        A configured ``gr.Interface``.
    """
    return gr.Interface(
        # The lambda closes over this call's own ``model`` parameter, so each
        # interface keeps its own pipeline (no late-binding surprise).
        fn=lambda text: classify_text(text, model),
        inputs=gr.Textbox(lines=2, label="Text", placeholder="Enter text here..."),
        outputs=gr.Label(label="Top SDG Predicted"),
        title=title,
        description="Enter a text and see the text classification result!",
        flagging_mode="never",
        api_name=api_name,
    )


iface1 = _single_model_interface(
    model1, "BERT SDG classification", "classify_bert"
)
iface2 = _single_model_interface(
    model2, "BERT multilingual SDG classification", "classify_bert-multilingual"
)
iface3 = _single_model_interface(
    model3, "DISTILBERT multilingual SDG classification", "classify_distilbert-multilingual"
)
iface4 = _single_model_interface(
    model4, "XLM-ROBERTA-LARGE SDG classification", "classify_xlm-roberta-large"
)
|
|
|
with gr.Blocks() as demo:
    # Tab order mirrors interface_list: the combined "ALL" view first, then
    # one tab per model.
    # NOTE(review): theme is passed to the nested TabbedInterface, not to the
    # outer Blocks that is launched; confirm it actually takes effect.
    gr.TabbedInterface(
        title="Sustainable Development Goals (SDG) Text Classifier App",
        theme="base",
        interface_list=[ifaceall, iface1, iface2, iface3, iface4],
        tab_names=[
            "ALL",
            "bert-sdg-classification",
            "bert-multilingual-sdg-classification",
            "distilbert-multilingual-sdg-classification",
            "xlm-roberta-large-sdg-classification",
        ],
    )

if __name__ == "__main__":
    # Print the installed Gradio version, then launch the app.
    print(gr.__version__)
    demo.launch()
|
|