LeetTools committed on
Commit b8a5090 · verified · 1 Parent(s): fc0d356

Upload ask.py

Files changed (1): ask.py (+160, -91)
ask.py CHANGED
@@ -1,11 +1,14 @@
 import json
 import logging
 import os
+import queue
+import re
 import urllib.parse
 from concurrent.futures import ThreadPoolExecutor
 from datetime import datetime
 from functools import partial
-from typing import Any, Dict, List, Optional, Tuple
+from queue import Queue
+from typing import Any, Dict, Generator, List, Optional, Tuple
 
 import click
 import duckdb
@@ -15,6 +18,7 @@ from bs4 import BeautifulSoup
 from dotenv import load_dotenv
 from jinja2 import BaseLoader, Environment
 from openai import OpenAI
+from regex import T
 
 script_dir = os.path.dirname(os.path.abspath(__file__))
 default_env_file = os.path.abspath(os.path.join(script_dir, ".env"))
@@ -25,6 +29,7 @@ def _get_logger(log_level: str) -> logging.Logger:
     logger.setLevel(log_level)
     if len(logger.handlers) > 0:
         return logger
+
    handler = logging.StreamHandler()
    formatter = logging.Formatter("%(asctime)s - %(levelname)s - %(message)s")
    handler.setFormatter(formatter)
@@ -417,57 +422,99 @@ Here is the context:
         output_length: int,
         url_list_str: str,
         model_name: str,
-    ) -> str:
+    ) -> Generator[Tuple[str, str], None, Tuple[str, str]]:
         logger = self.logger
-        if url_list_str is None or url_list_str.strip() == "":
-            logger.info("Searching the web ...")
-            links = self.search_web(query, date_restrict, target_site)
-            logger.info(f"✅ Found {len(links)} links for query: {query}")
-            for i, link in enumerate(links):
-                logger.debug(f"{i+1}. {link}")
-        else:
-            links = url_list_str.split("\n")
-
-        logger.info("Scraping the URLs ...")
-        scrape_results = self.scrape_urls(links)
-        logger.info(f"✅ Scraped {len(scrape_results)} URLs.")
-
-        logger.info("Chunking the text ...")
-        chunking_results = self.chunk_results(scrape_results, 1000, 100)
-        total_chunks = 0
-        for url, chunks in chunking_results.items():
-            logger.debug(f"URL: {url}")
-            total_chunks += len(chunks)
-            for i, chunk in enumerate(chunks):
-                logger.debug(f"Chunk {i+1}: {chunk}")
-        logger.info(f"✅ Generated {total_chunks} chunks ...")
-
-        logger.info(f"Saving {total_chunks} chunks to DB ...")
-        table_name = self.save_to_db(chunking_results)
-        logger.info(f"✅ Successfully embedded and saved chunks to DB.")
-
-        logger.info("Querying the vector DB to get context ...")
-        matched_chunks = self.vector_search(table_name, query)
-        for i, result in enumerate(matched_chunks):
-            logger.debug(f"{i+1}. {result}")
-        logger.info(f"✅ Got {len(matched_chunks)} matched chunks.")
-
-        logger.info("Running inference with context ...")
-        answer = self.run_inference(
-            query=query,
-            model_name=model_name,
-            matched_chunks=matched_chunks,
-            output_language=output_language,
-            output_length=output_length,
-        )
-        logger.info("✅ Finished inference API call.")
-        logger.info("Generating output ...")
+        log_queue = Queue()
+
+        queue_handler = logging.Handler()
+        formatter = logging.Formatter("%(asctime)s - %(levelname)s - %(message)s")
+        queue_handler.emit = lambda record: log_queue.put(formatter.format(record))
+        logger.addHandler(queue_handler)
+
+        def update_logs():
+            logs = []
+            while True:
+                try:
+                    log = log_queue.get_nowait()
+                    logs.append(log)
+                except queue.Empty:
+                    break
+            return "\n".join(logs)
+
+        def process_with_logs():
+            if url_list_str is None or url_list_str.strip() == "":
+                logger.info("Searching the web ...")
+                yield "", update_logs()
+                links = self.search_web(query, date_restrict, target_site)
+                logger.info(f"✅ Found {len(links)} links for query: {query}")
+                for i, link in enumerate(links):
+                    logger.debug(f"{i+1}. {link}")
+                yield "", update_logs()
+            else:
+                links = url_list_str.split("\n")
+
+            logger.info("Scraping the URLs ...")
+            yield "", update_logs()
+            scrape_results = self.scrape_urls(links)
+            logger.info(f"✅ Scraped {len(scrape_results)} URLs.")
+            yield "", update_logs()
+
+            logger.info("Chunking the text ...")
+            yield "", update_logs()
+            chunking_results = self.chunk_results(scrape_results, 1000, 100)
+            total_chunks = 0
+            for url, chunks in chunking_results.items():
+                logger.debug(f"URL: {url}")
+                total_chunks += len(chunks)
+                for i, chunk in enumerate(chunks):
+                    logger.debug(f"Chunk {i+1}: {chunk}")
+            logger.info(f"✅ Generated {total_chunks} chunks ...")
+            yield "", update_logs()
+
+            logger.info(f"Saving {total_chunks} chunks to DB ...")
+            yield "", update_logs()
+            table_name = self.save_to_db(chunking_results)
+            logger.info(f"✅ Successfully embedded and saved chunks to DB.")
+            yield "", update_logs()
+
+            logger.info("Querying the vector DB to get context ...")
+            matched_chunks = self.vector_search(table_name, query)
+            for i, result in enumerate(matched_chunks):
+                logger.debug(f"{i+1}. {result}")
+            logger.info(f"✅ Got {len(matched_chunks)} matched chunks.")
+            yield "", update_logs()
+
+            logger.info("Running inference with context ...")
+            yield "", update_logs()
+            answer = self.run_inference(
+                query=query,
+                model_name=model_name,
+                matched_chunks=matched_chunks,
+                output_language=output_language,
+                output_length=output_length,
+            )
+            logger.info("✅ Finished inference API call.")
+            logger.info("Generating output ...")
+            yield "", update_logs()
 
-        answer = f"# Answer\n\n{answer}\n"
-        references = "\n".join(
-            [f"[{i+1}] {result['url']}" for i, result in enumerate(matched_chunks)]
-        )
-        return f"{answer}\n\n# References\n\n{references}"
+            answer = f"# Answer\n\n{answer}\n"
+            references = "\n".join(
+                [f"[{i+1}] {result['url']}" for i, result in enumerate(matched_chunks)]
+            )
+            yield f"{answer}\n\n# References\n\n{references}", update_logs()
+
+        logs = ""
+        final_result = ""
+
+        try:
+            for result, log_update in process_with_logs():
+                logs += log_update + "\n"
+                final_result = result
+                yield final_result, logs
+        finally:
+            logger.removeHandler(queue_handler)
+
+        return final_result, logs
 
 
 def launch_gradio(
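
The rewritten run_query above is no longer a plain function returning a string: it is a generator that yields (answer, logs) tuples so a caller can refresh its display while the search, scrape, chunk, embed, and inference steps run. The log streaming works by attaching a logging.Handler whose emit pushes formatted records into a Queue, which is drained between steps. A minimal, self-contained sketch of that pattern, assuming nothing from ask.py (drain_logs and run_steps are illustrative names, not from the file):

    import logging
    import queue
    from queue import Queue
    from typing import Generator, Tuple

    logger = logging.getLogger("sketch")
    logger.setLevel(logging.INFO)

    # Route every log record into a queue so it can be surfaced incrementally.
    log_queue: Queue = Queue()
    queue_handler = logging.Handler()
    formatter = logging.Formatter("%(asctime)s - %(levelname)s - %(message)s")
    queue_handler.emit = lambda record: log_queue.put(formatter.format(record))
    logger.addHandler(queue_handler)

    def drain_logs() -> str:
        # Non-blocking: collect whatever has been logged since the last drain.
        lines = []
        while True:
            try:
                lines.append(log_queue.get_nowait())
            except queue.Empty:
                break
        return "\n".join(lines)

    def run_steps() -> Generator[Tuple[str, str], None, None]:
        # Each yield is (partial_answer, new_log_text); only the last yield
        # carries the real answer, mirroring run_query above.
        logger.info("step 1 ...")
        yield "", drain_logs()
        logger.info("step 2 ...")
        yield "", drain_logs()
        yield "final answer", drain_logs()

    answer = ""
    for answer, log_chunk in run_steps():
        print(log_chunk)
    print(answer)

Overriding emit on a bare logging.Handler instance, as the commit does, avoids writing a Handler subclass; logging.handlers.QueueHandler from the standard library would be a more conventional alternative.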
@@ -482,44 +529,63 @@ def launch_gradio(
     logger: logging.Logger,
 ) -> None:
     ask = Ask(logger=logger)
-    iface = gr.Interface(
-        fn=ask.run_query,
-        inputs=[
-            gr.Textbox(label="Query", value=query),
-            gr.Number(
-                label="Date Restrict (Optional) [0 or empty means no date limit.]",
-                value=date_restrict,
-            ),
-            gr.Textbox(
-                label="Target Sites (Optional) [Empty means search the whole web.]",
-                value=target_site,
-            ),
-            gr.Textbox(
-                label="Output Language (Optional) [Default is English.]",
-                value=output_language,
-            ),
-            gr.Number(
-                label="Output Length in words (Optional) [Default is automatically decided by LLM.]",
-                value=output_length,
-            ),
-            gr.Textbox(
-                label="URL List (Optional) [When specified, scrape the urls instead of searching the web.]",
-                lines=5,
-                max_lines=20,
-                value=url_list_str,
-            ),
-        ],
-        additional_inputs=[
-            gr.Textbox(label="Model Name", value=model_name),
-        ],
-        outputs="text",
-        show_progress=True,
-        flagging_options=[("Report Error", None)],
-        title="Ask.py - Web Search-Extract-Summarize",
-        description="Search the web with the query and summarize the results. Source code: https://github.com/pengfeng/ask.py",
-    )
-
-    iface.launch(share=share_ui)
+    with gr.Blocks() as demo:
+        gr.Markdown("# Ask.py - Web Search-Extract-Summarize")
+        gr.Markdown(
+            "Search the web with the query and summarize the results. Source code: https://github.com/pengfeng/ask.py"
+        )
+
+        with gr.Row():
+            with gr.Column():
+
+                query_input = gr.Textbox(label="Query", value=query)
+                date_restrict_input = gr.Number(
+                    label="Date Restrict (Optional) [0 or empty means no date limit.]",
+                    value=date_restrict,
+                )
+                target_site_input = gr.Textbox(
+                    label="Target Sites (Optional) [Empty means searching the whole web.]",
+                    value=target_site,
+                )
+                output_language_input = gr.Textbox(
+                    label="Output Language (Optional) [Default is English.]",
+                    value=output_language,
+                )
+                output_length_input = gr.Number(
+                    label="Output Length in words (Optional) [Default is automatically decided by LLM.]",
+                    value=output_length,
+                )
+                url_list_input = gr.Textbox(
+                    label="URL List (Optional) [When specified, scrape the urls instead of searching the web.]",
+                    lines=5,
+                    max_lines=20,
+                    value=url_list_str,
+                )
+
+                with gr.Accordion("More Options", open=False):
+                    model_name_input = gr.Textbox(label="Model Name", value=model_name)
+
+                submit_button = gr.Button("Submit")
+
+            with gr.Column():
+                answer_output = gr.Textbox(label="Answer")
+                logs_output = gr.Textbox(label="Logs", lines=10)
+
+        submit_button.click(
+            fn=ask.run_query,
+            inputs=[
+                query_input,
+                date_restrict_input,
+                target_site_input,
+                output_language_input,
+                output_length_input,
+                url_list_input,
+                model_name_input,
+            ],
+            outputs=[answer_output, logs_output],
+        )
+
+    demo.queue().launch(share=share_ui)
 
 
 @click.command(help="Search web for the query and summarize the results.")
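
The gr.Interface call above is replaced by a gr.Blocks layout in which the Submit button is bound directly to the generator ask.run_query, and the two yielded values feed the Answer and Logs textboxes. Gradio re-renders the bound outputs on every yield from a generator handler, which is what lets the log panel update while the pipeline runs; queueing has to be enabled for that, hence demo.queue().launch(...). A stripped-down sketch of the same wiring, independent of ask.py (slow_task and the component names are made up for illustration):

    import time

    import gradio as gr

    def slow_task(prompt: str):
        # Generator handler: Gradio updates the bound outputs on every yield.
        logs = ""
        for step in range(3):
            logs += f"working on step {step + 1} for: {prompt}\n"
            yield "", logs  # intermediate update: no answer yet, fresh logs
            time.sleep(1)
        yield f"answer for: {prompt}", logs  # final update

    with gr.Blocks() as demo:
        prompt_box = gr.Textbox(label="Prompt")
        run_button = gr.Button("Run")
        answer_box = gr.Textbox(label="Answer")
        logs_box = gr.Textbox(label="Logs", lines=5)
        run_button.click(fn=slow_task, inputs=[prompt_box], outputs=[answer_box, logs_box])

    demo.queue().launch()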
@@ -615,7 +681,8 @@ def search_extract_summarize(
         if query is None:
             raise Exception("Query is required for the command line mode")
         ask = Ask(logger=logger)
-        result = ask.run_query(
+
+        for result, _ in ask.run_query(
             query=query,
             date_restrict=date_restrict,
             target_site=target_site,
@@ -623,8 +690,10 @@
             output_length=output_length,
             url_list_str=_read_url_list(url_list_file),
             model_name=model_name,
-        )
-        click.echo(result)
+        ):
+            final_result = result
+
+        click.echo(final_result)
 
 
 if __name__ == "__main__":
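
On the command-line path the same generator is simply exhausted: the loop keeps only the last yielded answer and ignores the streamed log text, which already reaches the console through the logger's StreamHandler. A sketch of the assembled call as it presumably reads in context (the final_result initialization and the output_language keyword, which falls on the unchanged line between the two hunks, are assumptions):

    final_result = ""
    for result, _ in ask.run_query(
        query=query,
        date_restrict=date_restrict,
        target_site=target_site,
        output_language=output_language,  # assumed: this line is not shown in the diff
        output_length=output_length,
        url_list_str=_read_url_list(url_list_file),
        model_name=model_name,
    ):
        final_result = result

    click.echo(final_result)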
 