Pijush2023 committed on
Commit
8a9de56
·
verified ·
1 Parent(s): 2fbb8d5

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +14 -6
app.py CHANGED
@@ -360,8 +360,14 @@ def generate_answer(message, choice, retrieval_mode, selected_model):
360
  prompt_template = QA_CHAIN_PROMPT_1 if choice == "Details" else QA_CHAIN_PROMPT_2
361
 
362
  if retrieval_mode == "VDB":
 
 
 
 
 
 
363
  if selected_model == chat_model:
364
- # Use Langchain with GPT-4o
365
  qa_chain = RetrievalQA.from_chain_type(
366
  llm=chat_model,
367
  chain_type="stuff",
@@ -370,15 +376,17 @@ def generate_answer(message, choice, retrieval_mode, selected_model):
370
  )
371
  response = qa_chain({"query": message})
372
  return response['result'], extract_addresses(response['result'])
 
373
  elif selected_model == phi_pipe:
374
- # Directly use the Phi-3.5 model for text generation
375
- response = selected_model(message, **{
376
- "max_new_tokens": 500,
377
  "return_full_text": False,
378
- "temperature": 0.0,
379
- "do_sample": False,
380
  })[0]['generated_text']
381
  return response, extract_addresses(response)
 
382
  elif retrieval_mode == "KGF":
383
  response = chain_neo4j.invoke({"question": message})
384
  return response, extract_addresses(response)
 
360
  prompt_template = QA_CHAIN_PROMPT_1 if choice == "Details" else QA_CHAIN_PROMPT_2
361
 
362
  if retrieval_mode == "VDB":
363
+ # Retrieve context from the vector database
364
+ context = retriever.get_relevant_documents(message)
365
+
366
+ # Format the prompt
367
+ prompt = prompt_template.format(context=context, question=message)
368
+
369
  if selected_model == chat_model:
370
+ # Use GPT-4o with Langchain
371
  qa_chain = RetrievalQA.from_chain_type(
372
  llm=chat_model,
373
  chain_type="stuff",
 
376
  )
377
  response = qa_chain({"query": message})
378
  return response['result'], extract_addresses(response['result'])
379
+
380
  elif selected_model == phi_pipe:
381
+ # Use Phi-3.5 directly with the formatted prompt
382
+ response = selected_model(prompt, **{
383
+ "max_new_tokens": 300, # Limit the tokens for faster generation
384
  "return_full_text": False,
385
+ "temperature": 0.5, # Adjust temperature for more consistent answers
386
+ "do_sample": True,
387
  })[0]['generated_text']
388
  return response, extract_addresses(response)
389
+
390
  elif retrieval_mode == "KGF":
391
  response = chain_neo4j.invoke({"question": message})
392
  return response, extract_addresses(response)