import torch
import torch.nn as nn
from transformers import RobertaTokenizer, RobertaModel
import streamlit as st

# Run inference on CPU
device = torch.device("cpu")

# Weights for the MI-RoBERTa encoder used by MIRobertaClassifier; the tokenizer is the standard roberta-base tokenizer
model_path = 'mi-roberta-base'
tokenizer = RobertaTokenizer.from_pretrained('roberta-base')

# Mapping from condition labels to class indices
common_label_map = {'ADHD': 0, 'Anxiety': 1, 'bipolar': 2, 'BPD': 3, 'depression': 4, 'OCD': 5, 'ptsd': 6, 'none': 7}
num_classes = 8
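

# Both classifier classes below share the same architecture: a RoBERTa encoder whose
# first-token representation is passed through dropout and a linear layer to produce
# logits over the eight label classes.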
class MIRobertaClassifier(nn.Module):
    def __init__(self, num_classes, dropout_prob=0.3):
        super(MIRobertaClassifier, self).__init__()
        self.roberta = RobertaModel.from_pretrained(model_path)
        self.dropout = nn.Dropout(dropout_prob)
        self.fc = nn.Linear(self.roberta.config.hidden_size, num_classes)

    def forward(self, input_ids, attention_mask):
        outputs = self.roberta(input_ids=input_ids, attention_mask=attention_mask)
        # The hidden state of the first token (<s>) serves as the sequence representation
        last_hidden_state = outputs.last_hidden_state[:, 0, :]
        x = self.dropout(last_hidden_state)
        logits = self.fc(x)
        return logits


class RobertaClassifier(nn.Module):
    def __init__(self, num_classes, dropout_prob=0.3):
        super(RobertaClassifier, self).__init__()
        self.roberta = RobertaModel.from_pretrained('roberta-base')
        self.dropout = nn.Dropout(dropout_prob)
        self.fc = nn.Linear(self.roberta.config.hidden_size, num_classes)

    def forward(self, input_ids, attention_mask):
        outputs = self.roberta(input_ids=input_ids, attention_mask=attention_mask)
        last_hidden_state = outputs.last_hidden_state[:, 0, :]
        x = self.dropout(last_hidden_state)
        logits = self.fc(x)
        return logits
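

# Restore the two fine-tuned classifiers on the CPU; each checkpoint stores its weights
# under a 'state_dict' key.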
roberta_loaded_model_state = torch.load('reddit_roberta_state.pth', map_location=device)
roberta_model = RobertaClassifier(num_classes=num_classes).to(device)
roberta_model.load_state_dict(roberta_loaded_model_state['state_dict'])

mi_loaded_model_state = torch.load('reddit_miroberta_state.pth', map_location=device)
mi_model = MIRobertaClassifier(num_classes=num_classes).to(device)
mi_model.load_state_dict(mi_loaded_model_state['state_dict'])
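

# Tokenize the input, run both classifiers, average their logits (a simple ensemble),
# and return a label-to-probability mapping sorted by descending score.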
def predict_label(sentence, tokenizer, model1, model2, device):
    # Tokenize and pad/truncate the input to RoBERTa's 512-token limit
    tokenized_input = tokenizer(
        sentence,
        add_special_tokens=True,
        max_length=512,
        padding="max_length",
        truncation=True,
        return_tensors="pt"
    )

    input_ids = tokenized_input['input_ids'].to(device)
    attention_mask = tokenized_input['attention_mask'].to(device)

    # Run both models in evaluation mode without tracking gradients
    model1.eval()
    model2.eval()

    with torch.no_grad():
        outputs1 = model1(input_ids, attention_mask)
        outputs2 = model2(input_ids, attention_mask)

    # Average the two models' logits and convert to probabilities
    ensemble_outputs = (outputs1 + outputs2) / 2
    probabilities = torch.softmax(ensemble_outputs, dim=1)[0].tolist()

    # Score every label and sort by descending probability
    label_scores = {label: probabilities[index] for label, index in common_label_map.items()}
    sorted_label_scores = dict(sorted(label_scores.items(), key=lambda item: item[1], reverse=True))

    return sorted_label_scores
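

# Streamlit front end: collect free-text input and display the sorted label probabilities.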
st.title('Mental Illness Prediction')

sentence = st.text_area("Enter a longer passage of text to predict the mental illness category:")

if st.button('Predict'):
    predicted_response = predict_label(sentence, tokenizer, mi_model, roberta_model, device)
    st.json(predicted_response)