Upload 2 files
- app.py +95 -0
- requirements.txt +8 -0
app.py
ADDED
@@ -0,0 +1,95 @@
import streamlit as st
from langchain.chains import RetrievalQA
from langchain.memory import ConversationBufferWindowMemory
from langchain.embeddings import HuggingFaceInferenceAPIEmbeddings
from langchain.llms import HuggingFaceEndpoint
from langchain.prompts import PromptTemplate
from langchain_pinecone import PineconeVectorStore
from streamlit_chat import message
import re

def main():
    # Read the Hugging Face API token and Pinecone API key from Streamlit secrets;
    # the key names here are assumed placeholders and must match the secrets configured for this Space
    huggingfacehub_api_token = st.secrets["HUGGINGFACEHUB_API_TOKEN"]
    pinecone_api_key = st.secrets["PINECONE_API_KEY"]

    # Initialize embeddings served by the Hugging Face Inference API
    embeddings = HuggingFaceInferenceAPIEmbeddings(
        api_key=huggingfacehub_api_token, model_name="sentence-transformers/all-MiniLM-l6-v2"
    )

    # Connect to the existing Pinecone index
    vectorstore = PineconeVectorStore(
        index_name="chatbot-law",
        embedding=embeddings,
        pinecone_api_key=pinecone_api_key
    )

    # Define the LLM
    llm = HuggingFaceEndpoint(
        repo_id="togethercomputer/RedPajama-INCITE-Chat-3B-v1",
        huggingfacehub_api_token=huggingfacehub_api_token
    )

    # Define the prompt template
    prompt_template = """You are a Nigerian legal chatbot. Advise lawyers on questions regarding Nigerian law.
Use the following piece of context to answer the question.
If you don't know the answer, just say you don't know.
Keep the answer within six sentences and never ask users to seek advice from a professional lawyer.

Context: {context}
Question: {question}

Answer the question and provide additional helpful information, based on the pieces of information, if applicable.
"""

    prompt = PromptTemplate(
        template=prompt_template,
        input_variables=["context", "question"]
    )

    # Keep only the last five exchanges in conversation memory
    memory = ConversationBufferWindowMemory(k=5)

    # Initialize the RetrievalQA chain with memory
    qa = RetrievalQA.from_chain_type(
        llm=llm,
        chain_type="stuff",
        retriever=vectorstore.as_retriever(),
        chain_type_kwargs={"prompt": prompt, "verbose": False},
        memory=memory
    )

    # Generate a response and strip formatting artifacts from the model output
    def generate_response(user_input):
        response = qa({"query": user_input})
        # Remove any leading dashes and collapse newlines in the response
        cleaned_response = re.sub(r"^\s*[-–—]+\s*", "", response['result'])
        cleaned_response = cleaned_response.replace("\n", " ")
        return cleaned_response.strip()

    # Set the page title
    st.title("Nigerian Lawyer Chatbot")

    # Initialize session state for messages
    if 'messages' not in st.session_state:
        st.session_state.messages = []

    # Display the chat history
    for i, msg in enumerate(st.session_state.messages):
        if msg["is_user"]:
            message(msg["content"], is_user=True, key=str(i), avatar_style="micah")
        else:
            message(msg["content"], is_user=False, key=str(i), avatar_style="bottts")

    # Handle user input
    user_input = st.chat_input("Ask a legal question:")

    if user_input:
        # Append the user message, generate a reply, and append it to the history
        st.session_state.messages.append({"content": user_input, "is_user": True})
        response = generate_response(user_input)
        st.session_state.messages.append({"content": response, "is_user": False})
        st.rerun()  # Refresh the app to display the new messages

if __name__ == "__main__":
    main()
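For the st.secrets lookups in app.py to resolve, both API keys must be available as secrets. Below is a minimal sketch of a local .streamlit/secrets.toml, assuming the key names used in the code above; on Hugging Face Spaces the same values would instead be set as Space secrets in the settings UI, and the placeholder values here are not real tokens.

# .streamlit/secrets.toml (local development only; never commit real keys)
HUGGINGFACEHUB_API_TOKEN = "hf_..."  # Hugging Face Inference API token (placeholder)
PINECONE_API_KEY = "..."  # Pinecone API key (placeholder)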
requirements.txt
ADDED
@@ -0,0 +1,8 @@
streamlit
langchain
pinecone-client
sentence-transformers
langchain-pinecone
langchain_community
langchain_huggingface
streamlit_chat
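With both files in place, the Space runs app.py on startup. To exercise the app locally before pushing, the standard Streamlit workflow applies; this assumes the two secrets sketched above are already in .streamlit/secrets.toml:

pip install -r requirements.txt
streamlit run app.py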