LeetTools committed
Commit d493b39 · verified · 1 Parent(s): 084346e

Upload ask.py

Files changed (1)
  1. ask.py +138 -69
ask.py CHANGED
@@ -17,11 +17,22 @@ from bs4 import BeautifulSoup
  from dotenv import load_dotenv
  from jinja2 import BaseLoader, Environment
  from openai import OpenAI
+ from pydantic import BaseModel

  script_dir = os.path.dirname(os.path.abspath(__file__))
  default_env_file = os.path.abspath(os.path.join(script_dir, ".env"))


+ class AskSettings(BaseModel):
+     date_restrict: int
+     target_site: str
+     output_language: str
+     output_length: int
+     url_list: List[str]
+     inference_model_name: str
+     hybrid_search: bool
+
+
  def _get_logger(log_level: str) -> logging.Logger:
      logger = logging.getLogger(__name__)
      logger.setLevel(log_level)
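A note on the new model: AskSettings is a plain pydantic BaseModel, so every CLI flag and UI field that feeds it is validated up front. An illustrative sketch (not part of the commit) of what that buys:

    from typing import List

    from pydantic import BaseModel, ValidationError

    class AskSettings(BaseModel):
        date_restrict: int
        target_site: str
        output_language: str
        output_length: int
        url_list: List[str]
        inference_model_name: str
        hybrid_search: bool

    try:
        AskSettings(
            date_restrict="soon",  # not coercible to int -> ValidationError
            target_site="",
            output_language="English",
            output_length=0,
            url_list=[],
            inference_model_name="gpt-4o-mini",
            hybrid_search=False,
        )
    except ValidationError as err:
        print(err)  # reports which field failed and why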
@@ -35,18 +46,18 @@ def _get_logger(log_level: str) -> logging.Logger:
      return logger


- def _read_url_list(url_list_file: str) -> str:
-     if url_list_file is None:
-         return None
+ def _read_url_list(url_list_file: str) -> List[str]:
+     if not url_list_file:
+         return []

      with open(url_list_file, "r") as f:
          links = f.readlines()
-     links = [
+     url_list = [
          link.strip()
          for link in links
          if link.strip() != "" and not link.startswith("#")
      ]
-     return "\n".join(links)
+     return url_list


  class Ask:
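_read_url_list now returns a list instead of a newline-joined string, dropping blank lines and `#` comment lines. A quick sketch of the behavior (the file name is hypothetical):

    from pathlib import Path

    Path("urls.txt").write_text(
        "# seed pages\nhttps://duckdb.org/docs/\n\nhttps://gradio.app/\n"
    )
    print(_read_url_list("urls.txt"))
    # ['https://duckdb.org/docs/', 'https://gradio.app/']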
@@ -102,17 +113,17 @@ class Ask:
          self.embedding_model = "text-embedding-3-small"
          self.embedding_dimensions = 1536

-     def search_web(self, query: str, date_restrict: int, target_site: str) -> List[str]:
+     def search_web(self, query: str, settings: AskSettings) -> List[str]:
          escaped_query = urllib.parse.quote(query)
          url_base = (
              f"https://www.googleapis.com/customsearch/v1?key={self.search_api_key}"
              f"&cx={self.search_project_id}&q={escaped_query}"
          )
          url_paras = f"&safe=active"
-         if date_restrict is not None and date_restrict > 0:
-             url_paras += f"&dateRestrict={date_restrict}"
-         if target_site is not None and target_site != "":
-             url_paras += f"&siteSearch={target_site}&siteSearchFilter=i"
+         if settings.date_restrict > 0:
+             url_paras += f"&dateRestrict={settings.date_restrict}"
+         if settings.target_site:
+             url_paras += f"&siteSearch={settings.target_site}&siteSearchFilter=i"
          url = f"{url_base}{url_paras}"

          self.logger.debug(f"Searching for query: {query}")
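For reference, the request URL that search_web assembles from a sample settings object looks like this; KEY and CX stand in for the real credentials, and the sketch only mirrors the string concatenation above:

    import urllib.parse

    query = "what is hybrid search"
    url = (
        "https://www.googleapis.com/customsearch/v1?key=KEY&cx=CX"
        f"&q={urllib.parse.quote(query)}"
        "&safe=active&dateRestrict=7&siteSearch=example.com&siteSearchFilter=i"
    )
    print(url)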
@@ -153,6 +164,7 @@ class Ask:
          return found_links

      def _scape_url(self, url: str) -> Tuple[str, str]:
+         self.logger.info(f"Scraping {url} ...")
          try:
              response = self.session.get(url, timeout=10)
              soup = BeautifulSoup(response.content, "lxml", from_encoding="utf-8")
@@ -163,6 +175,9 @@ class Ask:
              body_text = " ".join(body_text.split()).strip()
              self.logger.debug(f"Scraped {url}: {body_text}...")
              if len(body_text) > 100:
+                 self.logger.info(
+                     f"✅ Successfully scraped {url} with length: {len(body_text)}"
+                 )
                  return url, body_text
              else:
                  self.logger.warning(
@@ -246,7 +261,10 @@ CREATE TABLE {table_name} (
          )
          return table_name

-     def save_to_db(self, chunking_results: Dict[str, List[str]]) -> str:
+     def save_chunks_to_db(self, chunking_results: Dict[str, List[str]]) -> str:
+         """
+         The key of chunking_results is the URL and the value is the list of chunks.
+         """
          client = self._get_api_client()
          embed_batch_size = 50
          query_batch_size = 100
@@ -266,6 +284,9 @@ CREATE TABLE {table_name} (
              all_embeddings = executor.map(partial_get_embedding, batches)
          self.logger.info(f"✅ Finished embedding.")

+         # batch the insert data to speed up the insertion: although the DuckDB
+         # doc says executeMany is optimized for batch inserts, we found it
+         # faster to batch the data ourselves and run a single insert per batch
          for chunk_batch, embeddings in all_embeddings:
              url = chunk_batch[0]
              list_chunks = chunk_batch[1]
@@ -277,7 +298,6 @@ CREATE TABLE {table_name} (
              )

              for i in range(0, len(insert_data), query_batch_size):
-                 # insert the batch into DuckDB
                  value_str = ", ".join(
                      [
                          f"('{url}', '{chunk}', {embedding})"
@@ -306,7 +326,13 @@ CREATE TABLE {table_name} (
          self.logger.info(f"✅ Created the full text search index ...")
          return table_name

-     def vector_search(self, table_name: str, query: str) -> List[Dict[str, Any]]:
+     def vector_search(
+         self, table_name: str, query: str, settings: AskSettings
+     ) -> List[Dict[str, Any]]:
+         """
+         The return value is a list of {url: str, chunk: str} records.
+         In the real world, we would define a Chunk class with more metadata such as offsets.
+         """
          client = self._get_api_client()
          embeddings = self.get_embedding(client, [query])[0]

@@ -328,6 +354,10 @@ CREATE TABLE {table_name} (
              }
              matched_chunks.append(result_record)

+         if settings.hybrid_search:
+             self.logger.info("Running full-text search ...")
+             pass
+
          return matched_chunks

      def _get_api_client(self) -> OpenAI:
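The hybrid-search branch is still a placeholder (`pass`). Since the code already builds a full-text index earlier ("Created the full text search index ..."), one plausible way to fill it in is DuckDB's FTS extension and its BM25 scorer; doc_id and the url/chunk column names below are assumptions, not code from this commit:

    fts_results = db_connection.execute(
        f"""
        SELECT url, chunk,
               fts_main_{table_name}.match_bm25(doc_id, ?) AS score
        FROM {table_name}
        WHERE score IS NOT NULL
        ORDER BY score DESC
        LIMIT 10
        """,
        [query],
    ).fetchall()

The BM25 hits would then be merged with matched_chunks, for example by reciprocal rank fusion.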
@@ -341,10 +371,8 @@ CREATE TABLE {table_name} (
      def run_inference(
          self,
          query: str,
-         model_name: str,
          matched_chunks: List[Dict[str, Any]],
-         output_language: str,
-         output_length: int,
+         settings: AskSettings,
      ) -> str:
          system_prompt = (
              "You are an expert summarizing the answers based on the provided contents."
@@ -371,11 +399,11 @@ Here is the context:
          for i, chunk in enumerate(matched_chunks):
              context += f"[{i+1}] {chunk['chunk']}\n"

-         if output_length is None or output_length == 0:
+         if not settings.output_length:
              length_instructions = ""
          else:
              length_instructions = (
-                 f"Please provide the answer in { output_length } words."
+                 f"Please provide the answer in { settings.output_length } words."
              )

          user_prompt = self._render_template(
@@ -383,17 +411,19 @@ Here is the context:
              {
                  "query": query,
                  "context": context,
-                 "language": output_language,
+                 "language": settings.output_language,
                  "length_instructions": length_instructions,
              },
          )

-         self.logger.debug(f"Running inference with model: {model_name}")
+         self.logger.debug(
+             f"Running inference with model: {settings.inference_model_name}"
+         )
          self.logger.debug(f"Final user prompt: {user_prompt}")

          api_client = self._get_api_client()
          completion = api_client.chat.completions.create(
-             model=model_name,
+             model=settings.inference_model_name,
              messages=[
                  {
                      "role": "system",
@@ -411,7 +441,7 @@ Here is the context:
          response_str = completion.choices[0].message.content
          return response_str

-     def run_query(
+     def run_query_gradio(
          self,
          query: str,
          date_restrict: int,
@@ -419,11 +449,27 @@ Here is the context:
          output_language: str,
          output_length: int,
          url_list_str: str,
-         model_name: str,
+         inference_model_name: str,
+         hybrid_search: bool,
      ) -> Generator[Tuple[str, str], None, Tuple[str, str]]:
          logger = self.logger
          log_queue = Queue()

+         if url_list_str:
+             url_list = url_list_str.split("\n")
+         else:
+             url_list = []
+
+         settings = AskSettings(
+             date_restrict=date_restrict,
+             target_site=target_site,
+             output_language=output_language,
+             output_length=output_length,
+             url_list=url_list,
+             inference_model_name=inference_model_name,
+             hybrid_search=hybrid_search,
+         )
+
          queue_handler = logging.Handler()
          formatter = logging.Formatter("%(asctime)s - %(levelname)s - %(message)s")
          queue_handler.emit = lambda record: log_queue.put(formatter.format(record))
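The log plumbing here is worth a note: a bare logging.Handler has its emit swapped for a lambda that pushes formatted records onto a Queue, which the generator below drains into the UI. The same pattern in isolation (the standard library's logging.handlers.QueueHandler is the off-the-shelf equivalent):

    import logging
    from queue import Queue

    log_queue = Queue()
    queue_handler = logging.Handler()
    formatter = logging.Formatter("%(asctime)s - %(levelname)s - %(message)s")
    queue_handler.emit = lambda record: log_queue.put(formatter.format(record))

    demo_logger = logging.getLogger("demo")
    demo_logger.setLevel(logging.INFO)
    demo_logger.addHandler(queue_handler)
    demo_logger.info("hello")
    print(log_queue.get_nowait())  # the formatted record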
@@ -439,17 +485,18 @@ Here is the context:
                      break
              return "\n".join(logs)

+         # wrap the process in a generator to yield the logs to integrate with Gradio
          def process_with_logs():
-             if url_list_str is None or url_list_str.strip() == "":
+             if len(settings.url_list) > 0:
+                 links = settings.url_list
+             else:
                  logger.info("Searching the web ...")
                  yield "", update_logs()
-                 links = self.search_web(query, date_restrict, target_site)
+                 links = self.search_web(query, settings)
                  logger.info(f"✅ Found {len(links)} links for query: {query}")
                  for i, link in enumerate(links):
                      logger.debug(f"{i+1}. {link}")
                  yield "", update_logs()
-             else:
-                 links = url_list_str.split("\n")

              logger.info("Scraping the URLs ...")
              yield "", update_logs()
@@ -471,12 +518,12 @@ Here is the context:

              logger.info(f"Saving {total_chunks} chunks to DB ...")
              yield "", update_logs()
-             table_name = self.save_to_db(chunking_results)
+             table_name = self.save_chunks_to_db(chunking_results)
              logger.info(f"✅ Successfully embedded and saved chunks to DB.")
              yield "", update_logs()

              logger.info("Querying the vector DB to get context ...")
-             matched_chunks = self.vector_search(table_name, query)
+             matched_chunks = self.vector_search(table_name, query, settings)
              for i, result in enumerate(matched_chunks):
                  logger.debug(f"{i+1}. {result}")
              logger.info(f"✅ Got {len(matched_chunks)} matched chunks.")
@@ -486,10 +533,8 @@ Here is the context:
              yield "", update_logs()
              answer = self.run_inference(
                  query=query,
-                 model_name=model_name,
                  matched_chunks=matched_chunks,
-                 output_language=output_language,
-                 output_length=output_length,
+                 settings=settings,
              )
              logger.info("✅ Finished inference API call.")
              logger.info("Generating output ...")
@@ -514,15 +559,30 @@ Here is the context:

          return final_result, logs

+     def run_query(
+         self,
+         query: str,
+         settings: AskSettings,
+     ) -> str:
+         url_list_str = "\n".join(settings.url_list)
+
+         for result, logs in self.run_query_gradio(
+             query=query,
+             date_restrict=settings.date_restrict,
+             target_site=settings.target_site,
+             output_language=settings.output_language,
+             output_length=settings.output_length,
+             url_list_str=url_list_str,
+             inference_model_name=settings.inference_model_name,
+             hybrid_search=settings.hybrid_search,
+         ):
+             final_result = result
+         return final_result
+

  def launch_gradio(
      query: str,
-     date_restrict: int,
-     target_site: str,
-     output_language: str,
-     output_length: int,
-     url_list_str: str,
-     model_name: str,
+     init_settings: AskSettings,
      share_ui: bool,
      logger: logging.Logger,
  ) -> None:
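run_query is now a thin synchronous wrapper that drains run_query_gradio and keeps the last yielded result. A sketch of the resulting programmatic call path (the values are illustrative; a real run still needs the search and OpenAI credentials from .env):

    ask = Ask(logger=_get_logger("INFO"))
    settings = AskSettings(
        date_restrict=0,
        target_site="",
        output_language="English",
        output_length=0,
        url_list=[],
        inference_model_name="gpt-4o-mini",
        hybrid_search=False,
    )
    answer = ask.run_query(query="What is DuckDB?", settings=settings)
    print(answer)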
@@ -537,31 +597,38 @@ def launch_gradio(
          with gr.Column():

              query_input = gr.Textbox(label="Query", value=query)
+             hybrid_search_input = gr.Checkbox(
+                 label="Hybrid Search [Use both vector search and full-text search.]",
+                 value=init_settings.hybrid_search,
+             )
              date_restrict_input = gr.Number(
                  label="Date Restrict (Optional) [0 or empty means no date limit.]",
-                 value=date_restrict,
+                 value=init_settings.date_restrict,
              )
              target_site_input = gr.Textbox(
                  label="Target Sites (Optional) [Empty means searching the whole web.]",
-                 value=target_site,
+                 value=init_settings.target_site,
              )
              output_language_input = gr.Textbox(
                  label="Output Language (Optional) [Default is English.]",
-                 value=output_language,
+                 value=init_settings.output_language,
              )
              output_length_input = gr.Number(
                  label="Output Length in words (Optional) [Default is automatically decided by LLM.]",
-                 value=output_length,
+                 value=init_settings.output_length,
              )
              url_list_input = gr.Textbox(
                  label="URL List (Optional) [When specified, scrape the urls instead of searching the web.]",
                  lines=5,
                  max_lines=20,
-                 value=url_list_str,
+                 value="\n".join(init_settings.url_list),
              )

              with gr.Accordion("More Options", open=False):
-                 model_name_input = gr.Textbox(label="Model Name", value=model_name)
+                 inference_model_name_input = gr.Textbox(
+                     label="Inference Model Name",
+                     value=init_settings.inference_model_name,
+                 )

              submit_button = gr.Button("Submit")
@@ -570,7 +637,7 @@ def launch_gradio(
              logs_output = gr.Textbox(label="Logs", lines=10)

      submit_button.click(
-         fn=ask.run_query,
+         fn=ask.run_query_gradio,
          inputs=[
              query_input,
              date_restrict_input,
@@ -578,7 +645,8 @@ def launch_gradio(
              output_language_input,
              output_length_input,
              url_list_input,
-             model_name_input,
+             inference_model_name_input,
+             hybrid_search_input,
          ],
          outputs=[answer_output, logs_output],
      )
@@ -593,14 +661,14 @@ def launch_gradio(
      "-d",
      type=int,
      required=False,
-     default=None,
+     default=0,
      help="Restrict search results to a specific date range, default is no restriction",
  )
  @click.option(
      "--target-site",
      "-s",
      required=False,
-     default=None,
+     default="",
      help="Restrict search results to a specific site, default is no restriction",
  )
  @click.option(
@@ -613,24 +681,29 @@ def launch_gradio(
      "--output-length",
      type=int,
      required=False,
-     default=None,
+     default=0,
      help="Output length for the answer",
  )
  @click.option(
      "--url-list-file",
      type=str,
      required=False,
-     default=None,
+     default="",
      show_default=True,
      help="Instead of doing web search, scrape the target URL list and answer the query based on the content",
  )
  @click.option(
-     "--model-name",
+     "--inference-model-name",
      "-m",
      required=False,
      default="gpt-4o-mini",
      help="Model name to use for inference",
  )
+ @click.option(
+     "--hybrid-search",
+     is_flag=True,
+     help="Use hybrid search mode with both vector search and full-text search",
+ )
  @click.option(
      "--web-ui",
      is_flag=True,
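With the renamed --inference-model-name option and the new --hybrid-search flag, an invocation might look like the sketch below, written with click's test runner; the --query/-q option is assumed from the part of the file this diff does not show, and a real run needs the API keys from .env:

    from click.testing import CliRunner

    runner = CliRunner()
    result = runner.invoke(
        search_extract_summarize,
        ["-q", "what is hybrid search", "--hybrid-search",
         "-d", "7", "-m", "gpt-4o-mini"],
    )
    print(result.output)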
@@ -652,13 +725,24 @@ def search_extract_summarize(
      output_language: str,
      output_length: int,
      url_list_file: str,
-     model_name: str,
+     inference_model_name: str,
+     hybrid_search: bool,
      web_ui: bool,
      log_level: str,
  ):
      load_dotenv(dotenv_path=default_env_file, override=False)
      logger = _get_logger(log_level)

+     settings = AskSettings(
+         date_restrict=date_restrict,
+         target_site=target_site,
+         output_language=output_language,
+         output_length=output_length,
+         url_list=_read_url_list(url_list_file),
+         inference_model_name=inference_model_name,
+         hybrid_search=hybrid_search,
+     )
+
      if web_ui or os.environ.get("RUN_GRADIO_UI", "false").lower() != "false":
          if os.environ.get("SHARE_GRADIO_UI", "false").lower() == "true":
              share_ui = True
@@ -666,12 +750,7 @@ def search_extract_summarize(
              share_ui = False
          launch_gradio(
              query=query,
-             date_restrict=date_restrict,
-             target_site=target_site,
-             output_language=output_language,
-             output_length=output_length,
-             url_list_str=_read_url_list(url_list_file),
-             model_name=model_name,
+             init_settings=settings,
              share_ui=share_ui,
              logger=logger,
          )
@@ -680,17 +759,7 @@ def search_extract_summarize(
              raise Exception("Query is required for the command line mode")
      ask = Ask(logger=logger)

-     for result, _ in ask.run_query(
-         query=query,
-         date_restrict=date_restrict,
-         target_site=target_site,
-         output_language=output_language,
-         output_length=output_length,
-         url_list_str=_read_url_list(url_list_file),
-         model_name=model_name,
-     ):
-         final_result = result
-
+     final_result = ask.run_query(query=query, settings=settings)
      click.echo(final_result)
 