mavinsao committed
Commit 436b48d · verified · 1 Parent(s): ff694a9

change to model

Files changed (1)
  1. app.py +21 -77
app.py CHANGED
@@ -5,67 +5,20 @@ import json
 import streamlit as st
 
 # Set device (GPU if available, otherwise CPU)
-device = torch.device("cpu")
+device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
 
-# Define the path to the model and tokenizer
-model_path = 'mi-roberta-base'  # pre-trained model
-tokenizer = RobertaTokenizer.from_pretrained('roberta-base')
+# Load model directly
+from transformers import AutoTokenizer, AutoModelForSequenceClassification
+
+tokenizer = AutoTokenizer.from_pretrained("mavinsao/mi-roberta-mental-illness")
+model = AutoModelForSequenceClassification.from_pretrained("mavinsao/mi-roberta-mental-illness")
 
 # Create a common label map
 common_label_map = {'ADHD': 0, 'Anxiety': 1, 'bipolar': 2, 'BPD': 3, 'depression': 4, 'OCD': 5, 'ptsd': 6, 'none': 7}
 num_classes = 8
 
 
-# MiRoBERTa
-class MIRobertaClassifier(nn.Module):
-    def __init__(self, num_classes, dropout_prob=0.3):
-        super(MIRobertaClassifier, self).__init__()
-        self.roberta = RobertaModel.from_pretrained(model_path)
-        self.dropout = nn.Dropout(dropout_prob)
-        self.fc = nn.Linear(self.roberta.config.hidden_size, num_classes)
-
-    def forward(self, input_ids, attention_mask):
-        outputs = self.roberta(input_ids=input_ids, attention_mask=attention_mask)
-        last_hidden_state = outputs.last_hidden_state[:, 0, :]
-        x = self.dropout(last_hidden_state)
-        logits = self.fc(x)
-        return logits
-
-
-# RoBERTa
-class RobertaClassifier(nn.Module):
-    def __init__(self, num_classes, dropout_prob=0.3):
-        super(RobertaClassifier, self).__init__()
-        self.roberta = RobertaModel.from_pretrained('roberta-base')
-        self.dropout = nn.Dropout(dropout_prob)
-        self.fc = nn.Linear(self.roberta.config.hidden_size, num_classes)
-
-    def forward(self, input_ids, attention_mask):
-        outputs = self.roberta(input_ids=input_ids, attention_mask=attention_mask)
-        last_hidden_state = outputs.last_hidden_state[:, 0, :]
-        x = self.dropout(last_hidden_state)
-        logits = self.fc(x)
-        return logits
-
-
-# Load the state dictionary into the model
-roberta_loaded_model_state = torch.load('reddit_roberta_state.pth', map_location=device)
-# Create an instance of your model
-roberta_model = MIRobertaClassifier(num_classes=num_classes).to(device)
-
-# Load the state dictionary into the model
-roberta_model.load_state_dict(roberta_loaded_model_state['state_dict'])
-
-# Load the state dictionary into the model
-mi_loaded_model_state = torch.load('reddit_miroberta_state.pth', map_location=device)
-# Create an instance of your model
-mi_model = MIRobertaClassifier(num_classes=num_classes).to(device)
-
-# Load the state dictionary into the model
-mi_model.load_state_dict(mi_loaded_model_state['state_dict'])
-
-
-def predict_label(sentence, tokenizer, model1, model2, device):
+def predict_labels(sentence, tokenizer, model, device, threshold=0.5, top_n=5):
     # Tokenize the sentence and create attention mask
     tokenized_input = tokenizer(
         sentence,
@@ -81,36 +34,27 @@ def predict_label(sentence, tokenizer, model1, model2, device):
     attention_mask = tokenized_input['attention_mask'].to(device)
 
     # Set the model to evaluation mode
-    mi_model.eval()
-    roberta_model.eval()
+    model.eval()
 
     # Make a prediction
     with torch.no_grad():
-        outputs1 = mi_model(input_ids, attention_mask)
-        outputs2 = roberta_model(input_ids, attention_mask)
-
-        # Ensemble predictions: averaging logits from both models
-        ensemble_outputs = (outputs1 + outputs2) / 2
-
-        # Apply softmax to get probabilities
-        probabilities = torch.softmax(ensemble_outputs, dim=1)[0].tolist()
+        output = model(input_ids, attention_mask)
 
-    # Map the predicted index back to the original class label using class_names
-    class_names = list(common_label_map.keys())
+    # Apply thresholding to the logits to obtain predicted labels
+    logits = output.logits
+    sigmoid_output = torch.sigmoid(logits.squeeze(dim=0))
+    indices_above_threshold = torch.arange(logits.shape[-1], device=device)[sigmoid_output > threshold]
 
-    # Get predicted index and score for each label
-    label_scores = {}
-    for i, label in enumerate(class_names):
-        label_index = common_label_map[label]
-        label_scores[label] = probabilities[label_index]
+    # Sort the indices by their sigmoid values
+    sorted_indices = indices_above_threshold[torch.argsort(sigmoid_output[indices_above_threshold], descending=True)]
 
-    # Sort label scores by score values in descending order
-    sorted_label_scores = {k: v for k, v in sorted(label_scores.items(), key=lambda item: item[1], reverse=True)}
+    # Map the predicted label indices back to the original class labels using the common label map
+    predicted_labels_with_score = [{"label": list(common_label_map.keys())[index], "score": sigmoid_output[index].item()} for index in sorted_indices[:top_n]]
 
-    # Get the predicted label
-    predicted_index = torch.argmax(ensemble_outputs, dim=1)
+    # Create a JSON object with labels and scores
+    json_result = [{"label": entry["label"], "score": entry["score"]} for entry in predicted_labels_with_score]
 
-    return sorted_label_scores
+    return json.dumps(json_result, indent=4)
 
 
 # Streamlit app
@@ -122,5 +66,5 @@ sentence = st.text_area("Enter the long sentence to predict your mental illness
 # Prediction button
 if st.button('Predict'):
     # Predict label
-    predicted_response = predict_label(sentence, tokenizer, mi_model, roberta_model, device)
+    predicted_response = predict_labels(sentence, tokenizer, model, device)
     st.json(predicted_response)
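
For a quick sanity check of the new setup outside the Streamlit app, the snippet below loads the same Hub checkpoint and scores a single sentence with the same per-label sigmoid approach. It is a minimal sketch, not part of the commit: the tokenizer arguments (truncation, padding, max_length=512), the example sentence, and reading label names from model.config.id2label are assumptions, since those details fall outside the hunks shown above.

# Standalone sketch (not part of the commit): exercise the Hub checkpoint used in app.py.
# Assumptions: tokenizer kwargs, the example sentence, and id2label-based label names.
import torch
from transformers import AutoTokenizer, AutoModelForSequenceClassification

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
tokenizer = AutoTokenizer.from_pretrained("mavinsao/mi-roberta-mental-illness")
model = AutoModelForSequenceClassification.from_pretrained("mavinsao/mi-roberta-mental-illness").to(device)
model.eval()

sentence = "I can't concentrate at work and I feel on edge all the time."
inputs = tokenizer(sentence, return_tensors="pt", truncation=True, padding=True, max_length=512).to(device)

with torch.no_grad():
    logits = model(**inputs).logits.squeeze(dim=0)

scores = torch.sigmoid(logits)  # independent probability per label
labels = [model.config.id2label[i] for i in range(scores.shape[-1])]
ranked = sorted(zip(labels, scores.tolist()), key=lambda pair: pair[1], reverse=True)
print(ranked[:5])  # top-5 labels with their scores

Reading label names from the checkpoint's config, if it provides them, avoids depending on the hard-coded common_label_map staying in sync with the model's output order.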