Shreyas094 committed
Commit 3cb16ec · verified · 1 Parent(s): 85693d5

Update app.py

Files changed (1)
  1. app.py +47 -24
app.py CHANGED
@@ -22,7 +22,12 @@ from langchain_community.embeddings import HuggingFaceEmbeddings
 from langchain_community.llms import HuggingFaceHub
 from langchain_core.documents import Document
 from sentence_transformers import SentenceTransformer
-from llama_parse import LlamaParse
+from llama_parse import LlamaParse
+from llama_cpp import Llama
+from llama_cpp_agent.llm_agent import LlamaCppAgent
+from llama_cpp_agent.messages_formatter import MessagesFormatterType
+from llama_cpp_agent.providers.llama_cpp_endpoint_provider import LlamaCppEndpointSettings
+

 huggingface_token = os.environ.get("HUGGINGFACE_TOKEN")
 llama_cloud_api_key = os.environ.get("LLAMA_CLOUD_API_KEY")
@@ -378,10 +383,25 @@ def prepare_context(query: str, documents: List[Document], max_tokens: int) -> str:

     return truncate_text(context, max_tokens)

+# Initialize LlamaCppAgent
+def initialize_llama_cpp_agent():
+    main_model = LlamaCppEndpointSettings(
+        completions_endpoint_url="http://127.0.0.1:8080/completion"
+    )
+    llama_cpp_agent = LlamaCppAgent(
+        main_model,
+        debug_output=False,
+        system_prompt="You are an AI assistant designed to help with RAG tasks.",
+        predefined_messages_formatter_type=MessagesFormatterType.CHATML
+    )
+    return llama_cpp_agent
+
+# Modify the ask_question function to use LlamaCppAgent
 def ask_question(question, temperature, top_p, repetition_penalty, web_search, chatbot, user_instructions):
     if not question:
         return "Please enter a question."

+    llama_cpp_agent = initialize_llama_cpp_agent()
     model = get_model(temperature, top_p, repetition_penalty)

     # Update the chatbot's model
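The new `initialize_llama_cpp_agent` talks to a llama.cpp server over HTTP rather than loading a model in-process, so nothing in this hunk works unless a server is already listening at http://127.0.0.1:8080. (Note that `from llama_cpp import Llama` is added above but is not used in any hunk shown here.) A minimal smoke test for the wiring, assuming such a server is running; the question text is purely illustrative:

    agent = initialize_llama_cpp_agent()
    # Round-trips one prompt through the local /completion endpoint.
    reply = agent.get_chat_response("In one sentence, what is retrieval-augmented generation?")
    print(reply)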
@@ -395,17 +415,14 @@ def ask_question(question, temperature, top_p, repetition_penalty, web_search, chatbot, user_instructions):
     database = None

     max_attempts = 3
-    max_input_tokens = 20000  # Leave room for the model's response
+    max_input_tokens = 20000
     max_output_tokens = 800

     if web_search:
         contextualized_question, topics, entity_tracker, _ = chatbot.process_question(question)

-        print(f"Contextualized question: {contextualized_question}")
-        print(f"User Instructions: {user_instructions}")
-
         try:
-            search_results = google_search(contextualized_question, num_results=5)  # Increased from 3 to 5
+            search_results = google_search(contextualized_question, num_results=5)
         except Exception as e:
             print(f"Error in web search: {e}")
             return f"I apologize, but I encountered an error while searching for information: {str(e)}"
@@ -426,8 +443,7 @@ def ask_question(question, temperature, top_p, repetition_penalty, web_search, chatbot, user_instructions):

         database.save_local("faiss_database")

-        # Prepare context using reranking
-        context_str = prepare_context(contextualized_question, web_docs, max_input_tokens // 2)  # Use half of max_input_tokens for context
+        context_str = prepare_context(contextualized_question, web_docs, max_input_tokens // 2)

         instruction_prompt = f"User Instructions: {user_instructions}\n" if user_instructions else ""

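Together with the `max_input_tokens = 20000` hunk above and the conversation-context hunk below, this fixes the token budget: half of the input budget goes to retrieved context and a quarter to conversation history. A sketch of the arithmetic, with the numbers taken from the diff:

    max_input_tokens = 20000
    context_budget = max_input_tokens // 2        # 10000 tokens of retrieved context
    conversation_budget = max_input_tokens // 4   # 5000 tokens of conversation history
    remaining = max_input_tokens - context_budget - conversation_budget
    print(remaining)  # 5000 tokens left for the template, question, and user instructions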
@@ -443,13 +459,11 @@ def ask_question(question, temperature, top_p, repetition_penalty, web_search, chatbot, user_instructions):
         Provide a concise and relevant answer to the question.
         """

-        prompt_val = ChatPromptTemplate.from_template(prompt_template)
-
-        current_conv_context = truncate_text(chatbot.get_context(), max_input_tokens // 4)  # Use quarter of max_input_tokens for conversation context
-        current_topics = topics[:5]  # Limit to top 5 topics
-        current_entities = {k: list(v)[:3] for k, v in entity_tracker.items()}  # Limit to top 3 entities per type
+        current_conv_context = truncate_text(chatbot.get_context(), max_input_tokens // 4)
+        current_topics = topics[:5]
+        current_entities = {k: list(v)[:3] for k, v in entity_tracker.items()}

-        formatted_prompt = prompt_val.format(
+        formatted_prompt = prompt_template.format(
             context=context_str,
             conv_context=current_conv_context,
             question=question,
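Swapping `ChatPromptTemplate.from_template(prompt_template)` for a plain `prompt_template.format(...)` changes brace handling: with `str.format`, any literal `{` or `}` in the template, or in the interpolated `instruction_prompt`, must be doubled, otherwise the call raises `ValueError` or `KeyError`. A small illustration:

    template = "Context: {context}\nRespond with JSON like {{\"answer\": \"...\"}}"
    # The doubled braces survive .format(); a single literal brace would raise.
    print(template.format(context="retrieved passage"))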
@@ -461,12 +475,17 @@ def ask_question(question, temperature, top_p, repetition_penalty, web_search, chatbot, user_instructions):
             formatted_prompt = truncate_text(formatted_prompt, max_input_tokens)

         try:
-            full_response = generate_chunked_response(model, formatted_prompt, max_tokens=max_output_tokens)
+            # Use LlamaCppAgent for initial response generation
+            initial_response = llama_cpp_agent.get_chat_response(formatted_prompt, temperature=temperature)
+
+            # Use generate_chunked_response for further refinement if needed
+            full_response = generate_chunked_response(model, initial_response, max_tokens=max_output_tokens)
+
             answer = extract_answer(full_response, user_instructions)
             all_answers.append(answer)
             break
         except Exception as e:
-            print(f"Error in generate_chunked_response: {e}")
+            print(f"Error in response generation: {e}")
             if attempt == max_attempts - 1:
                 all_answers.append(f"I apologize, but I encountered an error while generating the response. Please try again with a simpler question.")

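Note the order here: the full prompt goes to the local agent first, and the agent's draft (not the original prompt) is what `generate_chunked_response` refines, so an unreachable endpoint fails the whole attempt. A fallback sketch; treating any exception as "endpoint unavailable" is an assumption, not documented llama-cpp-agent behavior:

    try:
        initial_response = llama_cpp_agent.get_chat_response(formatted_prompt, temperature=temperature)
    except Exception as endpoint_error:
        # Assumed failure mode: the llama.cpp server is not reachable.
        # Fall back to letting the HuggingFace model answer the prompt directly.
        print(f"llama.cpp endpoint unavailable: {endpoint_error}")
        initial_response = formatted_prompt
    full_response = generate_chunked_response(model, initial_response, max_tokens=max_output_tokens)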
@@ -490,11 +509,10 @@ def ask_question(question, temperature, top_p, repetition_penalty, web_search, chatbot, user_instructions):
         if database is None:
             return "No documents available. Please upload PDF documents to answer questions."

-        retriever = database.as_retriever(search_kwargs={"k": 10})  # Retrieve more documents for reranking
+        retriever = database.as_retriever(search_kwargs={"k": 5})
         relevant_docs = retriever.get_relevant_documents(question)

-        # Prepare context using reranking
-        context_str = prepare_context(question, relevant_docs, max_input_tokens // 2)  # Use half of max_input_tokens for context
+        context_str = prepare_context(question, relevant_docs, max_input_tokens // 2)

         instruction_prompt = f"User Instructions: {user_instructions}\n" if user_instructions else ""

@@ -507,18 +525,22 @@ def ask_question(question, temperature, top_p, repetition_penalty, web_search, chatbot, user_instructions):
         Provide a summarized and direct answer to the question.
         """

-        prompt_val = ChatPromptTemplate.from_template(prompt_template)
-        formatted_prompt = prompt_val.format(context=context_str, question=question)
+        formatted_prompt = prompt_template.format(context=context_str, question=question)

         if estimate_tokens(formatted_prompt) > max_input_tokens:
             formatted_prompt = truncate_text(formatted_prompt, max_input_tokens)

         try:
-            full_response = generate_chunked_response(model, formatted_prompt, max_tokens=max_output_tokens)
+            # Use LlamaCppAgent for initial response generation
+            initial_response = llama_cpp_agent.get_chat_response(formatted_prompt, temperature=temperature)
+
+            # Use generate_chunked_response for further refinement if needed
+            full_response = generate_chunked_response(model, initial_response, max_tokens=max_output_tokens)
+
             answer = extract_answer(full_response, user_instructions)
             return answer
         except Exception as e:
-            print(f"Error in generate_chunked_response: {e}")
+            print(f"Error in response generation: {e}")
             if attempt == max_attempts - 1:
                 return f"I apologize, but I encountered an error while generating the response. Please try again with a simpler question."

@@ -591,13 +613,14 @@ with gr.Blocks() as demo:

     enhanced_context_driven_chatbot = EnhancedContextDrivenChatbot()

+    # Update the chat function to use the modified ask_question function
     def chat(question, history, temperature, top_p, repetition_penalty, web_search, user_instructions):
         answer = ask_question(question, temperature, top_p, repetition_penalty, web_search, enhanced_context_driven_chatbot, user_instructions)
         history.append((question, answer))
         return "", history

     submit_button.click(chat, inputs=[question_input, chatbot, temperature_slider, top_p_slider, repetition_penalty_slider, web_search_checkbox, instructions_input], outputs=[question_input, chatbot])
-
+
     clear_button = gr.Button("Clear Cache")
     clear_output = gr.Textbox(label="Cache Status")
     clear_button.click(clear_cache, inputs=[], outputs=clear_output)
 
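A design note on the new wiring: `ask_question` rebuilds the agent on every call via `initialize_llama_cpp_agent()`. A cached accessor would construct it once; a minimal sketch using only the standard library (`get_llama_cpp_agent` is a hypothetical helper, not part of this commit):

    from functools import lru_cache

    @lru_cache(maxsize=1)
    def get_llama_cpp_agent():
        # Build the agent once and reuse it across ask_question() calls.
        return initialize_llama_cpp_agent()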