import streamlit as st
import plotly.express as px
import torch

from torch import nn
from transformers import AutoTokenizer, AutoModelForSequenceClassification

option = st.selectbox("Select a toxicity analysis model:", ("RoBERTa", "DistilBERT", "XLM-RoBERTa"))
defaultTxt = "I hate you cancerous insects so much"
txt = st.text_area("Text to analyze", defaultTxt)

# Load the tokenizer and model weights; fall back to RoBERTa by default.
match option:
    case "RoBERTa":
        tokenizerPath = "s-nlp/roberta_toxicity_classifier"
        modelPath = "s-nlp/roberta_toxicity_classifier"
    case "DistilBERT":
        tokenizerPath = "citizenlab/distilbert-base-multilingual-cased-toxicity"
        modelPath = "citizenlab/distilbert-base-multilingual-cased-toxicity"
    case "XLM-RoBERTa":
        tokenizerPath = "unitary/multilingual-toxic-xlm-roberta"
        modelPath = "unitary/multilingual-toxic-xlm-roberta"
    case _:
        tokenizerPath = "s-nlp/roberta_toxicity_classifier"
        modelPath = "s-nlp/roberta_toxicity_classifier"
tokenizer = AutoTokenizer.from_pretrained(tokenizerPath)
model = AutoModelForSequenceClassification.from_pretrained(modelPath)

# Run the encoded text through the model to get the classification output.
# RoBERTa label indices: [0]: neutral, [1]: toxic
encoding = tokenizer.encode(txt, return_tensors="pt")
with torch.no_grad():
    result = model(encoding)

# Apply softmax to the logits to get class probabilities.
prediction = nn.functional.softmax(result.logits, dim=-1)
neutralProb = prediction[0][0].item()
toxicProb = prediction[0][1].item()

# Expected output from RoBERTa on the default text:
# Neutral: 0.0052
# Toxic: 0.9948
st.write("Classification Probabilities")
st.write(f"{neutralProb:.4f} - NEUTRAL")
st.write(f"{toxicProb:.4f} - TOXIC")
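
# Note: plotly.express is imported above but not used in the script as written.
# A minimal sketch of one way it could be used -- rendering the two class
# probabilities as a bar chart via st.plotly_chart. This chart is an assumption,
# not part of the original app:
fig = px.bar(
    x=["NEUTRAL", "TOXIC"],
    y=[neutralProb, toxicProb],
    labels={"x": "Class", "y": "Probability"},
)
st.plotly_chart(fig)

# To run the app locally (assuming this file is saved under the hypothetical
# name toxicity_app.py):
#   streamlit run toxicity_app.py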