Upload ask.py
Browse files
ask.py
CHANGED
@@ -19,7 +19,7 @@ script_dir = os.path.dirname(os.path.abspath(__file__))
|
|
19 |
default_env_file = os.path.abspath(os.path.join(script_dir, ".env"))
|
20 |
|
21 |
|
22 |
-
def
|
23 |
logger = logging.getLogger(__name__)
|
24 |
logger.setLevel(log_level)
|
25 |
handler = logging.StreamHandler()
|
@@ -29,20 +29,6 @@ def _get_logger(log_level: str) -> logging.Logger:
|
|
29 |
return logger
|
30 |
|
31 |
|
32 |
-
def _read_url_list(url_list_file: str) -> str:
|
33 |
-
if url_list_file is None:
|
34 |
-
return None
|
35 |
-
|
36 |
-
with open(url_list_file, "r") as f:
|
37 |
-
links = f.readlines()
|
38 |
-
links = [
|
39 |
-
link.strip()
|
40 |
-
for link in links
|
41 |
-
if link.strip() != "" and not link.startswith("#")
|
42 |
-
]
|
43 |
-
return "\n".join(links)
|
44 |
-
|
45 |
-
|
46 |
class Ask:
|
47 |
|
48 |
def __init__(self, logger: Optional[logging.Logger] = None):
|
@@ -51,7 +37,7 @@ class Ask:
|
|
51 |
if logger is not None:
|
52 |
self.logger = logger
|
53 |
else:
|
54 |
-
self.logger =
|
55 |
|
56 |
self.table_name = "document_chunks"
|
57 |
self.db_con = duckdb.connect(":memory:")
|
@@ -397,66 +383,84 @@ Here is the context:
|
|
397 |
response_str = completion.choices[0].message.content
|
398 |
return response_str
|
399 |
|
400 |
-
def run_query(
|
401 |
-
self,
|
402 |
-
query: str,
|
403 |
-
date_restrict: int,
|
404 |
-
target_site: str,
|
405 |
-
output_language: str,
|
406 |
-
output_length: int,
|
407 |
-
url_list_str: str,
|
408 |
-
model_name: str,
|
409 |
-
) -> str:
|
410 |
-
logger = self.logger
|
411 |
-
if url_list_str is None or url_list_str.strip() == "":
|
412 |
-
logger.info("Searching the web ...")
|
413 |
-
links = self.search_web(query, date_restrict, target_site)
|
414 |
-
logger.info(f"✅ Found {len(links)} links for query: {query}")
|
415 |
-
for i, link in enumerate(links):
|
416 |
-
logger.debug(f"{i+1}. {link}")
|
417 |
-
else:
|
418 |
-
links = url_list_str.split("\n")
|
419 |
-
|
420 |
-
logger.info("Scraping the URLs ...")
|
421 |
-
scrape_results = self.scrape_urls(links)
|
422 |
-
logger.info(f"✅ Scraped {len(scrape_results)} URLs.")
|
423 |
-
|
424 |
-
logger.info("Chunking the text ...")
|
425 |
-
chunking_results = self.chunk_results(scrape_results, 1000, 100)
|
426 |
-
total_chunks = 0
|
427 |
-
for url, chunks in chunking_results.items():
|
428 |
-
logger.debug(f"URL: {url}")
|
429 |
-
total_chunks += len(chunks)
|
430 |
-
for i, chunk in enumerate(chunks):
|
431 |
-
logger.debug(f"Chunk {i+1}: {chunk}")
|
432 |
-
logger.info(f"✅ Generated {total_chunks} chunks ...")
|
433 |
-
|
434 |
-
logger.info(f"Saving {total_chunks} chunks to DB ...")
|
435 |
-
self.save_to_db(chunking_results)
|
436 |
-
logger.info(f"✅ Successfully embedded and saved chunks to DB.")
|
437 |
-
|
438 |
-
logger.info("Querying the vector DB to get context ...")
|
439 |
-
matched_chunks = self.vector_search(query)
|
440 |
-
for i, result in enumerate(matched_chunks):
|
441 |
-
logger.debug(f"{i+1}. {result}")
|
442 |
-
logger.info(f"✅ Got {len(matched_chunks)} matched chunks.")
|
443 |
-
|
444 |
-
logger.info("Running inference with context ...")
|
445 |
-
answer = self.run_inference(
|
446 |
-
query=query,
|
447 |
-
model_name=model_name,
|
448 |
-
matched_chunks=matched_chunks,
|
449 |
-
output_language=output_language,
|
450 |
-
output_length=output_length,
|
451 |
-
)
|
452 |
-
logger.info("✅ Finished inference API call.")
|
453 |
-
logger.info("generateing output ...")
|
454 |
|
455 |
-
|
456 |
-
|
457 |
-
|
458 |
-
|
459 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
460 |
|
461 |
|
462 |
def launch_gradio(
|
@@ -467,12 +471,11 @@ def launch_gradio(
|
|
467 |
output_length: int,
|
468 |
url_list_str: str,
|
469 |
model_name: str,
|
|
|
470 |
share_ui: bool,
|
471 |
-
logger: logging.Logger,
|
472 |
) -> None:
|
473 |
-
ask = Ask(logger=logger)
|
474 |
iface = gr.Interface(
|
475 |
-
fn=
|
476 |
inputs=[
|
477 |
gr.Textbox(label="Query", value=query),
|
478 |
gr.Number(
|
@@ -500,6 +503,7 @@ def launch_gradio(
|
|
500 |
],
|
501 |
additional_inputs=[
|
502 |
gr.Textbox(label="Model Name", value=model_name),
|
|
|
503 |
],
|
504 |
outputs="text",
|
505 |
show_progress=True,
|
@@ -582,7 +586,6 @@ def search_extract_summarize(
|
|
582 |
log_level: str,
|
583 |
):
|
584 |
load_dotenv(dotenv_path=default_env_file, override=False)
|
585 |
-
logger = _get_logger(log_level)
|
586 |
|
587 |
if web_ui or os.environ.get("RUN_GRADIO_UI", "false").lower() != "false":
|
588 |
if os.environ.get("SHARE_GRADIO_UI", "false").lower() == "true":
|
@@ -597,14 +600,14 @@ def search_extract_summarize(
|
|
597 |
output_length=output_length,
|
598 |
url_list_str=_read_url_list(url_list_file),
|
599 |
model_name=model_name,
|
|
|
600 |
share_ui=share_ui,
|
601 |
-
logger=logger,
|
602 |
)
|
603 |
else:
|
604 |
if query is None:
|
605 |
raise Exception("Query is required for the command line mode")
|
606 |
-
|
607 |
-
result =
|
608 |
query=query,
|
609 |
date_restrict=date_restrict,
|
610 |
target_site=target_site,
|
@@ -612,6 +615,7 @@ def search_extract_summarize(
|
|
612 |
output_length=output_length,
|
613 |
url_list_str=_read_url_list(url_list_file),
|
614 |
model_name=model_name,
|
|
|
615 |
)
|
616 |
click.echo(result)
|
617 |
|
|
|
19 |
default_env_file = os.path.abspath(os.path.join(script_dir, ".env"))
|
20 |
|
21 |
|
22 |
+
def get_logger(log_level: str) -> logging.Logger:
|
23 |
logger = logging.getLogger(__name__)
|
24 |
logger.setLevel(log_level)
|
25 |
handler = logging.StreamHandler()
|
|
|
29 |
return logger
|
30 |
|
31 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
32 |
class Ask:
|
33 |
|
34 |
def __init__(self, logger: Optional[logging.Logger] = None):
|
|
|
37 |
if logger is not None:
|
38 |
self.logger = logger
|
39 |
else:
|
40 |
+
self.logger = get_logger("INFO")
|
41 |
|
42 |
self.table_name = "document_chunks"
|
43 |
self.db_con = duckdb.connect(":memory:")
|
|
|
383 |
response_str = completion.choices[0].message.content
|
384 |
return response_str
|
385 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
386 |
|
387 |
+
def _read_url_list(url_list_file: str) -> str:
|
388 |
+
if url_list_file is None:
|
389 |
+
return None
|
390 |
+
|
391 |
+
with open(url_list_file, "r") as f:
|
392 |
+
links = f.readlines()
|
393 |
+
links = [
|
394 |
+
link.strip()
|
395 |
+
for link in links
|
396 |
+
if link.strip() != "" and not link.startswith("#")
|
397 |
+
]
|
398 |
+
return "\n".join(links)
|
399 |
+
|
400 |
+
|
401 |
+
def _run_query(
|
402 |
+
query: str,
|
403 |
+
date_restrict: int,
|
404 |
+
target_site: str,
|
405 |
+
output_language: str,
|
406 |
+
output_length: int,
|
407 |
+
url_list_str: str,
|
408 |
+
model_name: str,
|
409 |
+
log_level: str,
|
410 |
+
) -> str:
|
411 |
+
logger = get_logger(log_level)
|
412 |
+
|
413 |
+
ask = Ask(logger=logger)
|
414 |
+
|
415 |
+
if url_list_str is None or url_list_str.strip() == "":
|
416 |
+
logger.info("Searching the web ...")
|
417 |
+
links = ask.search_web(query, date_restrict, target_site)
|
418 |
+
logger.info(f"✅ Found {len(links)} links for query: {query}")
|
419 |
+
for i, link in enumerate(links):
|
420 |
+
logger.debug(f"{i+1}. {link}")
|
421 |
+
else:
|
422 |
+
links = url_list_str.split("\n")
|
423 |
+
|
424 |
+
logger.info("Scraping the URLs ...")
|
425 |
+
scrape_results = ask.scrape_urls(links)
|
426 |
+
logger.info(f"✅ Scraped {len(scrape_results)} URLs.")
|
427 |
+
|
428 |
+
logger.info("Chunking the text ...")
|
429 |
+
chunking_results = ask.chunk_results(scrape_results, 1000, 100)
|
430 |
+
total_chunks = 0
|
431 |
+
for url, chunks in chunking_results.items():
|
432 |
+
logger.debug(f"URL: {url}")
|
433 |
+
total_chunks += len(chunks)
|
434 |
+
for i, chunk in enumerate(chunks):
|
435 |
+
logger.debug(f"Chunk {i+1}: {chunk}")
|
436 |
+
logger.info(f"✅ Generated {total_chunks} chunks ...")
|
437 |
+
|
438 |
+
logger.info(f"Saving {total_chunks} chunks to DB ...")
|
439 |
+
ask.save_to_db(chunking_results)
|
440 |
+
logger.info(f"✅ Successfully embedded and saved chunks to DB.")
|
441 |
+
|
442 |
+
logger.info("Querying the vector DB to get context ...")
|
443 |
+
matched_chunks = ask.vector_search(query)
|
444 |
+
for i, result in enumerate(matched_chunks):
|
445 |
+
logger.debug(f"{i+1}. {result}")
|
446 |
+
logger.info(f"✅ Got {len(matched_chunks)} matched chunks.")
|
447 |
+
|
448 |
+
logger.info("Running inference with context ...")
|
449 |
+
answer = ask.run_inference(
|
450 |
+
query=query,
|
451 |
+
model_name=model_name,
|
452 |
+
matched_chunks=matched_chunks,
|
453 |
+
output_language=output_language,
|
454 |
+
output_length=output_length,
|
455 |
+
)
|
456 |
+
logger.info("✅ Finished inference API call.")
|
457 |
+
logger.info("generateing output ...")
|
458 |
+
|
459 |
+
answer = f"# Answer\n\n{answer}\n"
|
460 |
+
references = "\n".join(
|
461 |
+
[f"[{i+1}] {result['url']}" for i, result in enumerate(matched_chunks)]
|
462 |
+
)
|
463 |
+
return f"{answer}\n\n# References\n\n{references}"
|
464 |
|
465 |
|
466 |
def launch_gradio(
|
|
|
471 |
output_length: int,
|
472 |
url_list_str: str,
|
473 |
model_name: str,
|
474 |
+
log_level: str,
|
475 |
share_ui: bool,
|
|
|
476 |
) -> None:
|
|
|
477 |
iface = gr.Interface(
|
478 |
+
fn=_run_query,
|
479 |
inputs=[
|
480 |
gr.Textbox(label="Query", value=query),
|
481 |
gr.Number(
|
|
|
503 |
],
|
504 |
additional_inputs=[
|
505 |
gr.Textbox(label="Model Name", value=model_name),
|
506 |
+
gr.Textbox(label="Log Level", value=log_level),
|
507 |
],
|
508 |
outputs="text",
|
509 |
show_progress=True,
|
|
|
586 |
log_level: str,
|
587 |
):
|
588 |
load_dotenv(dotenv_path=default_env_file, override=False)
|
|
|
589 |
|
590 |
if web_ui or os.environ.get("RUN_GRADIO_UI", "false").lower() != "false":
|
591 |
if os.environ.get("SHARE_GRADIO_UI", "false").lower() == "true":
|
|
|
600 |
output_length=output_length,
|
601 |
url_list_str=_read_url_list(url_list_file),
|
602 |
model_name=model_name,
|
603 |
+
log_level=log_level,
|
604 |
share_ui=share_ui,
|
|
|
605 |
)
|
606 |
else:
|
607 |
if query is None:
|
608 |
raise Exception("Query is required for the command line mode")
|
609 |
+
|
610 |
+
result = _run_query(
|
611 |
query=query,
|
612 |
date_restrict=date_restrict,
|
613 |
target_site=target_site,
|
|
|
615 |
output_length=output_length,
|
616 |
url_list_str=_read_url_list(url_list_file),
|
617 |
model_name=model_name,
|
618 |
+
log_level=log_level,
|
619 |
)
|
620 |
click.echo(result)
|
621 |
|