Update app.py
Browse files
app.py
CHANGED
@@ -1,40 +1,36 @@
|
|
1 |
-
# Use a pipeline as a high-level helper
|
2 |
from transformers import pipeline
|
3 |
import streamlit as st
|
4 |
import streamlit.components.v1 as components
|
5 |
|
|
|
6 |
pipe_1 = pipeline("text-classification", model="mavinsao/roberta-base-finetuned-mental-health")
|
7 |
pipe_2 = pipeline("text-classification", model="mavinsao/mi-roberta-base-finetuned-mental-health")
|
8 |
|
9 |
-
|
10 |
-
|
11 |
-
def st_display_background(image_url):
|
12 |
-
st_style = f"""
|
13 |
-
<style>
|
14 |
-
body {{
|
15 |
-
background-image: url("{image_url}") !important;
|
16 |
-
background-size: cover;
|
17 |
-
}}
|
18 |
-
</style>
|
19 |
-
"""
|
20 |
-
st.markdown(st_style, unsafe_allow_html=True)
|
21 |
-
|
22 |
-
image_url = "https://images.unsplash.com/photo-1504701954957-2010ec3bcec1?q=80&w=1974&auto=format&fit=crop&ixlib=rb-4.0.3&ixid=M3wxMjA3fDB8MHxwaG90by1wYWdlfHx8fGVufDB8fHx8fA%3D%3D"
|
23 |
-
|
24 |
-
|
25 |
-
def ensemble_predict(text):
|
26 |
# Store results from each model
|
27 |
results_1 = pipe_1(text)
|
28 |
results_2 = pipe_2(text)
|
29 |
|
30 |
-
|
31 |
-
|
32 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
33 |
for result in results:
|
34 |
label = result['label']
|
35 |
score = result['score']
|
36 |
-
ensemble_scores[label]
|
37 |
|
|
|
38 |
predicted_label = max(ensemble_scores, key=ensemble_scores.get)
|
39 |
confidence = ensemble_scores[predicted_label] # Ensemble confidence
|
40 |
|
@@ -46,21 +42,20 @@ st.title('Mental Illness Prediction')
|
|
46 |
# Input text area for user input
|
47 |
sentence = st.text_area("Enter the long sentence to predict your mental illness state:")
|
48 |
|
49 |
-
st_display_background(image_url)
|
50 |
-
|
51 |
if st.button('Predict'):
|
52 |
-
|
53 |
-
|
54 |
-
|
55 |
-
|
56 |
-
|
57 |
-
|
58 |
-
|
59 |
-
|
60 |
-
|
61 |
-
|
62 |
-
|
63 |
-
|
64 |
-
|
65 |
-
|
66 |
-
|
|
|
|
|
|
1 |
from transformers import pipeline
|
2 |
import streamlit as st
|
3 |
import streamlit.components.v1 as components
|
4 |
|
5 |
+
# Load the models
|
6 |
pipe_1 = pipeline("text-classification", model="mavinsao/roberta-base-finetuned-mental-health")
|
7 |
pipe_2 = pipeline("text-classification", model="mavinsao/mi-roberta-base-finetuned-mental-health")
|
8 |
|
9 |
+
# Function for ensemble prediction
|
10 |
+
def ensemble_predict(text):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
11 |
# Store results from each model
|
12 |
results_1 = pipe_1(text)
|
13 |
results_2 = pipe_2(text)
|
14 |
|
15 |
+
# Initialize a dictionary with all potential labels to ensure they are considered
|
16 |
+
ensemble_scores = {}
|
17 |
+
|
18 |
+
# Add all labels from the first model's output
|
19 |
+
for result in results_1:
|
20 |
+
ensemble_scores[result['label']] = 0
|
21 |
+
|
22 |
+
# Add all labels from the second model's output
|
23 |
+
for result in results_2:
|
24 |
+
ensemble_scores[result['label']] = 0
|
25 |
+
|
26 |
+
# Aggregate scores from both models
|
27 |
+
for results in [results_1, results_2]:
|
28 |
for result in results:
|
29 |
label = result['label']
|
30 |
score = result['score']
|
31 |
+
ensemble_scores[label] += score / 2 # Averaging the scores
|
32 |
|
33 |
+
# Determine the predicted label and confidence
|
34 |
predicted_label = max(ensemble_scores, key=ensemble_scores.get)
|
35 |
confidence = ensemble_scores[predicted_label] # Ensemble confidence
|
36 |
|
|
|
42 |
# Input text area for user input
|
43 |
sentence = st.text_area("Enter the long sentence to predict your mental illness state:")
|
44 |
|
|
|
|
|
45 |
if st.button('Predict'):
|
46 |
+
# Perform the prediction
|
47 |
+
predicted_label, confidence = ensemble_predict(sentence)
|
48 |
+
|
49 |
+
# CSS injection to target the labels
|
50 |
+
st.markdown("""
|
51 |
+
<style>
|
52 |
+
div[data-testid="metric-container"] {
|
53 |
+
font-weight: bold;
|
54 |
+
font-size: 18px; /* Adjust the font size as desired */
|
55 |
+
}
|
56 |
+
</style>
|
57 |
+
""", unsafe_allow_html=True)
|
58 |
+
|
59 |
+
# Display the result
|
60 |
+
st.write("Result:", predicted_label)
|
61 |
+
st.write("Confidence:", confidence)
|