Update app.py
app.py CHANGED
@@ -352,7 +352,7 @@ def estimate_tokens(text):
     # Rough estimate: 1 token ~= 4 characters
     return len(text) // 4
 
-def ask_question(question, temperature, top_p, repetition_penalty, web_search, chatbot):
+def ask_question(question, temperature, top_p, repetition_penalty, web_search, chatbot, user_instructions):
     if not question:
         return "Please enter a question."
 
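A note before the next hunk: estimate_tokens (shown as context above) relies on the rough 4-characters-per-token rule. A minimal sketch, with a hypothetical oversized prompt, of how that estimate interacts with the 32000-token budget and the 1000-token reply margin introduced later in this diff:

def estimate_tokens(text):
    # Rough estimate: 1 token ~= 4 characters, as in app.py
    return len(text) // 4

max_tokens = 32000                  # budget set inside ask_question below
prompt = "x" * 130000               # hypothetical oversized prompt (32500 estimated tokens)
print(estimate_tokens(prompt) <= max_tokens - 1000)  # False: exceeds the 31000-token ceiling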
@@ -368,16 +368,15 @@ def ask_question(question, temperature, top_p, repetition_penalty, web_search, c
     else:
         database = None
 
-    max_attempts = 3
+    max_attempts = 3
     context_reduction_factor = 0.7
-    max_tokens = 32000
+    max_tokens = 32000
 
     if web_search:
-        contextualized_question, topics, entity_tracker,
+        contextualized_question, topics, entity_tracker, _ = chatbot.process_question(question)
 
-        # Log the contextualized question and instructions separately for debugging
         print(f"Contextualized question: {contextualized_question}")
-        print(f"Instructions: {
+        print(f"User Instructions: {user_instructions}")
 
         try:
             search_results = google_search(contextualized_question, num_results=3)
@@ -403,7 +402,7 @@ def ask_question(question, temperature, top_p, repetition_penalty, web_search, c
 
         context_str = "\n".join([f"Source: {doc.metadata['source']}\nContent: {doc.page_content}" for doc in web_docs])
 
-        instruction_prompt = f"User Instructions: {
+        instruction_prompt = f"User Instructions: {user_instructions}\n" if user_instructions else ""
 
         prompt_template = f"""
         Answer the question based on the following web search results, conversation context, entity information, and user instructions:
@@ -419,7 +418,6 @@ def ask_question(question, temperature, top_p, repetition_penalty, web_search, c
 
         prompt_val = ChatPromptTemplate.from_template(prompt_template)
 
-        # Start with full context and progressively reduce if necessary
         current_context = context_str
         current_conv_context = chatbot.get_context()
         current_topics = topics
@@ -434,13 +432,11 @@ def ask_question(question, temperature, top_p, repetition_penalty, web_search, c
                 entities=json.dumps(current_entities)
             )
 
-            # Estimate token count (rough estimate)
             estimated_tokens = len(formatted_prompt) // 4
 
-            if estimated_tokens <= max_tokens - 1000:
+            if estimated_tokens <= max_tokens - 1000:
                 break
 
-            # Reduce context if estimated token count is too high
             current_context = current_context[:int(len(current_context) * context_reduction_factor)]
             current_conv_context = current_conv_context[:int(len(current_conv_context) * context_reduction_factor)]
             current_topics = current_topics[:max(1, int(len(current_topics) * context_reduction_factor))]
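The hunk above is the core of a progressive context-reduction loop: every prompt component is truncated to 70% of its length until the token estimate leaves a 1000-token margin. A standalone sketch of the pattern (the function name and simplified prompt are placeholders, not the app's real template):

def shrink_until_fits(context, question, max_tokens=32000, factor=0.7, max_attempts=3):
    # Trim the context until the estimated prompt size fits the budget,
    # mirroring the reduction loop in the diff above.
    for _ in range(max_attempts):
        prompt = f"Context: {context}\nQuestion: {question}"
        if len(prompt) // 4 <= max_tokens - 1000:
            return prompt
        context = context[:int(len(context) * factor)]
    raise ValueError("Context reduced too much. Unable to process the query.")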
@@ -450,7 +446,7 @@ def ask_question(question, temperature, top_p, repetition_penalty, web_search, c
             raise ValueError("Context reduced too much. Unable to process the query.")
 
         full_response = generate_chunked_response(model, formatted_prompt, max_tokens=1000)
-        answer = extract_answer(full_response,
+        answer = extract_answer(full_response, user_instructions)
         all_answers.append(answer)
         break
 
@@ -469,12 +465,11 @@ def ask_question(question, temperature, top_p, repetition_penalty, web_search, c
         sources_section = "\n\nSources:\n" + "\n".join(f"- {source}" for source in sources)
         answer += sources_section
 
-        # Update chatbot context with the answer
         chatbot.add_to_history(answer)
 
         return answer
 
-
+    else:  # PDF document chat
         for attempt in range(max_attempts):
             try:
                 if database is None:
@@ -484,11 +479,14 @@ def ask_question(question, temperature, top_p, repetition_penalty, web_search, c
                 relevant_docs = retriever.get_relevant_documents(question)
                 context_str = "\n".join([doc.page_content for doc in relevant_docs])
 
-                prompt_template = """
+                instruction_prompt = f"User Instructions: {user_instructions}\n" if user_instructions else ""
+
+                prompt_template = f"""
                 Answer the question based on the following context from the PDF document:
                 Context:
-                {context}
-                Question: {question}
+                {{context}}
+                Question: {{question}}
+                {instruction_prompt}
                 Provide a summarized and direct answer to the question.
                 """
 
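The brace doubling in the hunk above is the point of the change: once prompt_template becomes an f-string, {instruction_prompt} is interpolated immediately, while {{context}} and {{question}} are escaped down to single braces and survive as template variables. A sketch of the mechanism, assuming a langchain-core style import (the app's own import path may differ) and a stand-in instruction string:

from langchain_core.prompts import ChatPromptTemplate

instruction_prompt = "User Instructions: answer in one sentence.\n"  # stand-in value
# {{context}} and {{question}} are escaped by the f-string and come out as
# {context} and {question}; {instruction_prompt} is substituted right away.
prompt_template = f"""
Context: {{context}}
Question: {{question}}
{instruction_prompt}
"""
prompt_val = ChatPromptTemplate.from_template(prompt_template)
print(prompt_val.format(context="some context", question="some question"))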
@@ -498,17 +496,16 @@ def ask_question(question, temperature, top_p, repetition_penalty, web_search, c
 
                 estimated_tokens = estimate_tokens(formatted_prompt)
 
-                if estimated_tokens <= max_tokens - 1000:
+                if estimated_tokens <= max_tokens - 1000:
                     break
 
-                # Reduce context if estimated token count is too high
                 context_str = context_str[:int(len(context_str) * context_reduction_factor)]
 
                 if len(context_str) < 100:
                     raise ValueError("Context reduced too much. Unable to process the query.")
 
                 full_response = generate_chunked_response(model, formatted_prompt, max_tokens=1000)
-                answer = extract_answer(full_response)
+                answer = extract_answer(full_response, user_instructions)
 
                 return answer
 
@@ -524,6 +521,7 @@ def ask_question(question, temperature, top_p, repetition_penalty, web_search, c
 
     return "An unexpected error occurred. Please try again later."
 
+
 def extract_answer(full_response, instructions=None):
     answer_patterns = [
         r"Provide a concise and direct answer to the question without mentioning the web search or these instructions:",
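Context for this hunk: extract_answer walks a list of regex patterns like the one shown and strips any instruction boilerplate the model echoes back before its actual answer. A minimal sketch of that idea (a single illustrative pattern, not the app's full list):

import re

def strip_boilerplate(full_response):
    # Keep only what follows a known instruction phrase, if present.
    patterns = [
        r"Provide a concise and direct answer to the question without mentioning the web search or these instructions:",
    ]
    for pattern in patterns:
        match = re.search(pattern, full_response)
        if match:
            return full_response[match.end():].strip()
    return full_response.strip()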
@@ -575,6 +573,7 @@ with gr.Blocks() as demo:
     with gr.Column(scale=2):
         chatbot = gr.Chatbot(label="Conversation")
         question_input = gr.Textbox(label="Ask a question")
+        instructions_input = gr.Textbox(label="Instructions for response (optional)", placeholder="Enter any specific instructions for the response here")
         submit_button = gr.Button("Submit")
     with gr.Column(scale=1):
         temperature_slider = gr.Slider(label="Temperature", minimum=0.0, maximum=1.0, value=0.5, step=0.1)
@@ -584,12 +583,12 @@ with gr.Blocks() as demo:
 
     enhanced_context_driven_chatbot = EnhancedContextDrivenChatbot()
 
-    def chat(question, history, temperature, top_p, repetition_penalty, web_search):
-        answer = ask_question(question, temperature, top_p, repetition_penalty, web_search, enhanced_context_driven_chatbot)
+    def chat(question, history, temperature, top_p, repetition_penalty, web_search, user_instructions):
+        answer = ask_question(question, temperature, top_p, repetition_penalty, web_search, enhanced_context_driven_chatbot, user_instructions)
         history.append((question, answer))
         return "", history
 
-    submit_button.click(chat, inputs=[question_input, chatbot, temperature_slider, top_p_slider, repetition_penalty_slider, web_search_checkbox], outputs=[question_input, chatbot])
+    submit_button.click(chat, inputs=[question_input, chatbot, temperature_slider, top_p_slider, repetition_penalty_slider, web_search_checkbox, instructions_input], outputs=[question_input, chatbot])
 
     clear_button = gr.Button("Clear Cache")
     clear_output = gr.Textbox(label="Cache Status")
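The two UI hunks above wire the new instructions textbox through the event handler: components listed in submit_button.click(...) map positionally onto the chat callback's parameters, so instructions_input arrives as user_instructions. A reduced, self-contained sketch of that wiring (the echo-style callback stands in for ask_question):

import gradio as gr

with gr.Blocks() as demo:
    chatbot = gr.Chatbot(label="Conversation")
    question_input = gr.Textbox(label="Ask a question")
    instructions_input = gr.Textbox(label="Instructions for response (optional)")
    submit_button = gr.Button("Submit")

    def chat(question, history, user_instructions):
        # Inputs arrive in the same order as the components passed to click().
        answer = f"Echo: {question} (instructions: {user_instructions or 'none'})"
        history.append((question, answer))
        return "", history

    submit_button.click(
        chat,
        inputs=[question_input, chatbot, instructions_input],
        outputs=[question_input, chatbot],
    )

demo.launch()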