mavinsao committed on
Commit
a681e0a
·
verified ·
1 Parent(s): 0100fa7

create app.py

Browse files
Files changed (1) hide show
  1. app.py +125 -0
app.py ADDED
@@ -0,0 +1,125 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ import torch
3
+ import torch.nn as nn
4
+ from transformers import RobertaTokenizer, RobertaModel
5
+ import json
6
+
7
# Set device (GPU if available, otherwise CPU)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Identifier handed to RobertaModel.from_pretrained for the MiRoBERTa encoder
model_path = 'mi-roberta-base'  # pre-trained model
tokenizer = RobertaTokenizer.from_pretrained('roberta-base')

# Common label map: class name -> class index used by the classifier heads
common_label_map = {'ADHD': 0, 'Anxiety': 1, 'bipolar': 2, 'BPD': 3, 'depression': 4, 'OCD': 5, 'ptsd': 6, 'none': 7}
# Derived from the map so the head size cannot drift from the label set
num_classes = len(common_label_map)
17
+
18
+
19
+ # MiRoBERTa
20
class MIRobertaClassifier(nn.Module):
    """Classifier built from the encoder at ``model_path``, dropout and a linear head.

    Args:
        num_classes: Number of output logits produced by the linear head.
        dropout_prob: Dropout probability applied before the linear head.
    """

    def __init__(self, num_classes, dropout_prob=0.3):
        super().__init__()
        encoder = RobertaModel.from_pretrained(model_path)
        self.roberta = encoder
        self.dropout = nn.Dropout(p=dropout_prob)
        self.fc = nn.Linear(encoder.config.hidden_size, num_classes)

    def forward(self, input_ids, attention_mask):
        """Return the logits computed from the first-position hidden state."""
        encoded = self.roberta(input_ids=input_ids, attention_mask=attention_mask)
        # Position 0 of the sequence axis is used as the sequence representation.
        first_token = encoded.last_hidden_state[:, 0, :]
        return self.fc(self.dropout(first_token))
33
+
34
+
35
+ # RoBERTa
36
class RobertaClassifier(nn.Module):
    """Classifier built from the ``'roberta-base'`` encoder, dropout and a linear head.

    Args:
        num_classes: Number of output logits produced by the linear head.
        dropout_prob: Dropout probability applied before the linear head.
    """

    def __init__(self, num_classes, dropout_prob=0.3):
        super().__init__()
        encoder = RobertaModel.from_pretrained('roberta-base')
        self.roberta = encoder
        self.dropout = nn.Dropout(p=dropout_prob)
        self.fc = nn.Linear(encoder.config.hidden_size, num_classes)

    def forward(self, input_ids, attention_mask):
        """Return the logits computed from the first-position hidden state."""
        encoded = self.roberta(input_ids=input_ids, attention_mask=attention_mask)
        # Position 0 of the sequence axis is used as the sequence representation.
        first_token = encoded.last_hidden_state[:, 0, :]
        return self.fc(self.dropout(first_token))
49
+
50
+
51
# NOTE: torch.load unpickles the checkpoint file; only load files from a
# trusted source.

# RoBERTa classifier: the checkpoint must hold the weights under 'state_dict'.
roberta_loaded_model_state = torch.load('reddit_roberta_state.pth', map_location=device)
# Use the RoBERTa wrapper here (not the MiRoBERTa one) so the checkpoint is
# restored onto the encoder it belongs to.
roberta_model = RobertaClassifier(num_classes=num_classes).to(device)
roberta_model.load_state_dict(roberta_loaded_model_state['state_dict'])

# MiRoBERTa classifier: same checkpoint layout, encoder taken from model_path.
mi_loaded_model_state = torch.load('reddit_miroberta_state.pth', map_location=device)
mi_model = MIRobertaClassifier(num_classes=num_classes).to(device)
mi_model.load_state_dict(mi_loaded_model_state['state_dict'])
68
+
69
+
70
+
71
def predict_label(sentence, tokenizer, model, device):
    """Classify ``sentence`` with the MiRoBERTa + RoBERTa ensemble.

    The logits of the module-level ``mi_model`` and ``roberta_model`` are
    averaged, turned into probabilities with softmax, and the most probable
    class is reported.

    Args:
        sentence: Text to classify.
        tokenizer: Tokenizer called on ``sentence``; its result must provide
            ``'input_ids'`` and ``'attention_mask'`` tensors.
        model: Unused. Retained so existing callers keep working; the
            ensemble always uses ``mi_model`` and ``roberta_model``.
        device: Device the input tensors are moved to before inference.

    Returns:
        A JSON string with ``"label"`` (the predicted class name) and
        ``"score"`` (that class's softmax probability).
    """
    # Tokenize the sentence and create the attention mask
    encoded = tokenizer(
        sentence,
        add_special_tokens=True,
        max_length=512,
        padding="max_length",
        truncation=True,
        return_tensors="pt",
    )

    # Move the input tensors to the device
    input_ids = encoded['input_ids'].to(device)
    attention_mask = encoded['attention_mask'].to(device)

    # Evaluation mode disables dropout so predictions are deterministic
    mi_model.eval()
    roberta_model.eval()

    with torch.no_grad():
        mi_logits = mi_model(input_ids, attention_mask)
        roberta_logits = roberta_model(input_ids, attention_mask)

        # Ensemble predictions: average the logits from both models
        ensemble_logits = (mi_logits + roberta_logits) / 2

        # Softmax over the class axis; [0] selects the single input sentence
        probabilities = torch.softmax(ensemble_logits, dim=1)[0]

    predicted_index = int(torch.argmax(probabilities).item())

    # Invert the label map explicitly instead of relying on dict ordering
    index_to_label = {index: label for label, index in common_label_map.items()}
    predicted_label = index_to_label[predicted_index]

    response = {
        "label": predicted_label,
        "score": probabilities[predicted_index].item(),
    }

    return json.dumps(response)
113
+
114
+
115
# Streamlit app
st.title('Mental Illness Prediction')

# Input text area for user input
sentence = st.text_area("Enter the sentence to predict your mental illness state:")

# Prediction button
if st.button('Predict'):
    if sentence and sentence.strip():
        # The third argument only satisfies predict_label's signature; pass a
        # model that is actually defined in this module.
        predicted_response = predict_label(sentence, tokenizer, mi_model, device)
        # predict_label returns a JSON string, which st.json renders
        st.json(predicted_response)
    else:
        # Nothing to classify: ask for input instead of scoring empty text
        st.warning("Please enter a sentence before predicting.")