mavinsao commited on
Commit
bb77b43
·
verified ·
1 Parent(s): 35bd57e

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +34 -39
app.py CHANGED
@@ -1,40 +1,36 @@
1
- # Use a pipeline as a high-level helper
2
  from transformers import pipeline
3
  import streamlit as st
4
  import streamlit.components.v1 as components
5
 
 
6
  pipe_1 = pipeline("text-classification", model="mavinsao/roberta-base-finetuned-mental-health")
7
  pipe_2 = pipeline("text-classification", model="mavinsao/mi-roberta-base-finetuned-mental-health")
8
 
9
-
10
- # Streamlit app with background image
11
- def st_display_background(image_url):
12
- st_style = f"""
13
- <style>
14
- body {{
15
- background-image: url("{image_url}") !important;
16
- background-size: cover;
17
- }}
18
- </style>
19
- """
20
- st.markdown(st_style, unsafe_allow_html=True)
21
-
22
- image_url = "https://images.unsplash.com/photo-1504701954957-2010ec3bcec1?q=80&w=1974&auto=format&fit=crop&ixlib=rb-4.0.3&ixid=M3wxMjA3fDB8MHxwaG90by1wYWdlfHx8fGVufDB8fHx8fA%3D%3D"
23
-
24
-
25
- def ensemble_predict(text):
26
  # Store results from each model
27
  results_1 = pipe_1(text)
28
  results_2 = pipe_2(text)
29
 
30
- ensemble_scores = {}
31
-
32
- for results in [results_1, results_2]: # Iterate through predictions
 
 
 
 
 
 
 
 
 
 
33
  for result in results:
34
  label = result['label']
35
  score = result['score']
36
- ensemble_scores[label] = ensemble_scores.get(label, 0) + score / 2
37
 
 
38
  predicted_label = max(ensemble_scores, key=ensemble_scores.get)
39
  confidence = ensemble_scores[predicted_label] # Ensemble confidence
40
 
@@ -46,21 +42,20 @@ st.title('Mental Illness Prediction')
46
  # Input text area for user input
47
  sentence = st.text_area("Enter the long sentence to predict your mental illness state:")
48
 
49
- st_display_background(image_url)
50
-
51
  if st.button('Predict'):
52
- # ... (input validation ... )
53
- predicted_label, confidence = ensemble_predict(sentence)
54
-
55
- # CSS injection to target the labels
56
- st.markdown("""
57
- <style>
58
- div[data-testid="metric-container"] {
59
- font-weight: bold;
60
- font-size: 18px; /* Adjust the font size as desired */
61
- }
62
- </style>
63
- """, unsafe_allow_html=True)
64
-
65
- st.write("Result:", predicted_label)
66
- st.write("Confidence:", confidence)
 
 
 
1
  from transformers import pipeline
2
  import streamlit as st
3
  import streamlit.components.v1 as components
4
 
5
+ # Load the models
6
  pipe_1 = pipeline("text-classification", model="mavinsao/roberta-base-finetuned-mental-health")
7
  pipe_2 = pipeline("text-classification", model="mavinsao/mi-roberta-base-finetuned-mental-health")
8
 
9
+ # Function for ensemble prediction
10
+ def ensemble_predict(text):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
11
  # Store results from each model
12
  results_1 = pipe_1(text)
13
  results_2 = pipe_2(text)
14
 
15
+ # Initialize a dictionary with all potential labels to ensure they are considered
16
+ ensemble_scores = {}
17
+
18
+ # Add all labels from the first model's output
19
+ for result in results_1:
20
+ ensemble_scores[result['label']] = 0
21
+
22
+ # Add all labels from the second model's output
23
+ for result in results_2:
24
+ ensemble_scores[result['label']] = 0
25
+
26
+ # Aggregate scores from both models
27
+ for results in [results_1, results_2]:
28
  for result in results:
29
  label = result['label']
30
  score = result['score']
31
+ ensemble_scores[label] += score / 2 # Averaging the scores
32
 
33
+ # Determine the predicted label and confidence
34
  predicted_label = max(ensemble_scores, key=ensemble_scores.get)
35
  confidence = ensemble_scores[predicted_label] # Ensemble confidence
36
 
 
42
  # Input text area for user input
43
  sentence = st.text_area("Enter the long sentence to predict your mental illness state:")
44
 
 
 
45
  if st.button('Predict'):
46
+ # Perform the prediction
47
+ predicted_label, confidence = ensemble_predict(sentence)
48
+
49
+ # CSS injection to target the labels
50
+ st.markdown("""
51
+ <style>
52
+ div[data-testid="metric-container"] {
53
+ font-weight: bold;
54
+ font-size: 18px; /* Adjust the font size as desired */
55
+ }
56
+ </style>
57
+ """, unsafe_allow_html=True)
58
+
59
+ # Display the result
60
+ st.write("Result:", predicted_label)
61
+ st.write("Confidence:", confidence)