File size: 3,630 Bytes
1474633
 
934d19e
1474633
 
 
 
 
 
 
934d19e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1474633
 
 
 
934d19e
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
import pandas as pd
import os
from sentence_transformers import SentenceTransformer
from langchain_groq import ChatGroq
from langchain_huggingface import HuggingFaceEmbeddings
from langchain_chroma import Chroma
from langchain_core.prompts import PromptTemplate
from langchain_core.output_parsers import StrOutputParser
from langchain_core.runnables import RunnablePassthrough
import gradio as gr
import logging

# Root logger configuration: emit INFO and above, each record prefixed with
# its timestamp and level name. Runs once at import, before the app is built.
logging.basicConfig(
    format='%(asctime)s - %(levelname)s - %(message)s',
    level=logging.INFO,
)


try:
    # Load the FAQ data; the path is relative to the current working directory.
    df = pd.read_csv('./Mental_Health_FAQ.csv')

    # One "Question: ... Answer: ..." passage per FAQ row; these strings are
    # the documents indexed for retrieval. Iterating the two columns directly
    # avoids materializing a row Series per row as df.iloc[i] would.
    # NOTE(review): a missing 'Questions' or 'Answers' column raises KeyError,
    # which lands in the except branch below and shows the error UI.
    context_data = [
        f"Question: {question} Answer: {answer}"
        for question, answer in zip(df['Questions'], df['Answers'])
    ]

    # Read the Groq API key from the environment. Without it the LLM cannot
    # be called, so abort initialization with an explicit message.
    groq_key = os.environ.get('new_chatAPI_key')
    if not groq_key:
        raise ValueError("Groq API key not found in environment variables.")

    # LLM that writes the final answer from the retrieved context.
    llm = ChatGroq(model="llama-3.3-70b-versatile", api_key=groq_key)

    # Embedding model the vector store uses both to index the FAQ passages
    # and to embed incoming questions for similarity search.
    embed_model = HuggingFaceEmbeddings(model_name="mixedbread-ai/mxbai-embed-large-v1")

    # Vector store holding the FAQ passages. No persist_directory is passed.
    # NOTE(review): presumably this keeps the collection in memory and
    # rebuilds it on every start — confirm against the langchain_chroma docs.
    vectorstore = Chroma(
        collection_name="medical_dataset_store",
        embedding_function=embed_model,
    )
    vectorstore.add_texts(context_data)

    retriever = vectorstore.as_retriever()

    def _format_docs(docs):
        """Join the text of the retrieved documents into one context string.

        Without this step, {context} in the prompt would be filled with the
        string form of the retriever's list of Document objects (ids and
        metadata included) rather than just the passage text.
        """
        return "\n\n".join(doc.page_content for doc in docs)

    # Prompt template: instructs the model to answer from the retrieved context.
    template = ("""You are a mental health professional.
        Use the provided context to answer the question.
        If you don't know the answer, say so. Explain your answer in detail.
        Do not discuss the context in your response; just provide the answer directly.
        Context: {context}
        Question: {question}
        Answer:""")

    rag_prompt = PromptTemplate.from_template(template)

    # The incoming question string is sent to the retriever (whose results are
    # flattened to text) and also passed through unchanged as {question}.
    rag_chain = (
        {"context": retriever | _format_docs, "question": RunnablePassthrough()}
        | rag_prompt
        | llm
        | StrOutputParser()
    )

    def rag_memory_stream(message, history):
        """Stream the RAG answer to *message* as a progressively growing string.

        Each yielded value is the full answer text accumulated so far.
        *history* is not used, so answers do not take earlier turns into account.
        """
        partial_text = ""
        for chunk in rag_chain.stream(message):
            partial_text += chunk
            yield partial_text

    examples = [
        "I am not in a good mood",
        "what are the possible symptoms of depression?"
    ]

    description = "Real-time AI App with Groq API and LangChain to Answer medical questions"
    title = "ThriveTalk Expert :) Try me!"
    demo = gr.ChatInterface(fn=rag_memory_stream,
                            type="messages",
                            title=title,
                            description=description,
                            fill_height=True,
                            examples=examples,
                            theme="glass",
    )

except Exception as e:
    # Top-level boundary: any initialization failure (e.g. a missing CSV file
    # or an unset API key) is logged and replaced by a placeholder chat UI so
    # the app still starts. logging.exception also records the traceback,
    # which a plain error message would lose.
    logging.exception("An error occurred during initialization: %s", e)

    def error_function(message, history):
        """Reply to every message with a fixed error notice; inputs are ignored."""
        yield "An error has occurred. Please check the logs"

    demo = gr.ChatInterface(fn=error_function,
                            type="messages",
                            title="ERROR :(",
                            description="Please check the logs",
                            fill_height=True,
                            theme="glass",
    )


def _launch_app():
    """Start the Gradio server for the module-level ``demo`` interface."""
    demo.launch()


if __name__ == "__main__":
    _launch_app()