Kibo1 commited on
Commit
b603cba
·
verified ·
1 Parent(s): 38b44e3

Upload app.py

Browse files
Files changed (1) hide show
  1. app.py +110 -0
app.py ADDED
@@ -0,0 +1,110 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import pandas as pd
2
+ import os
3
+ from sentence_transformers import SentenceTransformer
4
+ from langchain_groq import ChatGroq
5
+ from langchain_huggingface import HuggingFaceEmbeddings
6
+ from langchain_chroma import Chroma
7
+ from langchain_core.prompts import PromptTemplate
8
+ from langchain_core.output_parsers import StrOutputParser
9
+ from langchain_core.runnables import RunnablePassthrough
10
+ import gradio as gr
11
+ import logging
12
+
13
+ # Set up basic logging (optional, but useful)
14
+ logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
15
+
16
+
17
+ try:
18
+ # Load the data - check for the file path
19
+ df = pd.read_csv('./Mental_Health_FAQ.csv')
20
+
21
+ context_data = []
22
+ for i in range(len(df)):
23
+ context = f"Question: {df.iloc[i]['Questions']} Answer: {df.iloc[i]['Answers']}"
24
+ context_data.append(context)
25
+
26
+ # Embed the contexts
27
+ embedding_model = SentenceTransformer('all-MiniLM-L6-v2')
28
+ context_embeddings = embedding_model.encode(context_data)
29
+
30
+
31
+ # Get the API Key - important to check this is set
32
+ groq_key = os.environ.get('new_chatAPI_key')
33
+ if not groq_key:
34
+ raise ValueError("Groq API key not found in environment variables.")
35
+
36
+
37
+ # LLM used for RAG
38
+ llm = ChatGroq(model="llama-3.3-70b-versatile",api_key=groq_key)
39
+
40
+ # Embedding model
41
+ embed_model = HuggingFaceEmbeddings(model_name="mixedbread-ai/mxbai-embed-large-v1")
42
+
43
+
44
+ # Create the Vector Store!
45
+ vectorstore = Chroma(
46
+ collection_name="medical_dataset_store",
47
+ embedding_function=embed_model,
48
+ )
49
+
50
+ # Add data to vector store
51
+ vectorstore.add_texts(context_data)
52
+
53
+ retriever = vectorstore.as_retriever()
54
+
55
+ # Create the prompt template
56
+ template = ("""You are a mental health professional.
57
+ Use the provided context to answer the question.
58
+ If you don't know the answer, say so. Explain your answer in detail.
59
+ Do not discuss the context in your response; just provide the answer directly.
60
+ Context: {context}
61
+ Question: {question}
62
+ Answer:""")
63
+
64
+ rag_prompt = PromptTemplate.from_template(template)
65
+ rag_chain = (
66
+ {"context": retriever, "question": RunnablePassthrough()}
67
+ | rag_prompt
68
+ | llm
69
+ | StrOutputParser()
70
+ )
71
+
72
+
73
+ def rag_memory_stream(message, history):
74
+ partial_text = ""
75
+ for new_text in rag_chain.stream(message):
76
+ partial_text += new_text
77
+ yield partial_text
78
+
79
+ examples = [
80
+ "I am not in a good mood",
81
+ "what is the possible symptompts of depression?"
82
+ ]
83
+
84
+ description = "Real-time AI App with Groq API and LangChain to Answer medical questions"
85
+ title = "ThriveTalk Expert :) Try me!"
86
+ demo = gr.ChatInterface(fn=rag_memory_stream,
87
+ type="messages",
88
+ title=title,
89
+ description=description,
90
+ fill_height=True,
91
+ examples=examples,
92
+ theme="glass",
93
+ )
94
+
95
+ except Exception as e:
96
+ logging.error(f"An error occurred during initialization: {e}")
97
+ # If there is an error then return a dummy error text to tell user
98
+ def error_function(message, history):
99
+ yield "An error has occurred. Please check the logs"
100
+ demo = gr.ChatInterface(fn=error_function,
101
+ type="messages",
102
+ title="ThriveTalk :(Ask me",
103
+ description="Please check the logs",
104
+ fill_height=True,
105
+ theme="glass",
106
+ )
107
+
108
+
109
+ if __name__ == "__main__":
110
+ demo.launch()