Pijush2023 commited on
Commit
c7cfbcf
·
verified ·
1 Parent(s): c8f1081

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +21 -18
app.py CHANGED
@@ -121,20 +121,19 @@ graph = Neo4jGraph(
121
  password="XCSXe1Jl_gjyJqoBGXDqY1UrfgDc4Z_RT5YGrxPAy-g"
122
  )
123
 
124
- dataset_name = "Pijush2023/birmindata07312024"
125
- page_content_column = 'events_description'
126
- loader = HuggingFaceDatasetLoader(dataset_name, page_content_column)
127
- data = loader.load()
 
 
128
 
129
- text_splitter = CharacterTextSplitter(chunk_size=100, chunk_overlap=50)
130
- documents = text_splitter.split_documents(data)
131
 
132
- embeddings = OpenAIEmbeddings(api_key=os.environ['OPENAI_API_KEY'])
133
- llm = ChatOpenAI(temperature=0, model='gpt-4o', api_key=os.environ['OPENAI_API_KEY'])
134
-
135
- llm_transformer = LLMGraphTransformer(llm=llm)
136
- graph_documents = llm_transformer.convert_to_graph_documents(documents)
137
- graph.add_graph_documents(graph_documents, baseEntityLabel=True, include_source=True)
138
 
139
  class Entities(BaseModel):
140
  names: list[str] = Field(..., description="All the person, organization, or business entities that appear in the text")
@@ -243,19 +242,23 @@ chain_neo4j = (
243
  def generate_answer(message, choice, retrieval_mode):
244
  logging.debug(f"generate_answer called with choice: {choice} and retrieval_mode: {retrieval_mode}")
245
 
 
 
246
  if retrieval_mode == "Vector":
247
- qa_chain = build_qa_chain(QA_CHAIN_PROMPT_1 if choice == "Details" else QA_CHAIN_PROMPT_2)
248
- agent = initialize_agent_with_prompt(QA_CHAIN_PROMPT_1 if choice == "Details" else QA_CHAIN_PROMPT_2)
249
- response = agent(message)
 
 
 
 
250
  elif retrieval_mode == "Knowledge-Graph":
251
  response = chain_neo4j.invoke({"question": message})
252
  else:
253
  response = "Invalid retrieval mode selected."
254
 
255
- addresses = extract_addresses(response['output'])
256
- return response['output'], addresses
257
 
258
- # The rest of your Gradio code...
259
  def bot(history, choice, tts_choice, retrieval_mode):
260
  if not history:
261
  return history
 
121
  password="XCSXe1Jl_gjyJqoBGXDqY1UrfgDc4Z_RT5YGrxPAy-g"
122
  )
123
 
124
+ # Avoid pushing the graph documents to Neo4j every time
125
+ # Only push the documents once and comment the code below after the initial push
126
+ # dataset_name = "Pijush2023/birmindata07312024"
127
+ # page_content_column = 'events_description'
128
+ # loader = HuggingFaceDatasetLoader(dataset_name, page_content_column)
129
+ # data = loader.load()
130
 
131
+ # text_splitter = CharacterTextSplitter(chunk_size=100, chunk_overlap=50)
132
+ # documents = text_splitter.split_documents(data)
133
 
134
+ # llm_transformer = LLMGraphTransformer(llm=llm)
135
+ # graph_documents = llm_transformer.convert_to_graph_documents(documents)
136
+ # graph.add_graph_documents(graph_documents, baseEntityLabel=True, include_source=True)
 
 
 
137
 
138
  class Entities(BaseModel):
139
  names: list[str] = Field(..., description="All the person, organization, or business entities that appear in the text")
 
242
  def generate_answer(message, choice, retrieval_mode):
243
  logging.debug(f"generate_answer called with choice: {choice} and retrieval_mode: {retrieval_mode}")
244
 
245
+ prompt_template = QA_CHAIN_PROMPT_1 if choice == "Details" else QA_CHAIN_PROMPT_2
246
+
247
  if retrieval_mode == "Vector":
248
+ qa_chain = RetrievalQA.from_chain_type(
249
+ llm=chat_model,
250
+ chain_type="stuff",
251
+ retriever=retriever,
252
+ chain_type_kwargs={"prompt": prompt_template}
253
+ )
254
+ response = qa_chain({"query": message})
255
  elif retrieval_mode == "Knowledge-Graph":
256
  response = chain_neo4j.invoke({"question": message})
257
  else:
258
  response = "Invalid retrieval mode selected."
259
 
260
+ return response['output'], extract_addresses(response['output'])
 
261
 
 
262
  def bot(history, choice, tts_choice, retrieval_mode):
263
  if not history:
264
  return history