create app.py
app.py
ADDED
@@ -0,0 +1,125 @@
import streamlit as st
import torch
import torch.nn as nn
from transformers import RobertaTokenizer, RobertaModel
import json

# Set device (CPU here; switch to "cuda" if a GPU is available)
device = torch.device("cpu")

# Define the path to the model and tokenizer
model_path = 'mi-roberta-base'  # pre-trained model
tokenizer = RobertaTokenizer.from_pretrained('roberta-base')

# Create a common label map shared by both models
common_label_map = {'ADHD': 0, 'Anxiety': 1, 'bipolar': 2, 'BPD': 3, 'depression': 4, 'OCD': 5, 'ptsd': 6, 'none': 7}
num_classes = 8

+
# MiRoBERTa
|
20 |
+
class MIRobertaClassifier(nn.Module):
|
21 |
+
def __init__(self, num_classes, dropout_prob=0.3):
|
22 |
+
super(MIRobertaClassifier, self).__init__()
|
23 |
+
self.roberta = RobertaModel.from_pretrained(model_path)
|
24 |
+
self.dropout = nn.Dropout(dropout_prob)
|
25 |
+
self.fc = nn.Linear(self.roberta.config.hidden_size, num_classes)
|
26 |
+
|
27 |
+
def forward(self, input_ids, attention_mask):
|
28 |
+
outputs = self.roberta(input_ids=input_ids, attention_mask=attention_mask)
|
29 |
+
last_hidden_state = outputs.last_hidden_state[:, 0, :]
|
30 |
+
x = self.dropout(last_hidden_state)
|
31 |
+
logits = self.fc(x)
|
32 |
+
return logits
|
33 |
+
|
34 |
+
|
35 |
+
# RoBERTa
|
36 |
+
class RobertaClassifier(nn.Module):
|
37 |
+
def __init__(self, num_classes, dropout_prob=0.3):
|
38 |
+
super(RobertaClassifier, self).__init__()
|
39 |
+
self.roberta = RobertaModel.from_pretrained('roberta-base')
|
40 |
+
self.dropout = nn.Dropout(dropout_prob)
|
41 |
+
self.fc = nn.Linear(self.roberta.config.hidden_size, num_classes)
|
42 |
+
|
43 |
+
def forward(self, input_ids, attention_mask):
|
44 |
+
outputs = self.roberta(input_ids=input_ids, attention_mask=attention_mask)
|
45 |
+
last_hidden_state = outputs.last_hidden_state[:, 0, :]
|
46 |
+
x = self.dropout(last_hidden_state)
|
47 |
+
logits = self.fc(x)
|
48 |
+
return logits
|
49 |
+
|
50 |
+
|
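# Both classes share the same head, so one shape check covers either.
# Hypothetical sanity check (not executed by the app):
#   m = RobertaClassifier(num_classes=num_classes)
#   logits = m(torch.zeros(1, 8, dtype=torch.long), torch.ones(1, 8))
#   logits.shape  # torch.Size([1, 8]) -- one logit per class
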
# Load the fine-tuned RoBERTa checkpoint
roberta_loaded_model_state = torch.load('reddit_roberta_state.pth', map_location=device)
# Create an instance of the model (both classes share the same head, so the
# state dict fits either; RobertaClassifier matches this checkpoint's base)
roberta_model = RobertaClassifier(num_classes=num_classes).to(device)
# Load the state dictionary into the model
roberta_model.load_state_dict(roberta_loaded_model_state['state_dict'])


# Load the fine-tuned MI-RoBERTa checkpoint
mi_loaded_model_state = torch.load('reddit_miroberta_state.pth', map_location=device)
# Create an instance of the model
mi_model = MIRobertaClassifier(num_classes=num_classes).to(device)
# Load the state dictionary into the model
mi_model.load_state_dict(mi_loaded_model_state['state_dict'])


def predict_label(sentence, tokenizer, device):
    # Tokenize the sentence and create attention mask
    tokenized_input = tokenizer(
        sentence,
        add_special_tokens=True,
        max_length=512,
        padding="max_length",
        truncation=True,
        return_tensors="pt"
    )

    # Move the input tensors to the device
    input_ids = tokenized_input['input_ids'].to(device)
    attention_mask = tokenized_input['attention_mask'].to(device)

    # Set both models to evaluation mode
    mi_model.eval()
    roberta_model.eval()

    # Make a prediction with each model
    with torch.no_grad():
        outputs1 = mi_model(input_ids, attention_mask)
        outputs2 = roberta_model(input_ids, attention_mask)

    # Ensemble predictions: average the logits from both models
    ensemble_outputs = (outputs1 + outputs2) / 2

    # Apply softmax to get per-class probabilities
    probabilities = torch.softmax(ensemble_outputs, dim=1)[0].tolist()

    # Map the predicted index back to the original class label
    class_names = list(common_label_map.keys())
    predicted_index = torch.argmax(ensemble_outputs, dim=1)
    predicted_label = class_names[predicted_index.item()]

    # Create JSON response
    response = {
        "label": predicted_label,
        "score": probabilities[common_label_map[predicted_label]]
    }

    return json.dumps(response)

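# Example round trip (hypothetical; label and score are illustrative):
#   predict_label("I feel on edge all the time", tokenizer, device)
#   -> '{"label": "Anxiety", "score": 0.87}'
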
# Streamlit app
st.title('Mental Illness Prediction')

# Input text area for user input
sentence = st.text_area("Enter a sentence to predict your mental illness state:")

# Prediction button
if st.button('Predict'):
    # Predict label with the two-model ensemble
    predicted_response = predict_label(sentence, tokenizer, device)
    st.json(predicted_response)
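
A quick way to exercise the predictor outside the Streamlit UI is a smoke test like the one below. This is a sketch, not part of the commit: it assumes both .pth checkpoints sit next to app.py, and the sample sentence is invented. Importing app also executes the Streamlit calls in bare mode, which typically just logs a warning.

import json
from app import predict_label, tokenizer, device

result = json.loads(predict_label(
    "I keep re-checking the locks before I can leave the house.",
    tokenizer,
    device,
))
print(result["label"], round(result["score"], 3))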