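"""Streamlit chat app: a Nigerian legal assistant that answers questions via
LangChain RetrievalQA over a Pinecone vector index, using Hugging Face models
for embeddings and generation.

Run locally with: streamlit run <path-to-this-file>
"""
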
import re

import streamlit as st
from langchain.chains import RetrievalQA
from langchain.embeddings import HuggingFaceInferenceAPIEmbeddings
from langchain.llms import HuggingFaceEndpoint
from langchain.memory import ConversationBufferWindowMemory
from langchain.prompts import PromptTemplate
from langchain_pinecone import PineconeVectorStore
from streamlit_chat import message


def main():
    # Load the Hugging Face API token and Pinecone API key from Streamlit secrets.
    huggingfacehub_api_token = st.secrets["huggingfacehub_api_token"]
    pinecone_api_key = st.secrets["pinecone_api_key"]
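
    # Assumed secrets layout (.streamlit/secrets.toml); the key names must
    # match the lookups above:
    #   huggingfacehub_api_token = "hf_..."
    #   pinecone_api_key = "..."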

    # Initialize embeddings (canonical model id: all-MiniLM-L6-v2, capital L).
    embeddings = HuggingFaceInferenceAPIEmbeddings(
        api_key=huggingfacehub_api_token,
        model_name="sentence-transformers/all-MiniLM-L6-v2",
    )
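
    # Assumption: the "chatbot-law" index below was built with this same
    # embedding model; all-MiniLM-L6-v2 produces 384-dimensional vectors, so
    # the index dimension must be 384 for queries to work.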

    # Connect to the existing Pinecone index.
    vectorstore = PineconeVectorStore(
        index_name="chatbot-law",
        embedding=embeddings,
        pinecone_api_key=pinecone_api_key,
    )
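
    # Optional (a sketch, not in the original): retrieval can be tuned when the
    # chain is built below, e.g.
    #   vectorstore.as_retriever(search_kwargs={"k": 4})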

    # Define the LLM served through the Hugging Face Inference API.
    llm = HuggingFaceEndpoint(
        repo_id="togethercomputer/RedPajama-INCITE-Chat-3B-v1",
        huggingfacehub_api_token=huggingfacehub_api_token,
    )
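
    # Depending on your langchain version, HuggingFaceEndpoint also accepts
    # generation parameters; illustrative values, not from the original:
    #   temperature=0.3, max_new_tokens=256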

    # Define the prompt template.
    prompt_template = """You are a Nigerian legal chatbot. Advise lawyers on questions regarding Nigerian law.
Use the following pieces of context to answer the question.
If you don't know the answer, just say you don't know.
Keep the answer within six sentences and never ask users to seek advice from a professional lawyer.
Context: {context}
Question: {question}
Answer the question and provide additional helpful information, based on the pieces of information, if applicable.
"""
    prompt = PromptTemplate(
        template=prompt_template,
        input_variables=["context", "question"],
    )

    # Keep the last 5 conversational turns. Note: this memory records the
    # exchange, but it only affects generation if the prompt template includes
    # a matching {history} variable, which the template above does not.
    memory = ConversationBufferWindowMemory(k=5)

    # Initialize the RetrievalQA chain with memory.
    qa = RetrievalQA.from_chain_type(
        llm=llm,
        chain_type="stuff",
        retriever=vectorstore.as_retriever(),
        chain_type_kwargs={"prompt": prompt, "verbose": False},
        memory=memory,
    )
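
    # Version-dependent caveat: if you later enable return_source_documents=True,
    # also pass output_key="result"; LangChain memory raises an error when a
    # chain returns multiple output keys and none is designated for saving.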

    # Generate a response and strip formatting artifacts from the model output.
    def generate_response(user_input):
        response = qa({"query": user_input})
        # Remove any leading dashes (hyphen, en dash, em dash) from the response.
        cleaned_response = re.sub(r"^\s*[-–—]+\s*", "", response["result"])
        cleaned_response = cleaned_response.replace("\n", " ")
        return cleaned_response.strip()

    # Set the page title.
    st.title("Nigerian Lawyer Chatbot")

    # Initialize session state for the message history.
    if "messages" not in st.session_state:
        st.session_state.messages = []

    # Render the conversation so far.
    for i, msg in enumerate(st.session_state.messages):
        if msg["is_user"]:
            message(msg["content"], is_user=True, key=str(i), avatar_style="micah")
        else:
            message(msg["content"], is_user=False, key=str(i), avatar_style="bottts")

    # Handle user input.
    user_input = st.chat_input("Ask a legal question:")
    if user_input:
        # Append the user message, generate a response, then append that too.
        st.session_state.messages.append({"content": user_input, "is_user": True})
        response = generate_response(user_input)
        st.session_state.messages.append({"content": response, "is_user": False})
        st.rerun()  # Refresh the app to display the new messages.


if __name__ == "__main__":
    main()