mavinsao commited on
Commit
0d00d6d
·
verified ·
1 Parent(s): c4020ea

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +24 -4
app.py CHANGED
@@ -2,7 +2,26 @@
2
  from transformers import pipeline
3
  import streamlit as st
4
 
5
- pipe = pipeline("text-classification", model="mavinsao/mi-roberta-base-finetuned-mental-illness")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6
 
7
  # Streamlit app
8
  st.title('Mental Illness Prediction')
@@ -10,7 +29,8 @@ st.title('Mental Illness Prediction')
10
  # Input text area for user input
11
  sentence = st.text_area("Enter the long sentence to predict your mental illness state:")
12
 
13
- # Prediction button
14
  if st.button('Predict'):
15
-
16
- st.write("Predicted labels:", pipe(sentence))
 
 
 
2
  from transformers import pipeline
3
  import streamlit as st
4
 
5
+ pipe_1 = pipeline("text-classification", model="mavinsao/mi-roberta-base-finetuned-mental-illness")
6
+ pipe_2 = pipeline("text-classification", model="mavinsao/roberta-mental-finetuned")
7
+
8
+
9
+ def ensemble_predict(text):
10
+ results_1 = pipe_1(text)
11
+ results_2 = pipe_2(text)
12
+
13
+ # Implement your chosen ensemble strategy here.
14
+ # Example with simple averaging:
15
+ ensemble_scores = {}
16
+ for result in results_1 + results_2:
17
+ label = result['label']
18
+ score = result['score']
19
+ ensemble_scores[label] = ensemble_scores.get(label, 0) + score / 2
20
+
21
+ predicted_label = max(ensemble_scores, key=ensemble_scores.get)
22
+ confidence = ensemble_scores[predicted_label]
23
+
24
+ return predicted_label, confidence
25
 
26
  # Streamlit app
27
  st.title('Mental Illness Prediction')
 
29
  # Input text area for user input
30
  sentence = st.text_area("Enter the long sentence to predict your mental illness state:")
31
 
 
32
  if st.button('Predict'):
33
+ # ... (input validation ... )
34
+ predicted_label, confidence = ensemble_predict(sentence)
35
+ st.write("Predicted label:", predicted_label)
36
+ st.write("Confidence:", confidence)