|
import streamlit as st |
|
from transformers import DistilBertTokenizer, DistilBertModel |
|
import logging |
|
logging.basicConfig(level=logging.ERROR) |
|
import torch |
|
|
|
# Intended maximum token sequence length for text fed to the model.
MAX_LEN = 100

# Prefer the GPU when one is available; otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

tokenizer = DistilBertTokenizer.from_pretrained('distilbert-base-uncased', truncation=True, do_lower_case=True)

# NOTE(review): DistilBERTClass is neither defined nor imported in the visible
# part of this file -- confirm it is brought into scope before this line runs.
model_DB = DistilBERTClass()

loaded_model_path = './model_DB_1.pt'

# Load the weights onto the CPU first so a checkpoint saved on a GPU machine
# still loads on a CPU-only host, then move the model to the selected device.
model_DB.load_state_dict(torch.load(loaded_model_path, map_location=torch.device('cpu')))
model_DB.to(device)

# This script only performs inference: eval mode disables training-only
# behaviour (e.g. dropout) so that predictions are deterministic.
model_DB.eval()
|
|
|
|
|
def sentiment_analysis_DB(input: str) -> int:
    """Classify the sentiment of *input* with the loaded DistilBERT model.

    The text is tokenized to a fixed length of ``MAX_LEN`` tokens (padded or
    truncated as needed), passed through ``model_DB``, and the sigmoid of the
    model output is thresholded at 0.5.

    Args:
        input: The text to classify. (The name shadows the ``input`` builtin;
            it is kept unchanged so existing keyword callers keep working.)

    Returns:
        ``1`` if the sigmoid of the model output is greater than 0.5,
        otherwise ``0``.
    """
    encoded = tokenizer.encode_plus(
        input,
        None,
        add_special_tokens=True,
        max_length=MAX_LEN,
        # Explicit padding/truncation replace the deprecated
        # ``pad_to_max_length=True`` and keep long inputs within MAX_LEN.
        padding='max_length',
        truncation=True,
        return_token_type_ids=True,
        return_tensors='pt',
    )

    # The model lives on ``device``; its inputs must be on the same device,
    # otherwise the forward pass fails when running on CUDA.
    ids = encoded['input_ids'].to(device)
    mask = encoded['attention_mask'].to(device)
    token_type_ids = encoded['token_type_ids'].to(device)

    # Idempotent: guarantees training-only layers (e.g. dropout) are disabled.
    model_DB.eval()

    # No gradients are needed for inference; skipping them saves memory/time.
    with torch.no_grad():
        output = model_DB(ids, mask, token_type_ids)
    print('Raw output is ', output)

    sigmoid_output = torch.sigmoid(output)
    print('Sigmoid output is ', sigmoid_output)

    # ``.item()`` requires a single-element tensor, i.e. one score per call.
    return 1 if sigmoid_output.item() > 0.5 else 0
|
|
|
|
|
# --- Streamlit user interface -------------------------------------------
st.title("Sentiment Analysis App")

user_input = st.text_area("Enter some text:")

if st.button("Analyze Sentiment"):
    if not user_input.strip():
        # Nothing meaningful to classify; avoid showing an arbitrary verdict
        # for empty or whitespace-only input.
        st.warning("Please enter some text to analyze.")
    else:
        result = sentiment_analysis_DB(user_input)

        # sentiment_analysis_DB returns 1 or 0; 1 is reported as positive.
        if result == 1:
            st.success("Positive sentiment detected!")
        else:
            st.error("Negative sentiment detected.")