Goodnight7 commited on
Commit
c69e60a
Β·
verified Β·
1 Parent(s): e6d300b

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +158 -183
app.py CHANGED
@@ -7,199 +7,174 @@ from langchain_core.tracers.context import collect_runs
7
  from qdrant_client import QdrantClient
8
  from dotenv import load_dotenv
9
  import os
10
- if "access_granted" not in st.session_state:
11
- st.session_state.access_granted = False
12
- if "profile" not in st.session_state:
13
- st.session_state.profile = None
14
- if "name" not in st.session_state:
15
- st.session_state.name = None
16
- if not st.session_state.access_granted:
17
- # Profile input section
18
- st.title("User Profile")
19
- name = st.text_input("Name")
20
- profile_selector = st.selectbox("Profile", options=["Patient", "Doctor"] )
21
-
22
- profile = profile_selector
23
- if profile and name:
24
- d = False
25
- else:
26
- d = True
27
-
28
- submission = st.button("Submit", disabled=d)
29
-
30
- if submission:
31
- st.session_state.profile = profile
32
- st.session_state.name = name
33
- st.session_state.access_granted = True # Grant access to main app
34
- st.rerun() # Reload the app
35
- else:
36
- load_dotenv()
37
- profile = st.session_state.profile
38
- client = Client()
39
- qdrant_api=os.getenv("QDRANT_API_KEY")
40
- qdrant_url=os.getenv("QDRANT_URL")
41
- qdrant_client = QdrantClient(qdrant_url ,api_key=qdrant_api)
42
- st.set_page_config(page_title = "MEDICAL CHATBOT")
43
- st.subheader(f"Hello {st.session_state.name}! How can I assist you today!")
44
-
45
- memory = lc_memory.ConversationBufferMemory(
46
- chat_memory=lc_memory.StreamlitChatMessageHistory(key="langchain_messages"),
47
- return_messages=True,
48
- memory_key="chat_history",
49
- )
50
- st.sidebar.markdown("## Feedback Scale")
51
- feedback_option = (
52
- "thumbs" if st.sidebar.toggle(label="`Faces` ⇄ `Thumbs`", value=False) else "faces"
53
- )
54
 
55
- with st.sidebar:
56
- model_name = st.selectbox("**Model**", options=["llama-3.1-70b-versatile","gemma2-9b-it","gemma-7b-it","llama-3.2-3b-preview", "llama3-70b-8192", "mixtral-8x7b-32768"])
57
- temp = st.slider("**Temperature**", min_value=0.0, max_value=1.0, step=0.001)
58
- n_docs = st.number_input("**Number of retrieved documents**", min_value=0, max_value=10, value=5, step=1)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
59
 
60
- if st.sidebar.button("Clear message history"):
61
- print("Clearing message history")
62
- memory.clear()
63
 
64
- retriever = retriever(n_docs=n_docs)
65
- # Create Chain
66
- chain = get_expression_chain(retriever,model_name,temp)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
67
 
68
- for msg in st.session_state.langchain_messages:
69
- avatar = "πŸ’" if msg.type == "ai" else None
70
- with st.chat_message(msg.type, avatar=avatar):
71
- st.markdown(msg.content)
72
 
 
 
 
 
 
 
 
73
 
74
- prompt = st.chat_input(placeholder="Describe your symptoms or medical questions ?")
 
 
 
 
75
 
76
- if prompt :
77
- with st.chat_message("user"):
78
- st.write(prompt)
79
-
80
- with st.chat_message("assistant", avatar="πŸ’"):
81
- message_placeholder = st.empty()
82
- full_response = ""
83
- # Define the basic input structure for the chains
84
- input_dict = {"input": prompt.lower()}
85
- used_docs = retriever.get_relevant_documents(prompt.lower())
86
-
87
- with collect_runs() as cb:
88
- for chunk in chain.stream(input_dict, config={"tags": ["MEDICAL CHATBOT"]}):
89
- full_response += chunk.content
90
- message_placeholder.markdown(full_response + "β–Œ")
91
- memory.save_context(input_dict, {"output": full_response})
92
- st.session_state.run_id = cb.traced_runs[0].id
93
- message_placeholder.markdown(full_response)
94
- if used_docs :
95
- docs_content = "\n\n".join(
96
- [
97
- f"Doc {i+1}:\n"
98
- f"Source: {doc.metadata['source']}\n"
99
- f"Title: {doc.metadata['title']}\n"
100
- f"Content: {doc.page_content}\n"
101
- for i, doc in enumerate(used_docs)
102
- ]
103
- )
104
-
105
-
106
-
107
- with st.sidebar:
108
- st.download_button(
109
- label="Consulted Documents",
110
- data=docs_content,
111
- file_name="Consulted_documents.txt",
112
- mime="text/plain",
113
- )
114
 
 
 
 
 
 
 
 
 
 
 
 
115
  with st.spinner("Just a sec! Dont enter prompts while loading pelase!"):
116
- run_id = st.session_state.run_id
117
- question_embedding = get_embeddings(prompt)
118
- answer_embedding = get_embeddings(full_response)
119
- # Add question and answer to Qdrant
120
- qdrant_client.upload_collection(
121
- collection_name="chat-history",
122
- payload=[
123
- {"text": prompt, "type": "question", "question_ID": run_id},
124
- {"text": full_response, "type": "answer", "question_ID": run_id, "used_docs":used_docs}
125
- ],
126
- vectors=[
127
- question_embedding,
128
- answer_embedding,
129
- ],
130
- parallel=4,
131
- max_retries=3,
132
- )
 
 
 
 
133
 
134
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
135
 
136
- if st.session_state.get("run_id"):
137
- run_id = st.session_state.run_id
138
- feedback = streamlit_feedback(
139
- feedback_type=feedback_option,
140
- optional_text_label="[Optional] Please provide an explanation",
141
- key=f"feedback_{run_id}",
142
- )
143
-
144
- # Define score mappings for both "thumbs" and "faces" feedback systems
145
- score_mappings = {
146
- "thumbs": {"πŸ‘": 1, "πŸ‘Ž": 0},
147
- "faces": {"πŸ˜€": 1, "πŸ™‚": 0.75, "😐": 0.5, "πŸ™": 0.25, "😞": 0},
148
- }
149
-
150
- # Get the score mapping based on the selected feedback option
151
- scores = score_mappings[feedback_option]
152
-
153
- if feedback:
154
- # Get the score from the selected feedback option's score mapping
155
- score = scores.get(feedback["score"])
156
-
157
- if score is not None:
158
- # Formulate feedback type string incorporating the feedback option
159
- # and score value
160
- feedback_type_str = f"{feedback_option} {feedback['score']}"
161
-
162
- # Record the feedback with the formulated feedback type string
163
- # and optional comment
164
- with st.spinner("Just a sec! Dont enter prompts while loading pelase!"):
165
- feedback_record = client.create_feedback(
166
- run_id,
167
- feedback_type_str,
168
- score=score,
169
- comment=feedback.get("text"),
170
- source_info={"profile":profile}
171
- )
172
- st.session_state.feedback = {
173
- "feedback_id": str(feedback_record.id),
174
- "score": score,
175
- }
176
- else:
177
- st.warning("Invalid feedback score.")
178
 
179
- with st.spinner("Just a sec! Dont enter prompts while loading pelase!"):
180
- if feedback.get("text"):
181
- comment = feedback.get("text")
182
- feedback_embedding = get_embeddings(comment)
183
- else:
184
- comment = "no comment"
185
- feedback_embedding = get_embeddings(comment)
186
-
187
-
188
- qdrant_client.upload_collection(
189
- collection_name="chat-history",
190
- payload=[
191
- {"text": comment,
192
- "Score:":score,
193
- "type": "feedback",
194
- "question_ID": run_id,
195
- "User_profile":profile}
196
- ],
197
- vectors=[
198
- feedback_embedding
199
- ],
200
- parallel=4,
201
- max_retries=3,
202
- )
203
-
204
-
205
-
 
7
  from qdrant_client import QdrantClient
8
  from dotenv import load_dotenv
9
  import os
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
10
 
11
+
12
+ load_dotenv()
13
+ client = Client()
14
+ qdrant_api=os.getenv("QDRANT_API_KEY")
15
+ qdrant_url=os.getenv("QDRANT_URL")
16
+ qdrant_client = QdrantClient(qdrant_url ,api_key=qdrant_api)
17
+ st.set_page_config(page_title = "MEDICAL CHATBOT")
18
+ st.subheader("Hello! How can I assist you today!")
19
+
20
+ memory = lc_memory.ConversationBufferMemory(
21
+ chat_memory=lc_memory.StreamlitChatMessageHistory(key="langchain_messages"),
22
+ return_messages=True,
23
+ memory_key="chat_history",
24
+ )
25
+ st.sidebar.markdown("## Feedback Scale")
26
+ feedback_option = (
27
+ "thumbs" if st.sidebar.toggle(label="`Faces` ⇄ `Thumbs`", value=False) else "faces"
28
+ )
29
+
30
+ with st.sidebar:
31
+ model_name = st.selectbox("**Model**", options=["llama-3.1-70b-versatile","gemma2-9b-it","gemma-7b-it","llama-3.2-3b-preview", "llama3-70b-8192", "mixtral-8x7b-32768"])
32
+ temp = st.slider("**Temperature**", min_value=0.0, max_value=1.0, step=0.001)
33
+ n_docs = st.number_input("**Number of retrieved documents**", min_value=0, max_value=10, value=5, step=1)
34
+
35
+ if st.sidebar.button("Clear message history"):
36
+ print("Clearing message history")
37
+ memory.clear()
38
+
39
+ retriever = retriever(n_docs=n_docs)
40
+ # Create Chain
41
+ chain = get_expression_chain(retriever,model_name,temp)
42
+
43
+ for msg in st.session_state.langchain_messages:
44
+ avatar = "πŸ’" if msg.type == "ai" else None
45
+ with st.chat_message(msg.type, avatar=avatar):
46
+ st.markdown(msg.content)
47
+
48
+
49
+ prompt = st.chat_input(placeholder="Describe your symptoms or medical questions ?")
50
+
51
+ if prompt :
52
+ with st.chat_message("user"):
53
+ st.write(prompt)
54
+
55
+ with st.chat_message("assistant", avatar="πŸ’"):
56
+ message_placeholder = st.empty()
57
+ full_response = ""
58
+ # Define the basic input structure for the chains
59
+ input_dict = {"input": prompt.lower()}
60
+ used_docs = retriever.get_relevant_documents(prompt.lower())
61
+
62
+ with collect_runs() as cb:
63
+ for chunk in chain.stream(input_dict, config={"tags": ["MEDICAL CHATBOT"]}):
64
+ full_response += chunk.content
65
+ message_placeholder.markdown(full_response + "β–Œ")
66
+ memory.save_context(input_dict, {"output": full_response})
67
+ st.session_state.run_id = cb.traced_runs[0].id
68
+ message_placeholder.markdown(full_response)
69
+ if used_docs :
70
+ docs_content = "\n\n".join(
71
+ [
72
+ f"Doc {i+1}:\n"
73
+ f"Source: {doc.metadata['source']}\n"
74
+ f"Title: {doc.metadata['title']}\n"
75
+ f"Content: {doc.page_content}\n"
76
+ for i, doc in enumerate(used_docs)
77
+ ]
78
+ )
79
 
80
+
 
 
81
 
82
+ with st.sidebar:
83
+ st.download_button(
84
+ label="Consulted Documents",
85
+ data=docs_content,
86
+ file_name="Consulted_documents.txt",
87
+ mime="text/plain",
88
+ )
89
+
90
+ with st.spinner("Just a sec! Dont enter prompts while loading pelase!"):
91
+ run_id = st.session_state.run_id
92
+ question_embedding = get_embeddings(prompt)
93
+ answer_embedding = get_embeddings(full_response)
94
+ # Add question and answer to Qdrant
95
+ qdrant_client.upload_collection(
96
+ collection_name="chat-history",
97
+ payload=[
98
+ {"text": prompt, "type": "question", "question_ID": run_id},
99
+ {"text": full_response, "type": "answer", "question_ID": run_id, "used_docs":used_docs}
100
+ ],
101
+ vectors=[
102
+ question_embedding,
103
+ answer_embedding,
104
+ ],
105
+ parallel=4,
106
+ max_retries=3,
107
+ )
108
 
109
+
 
 
 
110
 
111
+ if st.session_state.get("run_id"):
112
+ run_id = st.session_state.run_id
113
+ feedback = streamlit_feedback(
114
+ feedback_type=feedback_option,
115
+ optional_text_label="[Optional] Please provide an explanation",
116
+ key=f"feedback_{run_id}",
117
+ )
118
 
119
+ # Define score mappings for both "thumbs" and "faces" feedback systems
120
+ score_mappings = {
121
+ "thumbs": {"πŸ‘": 1, "πŸ‘Ž": 0},
122
+ "faces": {"πŸ˜€": 1, "πŸ™‚": 0.75, "😐": 0.5, "πŸ™": 0.25, "😞": 0},
123
+ }
124
 
125
+ # Get the score mapping based on the selected feedback option
126
+ scores = score_mappings[feedback_option]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
127
 
128
+ if feedback:
129
+ # Get the score from the selected feedback option's score mapping
130
+ score = scores.get(feedback["score"])
131
+
132
+ if score is not None:
133
+ # Formulate feedback type string incorporating the feedback option
134
+ # and score value
135
+ feedback_type_str = f"{feedback_option} {feedback['score']}"
136
+
137
+ # Record the feedback with the formulated feedback type string
138
+ # and optional comment
139
  with st.spinner("Just a sec! Dont enter prompts while loading pelase!"):
140
+ feedback_record = client.create_feedback(
141
+ run_id,
142
+ feedback_type_str,
143
+ score=score,
144
+ comment=feedback.get("text"),
145
+ source_info={"profile":profile}
146
+ )
147
+ st.session_state.feedback = {
148
+ "feedback_id": str(feedback_record.id),
149
+ "score": score,
150
+ }
151
+ else:
152
+ st.warning("Invalid feedback score.")
153
+
154
+ with st.spinner("Just a sec! Dont enter prompts while loading pelase!"):
155
+ if feedback.get("text"):
156
+ comment = feedback.get("text")
157
+ feedback_embedding = get_embeddings(comment)
158
+ else:
159
+ comment = "no comment"
160
+ feedback_embedding = get_embeddings(comment)
161
 
162
 
163
+ qdrant_client.upload_collection(
164
+ collection_name="chat-history",
165
+ payload=[
166
+ {"text": comment,
167
+ "Score:":score,
168
+ "type": "feedback",
169
+ "question_ID": run_id,
170
+ "User_profile":profile}
171
+ ],
172
+ vectors=[
173
+ feedback_embedding
174
+ ],
175
+ parallel=4,
176
+ max_retries=3,
177
+ )
178
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
179
 
180
+