File size: 1,597 Bytes
da8c970
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
d2fa23b
 
da8c970
 
 
 
a4e86cd
da8c970
 
 
 
 
 
 
 
 
 
 
 
 
04a19cd
da8c970
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
import gradio as gr
from transformers import AutoModelForSequenceClassification, AutoTokenizer, pipeline


class ToxicCommentClassification:
    def __init__(self, model_name: str):
        self.model = AutoModelForSequenceClassification.from_pretrained(model_name)
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        self.pipeline = pipeline(
            "text-classification",
            model=self.model,
            tokenizer=self.tokenizer,
            return_all_scores=True,
        )

    def predict(self, text):
        res = self.pipeline(text)[0]
        results = dict()
        is_normal = True
        for x in res:
            results[x['label']] = x['score']

            if float(x['score']) > 0.8:
                is_normal = False

        # if is_normal:
        #     results['normal'] = 1
        return results


def main():
    model = ToxicCommentClassification("DuongTrongChi/d-filter-v1.3")
    iface = gr.Interface(
        fn=model.predict,
        inputs=gr.Textbox(
            lines=3,
            placeholder="Hãy nhập nội dung vào đây",
            label="Input Text",
        ),
        outputs="label",
        title="Toxic Comment Classification",
        examples=[
            "Ôi chú chó này nhìn dễ thương thế!",
            "Cái lúc óc chó sống làm chi cho chật đất",
            "Cầm con dao này và đâm chết con chó này đi!",
            "Tôi dắt con mèo cưng của tôi đi đạo phố ở công viên."
        ],
    )

    iface.launch()


if __name__ == "__main__":
    main()