File size: 2,681 Bytes
291793f c12c8cd 29448e1 a681e0a bb77b43 35bd57e 0d00d6d bb77b43 3f7ec6a 0d00d6d bb77b43 3f7ec6a bb77b43 0d00d6d bb77b43 3f7ec6a 0d00d6d a681e0a ff694a9 a681e0a bb77b43 43f3fd4 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 |
from transformers import pipeline
import streamlit as st
import streamlit.components.v1 as components
# Load the models.
# Streamlit re-executes this script from the top on every widget interaction, so
# the (expensive) pipeline construction is wrapped in a cached loader: the two
# models are built once per server process and reused on subsequent reruns.
@st.cache_resource
def _load_pipelines():
    """Build and return the two text-classification pipelines used by the ensemble."""
    return (
        pipeline("text-classification", model="mavinsao/roberta-base-finetuned-mental-health"),
        pipeline("text-classification", model="mavinsao/mi-roberta-base-finetuned-mental-health"),
    )


pipe_1, pipe_2 = _load_pipelines()
# Function for ensemble prediction
def ensemble_predict(text, pipes=None):
    """Predict a label for *text* by averaging scores over several classifiers.

    Every classifier is called as ``pipe(text)`` and must return a list of
    ``{'label': ..., 'score': ...}`` dicts (a list wrapped in one extra outer
    list is also accepted). Each score is divided by the number of classifiers
    and accumulated per label; the label with the largest accumulated value
    wins.

    NOTE(review): a label that a classifier does not report contributes nothing
    for that classifier. If the pipelines only return their top label, two
    disagreeing models therefore yield a confidence of at most 0.5 -- confirm
    whether the pipelines should be configured to return all label scores.

    Args:
        text: Input forwarded unchanged to every classifier.
        pipes: Optional iterable of classifier callables. Defaults to the
            module-level ``pipe_1`` and ``pipe_2``.

    Returns:
        A ``(predicted_label, confidence)`` tuple, where ``confidence`` is the
        accumulated (averaged) score of the winning label.

    Raises:
        ValueError: If *pipes* is empty or no classifier returned any score.
    """
    if pipes is None:
        pipes = (pipe_1, pipe_2)
    pipes = tuple(pipes)
    if not pipes:
        raise ValueError("ensemble_predict requires at least one classifier")

    model_count = len(pipes)
    ensemble_scores = {}
    for pipe in pipes:
        results = pipe(text)
        # Unwrap the [[{...}, ...]] shape so both flat and nested outputs work.
        if results and isinstance(results[0], list):
            results = results[0]
        for result in results:
            label = result['label']
            # Labels are inserted in first-seen order, which keeps tie-breaking
            # in max() stable across runs.
            ensemble_scores[label] = ensemble_scores.get(label, 0) + result['score'] / model_count

    if not ensemble_scores:
        raise ValueError("no classifier returned any label scores")

    predicted_label = max(ensemble_scores, key=ensemble_scores.get)
    return predicted_label, ensemble_scores[predicted_label]
# Streamlit app
# Static page fragments are kept as module-level constants so their text stays
# flush-left (exactly as rendered) regardless of where they are used below.
_RESULT_CSS = """
<style>
div[data-testid="metric-container"] {
font-weight: bold;
font-size: 18px; /* Adjust the font size as desired */
}
</style>
"""

_DISCLAIMER = "Remember: This prediction is not a diagnosis. Always consult with a healthcare professional for proper evaluation and advice."

_ABOUT_MARKDOWN = """
### About Our Method
Our method is designed to assist mental health professionals, such as psychologists and psychiatrists, rather than replace them. Using our model to directly calculate mental illness labels can introduce biases, potentially leading to inaccurate diagnoses. Therefore, the predictions made by our model should only be used as a reference, with the final diagnosis being carefully determined by qualified professionals.
"""

st.title('Mental Illness Prediction')

# Input text area for user input
sentence = st.text_area("Enter the long sentence to predict your mental illness state:")

if st.button('Predict'):
    if not sentence or not sentence.strip():
        # Do not run the models on empty input: the result would be a
        # meaningless label presented with a confidence value.
        st.warning("Please enter some text before requesting a prediction.")
    else:
        # Perform the prediction
        predicted_label, confidence = ensemble_predict(sentence)

        # CSS injection to target the labels
        st.markdown(_RESULT_CSS, unsafe_allow_html=True)

        # Display the result
        st.write("Result:", predicted_label)
        st.write("Confidence:", confidence)

        # Additional reminder after prediction
        st.info(_DISCLAIMER)

# Additional information (shown regardless of whether a prediction was made)
st.markdown(_ABOUT_MARKDOWN)
|