anasmkh commited on
Commit
9a9542a
·
verified ·
1 Parent(s): 1e9c206

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +119 -0
app.py ADDED
@@ -0,0 +1,119 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from langchain.chains import RetrievalQA
2
+ from langchain.chat_models import ChatOpenAI
3
+ from IPython.display import display, Markdown
4
+ from langchain.llms import OpenAI
5
+ from langchain.memory import ConversationBufferMemory
6
+ from langchain.chains import ConversationalRetrievalChain
7
+ from langchain.text_splitter import CharacterTextSplitter
8
+ from langchain.indexes import VectorstoreIndexCreator
9
+ from langchain.document_loaders import PyPDFLoader
10
+ from langchain.embeddings import OpenAIEmbeddings
11
+ from langchain_core.vectorstores import InMemoryVectorStore
12
+ from langchain.vectorstores import FAISS
13
+ from langchain.retrievers import BM25Retriever,EnsembleRetriever
14
+ from langchain_core.prompts import ChatPromptTemplate
15
+ from langchain_core.output_parsers import StrOutputParser
16
+ from langchain.schema.runnable import RunnablePassthrough
17
+ import gradio as gr
18
+ import os
19
+
20
+ pdf_folder_path = "files"
21
+
22
+ documents = []
23
+ for filename in os.listdir(pdf_folder_path):
24
+ if filename.endswith(".pdf"):
25
+ file_path = os.path.join(pdf_folder_path, filename)
26
+ loader = PyPDFLoader(file_path)
27
+ documents.extend(loader.load())
28
+
29
+ text_splitter = CharacterTextSplitter()
30
+ text_splits=text_splitter.split_documents(documents)
31
+
32
+
33
+ openai_api_key = os.genenv("OPENAI_API_KEY")
34
+ openai_api_key = openai_api_key
35
+
36
+
37
+
38
+ embeddings = OpenAIEmbeddings()
39
+
40
+ vector_store = FAISS.from_documents(documents, embeddings)
41
+
42
+ retriever_vectordb = vector_store.as_retriever(search_kwargs={"k": 5})
43
+ keyword_retriever = BM25Retriever.from_documents(text_splits)
44
+ keyword_retriever.k = 5
45
+ ensemble_retriever = EnsembleRetriever(retrievers=[retriever_vectordb,keyword_retriever],
46
+ weights=[0.5, 0.5])
47
+
48
+
49
+
50
+ llm = ChatOpenAI(model="gpt-3.5-turbo", temperature=0.4, api_key=adminkey)
51
+
52
+ memory = ConversationBufferMemory(
53
+ memory_key="chat_history",
54
+ input_key="question" ,
55
+ return_messages=True
56
+ )
57
+
58
+
59
+ conversation_chain = ConversationalRetrievalChain.from_llm(
60
+ retriever=ensemble_retriever,
61
+ llm=llm,
62
+ memory=memory,
63
+ verbose=False
64
+ )
65
+
66
+
67
+ template = """
68
+ <|system|>>
69
+ You are an AI Assistant that follows instructions extremely well.
70
+ Please be truthful and give direct answers. Please tell 'I don't know' if user query is not in CONTEXT
71
+
72
+ CONTEXT: {context}
73
+ </s>
74
+ <|user|>
75
+ {query}
76
+ </s>
77
+ <|assistant|>
78
+ """
79
+
80
+ prompt = ChatPromptTemplate.from_template(template)
81
+ output_parser = StrOutputParser()
82
+
83
+ chain = (
84
+ {"context": conversation_chain, "query": RunnablePassthrough()}
85
+ | prompt
86
+ | llm
87
+ | output_parser
88
+ )
89
+
90
+
91
+
92
+ def chat_with_ai(user_input, chat_history):
93
+ response = chain.invoke(user_input)
94
+
95
+ chat_history.append((user_input, str(response)))
96
+
97
+ return chat_history, ""
98
+
99
+
100
+ def gradio_chatbot():
101
+ with gr.Blocks() as demo:
102
+ gr.Markdown("# Chat Interface for LlamaIndex")
103
+
104
+ chatbot = gr.Chatbot(label="Langchain Chatbot")
105
+ user_input = gr.Textbox(
106
+ placeholder="Ask a question...", label="Enter your question"
107
+ )
108
+
109
+ submit_button = gr.Button("Send")
110
+
111
+ chat_history = gr.State([])
112
+
113
+ submit_button.click(chat_with_ai, inputs=[user_input, chat_history], outputs=[chatbot, user_input])
114
+
115
+ user_input.submit(chat_with_ai, inputs=[user_input, chat_history], outputs=[chatbot, user_input])
116
+
117
+ return demo
118
+
119
+ gradio_chatbot().launch(debug=True)