sunbal7 commited on
Commit
7e297f6
Β·
verified Β·
1 Parent(s): 74a4cc4

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +49 -28
app.py CHANGED
@@ -1,15 +1,13 @@
1
  # app.py
2
  import streamlit as st
3
  from groq import Groq
4
- from transformers import pipeline
5
- import re
6
  import random
7
  from datetime import datetime
8
  from reportlab.lib.pagesizes import letter
9
  from reportlab.pdfgen import canvas
10
  import io
11
- import torch
12
- from transformers import AutoModelForSequenceClassification, AutoTokenizer
13
 
14
  # Initialize components
15
  try:
@@ -20,7 +18,10 @@ except KeyError:
20
 
21
  # Load personality model
22
  try:
23
- personality_model = AutoModelForSequenceClassification.from_pretrained("KevSun/Personality_LM", ignore_mismatched_sizes=True)
 
 
 
24
  personality_tokenizer = AutoTokenizer.from_pretrained("KevSun/Personality_LM")
25
  except Exception as e:
26
  st.error(f"Model loading error: {str(e)}")
@@ -52,6 +53,11 @@ st.markdown("""
52
  gap: 10px;
53
  margin: 20px 0;
54
  }
 
 
 
 
 
55
  </style>
56
  """, unsafe_allow_html=True)
57
 
@@ -67,9 +73,20 @@ QUESTION_BANK = [
67
  {"text": "How do you recharge after a stressful day? 🧘", "type": "serious", "trait": "neuroticism"}
68
  ]
69
 
 
 
 
 
 
 
 
 
 
 
 
 
70
  def get_dynamic_questions():
71
  """Generate random mix of funny/serious questions"""
72
- random.shuffle(QUESTION_BANK)
73
  return random.sample(QUESTION_BANK, 5)
74
 
75
  def analyze_personality(text):
@@ -84,14 +101,14 @@ def analyze_personality(text):
84
  def generate_social_post(platform, traits):
85
  """Generate platform-specific social post"""
86
  prompts = {
87
- "instagram": "Create an Instagram post with 3 emojis and 2 hashtags about:",
88
- "linkedin": "Create a professional LinkedIn post about personal growth with 1 emoji:",
89
- "facebook": "Create a friendly Facebook post with 2 emojis:",
90
- "whatsapp": "Create a casual WhatsApp status with 2 emojis:"
91
  }
92
  prompt = f"""{prompts[platform]}
93
- My personality strengths: {traits}
94
- Make it {['playful and visual', 'professional', 'friendly', 'casual'][list(prompts.keys()).index(platform)]}"""
95
 
96
  response = groq_client.chat.completions.create(
97
  model="mixtral-8x7b-32768",
@@ -100,13 +117,8 @@ def generate_social_post(platform, traits):
100
  )
101
  return response.choices[0].message.content
102
 
103
- # Session state management
104
- if 'questions' not in st.session_state:
105
- st.session_state.questions = get_dynamic_questions()
106
- if 'current_q' not in st.session_state:
107
- st.session_state.current_q = 0
108
- if 'show_post' not in st.session_state:
109
- st.session_state.show_post = False
110
 
111
  # Main UI
112
  st.title("🧠 PsychBot Pro")
@@ -126,11 +138,12 @@ if st.session_state.current_q < len(st.session_state.questions):
126
  user_input = st.text_input("Your response:", key=f"q{st.session_state.current_q}")
127
 
128
  if st.button("Next ➑️"):
 
129
  st.session_state.current_q += 1
130
  st.rerun()
131
  else:
132
  # Generate personality report
133
- combined_text = " ".join([st.session_state[f"q{i}"] for i in range(len(st.session_state.questions))])
134
  traits = analyze_personality(combined_text)
135
 
136
  st.balloons()
@@ -138,34 +151,42 @@ else:
138
 
139
  # Personality visualization
140
  cols = st.columns(5)
141
- traits = {k: v for k, v in sorted(traits.items(), key=lambda item: item[1], reverse=True)}
142
- for i, (trait, score) in enumerate(traits.items()):
143
  cols[i].metric(label=trait.upper(), value=f"{score*100:.0f}%")
144
 
145
  # Social post generation
 
146
  if st.button("πŸ“± Generate Social Media Post"):
147
  st.session_state.show_post = True
148
 
149
  if st.session_state.show_post:
150
  platforms = ["instagram", "linkedin", "facebook", "whatsapp"]
151
- selected = st.radio("Choose platform:", platforms, horizontal=True)
152
 
153
  if post := generate_social_post(selected, traits):
154
  st.markdown(f"""
155
  <div class="social-post">
156
  <h4>🎨 {selected.capitalize()} Post Draft</h4>
157
- <p>{post}</p>
158
- <button onclick="navigator.clipboard.writeText(`{post}`)">πŸ“‹ Copy</button>
159
  </div>
160
  """, unsafe_allow_html=True)
161
 
 
 
 
 
 
 
162
  # Sidebar
163
  with st.sidebar:
164
  st.markdown("## 🌈 Features")
165
  st.markdown("""
166
- - 🎭 Dynamic personality assessment
167
  - πŸ€– AI-generated social posts
168
  - πŸ“Š Visual trait analysis
169
- - πŸ’¬ Mix of fun/serious questions
 
170
  """)
171
- st.image("https://i.imgur.com/7Q4X4yN.png", width=200)
 
1
  # app.py
2
  import streamlit as st
3
  from groq import Groq
4
+ from transformers import AutoModelForSequenceClassification, AutoTokenizer
5
+ import torch
6
  import random
7
  from datetime import datetime
8
  from reportlab.lib.pagesizes import letter
9
  from reportlab.pdfgen import canvas
10
  import io
 
 
11
 
12
  # Initialize components
13
  try:
 
18
 
19
  # Load personality model
20
  try:
21
+ personality_model = AutoModelForSequenceClassification.from_pretrained(
22
+ "KevSun/Personality_LM",
23
+ ignore_mismatched_sizes=True
24
+ )
25
  personality_tokenizer = AutoTokenizer.from_pretrained("KevSun/Personality_LM")
26
  except Exception as e:
27
  st.error(f"Model loading error: {str(e)}")
 
53
  gap: 10px;
54
  margin: 20px 0;
55
  }
56
+ .response-box {
57
+ border-left: 3px solid #4CAF50;
58
+ padding: 10px;
59
+ margin: 10px 0;
60
+ }
61
  </style>
62
  """, unsafe_allow_html=True)
63
 
 
73
  {"text": "How do you recharge after a stressful day? 🧘", "type": "serious", "trait": "neuroticism"}
74
  ]
75
 
76
+ def initialize_session():
77
+ """Initialize all session state variables"""
78
+ if 'questions' not in st.session_state:
79
+ st.session_state.questions = random.sample(QUESTION_BANK, 5)
80
+ for i in range(5):
81
+ st.session_state[f'q{i}'] = ""
82
+ st.session_state.responses = []
83
+ if 'current_q' not in st.session_state:
84
+ st.session_state.current_q = 0
85
+ if 'show_post' not in st.session_state:
86
+ st.session_state.show_post = False
87
+
88
  def get_dynamic_questions():
89
  """Generate random mix of funny/serious questions"""
 
90
  return random.sample(QUESTION_BANK, 5)
91
 
92
  def analyze_personality(text):
 
101
  def generate_social_post(platform, traits):
102
  """Generate platform-specific social post"""
103
  prompts = {
104
+ "instagram": "Create an Instagram post with 3 emojis and 2 hashtags about personal growth:",
105
+ "linkedin": "Create a professional LinkedIn post about self-improvement with 1 emoji:",
106
+ "facebook": "Create a friendly Facebook post about personality insights with 2 emojis:",
107
+ "whatsapp": "Create a casual WhatsApp status about self-discovery with 2 emojis:"
108
  }
109
  prompt = f"""{prompts[platform]}
110
+ Based on these personality traits: {traits}
111
+ Make it {['visual and trendy', 'professional', 'friendly', 'casual'][list(prompts.keys()).index(platform)]}"""
112
 
113
  response = groq_client.chat.completions.create(
114
  model="mixtral-8x7b-32768",
 
117
  )
118
  return response.choices[0].message.content
119
 
120
+ # Initialize session state
121
+ initialize_session()
 
 
 
 
 
122
 
123
  # Main UI
124
  st.title("🧠 PsychBot Pro")
 
138
  user_input = st.text_input("Your response:", key=f"q{st.session_state.current_q}")
139
 
140
  if st.button("Next ➑️"):
141
+ st.session_state.responses.append(user_input)
142
  st.session_state.current_q += 1
143
  st.rerun()
144
  else:
145
  # Generate personality report
146
+ combined_text = " ".join(st.session_state.responses)
147
  traits = analyze_personality(combined_text)
148
 
149
  st.balloons()
 
151
 
152
  # Personality visualization
153
  cols = st.columns(5)
154
+ sorted_traits = sorted(traits.items(), key=lambda x: x[1], reverse=True)
155
+ for i, (trait, score) in enumerate(sorted_traits):
156
  cols[i].metric(label=trait.upper(), value=f"{score*100:.0f}%")
157
 
158
  # Social post generation
159
+ st.markdown("---")
160
  if st.button("πŸ“± Generate Social Media Post"):
161
  st.session_state.show_post = True
162
 
163
  if st.session_state.show_post:
164
  platforms = ["instagram", "linkedin", "facebook", "whatsapp"]
165
+ selected = st.radio("Choose platform:", platforms, format_func=lambda x: x.capitalize(), horizontal=True)
166
 
167
  if post := generate_social_post(selected, traits):
168
  st.markdown(f"""
169
  <div class="social-post">
170
  <h4>🎨 {selected.capitalize()} Post Draft</h4>
171
+ <div class="response-box">{post}</div>
172
+ <button onclick="navigator.clipboard.writeText(`{post}`)" style="margin-top:10px;">πŸ“‹ Copy Text</button>
173
  </div>
174
  """, unsafe_allow_html=True)
175
 
176
+ # Restart conversation
177
+ if st.button("πŸ”„ Start New Analysis"):
178
+ for key in list(st.session_state.keys()):
179
+ del st.session_state[key]
180
+ st.rerun()
181
+
182
  # Sidebar
183
  with st.sidebar:
184
  st.markdown("## 🌈 Features")
185
  st.markdown("""
186
+ - 🎭 Personality assessment
187
  - πŸ€– AI-generated social posts
188
  - πŸ“Š Visual trait analysis
189
+ - πŸ’¬ Dynamic questions
190
+ - πŸ“₯ PDF report download
191
  """)
192
+