# app.py
# (Copied from a Hugging Face Space page; the page header read "Spaces: Sleeping".)
import html
import io
import random
from datetime import datetime

import streamlit as st
import torch
from groq import APIError, Groq
from reportlab.lib.pagesizes import letter
from reportlab.pdfgen import canvas
from transformers import AutoModelForSequenceClassification, AutoTokenizer
# Configure Streamlit. This must be the first Streamlit command of the run:
# the cached model loader below may render a spinner, and the error paths
# call st.error, both of which would otherwise come before it.
st.set_page_config(page_title="π§ PsychBot Pro", layout="wide", page_icon="π€")


@st.cache_resource(show_spinner=False)
def _load_groq_client(api_key):
    """Build the Groq client once per API key and reuse it across reruns."""
    return Groq(api_key=api_key)


@st.cache_resource
def _load_personality_model(model_id="KevSun/Personality_LM"):
    """Load the personality classifier and its tokenizer once per process.

    Streamlit re-executes this whole script on every widget interaction;
    without caching, the model would be reloaded from disk each time.

    Returns:
        A ``(model, tokenizer)`` tuple.
    """
    model = AutoModelForSequenceClassification.from_pretrained(
        model_id,
        # Kept from the original call: lets a classifier head whose shape
        # differs from the config load (re-initialized) instead of raising.
        ignore_mismatched_sizes=True,
    )
    tokenizer = AutoTokenizer.from_pretrained(model_id)
    return model, tokenizer


# Initialize components
try:
    groq_client = _load_groq_client(st.secrets["GROQ_API_KEY"])
except KeyError:
    st.error("GROQ_API_KEY missing in secrets!")
    st.stop()

# Load personality model
try:
    personality_model, personality_tokenizer = _load_personality_model()
except Exception as e:  # top-level boundary: surface any load failure in the UI
    st.error(f"Model loading error: {str(e)}")
    st.stop()
# Stylesheet for the raw-HTML snippets rendered further down: the floating
# report title (.personality-title), question cards, social-post cards and
# their response box. .platform-selector appears unused in this file.
_CUSTOM_CSS = """
<style>
@keyframes float {
0% { transform: translateY(0px); }
50% { transform: translateY(-20px); }
100% { transform: translateY(0px); }
}
.personality-title {
animation: float 3s ease-in-out infinite;
}
.social-post {
border: 1px solid #e0e0e0;
border-radius: 15px;
padding: 20px;
margin: 15px 0;
background: #f8f9fa;
}
.platform-selector {
display: flex;
gap: 10px;
margin: 20px 0;
}
.response-box {
border-left: 3px solid #4CAF50;
padding: 10px;
margin: 10px 0;
}
.question-card {
padding: 20px;
border-radius: 15px;
margin: 10px 0;
background: #f8f9fa;
}
</style>
"""
st.markdown(_CUSTOM_CSS, unsafe_allow_html=True)
# Question bank as (prompt text, tone, Big Five trait) rows.
# "type" is "funny" or "serious" and picks the badge shown under a question;
# "trait" is not read anywhere else in the visible code.
_QUESTION_ROWS = (
    ("If you were a kitchen appliance, what would you be? πͺ", "funny", "openness"),
    ("How do you handle unexpected changes? π", "serious", "neuroticism"),
    ("What's your spirit animal and why? π¦", "funny", "agreeableness"),
    ("Describe your ideal work environment π₯οΈ", "serious", "conscientiousness"),
    ("If your life was a movie title, what would it be? π¬", "funny", "openness"),
    ("How do you approach conflict resolution? βοΈ", "serious", "agreeableness"),
    ("What superhero power matches your personality? π¦Έ", "funny", "extraversion"),
    ("How do you prioritize tasks? π ", "serious", "conscientiousness"),
)
QUESTION_BANK = [
    {"text": prompt_text, "type": tone, "trait": trait_name}
    for prompt_text, tone, trait_name in _QUESTION_ROWS
]
def initialize_session(num_questions=5):
    """Populate missing Streamlit session-state keys with their defaults.

    Keys that already exist are left untouched, so calling this on every
    rerun preserves the user's progress.

    Args:
        num_questions: How many distinct questions to draw from
            QUESTION_BANK for a new session. Clamped to the bank size so a
            smaller bank cannot make ``random.sample`` raise ValueError.
    """
    # Factories instead of values: a default is only built for a key that is
    # actually missing, so random.sample is not re-run (and the global RNG is
    # not advanced) on every rerun after the first.
    default_factories = {
        'questions': lambda: random.sample(
            QUESTION_BANK, min(num_questions, len(QUESTION_BANK))
        ),
        'current_q': lambda: 0,
        'show_post': lambda: False,
        'responses': list,
        'personality': lambda: None,
    }
    for key, make_default in default_factories.items():
        if key not in st.session_state:
            st.session_state[key] = make_default()
def analyze_personality(text, *, model=None, tokenizer=None, max_length=512):
    """Score *text* against five personality traits with the classifier.

    Args:
        text: Free-form text to classify (the app passes the user's answers
            joined with spaces).
        model: Sequence-classification model; defaults to the module-level
            ``personality_model``.
        tokenizer: Matching tokenizer; defaults to ``personality_tokenizer``.
        max_length: Token limit; longer input is truncated.

    Returns:
        Dict mapping each trait name to a probability. All five come from one
        softmax over the model's logits, so they sum to 1 and are relative to
        each other rather than independent per-trait scores.

    Raises:
        ValueError: If the model does not emit exactly one logit per trait
            (previously ``zip`` would silently drop or ignore labels).
    """
    model = personality_model if model is None else model
    tokenizer = personality_tokenizer if tokenizer is None else tokenizer
    # NOTE(review): this order is assumed to match the classifier head's
    # output order — confirm against the KevSun/Personality_LM model card.
    traits = ["openness", "conscientiousness", "extraversion", "agreeableness", "neuroticism"]

    inputs = tokenizer(text, return_tensors="pt", padding=True, truncation=True, max_length=max_length)
    with torch.no_grad():
        logits = model(**inputs).logits
    if logits.shape[-1] != len(traits):
        raise ValueError(
            f"expected {len(traits)} trait logits, got {logits.shape[-1]}"
        )
    # A single input string gives a batch of one; take its row.
    probs = torch.nn.functional.softmax(logits, dim=-1)[0]
    return {trait: float(prob) for trait, prob in zip(traits, probs)}
def generate_social_post(platform, traits, *, client=None,
                         model="mixtral-8x7b-32768", temperature=0.7):
    """Ask the Groq chat API for a post tailored to *platform* and *traits*.

    Args:
        platform: One of "instagram", "linkedin", "facebook", "whatsapp".
        traits: Mapping of trait name to score, as returned by
            ``analyze_personality``. Numeric scores are rendered as whole
            percentages (matching the report's metrics); other values are
            interpolated as-is.
        client: Groq client; defaults to the module-level ``groq_client``.
        model: Groq model id. NOTE(review): Groq has announced deprecation of
            ``mixtral-8x7b-32768`` — check Groq's deprecations page and pass a
            current model id here if requests fail.
        temperature: Sampling temperature for the completion.

    Returns:
        The generated post text, or "" if the API returned no content.

    Raises:
        KeyError: If *platform* is not one of the supported platforms.
    """
    client = groq_client if client is None else client

    # platform -> (instruction, emoji to include, tone)
    platform_specs = {
        "instagram": (
            "Create an Instagram post with 3 emojis and 2 hashtags about personal growth:",
            "πΈ",
            "visually appealing with hashtags",
        ),
        "linkedin": (
            "Create a professional LinkedIn post about self-improvement:",
            "πΌ",
            "professional and inspiring",
        ),
        "facebook": (
            "Create a friendly Facebook post with 2 emojis:",
            "π₯",
            "friendly and engaging",
        ),
        "whatsapp": (
            "Create a casual WhatsApp status with 2 emojis:",
            "π¬",
            "casual and fun",
        ),
    }
    instruction, emoji, tone = platform_specs[platform]

    trait_summary = ", ".join(
        f"{name}: {score:.0%}" if isinstance(score, (int, float)) else f"{name}: {score}"
        for name, score in traits.items()
    )
    prompt = "\n".join([
        instruction,
        f"Based on these personality traits: {trait_summary}",
        f"Include {emoji} emoji and make it {tone}",
    ])

    response = client.chat.completions.create(
        model=model,
        messages=[{"role": "user", "content": prompt}],
        temperature=temperature,
    )
    # message.content may be None (e.g. no text in the reply); keep a str.
    return response.choices[0].message.content or ""
# Initialize session state
initialize_session()

# Main UI
st.title("π§ PsychBot Pro")
st.markdown("### Your AI Personality Companion π€π¬")

# Platform id -> emoji for the radio labels. A dict lookup replaces indexing
# into one emoji string by position, which picks the wrong characters as soon
# as any emoji is longer than a single code point.
_PLATFORM_EMOJI = {
    "instagram": "πΈ",
    "linkedin": "πΌ",
    "facebook": "π₯",
    "whatsapp": "π¬",
}

if st.session_state.current_q < len(st.session_state.questions):
    # Dynamic question flow: show the current question, advance on "Next".
    q = st.session_state.questions[st.session_state.current_q]
    badge = 'π Fun question!' if q['type'] == 'funny' else 'π€ Serious reflection'
    with st.chat_message("assistant"):
        st.markdown(
            f'<div class="question-card"><h4>{q["text"]}</h4><p>{badge}</p></div>',
            unsafe_allow_html=True,
        )
    user_input = st.text_input("Your response:", key=f"q{st.session_state.current_q}")
    if st.button("Next β‘οΈ"):
        # Reject blank answers; otherwise the classifier can end up scoring
        # an empty transcript.
        if user_input.strip():
            st.session_state.responses.append(user_input)
            st.session_state.current_q += 1
            st.rerun()
        else:
            st.warning("Please type an answer before moving on.")
else:
    # Generate the personality report once per session.
    if st.session_state.personality is None:
        combined_text = " ".join(st.session_state.responses)
        st.session_state.personality = analyze_personality(combined_text)
        # Celebrate only when the report is first produced, not on every
        # rerun triggered by the widgets below.
        st.balloons()
    traits = st.session_state.personality

    # Personality report section
    st.markdown("## <div class='personality-title'>π Your Personality Report</div>", unsafe_allow_html=True)

    # Personality visualization: one metric per trait, highest first.
    sorted_traits = sorted(traits.items(), key=lambda item: item[1], reverse=True)
    cols = st.columns(len(sorted_traits))
    for col, (trait, score) in zip(cols, sorted_traits):
        col.metric(label=trait.upper(), value=f"{score*100:.0f}%")

    # Social post generation
    st.markdown("---")
    st.markdown("### π± Social Media Post Generator")
    platforms = ["instagram", "linkedin", "facebook", "whatsapp"]
    selected = st.radio(
        "Choose platform:",
        platforms,
        format_func=lambda x: f"{x.capitalize()} {_PLATFORM_EMOJI[x]}",
        horizontal=True,
    )
    if st.button("Generate Post β¨"):
        try:
            post = generate_social_post(selected, traits)
        except APIError as err:
            st.error(f"Post generation failed: {err}")
        else:
            # The post is LLM output: escape it before embedding it in raw
            # HTML. Newlines become <br> so they stay visible and a blank
            # line cannot terminate the HTML block early.
            safe_post = html.escape(post).replace("\n", "<br>")
            st.markdown(
                f'<div class="social-post">'
                f'<h4>π¨ {selected.capitalize()} Post Draft</h4>'
                f'<div class="response-box">{safe_post}</div>'
                f'</div>',
                unsafe_allow_html=True,
            )
            # st.code has a built-in copy-to-clipboard button. The previous
            # inline onclick handler pasted the raw post into a JS template
            # literal inside a double-quoted attribute, so any quote or
            # backtick in the post broke it (and was an injection vector).
            with st.expander("π Copy Text"):
                st.code(post, language=None)

    # Restart conversation: wipe all state (including text-input values) so
    # initialize_session() draws a fresh question set on the next run.
    if st.button("π Start New Analysis"):
        for key in list(st.session_state.keys()):
            del st.session_state[key]
        st.rerun()
# Sidebar: static list of the app's features.
# NOTE(review): the "PDF report download" bullet has no matching code in this
# file (reportlab, io and datetime are imported but never used here) —
# confirm it is implemented elsewhere or drop the bullet.
_SIDEBAR_FEATURES = """
- π Dynamic personality assessment
- π€ AI-generated social posts
- π Visual trait analysis
- π¬ Mix of fun/serious questions
- π₯ PDF report download
"""
st.sidebar.markdown("## π Features")
st.sidebar.markdown(_SIDEBAR_FEATURES)