import base64

import streamlit as st
import torch
from tokenizers import Tokenizer
from torch.utils.data import DataLoader

from model import CustomDataset, TransformerEncoder

st.set_page_config(layout="wide", page_title="TeknoFest We Bears NLP Competition",
                   page_icon="./media/3bears.ico")

# Label vocabulary: "O" = outside/no entity, three sentiment classes
# (olumsuz/nötr/olumlu = negative/neutral/positive), and "org" which marks
# continuation tokens of a multi-token organisation name (see predict_fonk).
tag2id = {"O": 0, "olumsuz": 1, "nötr": 2, "olumlu": 3, "org": 4}
id2tag = {value: key for key, value in tag2id.items()}
device = torch.device('cpu')


@st.cache_resource
def load_model_to_cpu(_model, path="model.pth"):
    """Load the checkpoint at *path* into *_model* on the CPU and return it.

    The leading underscore on ``_model`` tells Streamlit's resource cache not
    to try hashing the (unhashable) torch module.
    """
    checkpoint = torch.load(path, map_location=torch.device('cpu'))
    _model.load_state_dict(checkpoint)
    return _model


def get_base64(bin_file):
    """Return the contents of *bin_file* as a base64-encoded ASCII string."""
    with open(bin_file, 'rb') as f:
        data = f.read()
    return base64.b64encode(data).decode()


def predict_fonk(model, device, example, tokenizer):
    """Run token classification on *example* and group tokens into entities.

    Tokens labelled with a sentiment class start an entity; immediately
    following tokens labelled "org" (id 4) are appended to form the full
    multi-token entity name.

    Returns a dict with:
      - "entity_list": unique entity strings found in the text
      - "results":     [{"entity": ..., "sentiment": ...}, ...]
    """
    model.to(device)
    model.eval()

    predictions = []
    encoding = tokenizer.encode(example)
    predict_texts = [encoding.tokens]
    predict_input_ids = [encoding.ids]
    predict_attention_masks = [encoding.attention_mask]
    predict_token_type_ids = [encoding.type_ids]
    # CustomDataset requires a labels field; the values are never read during
    # inference, so type_ids are passed as placeholders.
    prediction_labels = [encoding.type_ids]

    predict_data = CustomDataset(predict_texts, predict_input_ids,
                                 predict_attention_masks,
                                 predict_token_type_ids, prediction_labels)
    predict_loader = DataLoader(predict_data, batch_size=1, shuffle=False)

    with torch.no_grad():
        for batch in predict_loader:
            batch_input_ids = batch['input_ids'].to(device)
            batch_att_mask = batch['attention_mask'].to(device)
            outputs = model(batch_input_ids, batch_att_mask)
            # Flatten to (num_tokens, num_classes) and take the argmax class.
            logits = outputs.view(-1, outputs.size(-1))
            _, predicted = torch.max(logits, 1)
            predictions.append(predicted)

    results_list = []
    entity_list = []
    results_dict = {}

    tokens = predict_loader.dataset[0]["text"]
    labels = predictions[0].tolist()
    attention = predict_attention_masks[0]

    for i, (token, label, att) in enumerate(zip(tokens, labels, attention)):
        # Skip padding (att == 0), "O" tokens, and bare continuation tokens:
        # only a sentiment-labelled token starts an entity.
        if att != 0 and label != 0 and label != 4:
            # Greedily absorb the run of following "org"-labelled tokens into
            # this entity's name.  A separate cursor `j` is used instead of
            # mutating the enumerate index.
            j = i
            for next_label in labels[i + 1:]:
                j += 1
                if next_label == 4:
                    token = token + " " + tokens[j]
                else:
                    break
            if token not in entity_list:
                entity_list.append(token)
                results_list.append({"entity": token, "sentiment": id2tag.get(label)})

    results_dict["entity_list"] = entity_list
    results_dict["results"] = results_list
    return results_dict


model = TransformerEncoder()
model = load_model_to_cpu(model, "model.pth")
tokenizer = Tokenizer.from_file("tokenizer.json")
background = get_base64("./media/background.jpg")

with open("./style/style.css", "r") as style:
    # NOTE(review): the f-string body is empty, so `style` and `background`
    # are read but never used — the CSS template content appears to have been
    # lost.  TODO: restore the original template or drop the unused reads.
    css = f""""""
st.markdown(css, unsafe_allow_html=True)

left, middle, right = st.columns([1, 1.5, 1])
main, comps, result = middle.tabs([" ", " ", " "])

with main:
    example = st.text_area(label='Metin Kutusu: ',
                           placeholder="Lütfen Şikayet veya Yorum Metnini Buraya Yazın, daha sonra Predicte tıklayın")
    if st.button("Predict"):
        predict_list = predict_fonk(model=model, device=device,
                                    example=example, tokenizer=tokenizer)
        st.write(predict_list)