# Source: Hugging Face Space by mesutdmn — "Streamlit App" (revision 110d80a)
# (page-header residue from the raw-file view converted to a comment so the
#  file remains valid Python)
import torch
from tokenizers import Tokenizer
from torch.utils.data import DataLoader
import streamlit as st
import base64
from model import CustomDataset, TransformerEncoder
# Configure the Streamlit page before any other st.* call (wide layout, custom title/icon).
st.set_page_config(layout="wide",page_title="TeknoFest We Bears NLP Competition", page_icon="./media/3bears.ico")
# Token-level tag vocabulary: "O" = outside, Turkish sentiment labels
# (olumsuz=negative, nötr=neutral, olumlu=positive), "org" = organisation/entity continuation.
tag2id = {"O": 0, "olumsuz": 1, "nötr": 2, "olumlu": 3, "org": 4}
# Reverse lookup: class id -> tag string, used when rendering predictions.
id2tag = {value: key for key, value in tag2id.items()}
# Inference is forced onto CPU (the app is deployed without a GPU).
device = torch.device('cpu')
@st.cache_resource
def load_model_to_cpu(_model, path="model.pth"):
    """Populate *_model* with the weights stored at *path* and return it.

    The checkpoint is always mapped onto the CPU, regardless of where it was
    saved. The leading underscore in ``_model`` tells Streamlit's
    ``cache_resource`` not to hash that argument when caching the result.
    """
    state = torch.load(path, map_location=torch.device('cpu'))
    _model.load_state_dict(state)
    return _model
def get_base64(bin_file):
    """Return the contents of *bin_file* as a base64-encoded ASCII string.

    Used to inline binary assets (e.g. the background image) into CSS.
    """
    with open(bin_file, 'rb') as handle:
        encoded = base64.b64encode(handle.read())
    return encoded.decode()
def predict_fonk(model, device, example, tokenizer):
    """Run token-level sentiment/entity prediction on a single input string.

    Tokenizes *example*, feeds it through *model*, and groups consecutive
    tokens tagged "org" (id 4) onto the preceding sentiment-tagged token to
    form multi-token entities.

    Returns a dict with:
      - "entity_list": unique entity strings found in the text
      - "results": list of {"entity": str, "sentiment": str} dicts

    NOTE(review): original indentation was lost in extraction; the nesting
    below is reconstructed from the logic — confirm against the deployed app.
    """
    model.to(device)
    model.eval()
    predictions = []
    # Tokenize the single example; wrap every field in a one-element list so
    # CustomDataset / DataLoader see a batch of size 1.
    encodings_prdict = tokenizer.encode(example)
    predict_texts = [encodings_prdict.tokens]
    predict_input_ids = [encodings_prdict.ids]
    predict_attention_masks = [encodings_prdict.attention_mask]
    predict_token_type_ids = [encodings_prdict.type_ids]
    # Labels are unused at inference time; type_ids are passed as a placeholder.
    prediction_labels = [encodings_prdict.type_ids]
    predict_data = CustomDataset(predict_texts, predict_input_ids, predict_attention_masks, predict_token_type_ids,
                                 prediction_labels)
    predict_loader = DataLoader(predict_data, batch_size=1, shuffle=False)
    with torch.no_grad():
        for dataset in predict_loader:
            batch_input_ids = dataset['input_ids'].to(device)
            batch_att_mask = dataset['attention_mask'].to(device)
            outputs = model(batch_input_ids, batch_att_mask)
            # Flatten (batch, seq, n_classes) -> (batch*seq, n_classes) and
            # take the argmax class per token.
            logits = outputs.view(-1, outputs.size(-1))  # Flatten the outputs
            _, predicted = torch.max(logits, 1)
            # Ignore padding tokens for predictions
            predictions.append(predicted)
    results_list = []
    entity_list = []
    results_dict = {}
    # Walk (token, predicted label, attention) triples for the single example.
    trio = zip(predict_loader.dataset[0]["text"], predictions[0].tolist(), predict_attention_masks[0])
    for i, (token, label, attention) in enumerate(trio):
        # Keep only real (non-padding) tokens whose label is a sentiment
        # class — skip "O" (0) and bare "org" continuations (4).
        if attention != 0 and label != 0 and label != 4:
            # Greedily absorb following tokens labelled "org" (4) into this
            # entity. NB: `i` is advanced manually here; `enumerate` keeps its
            # own counter, so tokens already merged are visited again by the
            # outer loop (their label != 4 check usually filters them out).
            for next_ones in predictions[0].tolist()[i+1:]:
                i += 1
                if next_ones == 4:
                    token = token + " " + predict_loader.dataset[0]["text"][i]
                else:
                    break
            # De-duplicate: each entity string is reported once with the
            # sentiment of its first occurrence.
            if token not in entity_list:
                entity_list.append(token)
                results_list.append({"entity": token, "sentiment": id2tag.get(label)})
    results_dict["entity_list"] = entity_list
    results_dict["results"] = results_list
    return results_dict
# ---- App start-up: build the model, restore weights, load the tokenizer ----
model = TransformerEncoder()
model = load_model_to_cpu(model, "model.pth")
tokenizer = Tokenizer.from_file("tokenizer.json")
# Inline the background image into the page CSS as a base64 data URI.
background = get_base64("./media/background.jpg")
with open("./style/style.css", "r") as style:
    # The CSS template contains a {background} placeholder filled in here.
    css = f"""<style>{style.read().format(background=background)}</style>"""
st.markdown(css, unsafe_allow_html=True)
# Center the UI: content lives in the middle column, split into three tabs.
left, middle, right = st.columns([1, 1.5, 1])
main, comps, result = middle.tabs([" ", " ", " "])
with main:
    # Free-text input (Turkish prompt: "please write the complaint/comment
    # here, then click Predict").
    example = st.text_area(label='Metin Kutusu: ', placeholder="Lütfen Şikayet veya Yorum Metnini Buraya Yazın, daha sonra Predicte tıklayın")
    if st.button("Predict"):
        # Run inference on the entered text and render the raw result dict.
        predict_list = predict_fonk(model=model, device=device, example=example, tokenizer=tokenizer)
        st.write(predict_list)