import os

import tiktoken
from dotenv import load_dotenv
from langchain.memory import ConversationSummaryBufferMemory
from langchain_chroma import Chroma
from langchain_community.document_loaders import PyPDFLoader
from langchain_groq import ChatGroq
from langchain_huggingface import HuggingFaceEmbeddings
from langchain_text_splitters import RecursiveCharacterTextSplitter

load_dotenv()

# Tokenizer used to measure chunk sizes in tokens rather than characters.
tokenizer = tiktoken.get_encoding("cl100k_base")

FILE_NAMES = os.listdir("data")

SYSTEM_PROMPT = """
You are an AI-powered medical assistant trained to provide prescription recommendations based on user symptoms. Your responses should be accurate, safe, and aligned with general medical guidelines.

When a user provides symptoms, follow these steps:
1. Ask clarifying questions if needed to ensure accurate symptom understanding.
2. Provide a probable condition or diagnosis based on symptoms.
3. Recommend suitable over-the-counter or prescription medications (mentioning that a doctor's consultation is advised for prescriptions).
4. Offer general care advice, such as lifestyle changes or home remedies.
5. If symptoms indicate a severe or emergency condition, advise the user to seek immediate medical attention.

Always be polite, professional, and ensure user safety in your responses. Avoid giving definitive diagnoses or prescriptions without medical consultation.

context: {context}
previous message summary: {previous_message_summary}
"""

human_template = "{question}"
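
# Illustrative wiring (assumption: the original app assembles its prompt
# elsewhere; `rag_prompt` is a name introduced here for illustration).
# ChatPromptTemplate pairs the system prompt with the human turn so that
# {context}, {previous_message_summary}, and {question} are filled at call time.
from langchain_core.prompts import ChatPromptTemplate

rag_prompt = ChatPromptTemplate.from_messages(
    [("system", SYSTEM_PROMPT), ("human", human_template)]
)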

# Groq model identifiers and generation settings.
NLP_MODEL_NAME = "llama3-70b-8192"
REASONING_MODEL_NAME = "mixtral-8x7b-32768"
REASONING_MODEL_TEMPERATURE = 0
NLP_MODEL_TEMPERATURE = 0
NLP_MODEL_MAX_TOKENS = 5400

# Chunking and retrieval settings (sizes are in tokens, via tiktoken_len below).
VECTOR_MAX_TOKENS = 100
VECTORS_TOKEN_OVERLAP_SIZE = 20
NUMBER_OF_VECTORS_FOR_RAG = 7
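
# Note (assumption): NLP_MODEL_NAME and REASONING_MODEL_NAME are not used in
# this module; presumably other parts of the app build clients from them, e.g.:
#
#   reasoning_llm = ChatGroq(
#       temperature=REASONING_MODEL_TEMPERATURE,
#       groq_api_key=os.getenv("GROQ_API_KEY"),
#       model_name=REASONING_MODEL_NAME,
#   )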


def tiktoken_len(text):
    """Return the number of cl100k_base tokens in `text`."""
    tokens = tokenizer.encode(text, disallowed_special=())
    return len(tokens)
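
# Example (illustrative): tiktoken_len("fever, dry cough, mild headache")
# returns a small integer token count; the splitter below uses this function
# as its length_function, so chunks are capped at VECTOR_MAX_TOKENS tokens,
# not characters.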


def get_vectorstore():
    """Load the persisted Chroma index if present; otherwise build it from ./data."""
    model_name = "BAAI/bge-small-en"
    model_kwargs = {"device": "cpu"}
    encode_kwargs = {"normalize_embeddings": True}
    hf = HuggingFaceEmbeddings(
        model_name=model_name, model_kwargs=model_kwargs, encode_kwargs=encode_kwargs
    )
    persist_directory = "./chroma_db"

    # Reuse an existing index instead of re-reading and re-splitting the files.
    if os.path.exists(persist_directory):
        return Chroma(persist_directory=persist_directory, embedding_function=hf)

    text_splitter = RecursiveCharacterTextSplitter(
        chunk_size=VECTOR_MAX_TOKENS,
        chunk_overlap=VECTORS_TOKEN_OVERLAP_SIZE,
        length_function=tiktoken_len,  # chunk sizes are measured in tokens
        separators=["\n\n\n", "\n\n", "\n", " ", ""],
    )

    all_splits = []
    for file_name in FILE_NAMES:
        if file_name.endswith(".pdf"):
            loader = PyPDFLoader(os.path.join("data", file_name))
            # loader.load() returns one Document per page; join them all so
            # the whole PDF is indexed.
            data = "\n".join(page.page_content for page in loader.load())
        else:
            with open(os.path.join("data", file_name), "r") as f:
                data = f.read()
        all_splits.extend(text_splitter.split_text(data))

    return Chroma.from_texts(
        texts=all_splits, embedding=hf, persist_directory=persist_directory
    )


# Streaming Groq chat model; also used to summarize conversation history.
chat = ChatGroq(
    temperature=0,
    groq_api_key=os.getenv("GROQ_API_KEY"),
    model_name="llama3-8b-8192",
    streaming=True,
)

# Keeps recent turns verbatim and summarizes older ones once the buffer
# exceeds max_token_limit tokens.
rag_memory = ConversationSummaryBufferMemory(llm=chat, max_token_limit=3000)

my_vector_store = get_vectorstore()
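

# Minimal end-to-end sketch (assumption: the original query path lives in
# another module; `answer_question` is a name introduced here for illustration).
def answer_question(question: str) -> str:
    # Retrieve the top-k chunks to ground the response.
    docs = my_vector_store.similarity_search(question, k=NUMBER_OF_VECTORS_FOR_RAG)
    context = "\n\n".join(doc.page_content for doc in docs)

    # Pull the rolling summary of earlier turns from memory.
    summary = rag_memory.load_memory_variables({}).get("history", "")

    # Fill the prompt defined above and call the model.
    messages = rag_prompt.format_messages(
        context=context, previous_message_summary=summary, question=question
    )
    response = chat.invoke(messages)

    # Record the exchange so future turns see a summary of it.
    rag_memory.save_context({"input": question}, {"output": response.content})
    return response.content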