Upload ask.py
ask.py CHANGED

@@ -1,11 +1,14 @@
 import json
 import logging
 import os
+import queue
+import re
 import urllib.parse
 from concurrent.futures import ThreadPoolExecutor
 from datetime import datetime
 from functools import partial
-from …
+from queue import Queue
+from typing import Any, Dict, Generator, List, Optional, Tuple
 
 import click
 import duckdb

@@ -15,6 +18,7 @@ from bs4 import BeautifulSoup
 from dotenv import load_dotenv
 from jinja2 import BaseLoader, Environment
 from openai import OpenAI
+from regex import T
 
 script_dir = os.path.dirname(os.path.abspath(__file__))
 default_env_file = os.path.abspath(os.path.join(script_dir, ".env"))
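
Review note: the added "from regex import T" (and possibly "import re") looks like an accidental editor auto-import; T is not referenced anywhere in the changed lines shown in this diff, so it is likely safe to drop unless it is used elsewhere in the file.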

@@ -25,6 +29,7 @@ def _get_logger(log_level: str) -> logging.Logger:
     logger.setLevel(log_level)
     if len(logger.handlers) > 0:
         return logger
+
     handler = logging.StreamHandler()
     formatter = logging.Formatter("%(asctime)s - %(levelname)s - %(message)s")
     handler.setFormatter(formatter)

@@ -417,57 +422,99 @@ Here is the context:
         output_length: int,
         url_list_str: str,
         model_name: str,
-    ) -> str:
+    ) -> Generator[Tuple[str, str], None, Tuple[str, str]]:
         logger = self.logger
-… (43 removed lines not shown)
+        log_queue = Queue()
+
+        queue_handler = logging.Handler()
+        formatter = logging.Formatter("%(asctime)s - %(levelname)s - %(message)s")
+        queue_handler.emit = lambda record: log_queue.put(formatter.format(record))
+        logger.addHandler(queue_handler)
+
+        def update_logs():
+            logs = []
+            while True:
+                try:
+                    log = log_queue.get_nowait()
+                    logs.append(log)
+                except queue.Empty:
+                    break
+            return "\n".join(logs)
+
+        def process_with_logs():
+            if url_list_str is None or url_list_str.strip() == "":
+                logger.info("Searching the web ...")
+                yield "", update_logs()
+                links = self.search_web(query, date_restrict, target_site)
+                logger.info(f"✅ Found {len(links)} links for query: {query}")
+                for i, link in enumerate(links):
+                    logger.debug(f"{i+1}. {link}")
+                yield "", update_logs()
+            else:
+                links = url_list_str.split("\n")
+
+            logger.info("Scraping the URLs ...")
+            yield "", update_logs()
+            scrape_results = self.scrape_urls(links)
+            logger.info(f"✅ Scraped {len(scrape_results)} URLs.")
+            yield "", update_logs()
+
+            logger.info("Chunking the text ...")
+            yield "", update_logs()
+            chunking_results = self.chunk_results(scrape_results, 1000, 100)
+            total_chunks = 0
+            for url, chunks in chunking_results.items():
+                logger.debug(f"URL: {url}")
+                total_chunks += len(chunks)
+                for i, chunk in enumerate(chunks):
+                    logger.debug(f"Chunk {i+1}: {chunk}")
+            logger.info(f"✅ Generated {total_chunks} chunks ...")
+            yield "", update_logs()
+
+            logger.info(f"Saving {total_chunks} chunks to DB ...")
+            yield "", update_logs()
+            table_name = self.save_to_db(chunking_results)
+            logger.info(f"✅ Successfully embedded and saved chunks to DB.")
+            yield "", update_logs()
+
+            logger.info("Querying the vector DB to get context ...")
+            matched_chunks = self.vector_search(table_name, query)
+            for i, result in enumerate(matched_chunks):
+                logger.debug(f"{i+1}. {result}")
+            logger.info(f"✅ Got {len(matched_chunks)} matched chunks.")
+            yield "", update_logs()
+
+            logger.info("Running inference with context ...")
+            yield "", update_logs()
+            answer = self.run_inference(
+                query=query,
+                model_name=model_name,
+                matched_chunks=matched_chunks,
+                output_language=output_language,
+                output_length=output_length,
+            )
+            logger.info("✅ Finished inference API call.")
+            logger.info("Generating output ...")
+            yield "", update_logs()
 
-… (5 removed lines not shown)
+            answer = f"# Answer\n\n{answer}\n"
+            references = "\n".join(
+                [f"[{i+1}] {result['url']}" for i, result in enumerate(matched_chunks)]
+            )
+            yield f"{answer}\n\n# References\n\n{references}", update_logs()
+
+        logs = ""
+        final_result = ""
+
+        try:
+            for result, log_update in process_with_logs():
+                logs += log_update + "\n"
+                final_result = result
+                yield final_result, logs
+        finally:
+            logger.removeHandler(queue_handler)
+
+        return final_result, logs
 
 
 def launch_gradio(
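
The new run_query wires log streaming by instantiating a bare logging.Handler and monkey-patching its emit to push formatted records onto a Queue; the nested process_with_logs generator then yields (answer, logs) pairs, and the Generator[Tuple[str, str], None, Tuple[str, str]] annotation records that the final return value (carried on StopIteration) is the same pair. The standard library ships this capture-into-a-queue pattern as logging.handlers.QueueHandler; a minimal self-contained sketch of the equivalent setup (names here are illustrative, not from ask.py):

    import logging
    import logging.handlers
    import queue

    log_queue: queue.Queue = queue.Queue()

    logger = logging.getLogger("demo")
    logger.setLevel(logging.INFO)
    # QueueHandler enqueues each LogRecord instead of writing it out.
    queue_handler = logging.handlers.QueueHandler(log_queue)
    logger.addHandler(queue_handler)

    formatter = logging.Formatter("%(asctime)s - %(levelname)s - %(message)s")

    def drain_logs() -> str:
        # Format and join everything queued since the last call,
        # without blocking when the queue is empty.
        lines = []
        while True:
            try:
                record = log_queue.get_nowait()
            except queue.Empty:
                break
            lines.append(formatter.format(record))
        return "\n".join(lines)

    logger.info("hello")
    print(drain_logs())  # one formatted line containing "hello"

Removing the handler in the finally block, as the diff does, keeps repeated calls from stacking handlers and emitting duplicate log lines.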

@@ -482,44 +529,63 @@ def launch_gradio(
     logger: logging.Logger,
 ) -> None:
     ask = Ask(logger=logger)
-… (38 removed lines not shown)
+    with gr.Blocks() as demo:
+        gr.Markdown("# Ask.py - Web Search-Extract-Summarize")
+        gr.Markdown(
+            "Search the web with the query and summarize the results. Source code: https://github.com/pengfeng/ask.py"
+        )
+
+        with gr.Row():
+            with gr.Column():
+
+                query_input = gr.Textbox(label="Query", value=query)
+                date_restrict_input = gr.Number(
+                    label="Date Restrict (Optional) [0 or empty means no date limit.]",
+                    value=date_restrict,
+                )
+                target_site_input = gr.Textbox(
+                    label="Target Sites (Optional) [Empty means searching the whole web.]",
+                    value=target_site,
+                )
+                output_language_input = gr.Textbox(
+                    label="Output Language (Optional) [Default is English.]",
+                    value=output_language,
+                )
+                output_length_input = gr.Number(
+                    label="Output Length in words (Optional) [Default is automatically decided by LLM.]",
+                    value=output_length,
+                )
+                url_list_input = gr.Textbox(
+                    label="URL List (Optional) [When specified, scrape the urls instead of searching the web.]",
+                    lines=5,
+                    max_lines=20,
+                    value=url_list_str,
+                )
+
+                with gr.Accordion("More Options", open=False):
+                    model_name_input = gr.Textbox(label="Model Name", value=model_name)
+
+                submit_button = gr.Button("Submit")
+
+            with gr.Column():
+                answer_output = gr.Textbox(label="Answer")
+                logs_output = gr.Textbox(label="Logs", lines=10)
+
+        submit_button.click(
+            fn=ask.run_query,
+            inputs=[
+                query_input,
+                date_restrict_input,
+                target_site_input,
+                output_language_input,
+                output_length_input,
+                url_list_input,
+                model_name_input,
+            ],
+            outputs=[answer_output, logs_output],
+        )
+
+    demo.queue().launch(share=share_ui)
 
 
 @click.command(help="Search web for the query and summarize the results.")
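
Passing the run_query generator directly as fn to submit_button.click is what makes the UI stream: when a Gradio event handler is a generator, each yielded tuple is applied to the outputs list as an incremental update, and demo.queue() enables the event queue this requires. A stripped-down sketch of the same mechanism (component names are illustrative):

    import time

    import gradio as gr

    def stream_steps(prompt: str):
        # Each yield must match the outputs list: (answer, logs).
        logs = ""
        for step in ["searching", "scraping", "summarizing"]:
            logs += step + " ...\n"
            yield "", logs  # intermediate update: logs only
            time.sleep(0.5)
        yield "Answer for: " + prompt, logs  # final update fills the answer box

    with gr.Blocks() as demo:
        prompt = gr.Textbox(label="Query")
        answer = gr.Textbox(label="Answer")
        logs = gr.Textbox(label="Logs", lines=5)
        gr.Button("Submit").click(fn=stream_steps, inputs=[prompt], outputs=[answer, logs])

    demo.queue().launch()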

@@ -615,7 +681,8 @@ def search_extract_summarize(
     if query is None:
         raise Exception("Query is required for the command line mode")
     ask = Ask(logger=logger)
-… (1 removed line not shown)
+
+    for result, _ in ask.run_query(
         query=query,
         date_restrict=date_restrict,
         target_site=target_site,

@@ -623,8 +690,10 @@ def search_extract_summarize(
         output_length=output_length,
         url_list_str=_read_url_list(url_list_file),
         model_name=model_name,
-    )
-… (1 removed line not shown)
+    ):
+        final_result = result
+
+    click.echo(final_result)
 
 
 if __name__ == "__main__":
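
With run_query now a generator, the command-line path drains it in a for loop and keeps only the last yielded result, which by then is the full answer-plus-references string; final_result is always bound because the generator yields at least once before finishing. The consumption pattern in isolation (a sketch, not ask.py code):

    def last_yield(results):
        # Exhaust a generator of (result, logs) pairs and keep
        # the final result, discarding intermediate log snapshots.
        final = ""
        for result, _logs in results:
            final = result
        return final

    print(last_yield(iter([("", "searching ..."), ("# Answer ...", "done")])))
    # -> "# Answer ..."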