File size: 3,973 Bytes
04e7b78
9604b3c
 
2adecad
 
 
9604b3c
2adecad
 
 
 
 
 
 
 
 
 
 
fb5842d
 
 
 
 
 
2adecad
 
d3061d0
2adecad
 
 
 
 
d3061d0
2adecad
 
 
 
 
 
 
 
fb5842d
 
2adecad
 
 
 
 
 
 
 
 
068f0da
fb5842d
 
 
 
 
 
 
 
 
 
 
 
2adecad
fb5842d
 
 
 
 
 
 
 
 
 
 
 
b38e092
fb5842d
 
 
 
 
2adecad
b38e092
fb5842d
 
 
 
 
 
 
b38e092
fb5842d
b38e092
d3061d0
2adecad
 
 
 
 
d3061d0
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
import gradio as gr
from transformers import pipeline

# Initialize the classifiers.
# Both pipelines are backed by the same NLI checkpoint, so load it once and let
# the text-classification pipeline reuse the zero-shot pipeline's model and
# tokenizer instead of holding a second copy of the weights in memory.
# (When a model object rather than a name is passed, the tokenizer must be
# given explicitly.)
MODEL_NAME = "tasksource/ModernBERT-base-nli"
zero_shot_classifier = pipeline("zero-shot-classification", model=MODEL_NAME)
nli_classifier = pipeline(
    "text-classification",
    model=zero_shot_classifier.model,
    tokenizer=zero_shot_classifier.tokenizer,
)

def process_input(text_input, labels_or_premise, mode):
    """Run the model for the selected mode and return scores for ``gr.Label``.

    Args:
        text_input: Contents of the "Input Text" box. In NLI mode this is the
            premise.
        labels_or_premise: Contents of the second box. In zero-shot mode it is
            a comma-separated list of candidate categories; in NLI mode it is
            the hypothesis (despite the parameter name).
        mode: The selected radio option. Anything other than
            "Zero-Shot Classification" is treated as NLI.

    Returns:
        A ``(scores, analysis)`` tuple where ``scores`` maps each label to its
        score and ``analysis`` is an empty string for the hidden Markdown
        output.

    Raises:
        gr.Error: If the input text is blank, if zero-shot mode receives no
            non-empty categories, or if NLI mode receives a blank hypothesis.
            Gradio shows this message to the user.
    """
    if not (text_input or "").strip():
        raise gr.Error("Please enter some input text.")

    if mode == "Zero-Shot Classification":
        # Drop blank entries (e.g. from "a,,b" or a trailing comma) and
        # duplicates while keeping the user's order; duplicates would
        # otherwise collapse into a single key of the result dict anyway.
        labels = list(dict.fromkeys(
            label.strip()
            for label in (labels_or_premise or "").split(",")
            if label.strip()
        ))
        if not labels:
            raise gr.Error("Please enter at least one category.")
        prediction = zero_shot_classifier(text_input, labels)
        return dict(zip(prediction["labels"], prediction["scores"])), ""

    # NLI mode: score the (premise, hypothesis) pair. top_k=None asks the
    # pipeline for every label's score instead of only the top one, so the
    # output shows a full distribution just like zero-shot mode.
    if not (labels_or_premise or "").strip():
        raise gr.Error("Please enter a hypothesis.")
    prediction = nli_classifier(
        {"text": text_input, "text_pair": labels_or_premise}, top_k=None
    )
    return {pred["label"]: pred["score"] for pred in prediction}, ""

def update_interface(mode):
    """Build the textbox update that relabels the second input for *mode*.

    Zero-shot mode asks for comma-separated categories; any other mode
    (the NLI option) asks for a hypothesis instead.
    """
    if mode != "Zero-Shot Classification":
        textbox_settings = {
            "label": "Hypothesis",
            "placeholder": "Enter a hypothesis to compare with the premise...",
        }
    else:
        textbox_settings = {
            "label": "🏷️ Categories",
            "placeholder": "Enter comma-separated categories...",
        }
    return gr.update(**textbox_settings)

with gr.Blocks() as demo:
    gr.Markdown("# 🤖 ModernBERT Text Analysis")

    # Mode selector. Its value is passed to process_input and also drives the
    # second textbox's label/placeholder and which example panel is visible.
    mode = gr.Radio(
        ["Zero-Shot Classification", "Natural Language Inference"],
        label="Select Mode",
        value="Zero-Shot Classification"
    )

    with gr.Column():
        # First input: the text to classify (zero-shot) or the premise (NLI).
        text_input = gr.Textbox(
            label="✍️ Input Text",
            placeholder="Enter your text...",
            lines=3
        )

        # Second input: comma-separated categories (zero-shot) or the
        # hypothesis (NLI); update_interface swaps its label and placeholder.
        labels_or_premise = gr.Textbox(
            label="🏷️ Categories",
            placeholder="Enter comma-separated categories...",
            lines=2
        )

        submit_btn = gr.Button("Submit")

        # process_input returns (scores, analysis); the Markdown slot is
        # hidden and only ever receives an empty string.
        outputs = [
            gr.Label(label="📊 Results"),
            gr.Markdown(label="📈 Analysis", visible=False)
        ]

        with gr.Column(variant="panel") as zero_shot_examples_panel:
            gr.Examples(
                examples=[
                    ["I need to buy groceries", "shopping, urgent tasks, leisure, philosophy"],
                    ["The sun is very bright today", "weather, astronomy, complaints, poetry"],
                    ["I love playing video games", "entertainment, sports, education, business"],
                    ["The car won't start", "transportation, art, cooking, literature"],
                    ["She wrote a beautiful poem", "creativity, finance, exercise, technology"]
                ],
                inputs=[text_input, labels_or_premise],
                label="Zero-Shot Classification Examples"
            )

        # Hidden on load because the radio starts in zero-shot mode; without
        # visible=False both example sets would be shown until the user first
        # changes the mode.
        with gr.Column(variant="panel", visible=False) as nli_examples_panel:
            gr.Examples(
                examples=[
                    ["A man is sleeping on a couch", "The man is awake"],
                    ["The restaurant is full of people", "The place is empty"],
                    ["The child is playing with toys", "The kid is having fun"],
                    ["It's raining outside", "The weather is wet"],
                    ["The dog is barking at the mailman", "There is a cat"]
                ],
                inputs=[text_input, labels_or_premise],
                label="Natural Language Inference Examples"
            )

    def update_visibility(mode):
        """Return (zero-shot panel, NLI panel) updates showing only the panel for *mode*."""
        return (
            gr.update(visible=(mode == "Zero-Shot Classification")),
            gr.update(visible=(mode == "Natural Language Inference"))
        )

    # Relabel the second textbox when the mode changes.
    mode.change(
        fn=update_interface,
        inputs=[mode],
        outputs=[labels_or_premise]
    )

    # Show the example panel matching the new mode.
    mode.change(
        fn=update_visibility,
        inputs=[mode],
        outputs=[zero_shot_examples_panel, nli_examples_panel]
    )

    submit_btn.click(
        fn=process_input,
        inputs=[text_input, labels_or_premise, mode],
        outputs=outputs
    )

# Start the web server only when run as a script, not when imported.
if __name__ == "__main__":
    demo.launch()