LeetTools committed on
Commit
fc0d356
·
verified ·
1 Parent(s): 72ff970

Upload ask.py

Browse files
Files changed (1) hide show
  1. ask.py +119 -112
ask.py CHANGED
@@ -3,6 +3,7 @@ import logging
3
  import os
4
  import urllib.parse
5
  from concurrent.futures import ThreadPoolExecutor
 
6
  from functools import partial
7
  from typing import Any, Dict, List, Optional, Tuple
8
 
@@ -19,9 +20,11 @@ script_dir = os.path.dirname(os.path.abspath(__file__))
19
  default_env_file = os.path.abspath(os.path.join(script_dir, ".env"))
20
 
21
 
22
- def get_logger(log_level: str) -> logging.Logger:
23
  logger = logging.getLogger(__name__)
24
  logger.setLevel(log_level)
 
 
25
  handler = logging.StreamHandler()
26
  formatter = logging.Formatter("%(asctime)s - %(levelname)s - %(message)s")
27
  handler.setFormatter(formatter)
@@ -29,6 +32,20 @@ def get_logger(log_level: str) -> logging.Logger:
29
  return logger
30
 
31
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
32
  class Ask:
33
 
34
  def __init__(self, logger: Optional[logging.Logger] = None):
@@ -37,9 +54,8 @@ class Ask:
37
  if logger is not None:
38
  self.logger = logger
39
  else:
40
- self.logger = get_logger("INFO")
41
 
42
- self.table_name = "document_chunks"
43
  self.db_con = duckdb.connect(":memory:")
44
 
45
  self.db_con.install_extension("vss")
@@ -48,17 +64,6 @@ class Ask:
48
  self.db_con.load_extension("fts")
49
  self.db_con.sql("CREATE SEQUENCE seq_docid START 1000")
50
 
51
- self.db_con.execute(
52
- f"""
53
- CREATE TABLE {self.table_name} (
54
- doc_id INTEGER PRIMARY KEY DEFAULT nextval('seq_docid'),
55
- url TEXT,
56
- chunk TEXT,
57
- vec FLOAT[{self.embedding_dimensions}]
58
- );
59
- """
60
- )
61
-
62
  self.session = requests.Session()
63
  user_agent: str = (
64
  "Mozilla/5.0 (Windows NT 10.0; Win64; x64) "
@@ -221,12 +226,31 @@ CREATE TABLE {self.table_name} (
221
  embeddings = self.get_embedding(client, texts)
222
  return chunk_batch, embeddings
223
 
224
- def save_to_db(self, chunking_results: Dict[str, List[str]]) -> None:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
225
  client = self._get_api_client()
226
  embed_batch_size = 50
227
  query_batch_size = 100
228
  insert_data = []
229
 
 
 
230
  batches: List[Tuple[str, List[str]]] = []
231
  for url, list_chunks in chunking_results.items():
232
  for i in range(0, len(list_chunks), embed_batch_size):
@@ -258,13 +282,13 @@ CREATE TABLE {self.table_name} (
258
  ]
259
  )
260
  query = f"""
261
- INSERT INTO {self.table_name} (url, chunk, vec) VALUES {value_str};
262
  """
263
  self.db_con.execute(query)
264
 
265
  self.db_con.execute(
266
  f"""
267
- CREATE INDEX cos_idx ON {self.table_name} USING HNSW (vec)
268
  WITH (metric = 'cosine');
269
  """
270
  )
@@ -272,19 +296,20 @@ CREATE TABLE {self.table_name} (
272
  self.db_con.execute(
273
  f"""
274
  PRAGMA create_fts_index(
275
- {self.table_name}, 'doc_id', 'chunk'
276
  );
277
  """
278
  )
279
  self.logger.info(f"✅ Created the full text search index ...")
 
280
 
281
- def vector_search(self, query: str) -> List[Dict[str, Any]]:
282
  client = self._get_api_client()
283
  embeddings = self.get_embedding(client, [query])[0]
284
 
285
  query_result: duckdb.DuckDBPyRelation = self.db_con.sql(
286
  f"""
287
- SELECT * FROM {self.table_name}
288
  ORDER BY array_distance(vec, {embeddings}::FLOAT[{self.embedding_dimensions}])
289
  LIMIT 10;
290
  """
@@ -383,84 +408,66 @@ Here is the context:
383
  response_str = completion.choices[0].message.content
384
  return response_str
385
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
386
 
387
- def _read_url_list(url_list_file: str) -> str:
388
- if url_list_file is None:
389
- return None
390
-
391
- with open(url_list_file, "r") as f:
392
- links = f.readlines()
393
- links = [
394
- link.strip()
395
- for link in links
396
- if link.strip() != "" and not link.startswith("#")
397
- ]
398
- return "\n".join(links)
399
-
400
-
401
- def _run_query(
402
- query: str,
403
- date_restrict: int,
404
- target_site: str,
405
- output_language: str,
406
- output_length: int,
407
- url_list_str: str,
408
- model_name: str,
409
- log_level: str,
410
- ) -> str:
411
- logger = get_logger(log_level)
412
-
413
- ask = Ask(logger=logger)
414
-
415
- if url_list_str is None or url_list_str.strip() == "":
416
- logger.info("Searching the web ...")
417
- links = ask.search_web(query, date_restrict, target_site)
418
- logger.info(f"✅ Found {len(links)} links for query: {query}")
419
- for i, link in enumerate(links):
420
- logger.debug(f"{i+1}. {link}")
421
- else:
422
- links = url_list_str.split("\n")
423
-
424
- logger.info("Scraping the URLs ...")
425
- scrape_results = ask.scrape_urls(links)
426
- logger.info(f"✅ Scraped {len(scrape_results)} URLs.")
427
-
428
- logger.info("Chunking the text ...")
429
- chunking_results = ask.chunk_results(scrape_results, 1000, 100)
430
- total_chunks = 0
431
- for url, chunks in chunking_results.items():
432
- logger.debug(f"URL: {url}")
433
- total_chunks += len(chunks)
434
- for i, chunk in enumerate(chunks):
435
- logger.debug(f"Chunk {i+1}: {chunk}")
436
- logger.info(f"✅ Generated {total_chunks} chunks ...")
437
-
438
- logger.info(f"Saving {total_chunks} chunks to DB ...")
439
- ask.save_to_db(chunking_results)
440
- logger.info(f"✅ Successfully embedded and saved chunks to DB.")
441
-
442
- logger.info("Querying the vector DB to get context ...")
443
- matched_chunks = ask.vector_search(query)
444
- for i, result in enumerate(matched_chunks):
445
- logger.debug(f"{i+1}. {result}")
446
- logger.info(f"✅ Got {len(matched_chunks)} matched chunks.")
447
-
448
- logger.info("Running inference with context ...")
449
- answer = ask.run_inference(
450
- query=query,
451
- model_name=model_name,
452
- matched_chunks=matched_chunks,
453
- output_language=output_language,
454
- output_length=output_length,
455
- )
456
- logger.info("✅ Finished inference API call.")
457
- logger.info("generateing output ...")
458
-
459
- answer = f"# Answer\n\n{answer}\n"
460
- references = "\n".join(
461
- [f"[{i+1}] {result['url']}" for i, result in enumerate(matched_chunks)]
462
- )
463
- return f"{answer}\n\n# References\n\n{references}"
464
 
465
 
466
  def launch_gradio(
@@ -471,11 +478,12 @@ def launch_gradio(
471
  output_length: int,
472
  url_list_str: str,
473
  model_name: str,
474
- log_level: str,
475
  share_ui: bool,
 
476
  ) -> None:
 
477
  iface = gr.Interface(
478
- fn=_run_query,
479
  inputs=[
480
  gr.Textbox(label="Query", value=query),
481
  gr.Number(
@@ -483,7 +491,7 @@ def launch_gradio(
483
  value=date_restrict,
484
  ),
485
  gr.Textbox(
486
- label="Target Sites (Optional) [Empty means seach the whole web.]",
487
  value=target_site,
488
  ),
489
  gr.Textbox(
@@ -503,7 +511,6 @@ def launch_gradio(
503
  ],
504
  additional_inputs=[
505
  gr.Textbox(label="Model Name", value=model_name),
506
- gr.Textbox(label="Log Level", value=log_level),
507
  ],
508
  outputs="text",
509
  show_progress=True,
@@ -515,12 +522,7 @@ def launch_gradio(
515
  iface.launch(share=share_ui)
516
 
517
 
518
- @click.command(help="Search web for the query and summarize the results")
519
- @click.option(
520
- "--web-ui",
521
- is_flag=True,
522
- help="Launch the web interface",
523
- )
524
  @click.option("--query", "-q", required=False, help="Query to search")
525
  @click.option(
526
  "--date-restrict",
@@ -565,6 +567,11 @@ def launch_gradio(
565
  default="gpt-4o-mini",
566
  help="Model name to use for inference",
567
  )
 
 
 
 
 
568
  @click.option(
569
  "-l",
570
  "--log-level",
@@ -575,7 +582,6 @@ def launch_gradio(
575
  show_default=True,
576
  )
577
  def search_extract_summarize(
578
- web_ui: bool,
579
  query: str,
580
  date_restrict: int,
581
  target_site: str,
@@ -583,9 +589,11 @@ def search_extract_summarize(
583
  output_length: int,
584
  url_list_file: str,
585
  model_name: str,
 
586
  log_level: str,
587
  ):
588
  load_dotenv(dotenv_path=default_env_file, override=False)
 
589
 
590
  if web_ui or os.environ.get("RUN_GRADIO_UI", "false").lower() != "false":
591
  if os.environ.get("SHARE_GRADIO_UI", "false").lower() == "true":
@@ -600,14 +608,14 @@ def search_extract_summarize(
600
  output_length=output_length,
601
  url_list_str=_read_url_list(url_list_file),
602
  model_name=model_name,
603
- log_level=log_level,
604
  share_ui=share_ui,
 
605
  )
606
  else:
607
  if query is None:
608
  raise Exception("Query is required for the command line mode")
609
-
610
- result = _run_query(
611
  query=query,
612
  date_restrict=date_restrict,
613
  target_site=target_site,
@@ -615,7 +623,6 @@ def search_extract_summarize(
615
  output_length=output_length,
616
  url_list_str=_read_url_list(url_list_file),
617
  model_name=model_name,
618
- log_level=log_level,
619
  )
620
  click.echo(result)
621
 
 
3
  import os
4
  import urllib.parse
5
  from concurrent.futures import ThreadPoolExecutor
6
+ from datetime import datetime
7
  from functools import partial
8
  from typing import Any, Dict, List, Optional, Tuple
9
 
 
20
  default_env_file = os.path.abspath(os.path.join(script_dir, ".env"))
21
 
22
 
23
+ def _get_logger(log_level: str) -> logging.Logger:
24
  logger = logging.getLogger(__name__)
25
  logger.setLevel(log_level)
26
+ if len(logger.handlers) > 0:
27
+ return logger
28
  handler = logging.StreamHandler()
29
  formatter = logging.Formatter("%(asctime)s - %(levelname)s - %(message)s")
30
  handler.setFormatter(formatter)
 
32
  return logger
33
 
34
 
35
+ def _read_url_list(url_list_file: str) -> str:
36
+ if url_list_file is None:
37
+ return None
38
+
39
+ with open(url_list_file, "r") as f:
40
+ links = f.readlines()
41
+ links = [
42
+ link.strip()
43
+ for link in links
44
+ if link.strip() != "" and not link.startswith("#")
45
+ ]
46
+ return "\n".join(links)
47
+
48
+
49
  class Ask:
50
 
51
  def __init__(self, logger: Optional[logging.Logger] = None):
 
54
  if logger is not None:
55
  self.logger = logger
56
  else:
57
+ self.logger = _get_logger("INFO")
58
 
 
59
  self.db_con = duckdb.connect(":memory:")
60
 
61
  self.db_con.install_extension("vss")
 
64
  self.db_con.load_extension("fts")
65
  self.db_con.sql("CREATE SEQUENCE seq_docid START 1000")
66
 
 
 
 
 
 
 
 
 
 
 
 
67
  self.session = requests.Session()
68
  user_agent: str = (
69
  "Mozilla/5.0 (Windows NT 10.0; Win64; x64) "
 
226
  embeddings = self.get_embedding(client, texts)
227
  return chunk_batch, embeddings
228
 
229
+ def _create_table(self) -> str:
230
+ # Simple ways to get a unique table name
231
+ timestamp = datetime.now().strftime("%Y_%m_%d_%H_%M_%S_%f")
232
+ table_name = f"document_chunks_{timestamp}"
233
+
234
+ self.db_con.execute(
235
+ f"""
236
+ CREATE TABLE {table_name} (
237
+ doc_id INTEGER PRIMARY KEY DEFAULT nextval('seq_docid'),
238
+ url TEXT,
239
+ chunk TEXT,
240
+ vec FLOAT[{self.embedding_dimensions}]
241
+ );
242
+ """
243
+ )
244
+ return table_name
245
+
246
+ def save_to_db(self, chunking_results: Dict[str, List[str]]) -> str:
247
  client = self._get_api_client()
248
  embed_batch_size = 50
249
  query_batch_size = 100
250
  insert_data = []
251
 
252
+ table_name = self._create_table()
253
+
254
  batches: List[Tuple[str, List[str]]] = []
255
  for url, list_chunks in chunking_results.items():
256
  for i in range(0, len(list_chunks), embed_batch_size):
 
282
  ]
283
  )
284
  query = f"""
285
+ INSERT INTO {table_name} (url, chunk, vec) VALUES {value_str};
286
  """
287
  self.db_con.execute(query)
288
 
289
  self.db_con.execute(
290
  f"""
291
+ CREATE INDEX {table_name}_cos_idx ON {table_name} USING HNSW (vec)
292
  WITH (metric = 'cosine');
293
  """
294
  )
 
296
  self.db_con.execute(
297
  f"""
298
  PRAGMA create_fts_index(
299
+ {table_name}, 'doc_id', 'chunk'
300
  );
301
  """
302
  )
303
  self.logger.info(f"✅ Created the full text search index ...")
304
+ return table_name
305
 
306
+ def vector_search(self, table_name: str, query: str) -> List[Dict[str, Any]]:
307
  client = self._get_api_client()
308
  embeddings = self.get_embedding(client, [query])[0]
309
 
310
  query_result: duckdb.DuckDBPyRelation = self.db_con.sql(
311
  f"""
312
+ SELECT * FROM {table_name}
313
  ORDER BY array_distance(vec, {embeddings}::FLOAT[{self.embedding_dimensions}])
314
  LIMIT 10;
315
  """
 
408
  response_str = completion.choices[0].message.content
409
  return response_str
410
 
411
+ def run_query(
412
+ self,
413
+ query: str,
414
+ date_restrict: int,
415
+ target_site: str,
416
+ output_language: str,
417
+ output_length: int,
418
+ url_list_str: str,
419
+ model_name: str,
420
+ ) -> str:
421
+ logger = self.logger
422
+ if url_list_str is None or url_list_str.strip() == "":
423
+ logger.info("Searching the web ...")
424
+ links = self.search_web(query, date_restrict, target_site)
425
+ logger.info(f"✅ Found {len(links)} links for query: {query}")
426
+ for i, link in enumerate(links):
427
+ logger.debug(f"{i+1}. {link}")
428
+ else:
429
+ links = url_list_str.split("\n")
430
+
431
+ logger.info("Scraping the URLs ...")
432
+ scrape_results = self.scrape_urls(links)
433
+ logger.info(f"✅ Scraped {len(scrape_results)} URLs.")
434
+
435
+ logger.info("Chunking the text ...")
436
+ chunking_results = self.chunk_results(scrape_results, 1000, 100)
437
+ total_chunks = 0
438
+ for url, chunks in chunking_results.items():
439
+ logger.debug(f"URL: {url}")
440
+ total_chunks += len(chunks)
441
+ for i, chunk in enumerate(chunks):
442
+ logger.debug(f"Chunk {i+1}: {chunk}")
443
+ logger.info(f"✅ Generated {total_chunks} chunks ...")
444
+
445
+ logger.info(f"Saving {total_chunks} chunks to DB ...")
446
+ table_name = self.save_to_db(chunking_results)
447
+ logger.info(f"✅ Successfully embedded and saved chunks to DB.")
448
+
449
+ logger.info("Querying the vector DB to get context ...")
450
+ matched_chunks = self.vector_search(table_name, query)
451
+ for i, result in enumerate(matched_chunks):
452
+ logger.debug(f"{i+1}. {result}")
453
+ logger.info(f"✅ Got {len(matched_chunks)} matched chunks.")
454
+
455
+ logger.info("Running inference with context ...")
456
+ answer = self.run_inference(
457
+ query=query,
458
+ model_name=model_name,
459
+ matched_chunks=matched_chunks,
460
+ output_language=output_language,
461
+ output_length=output_length,
462
+ )
463
+ logger.info("✅ Finished inference API call.")
464
+ logger.info("Generating output ...")
465
 
466
+ answer = f"# Answer\n\n{answer}\n"
467
+ references = "\n".join(
468
+ [f"[{i+1}] {result['url']}" for i, result in enumerate(matched_chunks)]
469
+ )
470
+ return f"{answer}\n\n# References\n\n{references}"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
471
 
472
 
473
  def launch_gradio(
 
478
  output_length: int,
479
  url_list_str: str,
480
  model_name: str,
 
481
  share_ui: bool,
482
+ logger: logging.Logger,
483
  ) -> None:
484
+ ask = Ask(logger=logger)
485
  iface = gr.Interface(
486
+ fn=ask.run_query,
487
  inputs=[
488
  gr.Textbox(label="Query", value=query),
489
  gr.Number(
 
491
  value=date_restrict,
492
  ),
493
  gr.Textbox(
494
+ label="Target Sites (Optional) [Empty means search the whole web.]",
495
  value=target_site,
496
  ),
497
  gr.Textbox(
 
511
  ],
512
  additional_inputs=[
513
  gr.Textbox(label="Model Name", value=model_name),
 
514
  ],
515
  outputs="text",
516
  show_progress=True,
 
522
  iface.launch(share=share_ui)
523
 
524
 
525
+ @click.command(help="Search web for the query and summarize the results.")
 
 
 
 
 
526
  @click.option("--query", "-q", required=False, help="Query to search")
527
  @click.option(
528
  "--date-restrict",
 
567
  default="gpt-4o-mini",
568
  help="Model name to use for inference",
569
  )
570
+ @click.option(
571
+ "--web-ui",
572
+ is_flag=True,
573
+ help="Launch the web interface",
574
+ )
575
  @click.option(
576
  "-l",
577
  "--log-level",
 
582
  show_default=True,
583
  )
584
  def search_extract_summarize(
 
585
  query: str,
586
  date_restrict: int,
587
  target_site: str,
 
589
  output_length: int,
590
  url_list_file: str,
591
  model_name: str,
592
+ web_ui: bool,
593
  log_level: str,
594
  ):
595
  load_dotenv(dotenv_path=default_env_file, override=False)
596
+ logger = _get_logger(log_level)
597
 
598
  if web_ui or os.environ.get("RUN_GRADIO_UI", "false").lower() != "false":
599
  if os.environ.get("SHARE_GRADIO_UI", "false").lower() == "true":
 
608
  output_length=output_length,
609
  url_list_str=_read_url_list(url_list_file),
610
  model_name=model_name,
 
611
  share_ui=share_ui,
612
+ logger=logger,
613
  )
614
  else:
615
  if query is None:
616
  raise Exception("Query is required for the command line mode")
617
+ ask = Ask(logger=logger)
618
+ result = ask.run_query(
619
  query=query,
620
  date_restrict=date_restrict,
621
  target_site=target_site,
 
623
  output_length=output_length,
624
  url_list_str=_read_url_list(url_list_file),
625
  model_name=model_name,
 
626
  )
627
  click.echo(result)
628