vishwask commited on
Commit
2b9fe40
·
verified ·
1 Parent(s): 531bd51

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +6 -16
app.py CHANGED
@@ -10,6 +10,7 @@ from langchain.llms import HuggingFacePipeline
10
  from langchain.chains import ConversationChain
11
  from langchain.memory import ConversationBufferMemory
12
  from langchain.llms import HuggingFaceHub
 
13
 
14
  from pathlib import Path
15
  import chromadb
@@ -127,24 +128,13 @@ def initialize_llmchain(temperature, max_tokens, top_k, vector_db, progress=gr.P
127
  "load_in_8bit": True})
128
 
129
  progress(0.75, desc="Defining buffer memory...")
130
- memory = ConversationBufferMemory(
131
- memory_key="chat_history",
132
- output_key='answer',
133
- return_messages=True
134
- )
135
  # retriever=vector_db.as_retriever(search_type="similarity", search_kwargs={'k': 3})
136
  retriever=vector_db.as_retriever()
137
  progress(0.8, desc="Defining retrieval chain...")
138
- qa_chain = ConversationalRetrievalChain.from_llm(
139
- llm,
140
- retriever=retriever,
141
- chain_type="stuff",
142
- memory=memory,
143
- # combine_docs_chain_kwargs={"prompt": your_prompt})
144
- return_source_documents=True,
145
- #return_generated_question=False,
146
- verbose=False,
147
- )
148
  progress(0.9, desc="Done!")
149
  return qa_chain
150
 
@@ -269,7 +259,7 @@ def demo():
269
  with gr.Row():
270
  slider_temperature = gr.Slider(value = 0.1,visible=False)
271
  with gr.Row():
272
- slider_maxtokens = gr.Slider(value = 1000, visible=False)
273
  with gr.Row():
274
  slider_topk = gr.Slider(value = 3, visible=False)
275
 
 
10
  from langchain.chains import ConversationChain
11
  from langchain.memory import ConversationBufferMemory
12
  from langchain.llms import HuggingFaceHub
13
+ from langchain.memory import ConversationTokenBufferMemory
14
 
15
  from pathlib import Path
16
  import chromadb
 
128
  "load_in_8bit": True})
129
 
130
  progress(0.75, desc="Defining buffer memory...")
131
+ #memory = ConversationBufferMemory(memory_key="chat_history",output_key='answer',return_messages=True)
132
+ memory = ConversationTokenBufferMemory(llm = llm, max_token_limit=100)
 
 
 
133
  # retriever=vector_db.as_retriever(search_type="similarity", search_kwargs={'k': 3})
134
  retriever=vector_db.as_retriever()
135
  progress(0.8, desc="Defining retrieval chain...")
136
+ qa_chain = ConversationalRetrievalChain.from_llm(llm,retriever=retriever,chain_type="stuff",
137
+ memory=memory,return_source_documents=True,verbose=False)
 
 
 
 
 
 
 
 
138
  progress(0.9, desc="Done!")
139
  return qa_chain
140
 
 
259
  with gr.Row():
260
  slider_temperature = gr.Slider(value = 0.1,visible=False)
261
  with gr.Row():
262
+ slider_maxtokens = gr.Slider(value = 4000, visible=False)
263
  with gr.Row():
264
  slider_topk = gr.Slider(value = 3, visible=False)
265