File size: 3,769 Bytes
5c0db7c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
import gradio as gr
from transformers import pipeline

# Build the four fine-tuned SDG text classifiers once, at import time, so every
# interface below shares the same pipeline objects.
model1 = pipeline(
    task="text-classification",
    model="albertmartinez/bert-sdg-classification",
)
model2 = pipeline(
    task="text-classification",
    model="albertmartinez/bert-multilingual-sdg-classification",
)
model3 = pipeline(
    task="text-classification",
    model="albertmartinez/distilbert-multilingual-sdg-classification",
)
model4 = pipeline(
    task="text-classification",
    model="albertmartinez/xlm-roberta-large-sdg-classification",
)


def classify_text(text, model, *, top_k=16, max_length=512):
    """Score ``text`` with a text-classification pipeline.

    Args:
        text: Input text to classify.
        model: A callable text-classification pipeline; it is called with
            ``top_k``, ``truncation=True`` and ``max_length`` keyword args and
            must return a list of ``{"label": ..., "score": ...}`` dicts.
        top_k: Number of highest-scoring labels to request from the model.
        max_length: Token limit passed to the pipeline; longer inputs are
            truncated (``truncation=True``).

    Returns:
        dict mapping each returned label to its score, the shape consumed by
        the ``gr.Label`` outputs in this app.
    """
    predictions = model(text, top_k=top_k, truncation=True, max_length=max_length)
    return {p["label"]: p["score"] for p in predictions}


def classify_all(text):
    """Classify ``text`` with every loaded model.

    Reuses ``classify_text`` so all models share one set of pipeline options
    instead of repeating them inline per model.

    Returns:
        list of four label -> score dicts, ordered bert, bert-multilingual,
        distilbert-multilingual, xlm-roberta-large. This is the order of the
        "ALL" tab's ``gr.Label`` outputs.
    """
    return [classify_text(text, model) for model in (model1, model2, model3, model4)]


# "ALL" tab: one text box feeding every model, with one Label panel per model.
# The panel order must match the order of the list returned by classify_all.
ifaceall = gr.Interface(
    fn=classify_all,
    api_name="classify_all",
    title="SDG text classification",
    description="Enter a text and see the text classification result!",
    flagging_mode="never",
    inputs=gr.Textbox(lines=2, label="Text", placeholder="Enter text here..."),
    outputs=[
        gr.Label(label="bert"),
        gr.Label(label="bert-multilingual"),
        gr.Label(label="distilbert-multilingual"),
        gr.Label(label="xlm-roberta-large"),
    ],
)

def _single_model_interface(model, title, api_name):
    """Build a ``gr.Interface`` that classifies text with a single model.

    Args:
        model: Text-classification pipeline handed to ``classify_text``.
        title: Title shown for this interface.
        api_name: API endpoint name for this interface.

    Returns:
        A ``gr.Interface`` mapping a text box to a ``gr.Label`` of scores.
    """
    return gr.Interface(
        # ``model`` is this factory's own parameter, so each interface keeps
        # the pipeline it was built with (no late-binding closure surprise).
        fn=lambda text: classify_text(text, model),
        inputs=gr.Textbox(lines=2, label="Text", placeholder="Enter text here..."),
        outputs=gr.Label(label="Top SDG Predicted"),
        title=title,
        description="Enter a text and see the text classification result!",
        flagging_mode="never",
        api_name=api_name,
    )


# One interface per model (first, second, third and fourth model respectively).
iface1 = _single_model_interface(
    model1, "BERT SDG classification", "classify_bert"
)
iface2 = _single_model_interface(
    model2, "BERT multilingual SDG classification", "classify_bert-multilingual"
)
iface3 = _single_model_interface(
    model3, "DISTILBERT multilingual SDG classification", "classify_distilbert-multilingual"
)
iface4 = _single_model_interface(
    model4, "XLM-ROBERTA-LARGE SDG classification", "classify_xlm-roberta-large"
)

# Combine all five interfaces into one tabbed app. TabbedInterface is itself a
# gr.Blocks app, so it is used directly as ``demo``. Wrapping it in another
# ``gr.Blocks()`` context would make that outer (default) Blocks the launched
# app, and the title and theme given here would not be applied.
demo = gr.TabbedInterface(
    interface_list=[ifaceall, iface1, iface2, iface3, iface4],
    tab_names=[
        "ALL",
        "bert-sdg-classification",
        "bert-multilingual-sdg-classification",
        "distilbert-multilingual-sdg-classification",
        "xlm-roberta-large-sdg-classification",
    ],
    title="Sustainable Development Goals (SDG) Text Classifier App",
    theme="base",
)

if __name__ == "__main__":
    # Report the installed Gradio release first; this helps when diagnosing
    # API differences between Gradio versions.
    installed_version = gr.__version__
    print(installed_version)
    demo.launch()