sunbal7 commited on
Commit
bd3a27a
Β·
verified Β·
1 Parent(s): 37f3a73

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +74 -130
app.py CHANGED
@@ -1,139 +1,83 @@
1
- # app.py
2
  import streamlit as st
3
- from langchain.embeddings import HuggingFaceEmbeddings
4
- from langchain.vectorstores import FAISS
5
- from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
6
- import torch
7
- import random
8
 
9
- # --------------------- Configuration ---------------------
10
- MODEL_NAME = "mistralai/Mistral-7B-Instruct-v0.2"
11
- EMBEDDING_MODEL = "sentence-transformers/all-mpnet-base-v2"
12
- CONSPIRACY_DB_PATH = "conspiracy_faiss_index"
13
- FACT_DB_PATH = "facts_faiss_index"
 
 
 
 
 
 
14
 
15
- # --------------------- Helper Functions ---------------------
16
- def load_vector_db(path):
17
- embeddings = HuggingFaceEmbeddings(model_name=EMBEDDING_MODEL)
18
- return FAISS.load_local(path, embeddings)
 
 
 
 
 
19
 
20
- def load_llm():
21
- tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
22
- model = AutoModelForCausalLM.from_pretrained(
23
- MODEL_NAME,
24
- device_map="auto",
25
- load_in_4bit=True,
26
- torch_dtype=torch.float16
27
- )
28
- return pipeline(
29
- "text-generation",
30
- model=model,
31
- tokenizer=tokenizer,
32
- device_map="auto"
33
- )
34
 
35
- # --------------------- RAG System ---------------------
36
- class ConspiracyTherapist:
37
- def __init__(self):
38
- self.conspiracy_db = load_vector_db(CONSPIRACY_DB_PATH)
39
- self.fact_db = load_vector_db(FACT_DB_PATH)
40
- self.llm = load_llm()
41
- self.conspiracy_phrases = [
42
- "WAKE UP SHEEPLE! πŸ‘",
43
- "OPEN YOUR EYES! πŸ‘οΈ",
44
- "THE TRUTH IS OUT THERE! πŸ‘½",
45
- "COINCIDENCE? I THINK NOT! πŸ€”"
46
- ]
47
 
48
- def generate_response(self, query):
49
- # Retrieve relevant conspiracy content
50
- conspiracy_docs = self.conspiracy_db.similarity_search(query, k=2)
51
-
52
- # Generate conspiracy-style response
53
- conspiracy_prompt = f"""<s>[INST] You are a paranoid conspiracy theorist. Respond to this query while embedding a critical thinking question:
54
- Query: {query}
55
- Context: {conspiracy_docs[0].page_content}
56
- [/INST]"""
57
-
58
- response = self.llm(
59
- conspiracy_prompt,
60
- max_new_tokens=256,
61
- do_sample=True,
62
- temperature=0.7,
63
- top_p=0.9
64
- )[0]['generated_text']
65
-
66
- # Add random conspiracy phrase
67
- response += f"\n\n{random.choice(self.conspiracy_phrases)}"
68
-
69
- # Retrieve factual counter-evidence
70
- fact_docs = self.fact_db.similarity_search(query, k=2)
71
- return response, fact_docs
72
-
73
- # --------------------- Streamlit UI ---------------------
74
- st.set_page_config(page_title="πŸ€– Conspiracy Therapist", layout="wide")
75
-
76
- # Custom CSS for retro terminal theme
77
- st.markdown("""
78
- <style>
79
- .stTextInput>div>div>input {
80
- color: #00ff00 !important;
81
- background-color: #000000 !important;
82
- }
83
- .st-emotion-cache-1v7f65g {
84
- color: #00ff00 !important;
85
- }
86
- .stMarkdown {
87
- color: #00ff00 !important;
88
- }
89
- body {
90
- background-color: #000000 !important;
91
- }
92
- .glitch {
93
- animation: glitch 1s linear infinite;
94
- }
95
- @keyframes glitch {
96
- 2% { text-shadow: 2px 0 red; }
97
- 4% { text-shadow: -2px 0 blue; }
98
- 96% { text-shadow: none; }
99
- }
100
- </style>
101
- """, unsafe_allow_html=True)
102
-
103
- # Initialize session state
104
- if "history" not in st.session_state:
105
- st.session_state.history = []
106
- if "therapist" not in st.session_state:
107
- st.session_state.therapist = ConspiracyTherapist()
108
-
109
- # Header
110
- st.markdown("<h1 class='glitch'>πŸ€– CONSPIRACY THERAPIST πŸ‘½</h1>", unsafe_allow_html=True)
111
-
112
- # Chat interface
113
- user_input = st.text_input("ASK YOUR BURNING QUESTION:", key="input")
114
-
115
- if user_input:
116
- # Generate response
117
- conspiracy_response, fact_docs = st.session_state.therapist.generate_response(user_input)
118
-
119
- # Store history
120
- st.session_state.history.append(("You", user_input))
121
- st.session_state.history.append(("Bot", conspiracy_response))
122
 
123
- # Display conversation
124
- for speaker, text in st.session_state.history[-2:]:
125
- st.markdown(f"""
126
- <div style='border: 1px solid #00ff00; padding: 10px; margin: 5px; border-radius: 5px;'>
127
- <strong>{speaker}:</strong> {text}
128
- </div>
129
- """, unsafe_allow_html=True)
130
 
131
- # Fact Check Report
132
- with st.expander("πŸ” SHOW FACT CHECK REPORT"):
133
- st.markdown("### πŸ•΅οΈβ™‚οΈ FACT VS FICTION")
134
- for doc in fact_docs:
135
- st.markdown(f"""
136
- **Claim**: {user_input}
137
- **Debunked**: {doc.page_content[:500]}...
138
- """)
 
 
 
 
 
 
 
 
 
 
 
 
 
139
 
 
 
 
 
1
  import streamlit as st
2
+ from textblob import TextBlob
3
+ # The transformers import is here in case you want to integrate a pre-trained model later.
4
+ # from transformers import pipeline
 
 
5
 
6
+ def get_roast(user_text):
7
+ """
8
+ Generates a roast based on the user's input.
9
+ (You can replace or extend this logic by integrating a model like Mixtral-8x7B.)
10
+ """
11
+ if "procrastinate" in user_text.lower():
12
+ return "Ah, the art of doing nothing! Do you charge Netflix for your couch imprint? πŸ›‹οΈ"
13
+ elif "always" in user_text.lower() or "never" in user_text.lower():
14
+ return "Wow, painting your world in extremes? Maybe it's time to add some shades of gray!"
15
+ else:
16
+ return "Is that a problem or a lifestyle choice? Time to get serious... or maybe not."
17
 
18
+ def analyze_text(user_text):
19
+ """
20
+ Analyzes the input text for sentiment and detects basic cognitive distortions.
21
+ For now, it counts occurrences of words like 'always' or 'never'.
22
+ """
23
+ blob = TextBlob(user_text)
24
+ sentiment = blob.sentiment.polarity # Ranges from -1 (negative) to 1 (positive)
25
+ distortions = sum(word in user_text.lower() for word in ["always", "never"])
26
+ return sentiment, distortions
27
 
28
+ def calculate_resilience_score(sentiment, distortions):
29
+ """
30
+ Calculates a resilience score based on sentiment and the number of detected distortions.
31
+ The score is capped between 0 and 100.
32
+ """
33
+ score = 100
34
+ # Adjust score by sentiment (scaled)
35
+ score += int(sentiment * 20)
36
+ # Penalize for cognitive distortions
37
+ score -= distortions * 10
38
+ # Ensure score stays within bounds
39
+ score = max(0, min(score, 100))
40
+ return score
 
41
 
42
+ def get_reframe_tips(score):
43
+ """
44
+ Provides reframe tips based on the resilience score.
45
+ """
46
+ if score < 50:
47
+ return "Remember: small steps lead to big changes. Try breaking tasks into manageable chunks and celebrate every little victory!"
48
+ elif score < 75:
49
+ return "You're on your way! Consider setting specific goals and challenge those negative thoughts with evidence."
50
+ else:
51
+ return "Keep up the great work! Your resilience is inspiring – maybe share some of that energy with someone who needs it!"
 
 
52
 
53
+ def main():
54
+ st.title("πŸ€– Roast Master Therapist Bot")
55
+ st.write("Share your problem, and let the roast and resilience tips begin!")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
56
 
57
+ # User input
58
+ user_input = st.text_area("What's troubling you?", placeholder="e.g., I procrastinate on everything...")
 
 
 
 
 
59
 
60
+ if st.button("Get Roast and Tips"):
61
+ if user_input.strip():
62
+ # Generate roast response
63
+ roast = get_roast(user_input)
64
+ st.markdown("### Roast:")
65
+ st.write(roast)
66
+
67
+ # Analyze user input
68
+ sentiment, distortions = analyze_text(user_input)
69
+ resilience_score = calculate_resilience_score(sentiment, distortions)
70
+ tips = get_reframe_tips(resilience_score)
71
+
72
+ # Display analysis and tips
73
+ st.markdown("### Resilience Analysis:")
74
+ st.write(f"**Resilience Score:** {resilience_score}/100")
75
+ st.write(tips)
76
+
77
+ # Fun Streamlit animation!
78
+ st.balloons()
79
+ else:
80
+ st.warning("Please share something so we can get roasting!")
81
 
82
+ if __name__ == '__main__':
83
+ main()