LeetTools committed
Commit 30df565 · verified · 1 Parent(s): 1d2d847

Upload ask.py

Files changed (1):
  1. ask.py +82 -86
ask.py CHANGED
@@ -19,7 +19,7 @@ script_dir = os.path.dirname(os.path.abspath(__file__))
 default_env_file = os.path.abspath(os.path.join(script_dir, ".env"))
 
 
-def get_logger(log_level: str) -> logging.Logger:
+def _get_logger(log_level: str) -> logging.Logger:
     logger = logging.getLogger(__name__)
     logger.setLevel(log_level)
     handler = logging.StreamHandler()
@@ -29,6 +29,20 @@ def get_logger(log_level: str) -> logging.Logger:
     return logger
 
 
+def _read_url_list(url_list_file: str) -> str:
+    if url_list_file is None:
+        return None
+
+    with open(url_list_file, "r") as f:
+        links = f.readlines()
+    links = [
+        link.strip()
+        for link in links
+        if link.strip() != "" and not link.startswith("#")
+    ]
+    return "\n".join(links)
+
+
 class Ask:
 
     def __init__(self, logger: Optional[logging.Logger] = None):
@@ -37,7 +51,7 @@ class Ask:
         if logger is not None:
             self.logger = logger
         else:
-            self.logger = get_logger("INFO")
+            self.logger = _get_logger("INFO")
 
         self.table_name = "document_chunks"
         self.db_con = duckdb.connect(":memory:")
@@ -383,84 +397,66 @@ Here is the context:
         response_str = completion.choices[0].message.content
         return response_str
 
-
-def _read_url_list(url_list_file: str) -> str:
-    if url_list_file is None:
-        return None
-
-    with open(url_list_file, "r") as f:
-        links = f.readlines()
-    links = [
-        link.strip()
-        for link in links
-        if link.strip() != "" and not link.startswith("#")
-    ]
-    return "\n".join(links)
-
-
-def _run_query(
-    query: str,
-    date_restrict: int,
-    target_site: str,
-    output_language: str,
-    output_length: int,
-    url_list_str: str,
-    model_name: str,
-    log_level: str,
-) -> str:
-    logger = get_logger(log_level)
-
-    ask = Ask(logger=logger)
-
-    if url_list_str is None or url_list_str.strip() == "":
-        logger.info("Searching the web ...")
-        links = ask.search_web(query, date_restrict, target_site)
-        logger.info(f"✅ Found {len(links)} links for query: {query}")
-        for i, link in enumerate(links):
-            logger.debug(f"{i+1}. {link}")
-    else:
-        links = url_list_str.split("\n")
-
-    logger.info("Scraping the URLs ...")
-    scrape_results = ask.scrape_urls(links)
-    logger.info(f"✅ Scraped {len(scrape_results)} URLs.")
-
-    logger.info("Chunking the text ...")
-    chunking_results = ask.chunk_results(scrape_results, 1000, 100)
-    total_chunks = 0
-    for url, chunks in chunking_results.items():
-        logger.debug(f"URL: {url}")
-        total_chunks += len(chunks)
-        for i, chunk in enumerate(chunks):
-            logger.debug(f"Chunk {i+1}: {chunk}")
-    logger.info(f"✅ Generated {total_chunks} chunks ...")
-
-    logger.info(f"Saving {total_chunks} chunks to DB ...")
-    ask.save_to_db(chunking_results)
-    logger.info(f"✅ Successfully embedded and saved chunks to DB.")
-
-    logger.info("Querying the vector DB to get context ...")
-    matched_chunks = ask.vector_search(query)
-    for i, result in enumerate(matched_chunks):
-        logger.debug(f"{i+1}. {result}")
-    logger.info(f"✅ Got {len(matched_chunks)} matched chunks.")
-
-    logger.info("Running inference with context ...")
-    answer = ask.run_inference(
-        query=query,
-        model_name=model_name,
-        matched_chunks=matched_chunks,
-        output_language=output_language,
-        output_length=output_length,
-    )
-    logger.info("✅ Finished inference API call.")
-    logger.info("generating output ...")
-
-    answer = f"# Answer\n\n{answer}\n"
-    references = "\n".join(
-        [f"[{i+1}] {result['url']}" for i, result in enumerate(matched_chunks)]
-    )
-    return f"{answer}\n\n# References\n\n{references}"
+    def run_query(
+        self,
+        query: str,
+        date_restrict: int,
+        target_site: str,
+        output_language: str,
+        output_length: int,
+        url_list_str: str,
+        model_name: str,
+    ) -> str:
+        logger = self.logger
+        if url_list_str is None or url_list_str.strip() == "":
+            logger.info("Searching the web ...")
+            links = self.search_web(query, date_restrict, target_site)
+            logger.info(f"✅ Found {len(links)} links for query: {query}")
+            for i, link in enumerate(links):
+                logger.debug(f"{i+1}. {link}")
+        else:
+            links = url_list_str.split("\n")
+
+        logger.info("Scraping the URLs ...")
+        scrape_results = self.scrape_urls(links)
+        logger.info(f"✅ Scraped {len(scrape_results)} URLs.")
+
+        logger.info("Chunking the text ...")
+        chunking_results = self.chunk_results(scrape_results, 1000, 100)
+        total_chunks = 0
+        for url, chunks in chunking_results.items():
+            logger.debug(f"URL: {url}")
+            total_chunks += len(chunks)
+            for i, chunk in enumerate(chunks):
+                logger.debug(f"Chunk {i+1}: {chunk}")
+        logger.info(f"✅ Generated {total_chunks} chunks ...")
+
+        logger.info(f"Saving {total_chunks} chunks to DB ...")
+        self.save_to_db(chunking_results)
+        logger.info(f"✅ Successfully embedded and saved chunks to DB.")
+
+        logger.info("Querying the vector DB to get context ...")
+        matched_chunks = self.vector_search(query)
+        for i, result in enumerate(matched_chunks):
+            logger.debug(f"{i+1}. {result}")
+        logger.info(f"✅ Got {len(matched_chunks)} matched chunks.")
+
+        logger.info("Running inference with context ...")
+        answer = self.run_inference(
+            query=query,
+            model_name=model_name,
+            matched_chunks=matched_chunks,
+            output_language=output_language,
+            output_length=output_length,
+        )
+        logger.info("✅ Finished inference API call.")
+        logger.info("generating output ...")
+
+        answer = f"# Answer\n\n{answer}\n"
+        references = "\n".join(
+            [f"[{i+1}] {result['url']}" for i, result in enumerate(matched_chunks)]
+        )
+        return f"{answer}\n\n# References\n\n{references}"
 
 
 def launch_gradio(
@@ -471,11 +467,12 @@ def launch_gradio(
     output_length: int,
     url_list_str: str,
     model_name: str,
-    log_level: str,
     share_ui: bool,
+    logger: logging.Logger,
 ) -> None:
+    ask = Ask(logger=logger)
     iface = gr.Interface(
-        fn=_run_query,
+        fn=ask.run_query,
         inputs=[
             gr.Textbox(label="Query", value=query),
             gr.Number(
@@ -503,7 +500,6 @@
         ],
         additional_inputs=[
            gr.Textbox(label="Model Name", value=model_name),
-            gr.Textbox(label="Log Level", value=log_level),
        ],
        outputs="text",
        show_progress=True,
@@ -586,6 +582,7 @@ def search_extract_summarize(
     log_level: str,
 ):
     load_dotenv(dotenv_path=default_env_file, override=False)
+    logger = _get_logger(log_level)
 
     if web_ui or os.environ.get("RUN_GRADIO_UI", "false").lower() != "false":
         if os.environ.get("SHARE_GRADIO_UI", "false").lower() == "true":
@@ -600,14 +597,14 @@
             output_length=output_length,
             url_list_str=_read_url_list(url_list_file),
             model_name=model_name,
-            log_level=log_level,
             share_ui=share_ui,
+            logger=logger,
         )
     else:
         if query is None:
             raise Exception("Query is required for the command line mode")
-
-        result = _run_query(
+        ask = Ask(logger=logger)
+        result = ask.run_query(
            query=query,
            date_restrict=date_restrict,
            target_site=target_site,
@@ -615,7 +612,6 @@
             output_length=output_length,
             url_list_str=_read_url_list(url_list_file),
             model_name=model_name,
-            log_level=log_level,
         )
         click.echo(result)
 
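For context, a minimal usage sketch of the refactored code path after this commit, mirroring the command-line branch of search_extract_summarize in the diff above. The import path `ask` and all argument values are illustrative assumptions, not part of the commit:

# Usage sketch (assumptions: ask.py is importable as `ask`; the query,
# model name, and numeric values below are placeholders for illustration).
from ask import Ask, _get_logger

logger = _get_logger("INFO")   # log level is now configured once by the caller
ask = Ask(logger=logger)       # the logger is injected into Ask instead of being
                               # rebuilt inside a module-level _run_query

result = ask.run_query(
    query="What is DuckDB?",      # placeholder query
    date_restrict=0,              # placeholder: no date restriction
    target_site="",               # placeholder: search the whole web
    output_language="English",
    output_length=300,
    url_list_str="",              # empty string -> falls back to web search
    model_name="gpt-4o-mini",     # placeholder model name
)
print(result)                     # "# Answer ..." followed by "# References ..."

This is the point of moving _run_query into the Ask class: the Gradio UI and the CLI now share one Ask instance and one caller-configured logger, instead of threading log_level through every call.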