Update app.py
app.py
CHANGED
@@ -22,7 +22,12 @@ from langchain_community.embeddings import HuggingFaceEmbeddings
 from langchain_community.llms import HuggingFaceHub
 from langchain_core.documents import Document
 from sentence_transformers import SentenceTransformer
-from llama_parse import
+from llama_parse import
+from llama_cpp import Llama
+from llama_cpp_agent.llm_agent import LlamaCppAgent
+from llama_cpp_agent.messages_formatter import MessagesFormatterType
+from llama_cpp_agent.providers.llama_cpp_endpoint_provider import LlamaCppEndpointSettings
+

 huggingface_token = os.environ.get("HUGGINGFACE_TOKEN")
 llama_cloud_api_key = os.environ.get("LLAMA_CLOUD_API_KEY")
@@ -378,10 +383,25 @@ def prepare_context(query: str, documents: List[Document], max_tokens: int) -> str:

     return truncate_text(context, max_tokens)

+# Initialize LlamaCppAgent
+def initialize_llama_cpp_agent():
+    main_model = LlamaCppEndpointSettings(
+        completions_endpoint_url="http://127.0.0.1:8080/completion"
+    )
+    llama_cpp_agent = LlamaCppAgent(
+        main_model,
+        debug_output=False,
+        system_prompt="You are an AI assistant designed to help with RAG tasks.",
+        predefined_messages_formatter_type=MessagesFormatterType.CHATML
+    )
+    return llama_cpp_agent
+
+# Modify the ask_question function to use LlamaCppAgent
 def ask_question(question, temperature, top_p, repetition_penalty, web_search, chatbot, user_instructions):
     if not question:
         return "Please enter a question."

+    llama_cpp_agent = initialize_llama_cpp_agent()
     model = get_model(temperature, top_p, repetition_penalty)

     # Update the chatbot's model
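For context on this hunk: `LlamaCppEndpointSettings` points the agent at a running llama.cpp HTTP server rather than loading a model in-process, so something must already be serving `/completion` on port 8080 for `get_chat_response` to succeed. Below is a minimal, self-contained sketch of the same wiring used on its own; the test question is made up, everything else reuses the names and arguments introduced in the hunk above.

# Standalone sketch of the agent setup added in this commit.
# Assumes a llama.cpp server is already listening on http://127.0.0.1:8080/completion;
# if it is not, get_chat_response raises a connection error.
from llama_cpp_agent.llm_agent import LlamaCppAgent
from llama_cpp_agent.messages_formatter import MessagesFormatterType
from llama_cpp_agent.providers.llama_cpp_endpoint_provider import LlamaCppEndpointSettings

def initialize_llama_cpp_agent():
    main_model = LlamaCppEndpointSettings(
        completions_endpoint_url="http://127.0.0.1:8080/completion"
    )
    return LlamaCppAgent(
        main_model,
        debug_output=False,
        system_prompt="You are an AI assistant designed to help with RAG tasks.",
        predefined_messages_formatter_type=MessagesFormatterType.CHATML,
    )

if __name__ == "__main__":
    agent = initialize_llama_cpp_agent()
    # Same call signature that ask_question uses further down in the diff.
    print(agent.get_chat_response("What is retrieval-augmented generation?", temperature=0.7))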
@@ -395,17 +415,14 @@ def ask_question(question, temperature, top_p, repetition_penalty, web_search, chatbot, user_instructions):
     database = None

     max_attempts = 3
-    max_input_tokens = 20000
+    max_input_tokens = 20000
     max_output_tokens = 800

     if web_search:
         contextualized_question, topics, entity_tracker, _ = chatbot.process_question(question)

-        print(f"Contextualized question: {contextualized_question}")
-        print(f"User Instructions: {user_instructions}")
-
         try:
-            search_results = google_search(contextualized_question, num_results=5)
+            search_results = google_search(contextualized_question, num_results=5)
         except Exception as e:
             print(f"Error in web search: {e}")
             return f"I apologize, but I encountered an error while searching for information: {str(e)}"
@@ -426,8 +443,7 @@ def ask_question(question, temperature, top_p, repetition_penalty, web_search, chatbot, user_instructions):

         database.save_local("faiss_database")

-
-        context_str = prepare_context(contextualized_question, web_docs, max_input_tokens // 2)  # Use half of max_input_tokens for context
+        context_str = prepare_context(contextualized_question, web_docs, max_input_tokens // 2)

         instruction_prompt = f"User Instructions: {user_instructions}\n" if user_instructions else ""

@@ -443,13 +459,11 @@ def ask_question(question, temperature, top_p, repetition_penalty, web_search, chatbot, user_instructions):
         Provide a concise and relevant answer to the question.
         """

-
-
-
-        current_topics = topics[:5]  # Limit to top 5 topics
-        current_entities = {k: list(v)[:3] for k, v in entity_tracker.items()}  # Limit to top 3 entities per type
+        current_conv_context = truncate_text(chatbot.get_context(), max_input_tokens // 4)
+        current_topics = topics[:5]
+        current_entities = {k: list(v)[:3] for k, v in entity_tracker.items()}

-        formatted_prompt = prompt_val.format(
+        formatted_prompt = prompt_template.format(
            context=context_str,
            conv_context=current_conv_context,
            question=question,
@@ -461,12 +475,17 @@ def ask_question(question, temperature, top_p, repetition_penalty, web_search, chatbot, user_instructions):
             formatted_prompt = truncate_text(formatted_prompt, max_input_tokens)

             try:
-
+                # Use LlamaCppAgent for initial response generation
+                initial_response = llama_cpp_agent.get_chat_response(formatted_prompt, temperature=temperature)
+
+                # Use generate_chunked_response for further refinement if needed
+                full_response = generate_chunked_response(model, initial_response, max_tokens=max_output_tokens)
+
                 answer = extract_answer(full_response, user_instructions)
                 all_answers.append(answer)
                 break
             except Exception as e:
-                print(f"Error in
+                print(f"Error in response generation: {e}")
                 if attempt == max_attempts - 1:
                     all_answers.append(f"I apologize, but I encountered an error while generating the response. Please try again with a simpler question.")

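The hunk above routes the prompt through `llama_cpp_agent.get_chat_response` for a first draft and then hands that draft to `generate_chunked_response` for refinement, retrying until `max_attempts` is exhausted. The sketch below shows that retry pattern in isolation; `answer_with_retries` and the `refine` callable are hypothetical stand-ins (in app.py the refinement step is `generate_chunked_response`, which is not shown in this diff).

# Hypothetical helper illustrating the retry loop implied by max_attempts / attempt.
def answer_with_retries(agent, refine, formatted_prompt, max_attempts=3, temperature=0.7):
    for attempt in range(max_attempts):
        try:
            # Stage 1: draft answer from the llama.cpp endpoint.
            initial_response = agent.get_chat_response(formatted_prompt, temperature=temperature)
            # Stage 2: refinement pass (generate_chunked_response in app.py).
            return refine(initial_response)
        except Exception as e:
            print(f"Error in response generation (attempt {attempt + 1}): {e}")
    return "I apologize, but I encountered an error while generating the response."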
@@ -490,11 +509,10 @@ def ask_question(question, temperature, top_p, repetition_penalty, web_search, chatbot, user_instructions):
         if database is None:
             return "No documents available. Please upload PDF documents to answer questions."

-        retriever = database.as_retriever(search_kwargs={"k":
+        retriever = database.as_retriever(search_kwargs={"k": 5})
         relevant_docs = retriever.get_relevant_documents(question)

-
-        context_str = prepare_context(question, relevant_docs, max_input_tokens // 2)  # Use half of max_input_tokens for context
+        context_str = prepare_context(question, relevant_docs, max_input_tokens // 2)

         instruction_prompt = f"User Instructions: {user_instructions}\n" if user_instructions else ""

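The retriever in this hunk reads from the FAISS index that an earlier hunk persists with `database.save_local("faiss_database")`. A short sketch of that save/load round trip, assuming `HuggingFaceEmbeddings` from the existing imports; the embedding model name and the example document are placeholders, and `allow_dangerous_deserialization` is required by newer langchain_community releases when reloading a local index.

# Sketch of the FAISS persistence round trip behind save_local / as_retriever.
from langchain_community.embeddings import HuggingFaceEmbeddings
from langchain_community.vectorstores import FAISS
from langchain_core.documents import Document

embeddings = HuggingFaceEmbeddings(model_name="sentence-transformers/all-MiniLM-L6-v2")  # assumed model
docs = [Document(page_content="Example passage to index.", metadata={"source": "web"})]

database = FAISS.from_documents(docs, embeddings)
database.save_local("faiss_database")

# On a later request the index is reloaded and queried, as in the hunk above.
database = FAISS.load_local("faiss_database", embeddings, allow_dangerous_deserialization=True)
retriever = database.as_retriever(search_kwargs={"k": 5})
relevant_docs = retriever.get_relevant_documents("example question")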
@@ -507,18 +525,22 @@ def ask_question(question, temperature, top_p, repetition_penalty, web_search, chatbot, user_instructions):
         Provide a summarized and direct answer to the question.
         """

-
-        formatted_prompt = prompt_val.format(context=context_str, question=question)
+        formatted_prompt = prompt_template.format(context=context_str, question=question)

         if estimate_tokens(formatted_prompt) > max_input_tokens:
             formatted_prompt = truncate_text(formatted_prompt, max_input_tokens)

         try:
-
+            # Use LlamaCppAgent for initial response generation
+            initial_response = llama_cpp_agent.get_chat_response(formatted_prompt, temperature=temperature)
+
+            # Use generate_chunked_response for further refinement if needed
+            full_response = generate_chunked_response(model, initial_response, max_tokens=max_output_tokens)
+
             answer = extract_answer(full_response, user_instructions)
             return answer
         except Exception as e:
-            print(f"Error in
+            print(f"Error in response generation: {e}")
             if attempt == max_attempts - 1:
                 return f"I apologize, but I encountered an error while generating the response. Please try again with a simpler question."

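`prompt_template` itself is defined elsewhere in app.py and is not part of this diff. For illustration only, a template compatible with the `.format(context=..., question=...)` call in the hunk above might look roughly like this; the wording is an assumption, apart from the closing instruction quoted in the hunk.

# Illustrative template only; the real prompt_template in app.py may differ.
prompt_template = """Using the following context extracted from the documents:
{context}

Answer this question: {question}

Provide a summarized and direct answer to the question.
"""

formatted_prompt = prompt_template.format(
    context="(retrieved passages would go here)",
    question="(the user's question would go here)",
)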
@@ -591,13 +613,14 @@ with gr.Blocks() as demo:

     enhanced_context_driven_chatbot = EnhancedContextDrivenChatbot()

+    # Update the chat function to use the modified ask_question function
     def chat(question, history, temperature, top_p, repetition_penalty, web_search, user_instructions):
         answer = ask_question(question, temperature, top_p, repetition_penalty, web_search, enhanced_context_driven_chatbot, user_instructions)
         history.append((question, answer))
         return "", history

     submit_button.click(chat, inputs=[question_input, chatbot, temperature_slider, top_p_slider, repetition_penalty_slider, web_search_checkbox, instructions_input], outputs=[question_input, chatbot])
-
+
     clear_button = gr.Button("Clear Cache")
     clear_output = gr.Textbox(label="Cache Status")
     clear_button.click(clear_cache, inputs=[], outputs=clear_output)
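The `submit_button.click(...)` wiring refers to widgets (`question_input`, `chatbot`, the three sliders, `web_search_checkbox`, `instructions_input`) that are declared earlier in app.py, outside this diff. A rough sketch of how such widgets are typically declared inside the same `gr.Blocks` context follows; the labels, ranges, and default values are assumptions, only the variable names come from the diff.

# Hypothetical declarations for the components referenced by submit_button.click above.
import gradio as gr

with gr.Blocks() as demo:
    chatbot = gr.Chatbot(label="Conversation")
    question_input = gr.Textbox(label="Ask a question")
    instructions_input = gr.Textbox(label="Optional instructions for the assistant")
    temperature_slider = gr.Slider(0.0, 1.0, value=0.7, label="Temperature")
    top_p_slider = gr.Slider(0.0, 1.0, value=0.9, label="Top-p")
    repetition_penalty_slider = gr.Slider(1.0, 2.0, value=1.1, label="Repetition penalty")
    web_search_checkbox = gr.Checkbox(label="Enable web search")
    submit_button = gr.Button("Submit")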