dharak003 committed on
Commit da1ce01 · verified · 1 Parent(s): 2720a8d

Upload 5 files

.gitattributes CHANGED
@@ -33,3 +33,5 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
  *.zip filter=lfs diff=lfs merge=lfs -text
  *.zst filter=lfs diff=lfs merge=lfs -text
  *tfevents* filter=lfs diff=lfs merge=lfs -text
+ chroma_db/chroma.sqlite3 filter=lfs diff=lfs merge=lfs -text
+ data/Symptom-Based[[:space:]]Medication[[:space:]]Prescribing.pdf filter=lfs diff=lfs merge=lfs -text
app.py ADDED
@@ -0,0 +1,75 @@
+ import streamlit as st
+ from app_config import SYSTEM_PROMPT, NLP_MODEL_NAME, NUMBER_OF_VECTORS_FOR_RAG, NLP_MODEL_TEMPERATURE, my_vector_store, tiktoken_len
+ from langchain.memory import ConversationSummaryBufferMemory
+ from langchain_core.messages import SystemMessage, HumanMessage
+ from langchain_groq import ChatGroq
+ from dotenv import load_dotenv
+ from pathlib import Path
+ import os
+
+ env_path = Path('.') / '.env'
+ load_dotenv(dotenv_path=env_path)
+
+ # Initialize the retriever and LLM once, outside session state
+ retriever = my_vector_store.as_retriever(k=NUMBER_OF_VECTORS_FOR_RAG)
+ llm = ChatGroq(temperature=NLP_MODEL_TEMPERATURE, groq_api_key=str(os.getenv('GROQ_API_KEY')), model_name=NLP_MODEL_NAME)
+
+ def response_generator(prompt: str) -> str:
+     """Retrieve relevant context, assemble the chat history, and query the LLM."""
+     try:
+         docs = retriever.invoke(prompt)
+         my_context = '\n\n'.join(doc.page_content for doc in docs)
+         system_message = SystemMessage(content=SYSTEM_PROMPT.format(
+             context=my_context,
+             previous_message_summary=st.session_state.rag_memory.moving_summary_buffer,
+         ))
+         # System prompt, then the buffered history, then the new user question
+         chat_messages = [system_message, *st.session_state.rag_memory.chat_memory.messages, HumanMessage(content=prompt)]
+         print("total tokens:", tiktoken_len(str(chat_messages)))
+         response = llm.invoke(chat_messages)
+         return response.content
+     except Exception as error:
+         print(error, "ERROR")
+         return "Oops! Something went wrong, please try again."
+
+ # Right-align user messages in the chat UI
+ st.markdown(
+     """
+     <style>
+     .st-emotion-cache-janbn0 {
+         flex-direction: row-reverse;
+         text-align: right;
+     }
+     </style>
+     """,
+     unsafe_allow_html=True,
+ )
+
+ # Initialize session state
+ if "messages" not in st.session_state:
+     st.session_state.messages = [{"role": "system", "content": SYSTEM_PROMPT}]
+ if "rag_memory" not in st.session_state:
+     st.session_state.rag_memory = ConversationSummaryBufferMemory(llm=llm, max_token_limit=5000)
+ if "retriever" not in st.session_state:
+     st.session_state.retriever = retriever
+
+ st.title("Call on Doc Prescription Recommendation")
+ container = st.container(height=600)
+ for message in st.session_state.messages:
+     if message["role"] != "system":
+         with container.chat_message(message["role"]):
+             st.write(message["content"])
+
+ if prompt := st.chat_input("Enter your query here... "):
+     with container.chat_message("user"):
+         st.write(prompt)
+     st.session_state.messages.append({"role": "user", "content": prompt})
+
+     with container.chat_message("assistant"):
+         response = response_generator(prompt=prompt)
+         st.write(response)
+
+     st.session_state.rag_memory.save_context({'input': prompt}, {'output': response})
+     st.session_state.messages.append({"role": "assistant", "content": response})
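The retrieval-augmented flow inside response_generator can also be exercised without the Streamlit UI. The following is a minimal sketch, not part of the commit: it assumes the app_config module above is importable, a .env file supplying GROQ_API_KEY, and an illustrative question string.

    # smoke_test.py — sketch: run one RAG turn outside Streamlit (assumptions noted above)
    import os
    from dotenv import load_dotenv
    from langchain_core.messages import SystemMessage, HumanMessage
    from langchain_groq import ChatGroq
    from app_config import SYSTEM_PROMPT, NLP_MODEL_NAME, NUMBER_OF_VECTORS_FOR_RAG, my_vector_store

    load_dotenv()
    retriever = my_vector_store.as_retriever(k=NUMBER_OF_VECTORS_FOR_RAG)
    llm = ChatGroq(temperature=0, groq_api_key=os.getenv("GROQ_API_KEY"), model_name=NLP_MODEL_NAME)

    question = "I have a sore throat and mild fever. What should I take?"  # illustrative
    docs = retriever.invoke(question)
    context = "\n\n".join(doc.page_content for doc in docs)
    messages = [
        SystemMessage(content=SYSTEM_PROMPT.format(context=context, previous_message_summary="")),
        HumanMessage(content=question),
    ]
    print(llm.invoke(messages).content)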
app_config.py ADDED
@@ -0,0 +1,80 @@
+ import tiktoken
+ from langchain_text_splitters import RecursiveCharacterTextSplitter
+ from langchain_chroma import Chroma
+ from langchain_huggingface import HuggingFaceEmbeddings
+ from langchain_community.document_loaders import PyPDFLoader
+ from langchain.memory import ConversationSummaryBufferMemory
+ from langchain_groq import ChatGroq
+ import os
+ from dotenv import load_dotenv
+
+ # Load environment variables from .env file
+ load_dotenv()
+ tokenizer = tiktoken.get_encoding('cl100k_base')
+ FILE_NAMES = os.listdir('data')
+
+ SYSTEM_PROMPT = """
+ You are an AI-powered medical assistant trained to provide prescription recommendations based on user symptoms. Your responses should be accurate, safe, and aligned with general medical guidelines.
+ When a user provides symptoms, follow these steps:
+ 1. Ask clarifying questions if needed to ensure accurate symptom understanding.
+ 2. Provide a probable condition or diagnosis based on the symptoms.
+ 3. Recommend suitable over-the-counter or prescription medications (noting that a doctor's consultation is advised for prescriptions).
+ 4. Offer general care advice, such as lifestyle changes or home remedies.
+ 5. If symptoms indicate a severe or emergency condition, advise the user to seek immediate medical attention.
+ Always be polite and professional, and ensure user safety in your responses. Avoid giving definitive diagnoses or prescriptions without medical consultation.
+ context: {context}
+ previous message summary: {previous_message_summary}
+ """
+
+ human_template = "{question}"
+
+ NLP_MODEL_NAME = "llama3-70b-8192"
+ REASONING_MODEL_NAME = "mixtral-8x7b-32768"
+ REASONING_MODEL_TEMPERATURE = 0
+ NLP_MODEL_TEMPERATURE = 0
+ NLP_MODEL_MAX_TOKENS = 5400
+ VECTOR_MAX_TOKENS = 100
+ VECTORS_TOKEN_OVERLAP_SIZE = 20
+ NUMBER_OF_VECTORS_FOR_RAG = 7
+
+ # Token-based length function used by the text splitter
+ def tiktoken_len(text):
+     tokens = tokenizer.encode(text, disallowed_special=())
+     return len(tokens)
+
+ def get_vectorstore():
+     """Build the Chroma vector store from ./data, or load it if already persisted."""
+     model_name = "BAAI/bge-small-en"
+     model_kwargs = {"device": "cpu"}
+     encode_kwargs = {"normalize_embeddings": True}
+     hf = HuggingFaceEmbeddings(
+         model_name=model_name, model_kwargs=model_kwargs, encode_kwargs=encode_kwargs
+     )
+     persist_directory = "./chroma_db"  # Directory to save the vector store
+     all_splits = []
+     for file_name in FILE_NAMES:
+         if file_name.endswith(".pdf"):
+             loader = PyPDFLoader(os.path.join("data", file_name))
+             # Join all pages, not just the first one
+             data = "\n".join(page.page_content for page in loader.load())
+         else:
+             with open(os.path.join("data", file_name), "r") as f:
+                 data = f.read()
+         text_splitter = RecursiveCharacterTextSplitter(
+             chunk_size=VECTOR_MAX_TOKENS,
+             chunk_overlap=VECTORS_TOKEN_OVERLAP_SIZE,
+             length_function=tiktoken_len,
+             separators=["\n\n\n", "\n\n", "\n", " ", ""]
+         )
+         all_splits = all_splits + text_splitter.split_text(data)
+
+     # Reuse the persisted vector store if it exists; otherwise build and persist it
+     if os.path.exists(persist_directory):
+         vectorstore = Chroma(persist_directory=persist_directory, embedding_function=hf)
+     else:
+         vectorstore = Chroma.from_texts(
+             texts=all_splits, embedding=hf, persist_directory=persist_directory
+         )
+     return vectorstore
+
+ chat = ChatGroq(temperature=0, groq_api_key=os.getenv("GROQ_API_KEY"), model_name="llama3-8b-8192", streaming=True)
+ rag_memory = ConversationSummaryBufferMemory(llm=chat, max_token_limit=3000)
+ my_vector_store = get_vectorstore()
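Because the splitter's length_function is tiktoken_len, chunk_size=100 caps chunks at roughly 100 tokens rather than 100 characters. A small sketch of that behavior, runnable with only tiktoken and langchain-text-splitters installed (the sample text is illustrative):

    # chunking_demo.py — sketch of the token-based splitting used in get_vectorstore()
    import tiktoken
    from langchain_text_splitters import RecursiveCharacterTextSplitter

    tokenizer = tiktoken.get_encoding("cl100k_base")

    def tiktoken_len(text: str) -> int:
        return len(tokenizer.encode(text, disallowed_special=()))

    splitter = RecursiveCharacterTextSplitter(
        chunk_size=100,    # VECTOR_MAX_TOKENS
        chunk_overlap=20,  # VECTORS_TOKEN_OVERLAP_SIZE
        length_function=tiktoken_len,
        separators=["\n\n\n", "\n\n", "\n", " ", ""],
    )
    sample = "Fever is a common symptom. " * 40  # illustrative text
    chunks = splitter.split_text(sample)
    print(len(chunks), "chunks; token counts:", [tiktoken_len(c) for c in chunks])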
chroma_db/chroma.sqlite3 ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:f1051e43ac0d6482edc4f0f8acf3c8663b49b6a65607665bb621ccfe549fae19
+ size 167936
data/Symptom-Based Medication Prescribing.pdf ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:c0cd17d1bdd348afcfc033468df5e8a2a16a0ebe38e8bc184b5fc4f3be9cc4a6
+ size 169001
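Both binary files above are committed as Git LFS pointer stubs: three lines recording the spec version, the SHA-256 of the real payload, and its size in bytes, while the payload itself lives in LFS storage (the .gitattributes entries at the top of this commit route them there). A minimal sketch of reading such a stub; the helper name is hypothetical, and it only applies when the file on disk is the pointer text, i.e. checked out without git-lfs installed:

    # lfs_pointer.py — sketch: parse a Git LFS pointer stub into a dict
    def parse_lfs_pointer(path: str) -> dict:
        fields = {}
        with open(path, "r") as f:
            for line in f:
                key, _, value = line.strip().partition(" ")
                fields[key] = value  # e.g. "oid" -> "sha256:f1051e..."
        return fields

    info = parse_lfs_pointer("chroma_db/chroma.sqlite3")  # path from this commit
    print(info["oid"], int(info["size"]))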
requirements.txt ADDED
@@ -0,0 +1,169 @@
+ aiohappyeyeballs==2.6.1
+ aiohttp==3.11.16
+ aiosignal==1.3.2
+ altair==5.5.0
+ annotated-types==0.7.0
+ anyio==4.9.0
+ asgiref==3.8.1
+ async-timeout==4.0.3
+ attrs==25.3.0
+ backoff==2.2.1
+ bcrypt==4.3.0
+ blinker==1.9.0
+ build==1.2.2.post1
+ cachetools==5.5.2
+ certifi==2025.1.31
+ charset-normalizer==3.4.1
+ chroma-hnswlib==0.7.6
+ chromadb==0.6.3
+ click==8.1.8
+ coloredlogs==15.0.1
+ dataclasses-json==0.6.7
+ Deprecated==1.2.18
+ distro==1.9.0
+ durationpy==0.9
+ exceptiongroup==1.2.2
+ fastapi==0.115.12
+ filelock==3.18.0
+ flatbuffers==25.2.10
+ frozenlist==1.5.0
+ fsspec==2025.3.2
+ gitdb==4.0.12
+ GitPython==3.1.44
+ google-auth==2.38.0
+ googleapis-common-protos==1.69.2
+ greenlet==3.1.1
+ groq==0.21.0
+ grpcio==1.71.0
+ h11==0.14.0
+ httpcore==1.0.7
+ httptools==0.6.4
+ httpx==0.28.1
+ httpx-sse==0.4.0
+ huggingface-hub==0.30.1
+ humanfriendly==10.0
+ idna==3.10
+ importlib_metadata==8.6.1
+ importlib_resources==6.5.2
+ Jinja2==3.1.6
+ joblib==1.4.2
+ jsonpatch==1.33
+ jsonpointer==3.0.0
+ jsonschema==4.23.0
+ jsonschema-specifications==2024.10.1
+ kubernetes==32.0.1
+ langchain==0.3.22
+ langchain-chroma==0.2.2
+ langchain-community==0.3.20
+ langchain-core==0.3.49
+ langchain-groq==0.3.2
+ langchain-huggingface==0.1.2
+ langchain-text-splitters==0.3.7
+ langsmith==0.3.22
+ markdown-it-py==3.0.0
+ MarkupSafe==3.0.2
+ marshmallow==3.26.1
+ mdurl==0.1.2
+ mmh3==5.1.0
+ monotonic==1.6
+ mpmath==1.3.0
+ multidict==6.3.1
+ mypy-extensions==1.0.0
+ narwhals==1.33.0
+ networkx==3.4.2
+ numpy==1.26.4
+ nvidia-cublas-cu12==12.4.5.8
+ nvidia-cuda-cupti-cu12==12.4.127
+ nvidia-cuda-nvrtc-cu12==12.4.127
+ nvidia-cuda-runtime-cu12==12.4.127
+ nvidia-cudnn-cu12==9.1.0.70
+ nvidia-cufft-cu12==11.2.1.3
+ nvidia-curand-cu12==10.3.5.147
+ nvidia-cusolver-cu12==11.6.1.9
+ nvidia-cusparse-cu12==12.3.1.170
+ nvidia-cusparselt-cu12==0.6.2
+ nvidia-nccl-cu12==2.21.5
+ nvidia-nvjitlink-cu12==12.4.127
+ nvidia-nvtx-cu12==12.4.127
+ oauthlib==3.2.2
+ onnxruntime==1.21.0
+ opentelemetry-api==1.31.1
+ opentelemetry-exporter-otlp-proto-common==1.31.1
+ opentelemetry-exporter-otlp-proto-grpc==1.31.1
+ opentelemetry-instrumentation==0.52b1
+ opentelemetry-instrumentation-asgi==0.52b1
+ opentelemetry-instrumentation-fastapi==0.52b1
+ opentelemetry-proto==1.31.1
+ opentelemetry-sdk==1.31.1
+ opentelemetry-semantic-conventions==0.52b1
+ opentelemetry-util-http==0.52b1
+ orjson==3.10.16
+ overrides==7.7.0
+ packaging==24.2
+ pandas==2.2.3
+ pillow==11.2.0
+ posthog==3.23.0
+ propcache==0.3.1
+ protobuf==5.29.4
+ pyarrow==19.0.1
+ pyasn1==0.6.1
+ pyasn1_modules==0.4.2
+ pydantic==2.11.1
+ pydantic-settings==2.8.1
+ pydantic_core==2.33.0
+ pydeck==0.9.1
+ Pygments==2.19.1
+ pypdf==5.4.0
+ PyPika==0.48.9
+ pyproject_hooks==1.2.0
+ python-dateutil==2.9.0.post0
+ python-dotenv==1.1.0
+ pytz==2025.2
+ PyYAML==6.0.2
+ referencing==0.36.2
+ regex==2024.11.6
+ requests==2.32.3
+ requests-oauthlib==2.0.0
+ requests-toolbelt==1.0.0
+ rich==14.0.0
+ rpds-py==0.24.0
+ rsa==4.9
+ safetensors==0.5.3
+ scikit-learn==1.6.1
+ scipy==1.15.2
+ sentence-transformers==4.0.1
+ shellingham==1.5.4
+ six==1.17.0
+ smmap==5.0.2
+ sniffio==1.3.1
+ SQLAlchemy==2.0.40
+ starlette==0.46.1
+ streamlit==1.44.1
+ sympy==1.13.1
+ tenacity==9.0.0
+ threadpoolctl==3.6.0
+ tiktoken==0.9.0
+ tokenizers==0.21.1
+ toml==0.10.2
+ tomli==2.2.1
+ torch==2.6.0
+ tornado==6.4.2
+ tqdm==4.67.1
+ transformers==4.50.3
+ triton==3.2.0
+ typer==0.15.2
+ typing-inspect==0.9.0
+ typing-inspection==0.4.0
+ typing_extensions==4.13.0
+ tzdata==2025.2
+ urllib3==2.3.0
+ uvicorn==0.34.0
+ uvloop==0.21.0
+ watchdog==6.0.0
+ watchfiles==1.0.4
+ websocket-client==1.8.0
+ websockets==15.0.1
+ wrapt==1.17.2
+ yarl==1.18.3
+ zipp==3.21.0
+ zstandard==0.23.0