Upload ask.py
ask.py
CHANGED
@@ -3,6 +3,7 @@ import logging
 import os
 import urllib.parse
 from concurrent.futures import ThreadPoolExecutor
+from datetime import datetime
 from functools import partial
 from typing import Any, Dict, List, Optional, Tuple

@@ -19,9 +20,11 @@ script_dir = os.path.dirname(os.path.abspath(__file__))
 default_env_file = os.path.abspath(os.path.join(script_dir, ".env"))


-def get_logger(log_level: str) -> logging.Logger:
+def _get_logger(log_level: str) -> logging.Logger:
     logger = logging.getLogger(__name__)
     logger.setLevel(log_level)
+    if len(logger.handlers) > 0:
+        return logger
     handler = logging.StreamHandler()
     formatter = logging.Formatter("%(asctime)s - %(levelname)s - %(message)s")
     handler.setFormatter(formatter)
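The handlers check added to _get_logger above keeps repeated calls from attaching a second StreamHandler and printing every message twice. A minimal standalone sketch of that guard, using only the standard logging module (the logger name and helper name here are illustrative, not taken from ask.py):

    import logging

    def get_logger_sketch(log_level: str) -> logging.Logger:
        # logging.getLogger returns the same object for the same name on every call
        logger = logging.getLogger("ask_sketch")
        logger.setLevel(log_level)
        if len(logger.handlers) > 0:
            # already configured by an earlier call; do not add another handler
            return logger
        handler = logging.StreamHandler()
        handler.setFormatter(
            logging.Formatter("%(asctime)s - %(levelname)s - %(message)s")
        )
        logger.addHandler(handler)
        return logger

    log = get_logger_sketch("INFO")
    log = get_logger_sketch("INFO")  # second call hits the guard
    log.info("printed once, not twice")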
@@ -29,6 +32,20 @@ def get_logger(log_level: str) -> logging.Logger:
     return logger


+def _read_url_list(url_list_file: str) -> str:
+    if url_list_file is None:
+        return None
+
+    with open(url_list_file, "r") as f:
+        links = f.readlines()
+    links = [
+        link.strip()
+        for link in links
+        if link.strip() != "" and not link.startswith("#")
+    ]
+    return "\n".join(links)
+
+
 class Ask:

     def __init__(self, logger: Optional[logging.Logger] = None):
@@ -37,9 +54,8 @@ class Ask:
         if logger is not None:
             self.logger = logger
         else:
-            self.logger = get_logger("INFO")
+            self.logger = _get_logger("INFO")

-        self.table_name = "document_chunks"
         self.db_con = duckdb.connect(":memory:")

         self.db_con.install_extension("vss")
@@ -48,17 +64,6 @@ class Ask:
         self.db_con.load_extension("fts")
         self.db_con.sql("CREATE SEQUENCE seq_docid START 1000")

-        self.db_con.execute(
-            f"""
-CREATE TABLE {self.table_name} (
-    doc_id INTEGER PRIMARY KEY DEFAULT nextval('seq_docid'),
-    url TEXT,
-    chunk TEXT,
-    vec FLOAT[{self.embedding_dimensions}]
-);
-        """
-        )
-
         self.session = requests.Session()
         user_agent: str = (
             "Mozilla/5.0 (Windows NT 10.0; Win64; x64) "
@@ -221,12 +226,31 @@ CREATE TABLE {self.table_name} (
         embeddings = self.get_embedding(client, texts)
         return chunk_batch, embeddings

-    def save_to_db(self, chunking_results: Dict[str, List[str]]) -> None:
+    def _create_table(self) -> str:
+        # Simple ways to get a unique table name
+        timestamp = datetime.now().strftime("%Y_%m_%d_%H_%M_%S_%f")
+        table_name = f"document_chunks_{timestamp}"
+
+        self.db_con.execute(
+            f"""
+CREATE TABLE {table_name} (
+    doc_id INTEGER PRIMARY KEY DEFAULT nextval('seq_docid'),
+    url TEXT,
+    chunk TEXT,
+    vec FLOAT[{self.embedding_dimensions}]
+);
+        """
+        )
+        return table_name
+
+    def save_to_db(self, chunking_results: Dict[str, List[str]]) -> str:
         client = self._get_api_client()
         embed_batch_size = 50
         query_batch_size = 100
         insert_data = []

+        table_name = self._create_table()
+
         batches: List[Tuple[str, List[str]]] = []
         for url, list_chunks in chunking_results.items():
             for i in range(0, len(list_chunks), embed_batch_size):
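The _create_table helper above timestamps the table name so that each save_to_db call gets its own table on the shared in-memory DuckDB connection instead of reusing a single document_chunks table. A standalone sketch of that naming scheme (requires the duckdb package; the column list is trimmed for brevity and is not the full schema):

    import time
    from datetime import datetime

    import duckdb

    con = duckdb.connect(":memory:")
    for _ in range(2):
        timestamp = datetime.now().strftime("%Y_%m_%d_%H_%M_%S_%f")
        table_name = f"document_chunks_{timestamp}"
        con.execute(f"CREATE TABLE {table_name} (doc_id INTEGER, url TEXT, chunk TEXT)")
        time.sleep(0.001)  # make sure the next timestamp differs at microsecond resolution
    print(con.sql("SHOW TABLES"))  # two separate document_chunks_* tables coexist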
@@ -258,13 +282,13 @@ CREATE TABLE {self.table_name} (
             ]
         )
         query = f"""
-            INSERT INTO {self.table_name} (url, chunk, vec) VALUES {value_str};
+            INSERT INTO {table_name} (url, chunk, vec) VALUES {value_str};
         """
         self.db_con.execute(query)

         self.db_con.execute(
             f"""
-            CREATE INDEX cos_idx ON {self.table_name} USING HNSW (vec)
+            CREATE INDEX {table_name}_cos_idx ON {table_name} USING HNSW (vec)
             WITH (metric = 'cosine');
         """
         )
@@ -272,19 +296,20 @@ CREATE TABLE {self.table_name} (
         self.db_con.execute(
             f"""
             PRAGMA create_fts_index(
-                {self.table_name}, 'doc_id', 'chunk'
+                {table_name}, 'doc_id', 'chunk'
             );
         """
         )
         self.logger.info(f"✅ Created the full text search index ...")
+        return table_name

-    def vector_search(self, query: str) -> List[Dict[str, Any]]:
+    def vector_search(self, table_name: str, query: str) -> List[Dict[str, Any]]:
         client = self._get_api_client()
         embeddings = self.get_embedding(client, [query])[0]

         query_result: duckdb.DuckDBPyRelation = self.db_con.sql(
             f"""
-            SELECT * FROM {self.table_name}
+            SELECT * FROM {table_name}
             ORDER BY array_distance(vec, {embeddings}::FLOAT[{self.embedding_dimensions}])
             LIMIT 10;
         """
@@ -383,84 +408,66 @@ Here is the context:
         response_str = completion.choices[0].message.content
         return response_str

-
-def _read_url_list(url_list_file: str) -> str:
-    if url_list_file is None:
-        return None
-
-    with open(url_list_file, "r") as f:
-        links = f.readlines()
-    links = [
-        link.strip()
-        for link in links
-        if link.strip() != "" and not link.startswith("#")
-    ]
-    return "\n".join(links)
-
-
-def _run_query(
-    query: str,
-    date_restrict: int,
-    target_site: str,
-    output_language: str,
-    output_length: int,
-    url_list_str: str,
-    model_name: str,
-    log_level: str,
-) -> str:
-    logger = get_logger(log_level)
-
-    ask = Ask(logger=logger)
-
-    if url_list_str is None or url_list_str.strip() == "":
-        logger.info("Searching the web ...")
-        links = ask.search_web(query, date_restrict, target_site)
-        logger.info(f"✅ Found {len(links)} links for query: {query}")
-        for i, link in enumerate(links):
-            logger.debug(f"{i+1}. {link}")
-    else:
-        links = url_list_str.split("\n")
-
-    logger.info("Scraping the URLs ...")
-    scrape_results = ask.scrape_urls(links)
-    logger.info(f"✅ Scraped {len(scrape_results)} URLs.")
-
-    logger.info("Chunking the text ...")
-    chunking_results = ask.chunk_results(scrape_results, 1000, 100)
-    total_chunks = 0
-    for url, chunks in chunking_results.items():
-        logger.debug(f"URL: {url}")
-        total_chunks += len(chunks)
-        for i, chunk in enumerate(chunks):
-            logger.debug(f"Chunk {i+1}: {chunk}")
-    logger.info(f"✅ Generated {total_chunks} chunks ...")
-
-    logger.info(f"Saving {total_chunks} chunks to DB ...")
-    ask.save_to_db(chunking_results)
-    logger.info(f"✅ Successfully embedded and saved chunks to DB.")
-
-    logger.info("Querying the vector DB to get context ...")
-    matched_chunks = ask.vector_search(query)
-    for i, result in enumerate(matched_chunks):
-        logger.debug(f"{i+1}. {result}")
-    logger.info(f"✅ Got {len(matched_chunks)} matched chunks.")
-
-    logger.info("Running inference with context ...")
-    answer = ask.run_inference(
-        query=query,
-        model_name=model_name,
-        matched_chunks=matched_chunks,
-        output_language=output_language,
-        output_length=output_length,
-    )
-    logger.info("✅ Finished inference API call.")
-    logger.info("generateing output ...")
-
-    answer = f"# Answer\n\n{answer}\n"
-    references = "\n".join(
-        [f"[{i+1}] {result['url']}" for i, result in enumerate(matched_chunks)]
-    )
-    return f"{answer}\n\n# References\n\n{references}"
+    def run_query(
+        self,
+        query: str,
+        date_restrict: int,
+        target_site: str,
+        output_language: str,
+        output_length: int,
+        url_list_str: str,
+        model_name: str,
+    ) -> str:
+        logger = self.logger
+        if url_list_str is None or url_list_str.strip() == "":
+            logger.info("Searching the web ...")
+            links = self.search_web(query, date_restrict, target_site)
+            logger.info(f"✅ Found {len(links)} links for query: {query}")
+            for i, link in enumerate(links):
+                logger.debug(f"{i+1}. {link}")
+        else:
+            links = url_list_str.split("\n")
+
+        logger.info("Scraping the URLs ...")
+        scrape_results = self.scrape_urls(links)
+        logger.info(f"✅ Scraped {len(scrape_results)} URLs.")
+
+        logger.info("Chunking the text ...")
+        chunking_results = self.chunk_results(scrape_results, 1000, 100)
+        total_chunks = 0
+        for url, chunks in chunking_results.items():
+            logger.debug(f"URL: {url}")
+            total_chunks += len(chunks)
+            for i, chunk in enumerate(chunks):
+                logger.debug(f"Chunk {i+1}: {chunk}")
+        logger.info(f"✅ Generated {total_chunks} chunks ...")
+
+        logger.info(f"Saving {total_chunks} chunks to DB ...")
+        table_name = self.save_to_db(chunking_results)
+        logger.info(f"✅ Successfully embedded and saved chunks to DB.")
+
+        logger.info("Querying the vector DB to get context ...")
+        matched_chunks = self.vector_search(table_name, query)
+        for i, result in enumerate(matched_chunks):
+            logger.debug(f"{i+1}. {result}")
+        logger.info(f"✅ Got {len(matched_chunks)} matched chunks.")
+
+        logger.info("Running inference with context ...")
+        answer = self.run_inference(
+            query=query,
+            model_name=model_name,
+            matched_chunks=matched_chunks,
+            output_language=output_language,
+            output_length=output_length,
+        )
+        logger.info("✅ Finished inference API call.")
+        logger.info("Generating output ...")
+
+        answer = f"# Answer\n\n{answer}\n"
+        references = "\n".join(
+            [f"[{i+1}] {result['url']}" for i, result in enumerate(matched_chunks)]
+        )
+        return f"{answer}\n\n# References\n\n{references}"


 def launch_gradio(
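With the module-level _run_query replaced by the Ask.run_query method above, the whole search / scrape / chunk / embed / answer pipeline is driven through one Ask instance that already carries its logger, and the per-query table name flows from save_to_db into vector_search. A usage sketch mirroring the call made in search_extract_summarize further down (argument values are illustrative, and it assumes ask.py is importable with the search and LLM API keys configured):

    logger = _get_logger("INFO")
    ask = Ask(logger=logger)
    result = ask.run_query(
        query="What is DuckDB vector similarity search?",
        date_restrict=0,
        target_site="",
        output_language="English",
        output_length=0,
        url_list_str="",
        model_name="gpt-4o-mini",
    )
    print(result)  # markdown answer followed by a references list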
@@ -471,11 +478,12 @@ def launch_gradio(
     output_length: int,
     url_list_str: str,
     model_name: str,
-    log_level: str,
     share_ui: bool,
+    logger: logging.Logger,
 ) -> None:
+    ask = Ask(logger=logger)
     iface = gr.Interface(
-        fn=_run_query,
+        fn=ask.run_query,
         inputs=[
             gr.Textbox(label="Query", value=query),
             gr.Number(
@@ -483,7 +491,7 @@ def launch_gradio(
                 value=date_restrict,
             ),
             gr.Textbox(
-                label="Target Sites (Optional) [Empty means
+                label="Target Sites (Optional) [Empty means search the whole web.]",
                 value=target_site,
             ),
             gr.Textbox(
@@ -503,7 +511,6 @@ def launch_gradio(
         ],
         additional_inputs=[
             gr.Textbox(label="Model Name", value=model_name),
-            gr.Textbox(label="Log Level", value=log_level),
         ],
         outputs="text",
         show_progress=True,
@@ -515,12 +522,7 @@ def launch_gradio(
     iface.launch(share=share_ui)


-@click.command(help="Search web for the query and summarize the results")
-@click.option(
-    "--web-ui",
-    is_flag=True,
-    help="Launch the web interface",
-)
+@click.command(help="Search web for the query and summarize the results.")
 @click.option("--query", "-q", required=False, help="Query to search")
 @click.option(
     "--date-restrict",
@@ -565,6 +567,11 @@ def launch_gradio(
     default="gpt-4o-mini",
     help="Model name to use for inference",
 )
+@click.option(
+    "--web-ui",
+    is_flag=True,
+    help="Launch the web interface",
+)
 @click.option(
     "-l",
     "--log-level",
@@ -575,7 +582,6 @@ def launch_gradio(
     show_default=True,
 )
 def search_extract_summarize(
-    web_ui: bool,
     query: str,
     date_restrict: int,
     target_site: str,
@@ -583,9 +589,11 @@ def search_extract_summarize(
     output_length: int,
     url_list_file: str,
     model_name: str,
+    web_ui: bool,
     log_level: str,
 ):
     load_dotenv(dotenv_path=default_env_file, override=False)
+    logger = _get_logger(log_level)

     if web_ui or os.environ.get("RUN_GRADIO_UI", "false").lower() != "false":
         if os.environ.get("SHARE_GRADIO_UI", "false").lower() == "true":
|
             output_length=output_length,
             url_list_str=_read_url_list(url_list_file),
             model_name=model_name,
-            log_level=log_level,
             share_ui=share_ui,
+            logger=logger,
         )
     else:
         if query is None:
             raise Exception("Query is required for the command line mode")
-
-        result = _run_query(
+        ask = Ask(logger=logger)
+        result = ask.run_query(
             query=query,
             date_restrict=date_restrict,
             target_site=target_site,
@@ -615,7 +623,6 @@ def search_extract_summarize(
             output_length=output_length,
             url_list_str=_read_url_list(url_list_file),
             model_name=model_name,
-            log_level=log_level,
         )
         click.echo(result)
