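"""Streamlit app for mental illness prediction.

Ensembles two fine-tuned RoBERTa-based classifiers (an MI-RoBERTa model and a
roberta-base model) to score a piece of text against eight mental-health labels.
"""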
import torch
import torch.nn as nn
from transformers import RobertaTokenizer, RobertaModel
import streamlit as st
# Run inference on CPU
device = torch.device("cpu")
# Define the path to the model and tokenizer
model_path = 'mi-roberta-base' # pre-trained model
tokenizer = RobertaTokenizer.from_pretrained('roberta-base')
# Create a common label map
common_label_map = {'ADHD': 0, 'Anxiety': 1, 'bipolar': 2, 'BPD': 3, 'depression': 4, 'OCD': 5, 'ptsd': 6, 'none': 7}
num_classes = 8
# MI-RoBERTa classifier: MI-RoBERTa encoder with a linear classification head
class MIRobertaClassifier(nn.Module):
    def __init__(self, num_classes, dropout_prob=0.3):
        super(MIRobertaClassifier, self).__init__()
        self.roberta = RobertaModel.from_pretrained(model_path)
        self.dropout = nn.Dropout(dropout_prob)
        self.fc = nn.Linear(self.roberta.config.hidden_size, num_classes)

    def forward(self, input_ids, attention_mask):
        outputs = self.roberta(input_ids=input_ids, attention_mask=attention_mask)
        # Use the [CLS] token representation as the sequence embedding
        last_hidden_state = outputs.last_hidden_state[:, 0, :]
        x = self.dropout(last_hidden_state)
        logits = self.fc(x)
        return logits
# Plain RoBERTa classifier: roberta-base encoder with a linear classification head
class RobertaClassifier(nn.Module):
    def __init__(self, num_classes, dropout_prob=0.3):
        super(RobertaClassifier, self).__init__()
        self.roberta = RobertaModel.from_pretrained('roberta-base')
        self.dropout = nn.Dropout(dropout_prob)
        self.fc = nn.Linear(self.roberta.config.hidden_size, num_classes)

    def forward(self, input_ids, attention_mask):
        outputs = self.roberta(input_ids=input_ids, attention_mask=attention_mask)
        # Use the [CLS] token representation as the sequence embedding
        last_hidden_state = outputs.last_hidden_state[:, 0, :]
        x = self.dropout(last_hidden_state)
        logits = self.fc(x)
        return logits
# Load the fine-tuned RoBERTa checkpoint and restore its weights
roberta_loaded_model_state = torch.load('reddit_roberta_state.pth', map_location=device)
roberta_model = RobertaClassifier(num_classes=num_classes).to(device)
roberta_model.load_state_dict(roberta_loaded_model_state['state_dict'])

# Load the fine-tuned MI-RoBERTa checkpoint and restore its weights
mi_loaded_model_state = torch.load('reddit_miroberta_state.pth', map_location=device)
mi_model = MIRobertaClassifier(num_classes=num_classes).to(device)
mi_model.load_state_dict(mi_loaded_model_state['state_dict'])
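
# The two classifiers are used as an ensemble: predict_label below averages their
# logits before converting them to probabilities.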
def predict_label(sentence, tokenizer, model1, model2, device):
    # Tokenize the sentence and create the attention mask
    tokenized_input = tokenizer(
        sentence,
        add_special_tokens=True,
        max_length=512,
        padding="max_length",
        truncation=True,
        return_tensors="pt"
    )

    # Move the input tensors to the device
    input_ids = tokenized_input['input_ids'].to(device)
    attention_mask = tokenized_input['attention_mask'].to(device)

    # Set both models to evaluation mode
    model1.eval()
    model2.eval()

    # Make a prediction with each model
    with torch.no_grad():
        outputs1 = model1(input_ids, attention_mask)
        outputs2 = model2(input_ids, attention_mask)

    # Ensemble the predictions by averaging the logits of both models
    ensemble_outputs = (outputs1 + outputs2) / 2

    # Apply softmax to turn the averaged logits into probabilities
    probabilities = torch.softmax(ensemble_outputs, dim=1)[0].tolist()

    # Map each class label to its probability using the common label map
    class_names = list(common_label_map.keys())
    label_scores = {}
    for label in class_names:
        label_scores[label] = probabilities[common_label_map[label]]

    # Sort the label scores in descending order of probability
    sorted_label_scores = dict(sorted(label_scores.items(), key=lambda item: item[1], reverse=True))

    return sorted_label_scores
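
# Example (hypothetical input and scores):
#   predict_label("I keep losing focus and can't sit still.", tokenizer, mi_model, roberta_model, device)
#   -> {'ADHD': 0.61, 'Anxiety': 0.14, ..., 'none': 0.01}  # sorted by probability, highest first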
# Streamlit app
st.title('Mental Illness Prediction')

# Input text area for user input
sentence = st.text_area("Enter the long sentence to predict your mental illness state:")

# Prediction button
if st.button('Predict'):
    # Predict the label scores and display them as JSON
    predicted_response = predict_label(sentence, tokenizer, mi_model, roberta_model, device)
    st.json(predicted_response)
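
# To launch the app locally (assuming this file is saved as app.py): streamlit run app.py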