# Spaces:
# Sleeping
# Sleeping
import os
import random
import traceback
from pathlib import Path

import streamlit as st
from dotenv import load_dotenv
from langchain.chains.summarize import load_summarize_chain
from langchain.memory import ConversationSummaryBufferMemory
from langchain.prompts import PromptTemplate
from langchain_core.messages import AIMessage, HumanMessage, SystemMessage
from langchain_groq import ChatGroq

from app_config import (
    NLP_MODEL_MAX_TOKENS,
    NLP_MODEL_NAME,
    NLP_MODEL_TEMPERATURE,
    NUMBER_OF_VECTORS_FOR_RAG,
    SYSTEM_PROMPT,
    VECTOR_MAX_TOKENS,
    chat,
    my_vector_store,
    tiktoken_len,
)
# Load environment variables (e.g. GROQ_API_KEY, read further below) from a
# `.env` file resolved relative to the current working directory.
env_path = Path(".").joinpath(".env")
load_dotenv(dotenv_path=env_path)
def response_generator(prompt: str) -> str:
    """Answer a user query with retrieval-augmented generation.

    Retrieves documents relevant to ``prompt`` with ``st.session_state.retriever``,
    fills ``SYSTEM_PROMPT`` with the joined document text and the conversation
    summary held by ``st.session_state.rag_memory``, then sends the system
    message, the buffered chat history and the new query to
    ``st.session_state.llm``.

    Args:
        prompt: The user's query.

    Returns:
        The model's answer text, or a generic apology message if any step
        fails (the error and its traceback are printed to the server console).
    """
    try:
        memory = st.session_state.rag_memory
        docs = st.session_state.retriever.invoke(prompt)
        my_context = "\n\n".join(doc.page_content for doc in docs)

        system_message = SystemMessage(
            content=SYSTEM_PROMPT.format(
                context=my_context,
                previous_message_summary=memory.moving_summary_buffer,
            )
        )
        print(system_message)

        # Build the message list directly: system prompt, buffered history,
        # then the new query. (Previously this went through
        # `BaseMessage.__add__`, which builds a ChatPromptTemplate only to read
        # `.messages` back out of it.)
        chat_messages = [
            system_message,
            *memory.chat_memory.messages,
            HumanMessage(content=prompt),
        ]
        print("total tokens: ", tiktoken_len(str(chat_messages)))

        response = st.session_state.llm.invoke(chat_messages)
        return response.content
    except Exception as error:
        # UI boundary: a failure must not crash the Streamlit run, but keep the
        # full traceback in the server log so the cause can be diagnosed.
        print(error, "ERROR")
        traceback.print_exc()
        return "Oops! something went wrong, please try again."
# Inject page CSS: elements carrying this class get a reversed flex direction
# and right-aligned text (presumably the user's chat bubbles).
# NOTE(review): `st-emotion-cache-*` names are hashes generated by Streamlit's
# styling layer and may change between Streamlit versions — confirm the
# selector still matches after upgrading.
st.markdown(
    body="""
<style>
.st-emotion-cache-janbn0 {
flex-direction: row-reverse;
text-align: right;
}
</style>
""",
    unsafe_allow_html=True,
)
# Initialise per-session objects. Streamlit re-executes this script on every
# interaction, so each object is created only if it is not already stored.
if "messages" not in st.session_state:
    # The system entry is stored in the history but skipped when rendering.
    st.session_state.messages = [{"role": "system", "content": SYSTEM_PROMPT}]

if "llm" not in st.session_state:
    groq_api_key = os.getenv("GROQ_API_KEY")
    if not groq_api_key:
        # `str(None)` would send the literal key "None" and fail later with an
        # opaque authentication error; stop with a clear message instead.
        st.error("GROQ_API_KEY is not set. Add it to the environment or the .env file.")
        st.stop()
    # NOTE(review): NLP_MODEL_MAX_TOKENS is imported from app_config but not
    # applied anywhere visible — confirm whether ChatGroq should get max_tokens.
    st.session_state.llm = ChatGroq(
        temperature=NLP_MODEL_TEMPERATURE,
        groq_api_key=groq_api_key,
        model_name=NLP_MODEL_NAME,
    )

if "rag_memory" not in st.session_state:
    # Keeps recent turns verbatim and summarises older ones with the same LLM
    # once the buffer exceeds max_token_limit.
    st.session_state.rag_memory = ConversationSummaryBufferMemory(
        llm=st.session_state.llm, max_token_limit=5000
    )

if "retriever" not in st.session_state:
    # The number of documents must be passed via `search_kwargs`; a bare `k=`
    # keyword is not a retriever field and is silently ignored, so the
    # store's default k would be used instead of NUMBER_OF_VECTORS_FOR_RAG.
    st.session_state.retriever = my_vector_store.as_retriever(
        search_kwargs={"k": NUMBER_OF_VECTORS_FOR_RAG}
    )
st.title("Insurance Bot")

# Replay the stored conversation; the system-prompt entry is not shown.
container = st.container(height=600)
for message in st.session_state.messages:
    if message["role"] != "system":
        with container.chat_message(message["role"]):
            st.write(message["content"])

if prompt := st.chat_input("Enter your query here... "):
    with container.chat_message("user"):
        st.write(prompt)
    st.session_state.messages.append({"role": "user", "content": prompt})

    with container.chat_message("assistant"):
        response = response_generator(prompt=prompt)
        print("Response is:", response)
        st.write(response)

    # Record the reply in the display history before updating the memory, so
    # the answer survives the next rerun even if the memory update fails.
    st.session_state.messages.append({"role": "assistant", "content": response})
    try:
        # save_context may call the LLM to summarise older turns, i.e. it can
        # raise network/API errors outside response_generator's handling.
        st.session_state.rag_memory.save_context({"input": prompt}, {"output": response})
    except Exception as error:
        print(error, "ERROR")
        traceback.print_exc()