LeetTools committed on
Commit
72ff970
·
verified ·
1 Parent(s): 30df565

Upload ask.py

Browse files
Files changed (1) hide show
  1. ask.py +86 -82
ask.py CHANGED
@@ -19,7 +19,7 @@ script_dir = os.path.dirname(os.path.abspath(__file__))
19
  default_env_file = os.path.abspath(os.path.join(script_dir, ".env"))
20
 
21
 
22
- def _get_logger(log_level: str) -> logging.Logger:
23
  logger = logging.getLogger(__name__)
24
  logger.setLevel(log_level)
25
  handler = logging.StreamHandler()
@@ -29,20 +29,6 @@ def _get_logger(log_level: str) -> logging.Logger:
29
  return logger
30
 
31
 
32
- def _read_url_list(url_list_file: str) -> str:
33
- if url_list_file is None:
34
- return None
35
-
36
- with open(url_list_file, "r") as f:
37
- links = f.readlines()
38
- links = [
39
- link.strip()
40
- for link in links
41
- if link.strip() != "" and not link.startswith("#")
42
- ]
43
- return "\n".join(links)
44
-
45
-
46
  class Ask:
47
 
48
  def __init__(self, logger: Optional[logging.Logger] = None):
@@ -51,7 +37,7 @@ class Ask:
51
  if logger is not None:
52
  self.logger = logger
53
  else:
54
- self.logger = _get_logger("INFO")
55
 
56
  self.table_name = "document_chunks"
57
  self.db_con = duckdb.connect(":memory:")
@@ -397,66 +383,84 @@ Here is the context:
397
  response_str = completion.choices[0].message.content
398
  return response_str
399
 
400
- def run_query(
401
- self,
402
- query: str,
403
- date_restrict: int,
404
- target_site: str,
405
- output_language: str,
406
- output_length: int,
407
- url_list_str: str,
408
- model_name: str,
409
- ) -> str:
410
- logger = self.logger
411
- if url_list_str is None or url_list_str.strip() == "":
412
- logger.info("Searching the web ...")
413
- links = self.search_web(query, date_restrict, target_site)
414
- logger.info(f"βœ… Found {len(links)} links for query: {query}")
415
- for i, link in enumerate(links):
416
- logger.debug(f"{i+1}. {link}")
417
- else:
418
- links = url_list_str.split("\n")
419
-
420
- logger.info("Scraping the URLs ...")
421
- scrape_results = self.scrape_urls(links)
422
- logger.info(f"βœ… Scraped {len(scrape_results)} URLs.")
423
-
424
- logger.info("Chunking the text ...")
425
- chunking_results = self.chunk_results(scrape_results, 1000, 100)
426
- total_chunks = 0
427
- for url, chunks in chunking_results.items():
428
- logger.debug(f"URL: {url}")
429
- total_chunks += len(chunks)
430
- for i, chunk in enumerate(chunks):
431
- logger.debug(f"Chunk {i+1}: {chunk}")
432
- logger.info(f"βœ… Generated {total_chunks} chunks ...")
433
-
434
- logger.info(f"Saving {total_chunks} chunks to DB ...")
435
- self.save_to_db(chunking_results)
436
- logger.info(f"βœ… Successfully embedded and saved chunks to DB.")
437
-
438
- logger.info("Querying the vector DB to get context ...")
439
- matched_chunks = self.vector_search(query)
440
- for i, result in enumerate(matched_chunks):
441
- logger.debug(f"{i+1}. {result}")
442
- logger.info(f"βœ… Got {len(matched_chunks)} matched chunks.")
443
-
444
- logger.info("Running inference with context ...")
445
- answer = self.run_inference(
446
- query=query,
447
- model_name=model_name,
448
- matched_chunks=matched_chunks,
449
- output_language=output_language,
450
- output_length=output_length,
451
- )
452
- logger.info("βœ… Finished inference API call.")
453
- logger.info("generateing output ...")
454
 
455
- answer = f"# Answer\n\n{answer}\n"
456
- references = "\n".join(
457
- [f"[{i+1}] {result['url']}" for i, result in enumerate(matched_chunks)]
458
- )
459
- return f"{answer}\n\n# References\n\n{references}"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
460
 
461
 
462
  def launch_gradio(
@@ -467,12 +471,11 @@ def launch_gradio(
467
  output_length: int,
468
  url_list_str: str,
469
  model_name: str,
 
470
  share_ui: bool,
471
- logger: logging.Logger,
472
  ) -> None:
473
- ask = Ask(logger=logger)
474
  iface = gr.Interface(
475
- fn=ask.run_query,
476
  inputs=[
477
  gr.Textbox(label="Query", value=query),
478
  gr.Number(
@@ -500,6 +503,7 @@ def launch_gradio(
500
  ],
501
  additional_inputs=[
502
  gr.Textbox(label="Model Name", value=model_name),
 
503
  ],
504
  outputs="text",
505
  show_progress=True,
@@ -582,7 +586,6 @@ def search_extract_summarize(
582
  log_level: str,
583
  ):
584
  load_dotenv(dotenv_path=default_env_file, override=False)
585
- logger = _get_logger(log_level)
586
 
587
  if web_ui or os.environ.get("RUN_GRADIO_UI", "false").lower() != "false":
588
  if os.environ.get("SHARE_GRADIO_UI", "false").lower() == "true":
@@ -597,14 +600,14 @@ def search_extract_summarize(
597
  output_length=output_length,
598
  url_list_str=_read_url_list(url_list_file),
599
  model_name=model_name,
 
600
  share_ui=share_ui,
601
- logger=logger,
602
  )
603
  else:
604
  if query is None:
605
  raise Exception("Query is required for the command line mode")
606
- ask = Ask(logger=logger)
607
- result = ask.run_query(
608
  query=query,
609
  date_restrict=date_restrict,
610
  target_site=target_site,
@@ -612,6 +615,7 @@ def search_extract_summarize(
612
  output_length=output_length,
613
  url_list_str=_read_url_list(url_list_file),
614
  model_name=model_name,
 
615
  )
616
  click.echo(result)
617
 
 
19
  default_env_file = os.path.abspath(os.path.join(script_dir, ".env"))
20
 
21
 
22
+ def get_logger(log_level: str) -> logging.Logger:
23
  logger = logging.getLogger(__name__)
24
  logger.setLevel(log_level)
25
  handler = logging.StreamHandler()
 
29
  return logger
30
 
31
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
32
  class Ask:
33
 
34
  def __init__(self, logger: Optional[logging.Logger] = None):
 
37
  if logger is not None:
38
  self.logger = logger
39
  else:
40
+ self.logger = get_logger("INFO")
41
 
42
  self.table_name = "document_chunks"
43
  self.db_con = duckdb.connect(":memory:")
 
383
  response_str = completion.choices[0].message.content
384
  return response_str
385
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
386
 
387
+ def _read_url_list(url_list_file: str) -> str:
388
+ if url_list_file is None:
389
+ return None
390
+
391
+ with open(url_list_file, "r") as f:
392
+ links = f.readlines()
393
+ links = [
394
+ link.strip()
395
+ for link in links
396
+ if link.strip() != "" and not link.startswith("#")
397
+ ]
398
+ return "\n".join(links)
399
+
400
+
401
+ def _run_query(
402
+ query: str,
403
+ date_restrict: int,
404
+ target_site: str,
405
+ output_language: str,
406
+ output_length: int,
407
+ url_list_str: str,
408
+ model_name: str,
409
+ log_level: str,
410
+ ) -> str:
411
+ logger = get_logger(log_level)
412
+
413
+ ask = Ask(logger=logger)
414
+
415
+ if url_list_str is None or url_list_str.strip() == "":
416
+ logger.info("Searching the web ...")
417
+ links = ask.search_web(query, date_restrict, target_site)
418
+ logger.info(f"βœ… Found {len(links)} links for query: {query}")
419
+ for i, link in enumerate(links):
420
+ logger.debug(f"{i+1}. {link}")
421
+ else:
422
+ links = url_list_str.split("\n")
423
+
424
+ logger.info("Scraping the URLs ...")
425
+ scrape_results = ask.scrape_urls(links)
426
+ logger.info(f"βœ… Scraped {len(scrape_results)} URLs.")
427
+
428
+ logger.info("Chunking the text ...")
429
+ chunking_results = ask.chunk_results(scrape_results, 1000, 100)
430
+ total_chunks = 0
431
+ for url, chunks in chunking_results.items():
432
+ logger.debug(f"URL: {url}")
433
+ total_chunks += len(chunks)
434
+ for i, chunk in enumerate(chunks):
435
+ logger.debug(f"Chunk {i+1}: {chunk}")
436
+ logger.info(f"βœ… Generated {total_chunks} chunks ...")
437
+
438
+ logger.info(f"Saving {total_chunks} chunks to DB ...")
439
+ ask.save_to_db(chunking_results)
440
+ logger.info(f"βœ… Successfully embedded and saved chunks to DB.")
441
+
442
+ logger.info("Querying the vector DB to get context ...")
443
+ matched_chunks = ask.vector_search(query)
444
+ for i, result in enumerate(matched_chunks):
445
+ logger.debug(f"{i+1}. {result}")
446
+ logger.info(f"βœ… Got {len(matched_chunks)} matched chunks.")
447
+
448
+ logger.info("Running inference with context ...")
449
+ answer = ask.run_inference(
450
+ query=query,
451
+ model_name=model_name,
452
+ matched_chunks=matched_chunks,
453
+ output_language=output_language,
454
+ output_length=output_length,
455
+ )
456
+ logger.info("βœ… Finished inference API call.")
457
+ logger.info("generateing output ...")
458
+
459
+ answer = f"# Answer\n\n{answer}\n"
460
+ references = "\n".join(
461
+ [f"[{i+1}] {result['url']}" for i, result in enumerate(matched_chunks)]
462
+ )
463
+ return f"{answer}\n\n# References\n\n{references}"
464
 
465
 
466
  def launch_gradio(
 
471
  output_length: int,
472
  url_list_str: str,
473
  model_name: str,
474
+ log_level: str,
475
  share_ui: bool,
 
476
  ) -> None:
 
477
  iface = gr.Interface(
478
+ fn=_run_query,
479
  inputs=[
480
  gr.Textbox(label="Query", value=query),
481
  gr.Number(
 
503
  ],
504
  additional_inputs=[
505
  gr.Textbox(label="Model Name", value=model_name),
506
+ gr.Textbox(label="Log Level", value=log_level),
507
  ],
508
  outputs="text",
509
  show_progress=True,
 
586
  log_level: str,
587
  ):
588
  load_dotenv(dotenv_path=default_env_file, override=False)
 
589
 
590
  if web_ui or os.environ.get("RUN_GRADIO_UI", "false").lower() != "false":
591
  if os.environ.get("SHARE_GRADIO_UI", "false").lower() == "true":
 
600
  output_length=output_length,
601
  url_list_str=_read_url_list(url_list_file),
602
  model_name=model_name,
603
+ log_level=log_level,
604
  share_ui=share_ui,
 
605
  )
606
  else:
607
  if query is None:
608
  raise Exception("Query is required for the command line mode")
609
+
610
+ result = _run_query(
611
  query=query,
612
  date_restrict=date_restrict,
613
  target_site=target_site,
 
615
  output_length=output_length,
616
  url_list_str=_read_url_list(url_list_file),
617
  model_name=model_name,
618
+ log_level=log_level,
619
  )
620
  click.echo(result)
621