mavinsao's picture
create app.py
a681e0a verified
raw
history blame
4.22 kB
import streamlit as st
import torch
import torch.nn as nn
from transformers import RobertaTokenizer, RobertaModel
import json
# Set device (GPU if available, otherwise CPU).
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Hub identifier (or local path) of the MI-RoBERTa backbone used by
# MIRobertaClassifier.
model_path = 'mi-roberta-base'  # pre-trained model
# A single tokenizer is shared by both ensemble members in predict_label.
tokenizer = RobertaTokenizer.from_pretrained('roberta-base')
# Label name -> classifier output index. Presumably this mirrors the label
# encoding used when the checkpoints were trained -- TODO confirm.
common_label_map = {'ADHD': 0, 'Anxiety': 1, 'bipolar': 2, 'BPD': 3, 'depression': 4, 'OCD': 5, 'ptsd': 6, 'none': 7}
# Derived from the label map so the head size can never drift from the labels.
num_classes = len(common_label_map)
# MiRoBERTa
class MIRobertaClassifier(nn.Module):
    """Linear classification head on the MI-RoBERTa backbone at ``model_path``.

    The hidden state at sequence position 0 (the ``<s>`` token when inputs
    are built with special tokens) goes through dropout and then a linear
    layer that produces ``num_classes`` logits.

    The submodule names ``roberta``, ``dropout`` and ``fc`` determine the
    state-dict keys that saved checkpoints are loaded against, so they must
    not be renamed.
    """

    def __init__(self, num_classes, dropout_prob=0.3):
        super().__init__()
        backbone = RobertaModel.from_pretrained(model_path)
        self.roberta = backbone
        self.dropout = nn.Dropout(dropout_prob)
        self.fc = nn.Linear(backbone.config.hidden_size, num_classes)

    def forward(self, input_ids, attention_mask):
        """Return raw (un-normalised) class logits for a batch of token ids."""
        encoded = self.roberta(input_ids=input_ids, attention_mask=attention_mask)
        first_token_state = encoded.last_hidden_state[:, 0, :]
        return self.fc(self.dropout(first_token_state))
# RoBERTa
class RobertaClassifier(nn.Module):
    """Linear classification head on the stock ``roberta-base`` backbone.

    The hidden state at sequence position 0 (the ``<s>`` token when inputs
    are built with special tokens) goes through dropout and then a linear
    layer that produces ``num_classes`` logits.

    The submodule names ``roberta``, ``dropout`` and ``fc`` determine the
    state-dict keys that saved checkpoints are loaded against, so they must
    not be renamed.
    """

    def __init__(self, num_classes, dropout_prob=0.3):
        super().__init__()
        backbone = RobertaModel.from_pretrained('roberta-base')
        self.roberta = backbone
        self.dropout = nn.Dropout(dropout_prob)
        self.fc = nn.Linear(backbone.config.hidden_size, num_classes)

    def forward(self, input_ids, attention_mask):
        """Return raw (un-normalised) class logits for a batch of token ids."""
        encoded = self.roberta(input_ids=input_ids, attention_mask=attention_mask)
        first_token_state = encoded.last_hidden_state[:, 0, :]
        return self.fc(self.dropout(first_token_state))
# Load the two fine-tuned ensemble members.
def _load_classifier(model_cls, checkpoint_path):
    """Build ``model_cls`` on ``device`` and load fine-tuned weights into it.

    Args:
        model_cls: Classifier class, constructed with ``num_classes``.
        checkpoint_path: File whose unpickled content is a dict with a
            ``'state_dict'`` entry.

    Returns:
        The loaded classifier, already switched to eval mode.

    NOTE: ``torch.load`` unpickles the file, so only trusted checkpoints
    should ever be loaded here.
    """
    checkpoint = torch.load(checkpoint_path, map_location=device)
    classifier = model_cls(num_classes=num_classes).to(device)
    classifier.load_state_dict(checkpoint['state_dict'])
    # Inference only: disable dropout once, up front.
    classifier.eval()
    return classifier


# The plain RoBERTa checkpoint goes into RobertaClassifier (roberta-base
# backbone); the MI-RoBERTa checkpoint goes into MIRobertaClassifier.
# NOTE(review): Streamlit re-executes this script on every interaction, so both
# models are reloaded on each click; consider caching (e.g. st.cache_resource)
# if the installed Streamlit version provides it.
roberta_model = _load_classifier(RobertaClassifier, 'reddit_roberta_state.pth')
mi_model = _load_classifier(MIRobertaClassifier, 'reddit_miroberta_state.pth')
def predict_label(sentence, tokenizer, model, device, label_map=None):
    """Classify ``sentence`` with one classifier or an ensemble of them.

    The logits of all ensemble members are averaged and converted to
    probabilities with a softmax. The most probable class is then reported.

    Args:
        sentence: Input text to classify.
        tokenizer: Callable tokenizer returning ``input_ids`` and
            ``attention_mask`` tensors (e.g. a Hugging Face tokenizer).
        model: A single classifier module, or a list/tuple of classifier
            modules whose logits are averaged. Each is called as
            ``m(input_ids, attention_mask)``.
        device: Torch device the input tensors are moved to.
        label_map: Mapping of label name -> output index. Defaults to the
            module-level ``common_label_map``.

    Returns:
        A JSON string ``{"label": <str>, "score": <float>}``, where ``score``
        is the softmax probability of the predicted label.

    Raises:
        ValueError: If an empty list/tuple of models is given.
    """
    if label_map is None:
        label_map = common_label_map
    # Invert the map so index -> label does not depend on dict insertion order.
    index_to_label = {index: label for label, index in label_map.items()}
    models = list(model) if isinstance(model, (list, tuple)) else [model]
    if not models:
        raise ValueError("predict_label needs at least one model")

    # Tokenize the sentence and create attention mask
    tokenized_input = tokenizer(
        sentence,
        add_special_tokens=True,
        max_length=512,
        padding="max_length",
        truncation=True,
        return_tensors="pt",
    )
    input_ids = tokenized_input['input_ids'].to(device)
    attention_mask = tokenized_input['attention_mask'].to(device)

    with torch.no_grad():
        member_logits = []
        for member in models:
            # Evaluation mode disables dropout so predictions are deterministic.
            member.eval()
            member_logits.append(member(input_ids, attention_mask))
        # Ensemble: average the logits of all members, then normalise.
        ensemble_logits = torch.stack(member_logits).mean(dim=0)
        probabilities = torch.softmax(ensemble_logits, dim=1)

    # A single sentence was tokenized, so the batch row of interest is 0.
    predicted_index = int(torch.argmax(probabilities, dim=1)[0])
    response = {
        "label": index_to_label[predicted_index],
        "score": float(probabilities[0, predicted_index]),
    }
    return json.dumps(response)
# Streamlit app
st.title('Mental Illness Prediction')
# Input text area for user input
sentence = st.text_area("Enter the sentence to predict your mental illness state:")
# Prediction button
if st.button('Predict'):
    if not sentence.strip():
        # Nothing to classify; avoid running both models on an empty prompt.
        st.warning("Please enter some text to analyze.")
    else:
        # Ensemble the MI-RoBERTa and RoBERTa classifiers loaded above.
        predicted_response = predict_label(sentence, tokenizer, (mi_model, roberta_model), device)
        st.json(predicted_response)