|
from DistilBERT import model_DB |
|
import streamlit as st |
|
from transformers import DistilBertTokenizer, DistilBertModel |
|
import logging |
|
logging.basicConfig(level=logging.ERROR) |
|
import torch |
|
|
|
# Token length that sentiment_analysis_DB passes to the tokenizer as `max_length`.
MAX_LEN = 100

# Prefer a CUDA GPU when torch can see one; otherwise run on the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

# Pretrained uncased DistilBERT vocabulary, shared by every request.
# NOTE(review): `truncation` is normally a per-call option of the encode methods;
# passing it to from_pretrained is kept as in the original but likely has no
# effect at encode time — confirm.
tokenizer = DistilBertTokenizer.from_pretrained(
    'distilbert-base-uncased',
    truncation=True,
    do_lower_case=True,
)
|
|
|
def sentiment_analysis_DB(input):
    """Classify the sentiment of ``input`` with the DistilBERT model ``model_DB``.

    The text is tokenized to exactly ``MAX_LEN`` tokens (truncated or padded).
    It is then run through ``model_DB`` in inference mode, and the first score
    of the model's output is turned into a binary label.

    Args:
        input: The raw text to classify. The parameter name is kept for
            backward compatibility even though it shadows the ``input`` builtin.

    Returns:
        int: ``1`` for positive sentiment, ``0`` otherwise.
    """
    encoded = tokenizer.encode_plus(
        input,
        None,
        add_special_tokens=True,
        max_length=MAX_LEN,
        # Explicit replacement for the deprecated `pad_to_max_length=True`.
        # Truncation keeps long texts from exceeding MAX_LEN.
        padding='max_length',
        truncation=True,
        return_token_type_ids=True,
        # Return batched (1, MAX_LEN) torch tensors rather than plain Python
        # lists, which a torch forward pass cannot consume.
        return_tensors='pt',
    )

    # Place the inputs on the same device as the model's weights. Fall back to
    # the module-level `device` if the model exposes no parameters.
    try:
        target = next(model_DB.parameters()).device
    except (AttributeError, StopIteration):
        target = device

    ids = encoded['input_ids'].to(target)
    mask = encoded['attention_mask'].to(target)
    token_type_ids = encoded['token_type_ids'].to(target)

    # Inference only: switch off training-time behavior such as dropout (when
    # model_DB is an nn.Module) and skip building the autograd graph.
    if hasattr(model_DB, 'eval'):
        model_DB.eval()
    with torch.no_grad():
        output = model_DB(ids, mask, token_type_ids)

    # NOTE(review): assumes model_DB returns a tensor of raw logits whose first
    # entry is the positive-class score (e.g. a single-logit binary head).
    # TODO confirm against DistilBERT.py.
    # sigmoid(logit) >= 0.5 is equivalent to logit >= 0.
    score = torch.sigmoid(output.reshape(-1)[0]).item()
    return 1 if score >= 0.5 else 0
|
|
|
|
|
# Streamlit UI: a text box plus a button that runs the classifier on its contents.
st.title("Sentiment Analysis App")

user_input = st.text_area("Enter some text:")

if st.button("Analyze Sentiment"):
    # Blank input would otherwise still be sent through the model and reported
    # as positive or negative, which is meaningless.
    if not (user_input and user_input.strip()):
        st.warning("Please enter some text to analyze.")
    else:
        result = sentiment_analysis_DB(user_input)

        if result == 1:
            st.success("Positive sentiment detected!")
        else:
            st.error("Negative sentiment detected.")