import streamlit as st
from langchain.chains import RetrievalQA
from langchain.memory import ConversationBufferWindowMemory
from langchain.embeddings import HuggingFaceInferenceAPIEmbeddings
from langchain.llms import HuggingFaceEndpoint
from langchain.prompts import PromptTemplate
from langchain_pinecone import PineconeVectorStore
from streamlit_chat import message
import re
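
# Assumed dependencies (versions are not pinned in the source): streamlit,
# streamlit-chat, langchain, langchain-pinecone. The two st.secrets keys read
# in main() must exist in .streamlit/secrets.toml before the app starts.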

def main():
    # Set your Hugging Face API token and Pinecone API key
    huggingfacehub_api_token = st.secrets["huggingfacehub_api_token"]
    pinecone_api_key = st.secrets["pinecone_api_key"]

    # Initialize embeddings
    embeddings = HuggingFaceInferenceAPIEmbeddings(
        api_key=huggingfacehub_api_token,
        model_name="sentence-transformers/all-MiniLM-L6-v2",
    )
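    # all-MiniLM-L6-v2 produces 384-dimensional vectors; the Pinecone index
    # queried below must have been created with the same dimension.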

    # Initialize Pinecone
    vectorstore = PineconeVectorStore(
        index_name="chatbot-law",
        embedding=embeddings, 
        pinecone_api_key=pinecone_api_key
    )
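    # Assumes an index named "chatbot-law" already exists and was populated
    # elsewhere; this app only queries it and never writes to it.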

    # Define the LLM
    llm = HuggingFaceEndpoint(
        repo_id="togethercomputer/RedPajama-INCITE-Chat-3B-v1",
        huggingfacehub_api_token=huggingfacehub_api_token,
    )

    # Define the prompt template
    prompt_template = """You are a Nigerian legal chatbot. Advise lawyers on questions regarding Nigerian law.
    Use the following pieces of context to answer the question.
    If you don't know the answer, just say you don't know.
    Keep the answer within six sentences and never ask users to seek advice from a professional lawyer.

    Context: {context}
    Question: {question}

    Answer the question and provide additional helpful information based on the given context, if applicable.
    """

    prompt = PromptTemplate(
        template=prompt_template,
        input_variables=["context", "question"]
    )

    # Initialize memory
    memory = ConversationBufferWindowMemory(k=5)
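    # Note: the prompt above defines no {history} variable, so this window
    # memory records the last 5 exchanges but is never fed back into the model.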

    # Initialize the RetrievalQA chain with memory
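    # chain_type="stuff" concatenates every retrieved chunk into the single
    # {context} slot of the prompt, so very long chunks can overflow the
    # model's context window.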
    qa = RetrievalQA.from_chain_type(
        llm=llm,
        chain_type="stuff",
        retriever=vectorstore.as_retriever(),
        chain_type_kwargs={"prompt": prompt, "verbose": False},
        memory=memory
    )

    # Function to generate response
    def generate_response(user_input):
        response = qa.invoke({"query": user_input})
        # Strip leading dashes the model sometimes prepends, then flatten newlines
        cleaned_response = re.sub(r"^\s*[-–—]+\s*", "", response["result"])
        cleaned_response = cleaned_response.replace("\n", " ")
        return cleaned_response.strip()

    # Set the app title
    st.title("Nigerian Lawyer Chatbot")

    # Initialize session state for messages
    if 'messages' not in st.session_state:
        st.session_state.messages = []

    # Display the chat
    for i, msg in enumerate(st.session_state.messages):
        if msg["is_user"]:
            message(msg["content"], is_user=True, key=str(i), avatar_style="micah")
        else:
            message(msg["content"], is_user=False, key=str(i), avatar_style="bottts")

    # Handle user input
    user_input = st.chat_input("Ask a legal question:")

    if user_input:
        # Append user message and generate response
        st.session_state.messages.append({"content": user_input, "is_user": True})
        response = generate_response(user_input)
        st.session_state.messages.append({"content": response, "is_user": False})
        st.rerun()  # Refresh the app to display the new messages

if __name__ == "__main__":
    main()