# mohd43's picture
# Update app.py
# 245b7fc verified
# raw
# history blame
# 2.78 kB
from transformers import BertForSequenceClassification, BertTokenizer
import torch
import gradio as gr
# Load the FinBERT tokenizer and sequence-classification model once at import
# time (downloads weights from the Hugging Face hub on first run).
tokenizer = BertTokenizer.from_pretrained('ProsusAI/finbert')
model = BertForSequenceClassification.from_pretrained('ProsusAI/finbert')
def predict(input_text):
    """Classify the financial sentiment of *input_text* with FinBERT.

    Long inputs are handled by splitting the token sequence into chunks of
    at most 512 tokens (510 content tokens plus [CLS]/[SEP]), averaging the
    per-chunk class probabilities, and returning the argmax class.

    Args:
        input_text (str): Arbitrary-length text to classify.

    Returns:
        str: One of "positive", "negative", or "neutral".
    """
    # Tokenize without special tokens; [CLS]/[SEP] are added per chunk below.
    tokens = tokenizer.encode_plus(
        input_text, add_special_tokens=False, return_tensors='pt'
    )

    def get_input_ids_and_attention_mask_chunk():
        """
        This function splits the input_ids and attention_mask into chunks of size 'chunksize'.
        It also adds special tokens (101 for [CLS] and 102 for [SEP]) at the start and end of each chunk.
        If the length of a chunk is less than 'chunksize', it pads the chunk with zeros at the end.
        Returns:
            input_id_chunks (List[torch.Tensor]): List of chunked input_ids.
            attention_mask_chunks (List[torch.Tensor]): List of chunked attention_masks.
        """
        chunksize = 512
        # Reserve two positions per chunk for the [CLS]/[SEP] special tokens.
        input_id_chunks = list(tokens['input_ids'][0].split(chunksize - 2))
        attention_mask_chunks = list(tokens['attention_mask'][0].split(chunksize - 2))
        for i in range(len(input_id_chunks)):
            # 101 = [CLS], 102 = [SEP] in the BERT vocabulary.
            input_id_chunks[i] = torch.cat([
                torch.tensor([101]), input_id_chunks[i], torch.tensor([102])
            ])
            attention_mask_chunks[i] = torch.cat([
                torch.tensor([1]), attention_mask_chunks[i], torch.tensor([1])
            ])
            pad_length = chunksize - input_id_chunks[i].shape[0]
            if pad_length > 0:
                # Pad with integer zeros: torch.Tensor([0]*n) would create a
                # float32 tensor and mismatch the int64 token-id dtype.
                padding = torch.zeros(pad_length, dtype=torch.long)
                input_id_chunks[i] = torch.cat([input_id_chunks[i], padding])
                attention_mask_chunks[i] = torch.cat([attention_mask_chunks[i], padding])
        return input_id_chunks, attention_mask_chunks

    input_id_chunks, attention_mask_chunks = get_input_ids_and_attention_mask_chunk()
    input_dict = {
        'input_ids': torch.stack(input_id_chunks).long(),
        'attention_mask': torch.stack(attention_mask_chunks).int()
    }
    # Inference only: disable autograd to avoid building a useless graph.
    with torch.no_grad():
        outputs = model(**input_dict)
    probabilities = torch.nn.functional.softmax(outputs[0], dim=-1)
    # Average class probabilities across all chunks of the document.
    mean_probabilities = probabilities.mean(dim=0)
    output = torch.argmax(mean_probabilities).item()
    # NOTE(review): label order assumed 0=positive, 1=negative, 2=neutral per
    # the original code — confirm against model.config.id2label.
    labels = {0: "positive", 1: "negative", 2: "neutral"}
    return labels[output]
# Wire the classifier into a minimal Gradio UI: one text box in, the
# predicted sentiment label out; predictions update live as the user types.
gradio_app = gr.Interface(
    fn=predict,
    inputs=gr.Textbox(label="Write a text"),
    outputs=gr.Textbox(label="output a text"),
    title="Financial Sentiment Analysis",
    live=True,
    allow_flagging="never",
)

# Start the web server (blocks until the process is stopped).
gradio_app.launch()