LeetTools committed on
Commit 1d2d847 · verified · 1 parent: c6e4506

Upload 3 files

Files changed (3)
  1. README.md +207 -9
  2. ask.py +624 -0
  3. requirements.txt +9 -0
README.md CHANGED
@@ -1,14 +1,212 @@
1
  ---
2
- title: AskPy
3
- emoji: 😻
4
- colorFrom: red
5
- colorTo: blue
6
  sdk: gradio
7
  sdk_version: 5.3.0
8
- app_file: app.py
9
- pinned: false
10
- license: mit
11
- short_description: GradIO Demo for AskPy
12
  ---
 
13
 
14
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
1
  ---
2
+ title: ask.py
3
+ app_file: ask.py
4
  sdk: gradio
5
  sdk_version: 5.3.0
6
  ---
7
+ # ask.py
8
 
9
+ [![License](https://img.shields.io/github/license/pengfeng/ask.py)](LICENSE)
10
+
11
+ A single Python program to implement the search-extract-summarize flow, similar to AI search
12
+ engines such as Perplexity.
13
+
14
+ > [UPDATE]
15
+ >
16
+ > - 2024-10-22: add Gradio integration
17
+ > - 2024-10-21: use DuckDB for the vector search and an API for embeddings
18
+ > - 2024-10-20: allow specifying a list of input URLs
19
+ > - 2024-10-18: output-language and output-length parameters for LLM
20
+ > - 2024-10-18: date-restrict and target-site parameters for search
21
+
22
+ > [!NOTE]
23
+ > Our main goal is to illustrate the basic concepts of AI search engines using raw constructs.
24
+ > Performance and scalability are out of scope for this program.
25
+
26
+ ## The search-extract-summarize flow
27
+
28
+ Given a query, the program will (see the sketch after this list):
29
+
30
+ - search Google for the top 10 web pages
31
+ - crawl and scrape the pages for their text content
32
+ - split the text content into overlapping chunks and save them into a vector DB
33
+ - perform a vector search with the query and find the top 10 matched chunks
34
+ - use the top 10 chunks as the context to ask an LLM to generate the answer
35
+ - output the answer with the references
36
+
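+ These steps map one-to-one onto the `Ask` class in `ask.py`. The snippet below is a minimal
+ sketch of the same flow driven programmatically; it assumes the API keys from the Quick start
+ section are set, mirrors the chunking defaults used by `_run_query`, and leaves out the error
+ handling, logging, and Gradio UI of the real program.
+
+ ```python
+ # minimal sketch of the search-extract-summarize flow using the Ask class from ask.py
+ from ask import Ask
+
+ query = "What is an LLM agent?"
+ ask = Ask()  # needs SEARCH_API_KEY, SEARCH_PROJECT_KEY, and LLM_API_KEY in the environment
+
+ links = ask.search_web(query, date_restrict=None, target_site=None)  # top-10 Google results
+ scrape_results = ask.scrape_urls(links)                              # url -> page text
+ chunking_results = ask.chunk_results(scrape_results, 1000, 100)      # 1000-char chunks, 100 overlap
+ ask.save_to_db(chunking_results)                                     # embed chunks and store in DuckDB
+ matched_chunks = ask.vector_search(query)                            # top-10 chunks for the query
+
+ answer = ask.run_inference(
+     query=query,
+     model_name="gpt-4o-mini",          # default model in ask.py
+     matched_chunks=matched_chunks,
+     output_language="English",
+     output_length=None,                # let the LLM decide the length
+ )
+ print(answer)
+ ```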
37
+ Of course, this flow is a much simplified version of what real AI search engines do, but it is a good
38
+ starting point for understanding the basic concepts.
39
+
40
+ One benefit is that we can directly control the search behavior and the output format.
41
+
42
+ For example (see the sketch after this list), we can:
43
+
44
+ - search with date-restrict to only retrieve the latest information.
45
+ - search within a target-site so that the answer is created only from its contents.
46
+ - ask the LLM to answer the question in a specific language.
47
+ - ask the LLM to answer with a specific length.
48
+ - crawl a specific list of URLs and answer based on those contents only.
49
+
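+ The search-side options map directly to Google Custom Search parameters inside `search_web` in
+ `ask.py`, while the output-side options only shape the LLM prompt. A rough sketch of the mapping,
+ using the same example values as the sample run later in this README:
+
+ ```python
+ # sketch: how -d / -s become Custom Search query parameters (mirrors search_web in ask.py)
+ date_restrict = 1            # -d 1           -> only pages from the previous day
+ target_site = "openai.com"   # -s openai.com  -> only pages from this site
+
+ url_paras = "&safe=active"
+ if date_restrict is not None and date_restrict > 0:
+     url_paras += f"&dateRestrict={date_restrict}"
+ if target_site is not None and target_site != "":
+     url_paras += f"&siteSearch={target_site}&siteSearchFilter=i"
+ print(url_paras)  # &safe=active&dateRestrict=1&siteSearch=openai.com&siteSearchFilter=i
+
+ # --output-language and --output-length are not search parameters: they are rendered into the
+ # LLM prompt template, and --url-list-file skips the search step entirely.
+ ```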
50
+ ## Quick start
51
+
52
+ ```bash
53
+
54
+ pip install -r requirements.txt
55
+
56
+ # modify the .env file to set the API keys, or export them as environment variables as below
57
+
58
+ # right now we use Google search API
59
+ export SEARCH_API_KEY="your-google-search-api-key"
60
+ export SEARCH_PROJECT_KEY="your-google-cx-key"
61
+
62
+ # right now we use OpenAI API
63
+ export LLM_API_KEY="your-openai-api-key"
64
+
65
+ # run the program
66
+ python ask.py -q "What is an LLM agent?"
67
+
68
+ # we can specify more parameters to control the behavior, such as date-restrict and target-site
69
+ python ask.py --help
70
+ Usage: ask.py [OPTIONS]
71
+
72
+ Search web for the query and summarize the results
73
+
74
+ Options:
75
+ --web-ui Launch the web interface
76
+ -q, --query TEXT Query to search
77
+ -d, --date-restrict INTEGER Restrict search results to a specific date
78
+ range, default is no restriction
79
+ -s, --target-site TEXT Restrict search results to a specific site,
80
+ default is no restriction
81
+ --output-language TEXT Output language for the answer
82
+ --output-length INTEGER Output length for the answer
83
+ --url-list-file TEXT Instead of doing web search, scrape the
84
+ target URL list and answer the query based
85
+ on the content
86
+ -m, --model-name TEXT Model name to use for inference
87
+ -l, --log-level [DEBUG|INFO|WARNING|ERROR]
88
+ Set the logging level [default: INFO]
89
+ --help Show this message and exit.
90
+ ```
91
+
92
+ ## Libraries and APIs used
93
+
94
+ - [Google Search API](https://developers.google.com/custom-search/v1/overview)
95
+ - [OpenAI API](https://beta.openai.com/docs/api-reference/completions/create)
96
+ - [Jinja2](https://jinja.palletsprojects.com/en/3.0.x/)
97
+ - [bs4](https://www.crummy.com/software/BeautifulSoup/bs4/doc/)
98
+ - [DuckDB](https://github.com/duckdb/duckdb)
99
+ - [Gradio](https://www.gradio.app/)
100
+
101
+ ## Screenshot of the Gradio integration
102
+
103
+ ![image](https://github.com/user-attachments/assets/0483e6a2-75d7-4fbd-813f-bfa13839c836)
104
+
105
+ ## Sample output
106
+
107
+ ### General Search
108
+
109
+ ```
110
+ % python ask.py -q "Why do we need agentic RAG even if we have ChatGPT?"
111
+
112
+ ✅ Found 10 links for query: Why do we need agentic RAG even if we have ChatGPT?
113
+ ✅ Scraping the URLs ...
114
+ ✅ Scraped 10 URLs ...
115
+ ✅ Chunking the text ...
116
+ ✅ Saving to vector DB ...
117
+ ✅ Querying the vector DB ...
118
+ ✅ Running inference with context ...
119
+
120
+ # Answer
121
+
122
+ Agentic RAG (Retrieval-Augmented Generation) is needed alongside ChatGPT for several reasons:
123
+
124
+ 1. **Precision and Contextual Relevance**: While ChatGPT offers generative responses, it may not
125
+ reliably provide precise answers, especially when specific, accurate information is critical[5].
126
+ Agentic RAG enhances this by integrating retrieval mechanisms that improve response context and
127
+ accuracy, allowing users to access the most relevant and recent data without the need for costly
128
+ model fine-tuning[2].
129
+
130
+ 2. **Customizability**: RAG allows businesses to create tailored chatbots that can securely
131
+ reference company-specific data[2]. In contrast, ChatGPT’s broader capabilities may not be
132
+ directly suited for specialized, domain-specific questions without comprehensive customization[3].
133
+
134
+ 3. **Complex Query Handling**: RAG can be optimized for complex queries and can be adjusted to
135
+ work better with specific types of inputs, such as comparing and contrasting information, a task
136
+ where ChatGPT may struggle under certain circumstances[9]. This level of customization can lead to
137
+ better performance in niche applications where precise retrieval of information is crucial.
138
+
139
+ 4. **Asynchronous Processing Capabilities**: Future agentic systems aim to integrate asynchronous
140
+ handling of actions, allowing for parallel processing and reducing wait times for retrieval and
141
+ computation, which is a limitation in the current form of ChatGPT[7]. This advancement would enhance
142
+ overall efficiency and responsiveness in conversations.
143
+
144
+ 5. **Incorporating Retrieved Information Effectively**: Using RAG can significantly improve how
145
+ retrieved information is utilized within a conversation. By effectively managing the context and
146
+ relevance of retrieved documents, RAG helps in framing prompts that can guide ChatGPT towards
147
+ delivering more accurate responses[10].
148
+
149
+ In summary, while ChatGPT excels in generating conversational responses, agentic RAG brings
150
+ precision, customization, and efficiency that can significantly enhance the overall conversational
151
+ AI experience.
152
+
153
+ # References
154
+
155
+ [1] https://community.openai.com/t/how-to-use-rag-properly-and-what-types-of-query-it-is-good-at/658204
156
+ [2] https://www.linkedin.com/posts/brianjuliusdc_dax-powerbi-chatgpt-activity-7235953280177041408-wQqq
157
+ [3] https://community.openai.com/t/how-to-use-rag-properly-and-what-types-of-query-it-is-good-at/658204
158
+ [4] https://community.openai.com/t/prompt-engineering-for-rag/621495
159
+ [5] https://www.ben-evans.com/benedictevans/2024/6/8/building-ai-products
160
+ [6] https://community.openai.com/t/prompt-engineering-for-rag/621495
161
+ [7] https://www.linkedin.com/posts/kurtcagle_agentic-rag-personalizing-and-optimizing-activity-7198097129993613312-z7Sm
162
+ [8] https://community.openai.com/t/how-to-use-rag-properly-and-what-types-of-query-it-is-good-at/658204
163
+ [9] https://community.openai.com/t/how-to-use-rag-properly-and-what-types-of-query-it-is-good-at/658204
164
+ [10] https://community.openai.com/t/prompt-engineering-for-rag/621495
165
+ ```
166
+
167
+ ### Only use the latest information from a specific site
168
+
169
+ The following query will only use information from openai.com that was updated in the previous
170
+ day. The behavior is similar to using the "site:openai.com" operator together with a date restriction in Google
171
+ search.
172
+
173
+ ```
174
+ % python ask.py -q "OpenAI Swarm Framework" -d 1 -s openai.com
175
+ ✅ Found 10 links for query: OpenAI Swarm Framework
176
+ ✅ Scraping the URLs ...
177
+ ✅ Scraped 10 URLs ...
178
+ ✅ Chunking the text ...
179
+ ✅ Saving to vector DB ...
180
+ ✅ Querying the vector DB to get context ...
181
+ ✅ Running inference with context ...
182
+
183
+ # Answer
184
+
185
+ OpenAI Swarm Framework is an experimental platform designed for building, orchestrating, and
186
+ deploying multi-agent systems, enabling multiple AI agents to collaborate on complex tasks. It contrasts
187
+ with traditional single-agent models by facilitating agent interaction and coordination, thus enhancing
188
+ efficiency[5][9]. The framework provides developers with a way to orchestrate these agent systems in
189
+ a lightweight manner, leveraging Node.js for scalable applications[1][4].
190
+
191
+ One implementation of this framework is Swarm.js, which serves as a Node.js SDK, allowing users to
192
+ create and manage agents that perform tasks and hand off conversations. Swarm.js is positioned as
193
+ an educational tool, making it accessible for both beginners and experts, although it may still contain
194
+ bugs and is currently lightweight[1][3][7]. This new approach emphasizes multi-agent collaboration and is
195
+ well-suited for back-end development, requiring some programming expertise for effective implementation[9].
196
+
197
+ Overall, OpenAI Swarm facilitates a shift in how AI systems can collaborate, differing from existing
198
+ OpenAI tools by focusing on backend orchestration rather than user-interactive front-end applications[9].
199
+
200
+ # References
201
+
202
+ [1] https://community.openai.com/t/introducing-swarm-js-node-js-implementation-of-openai-swarm/977510
203
+ [2] https://community.openai.com/t/introducing-swarm-js-a-node-js-implementation-of-openai-swarm/977510
204
+ [3] https://community.openai.com/t/introducing-swarm-js-node-js-implementation-of-openai-swarm/977510
205
+ [4] https://community.openai.com/t/introducing-swarm-js-a-node-js-implementation-of-openai-swarm/977510
206
+ [5] https://community.openai.com/t/swarm-some-initial-insights/976602
207
+ [6] https://community.openai.com/t/swarm-some-initial-insights/976602
208
+ [7] https://community.openai.com/t/introducing-swarm-js-node-js-implementation-of-openai-swarm/977510
209
+ [8] https://community.openai.com/t/introducing-swarm-js-a-node-js-implementation-of-openai-swarm/977510
210
+ [9] https://community.openai.com/t/swarm-some-initial-insights/976602
211
+ [10] https://community.openai.com/t/swarm-some-initial-insights/976602
212
+ ```
ask.py ADDED
@@ -0,0 +1,624 @@
1
+ import json
2
+ import logging
3
+ import os
4
+ import urllib.parse
5
+ from concurrent.futures import ThreadPoolExecutor
6
+ from functools import partial
7
+ from typing import Any, Dict, List, Optional, Tuple
8
+
9
+ import click
10
+ import duckdb
11
+ import gradio as gr
12
+ import requests
13
+ from bs4 import BeautifulSoup
14
+ from dotenv import load_dotenv
15
+ from jinja2 import BaseLoader, Environment
16
+ from openai import OpenAI
17
+
18
+ script_dir = os.path.dirname(os.path.abspath(__file__))
19
+ default_env_file = os.path.abspath(os.path.join(script_dir, ".env"))
20
+
21
+
22
+ def get_logger(log_level: str) -> logging.Logger:
23
+ logger = logging.getLogger(__name__)
24
+ logger.setLevel(log_level)
25
+ handler = logging.StreamHandler()
26
+ formatter = logging.Formatter("%(asctime)s - %(levelname)s - %(message)s")
27
+ handler.setFormatter(formatter)
28
+ logger.addHandler(handler)
29
+ return logger
30
+
31
+
32
+ class Ask:
33
+
34
+ def __init__(self, logger: Optional[logging.Logger] = None):
35
+ self.read_env_variables()
36
+
37
+ if logger is not None:
38
+ self.logger = logger
39
+ else:
40
+ self.logger = get_logger("INFO")
41
+
42
+ self.table_name = "document_chunks"
43
+ self.db_con = duckdb.connect(":memory:")
44
+
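+ # vss provides the HNSW vector index used for retrieval; fts provides full-text search over the chunks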
45
+ self.db_con.install_extension("vss")
46
+ self.db_con.load_extension("vss")
47
+ self.db_con.install_extension("fts")
48
+ self.db_con.load_extension("fts")
49
+ self.db_con.sql("CREATE SEQUENCE seq_docid START 1000")
50
+
51
+ self.db_con.execute(
52
+ f"""
53
+ CREATE TABLE {self.table_name} (
54
+ doc_id INTEGER PRIMARY KEY DEFAULT nextval('seq_docid'),
55
+ url TEXT,
56
+ chunk TEXT,
57
+ vec FLOAT[{self.embedding_dimensions}]
58
+ );
59
+ """
60
+ )
61
+
62
+ self.session = requests.Session()
63
+ user_agent: str = (
64
+ "Mozilla/5.0 (Windows NT 10.0; Win64; x64) "
65
+ "AppleWebKit/537.36 (KHTML, like Gecko) "
66
+ "Chrome/119.0.0.0 Safari/537.36 Edg/119.0.0.0"
67
+ )
68
+ self.session.headers.update({"User-Agent": user_agent})
69
+
70
+ def read_env_variables(self) -> None:
71
+ err_msg = ""
72
+
73
+ self.search_api_key = os.environ.get("SEARCH_API_KEY")
74
+ if self.search_api_key is None:
75
+ err_msg += "SEARCH_API_KEY env variable not set.\n"
76
+ self.search_project_id = os.environ.get("SEARCH_PROJECT_KEY")
77
+ if self.search_project_id is None:
78
+ err_msg += "SEARCH_PROJECT_KEY env variable not set.\n"
79
+ self.llm_api_key = os.environ.get("LLM_API_KEY")
80
+ if self.llm_api_key is None:
81
+ err_msg += "LLM_API_KEY env variable not set.\n"
82
+
83
+ if err_msg != "":
84
+ raise Exception(f"\n{err_msg}\n")
85
+
86
+ self.llm_base_url = os.environ.get("LLM_BASE_URL")
87
+ if self.llm_base_url is None:
88
+ self.llm_base_url = "https://api.openai.com/v1"
89
+
90
+ self.embedding_model = os.environ.get("EMBEDDING_MODEL")
91
+ self.embedding_dimensions = os.environ.get("EMBEDDING_DIMENSIONS")
92
+
93
+ if self.embedding_model is None or self.embedding_dimensions is None:
94
+ self.embedding_model = "text-embedding-3-small"
95
+ self.embedding_dimensions = 1536
96
+
97
+ def search_web(self, query: str, date_restrict: int, target_site: str) -> List[str]:
98
+ escaped_query = urllib.parse.quote(query)
99
+ url_base = (
100
+ f"https://www.googleapis.com/customsearch/v1?key={self.search_api_key}"
101
+ f"&cx={self.search_project_id}&q={escaped_query}"
102
+ )
103
+ url_paras = "&safe=active"
104
+ if date_restrict is not None and date_restrict > 0:
105
+ url_paras += f"&dateRestrict={date_restrict}"
106
+ if target_site is not None and target_site != "":
107
+ url_paras += f"&siteSearch={target_site}&siteSearchFilter=i"
108
+ url = f"{url_base}{url_paras}"
109
+
110
+ self.logger.debug(f"Searching for query: {query}")
111
+
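+ # note: a single Custom Search API call returns at most 10 results, which is all this flow needs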
112
+ resp = requests.get(url)
113
+
114
+ if resp is None:
115
+ raise Exception("No response from search API")
116
+
117
+ search_results_dict = json.loads(resp.text)
118
+ if "error" in search_results_dict:
119
+ raise Exception(
120
+ f"Error in search API response: {search_results_dict['error']}"
121
+ )
122
+
123
+ if "searchInformation" not in search_results_dict:
124
+ raise Exception(
125
+ f"No search information in search API response: {resp.text}"
126
+ )
127
+
128
+ total_results = search_results_dict["searchInformation"].get("totalResults", 0)
129
+ if total_results == 0:
130
+ self.logger.warning(f"No results found for query: {query}")
131
+ return []
132
+
133
+ results = search_results_dict.get("items", [])
134
+ if results is None or len(results) == 0:
135
+ self.logger.warning(f"No result items in the response for query: {query}")
136
+ return []
137
+
138
+ found_links = []
139
+ for result in results:
140
+ link = result.get("link", None)
141
+ if link is None or link == "":
142
+ self.logger.warning(f"Search result link missing: {result}")
143
+ continue
144
+ found_links.append(link)
145
+ return found_links
146
+
147
+ def _scrape_url(self, url: str) -> Tuple[str, str]:
148
+ try:
149
+ response = self.session.get(url, timeout=10)
150
+ soup = BeautifulSoup(response.content, "lxml", from_encoding="utf-8")
151
+
152
+ body_tag = soup.body
153
+ if body_tag:
154
+ body_text = body_tag.get_text()
155
+ body_text = " ".join(body_text.split()).strip()
156
+ self.logger.debug(f"Scraped {url}: {body_text}...")
157
+ if len(body_text) > 100:
158
+ return url, body_text
159
+ else:
160
+ self.logger.warning(
161
+ f"Body text too short for url: {url}, length: {len(body_text)}"
162
+ )
163
+ return url, ""
164
+ else:
165
+ self.logger.warning(f"No body tag found in the response for url: {url}")
166
+ return url, ""
167
+ except Exception as e:
168
+ self.logger.error(f"Scraping error {url}: {e}")
169
+ return url, ""
170
+
171
+ def scrape_urls(self, urls: List[str]) -> Dict[str, str]:
172
+ # the key is the url and the value is the body text
173
+ scrape_results: Dict[str, str] = {}
174
+
175
+ partial_scrape = partial(self._scrape_url)
176
+ with ThreadPoolExecutor(max_workers=10) as executor:
177
+ results = executor.map(partial_scrape, urls)
178
+
179
+ for url, body_text in results:
180
+ if body_text != "":
181
+ scrape_results[url] = body_text
182
+
183
+ return scrape_results
184
+
185
+ def chunk_results(
186
+ self, scrape_results: Dict[str, str], size: int, overlap: int
187
+ ) -> Dict[str, List[str]]:
188
+ chunking_results: Dict[str, List[str]] = {}
189
+ for url, text in scrape_results.items():
190
+ chunks = []
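+ # fixed-size character windows; consecutive chunks share `overlap` characters of text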
191
+ for pos in range(0, len(text), size - overlap):
192
+ chunks.append(text[pos : pos + size])
193
+ chunking_results[url] = chunks
194
+ return chunking_results
195
+
196
+ def get_embedding(self, client: OpenAI, texts: List[str]) -> List[List[float]]:
197
+ if len(texts) == 0:
198
+ return []
199
+
200
+ response = client.embeddings.create(input=texts, model=self.embedding_model)
201
+ embeddings = []
202
+ for i in range(len(response.data)):
203
+ embeddings.append(response.data[i].embedding)
204
+ return embeddings
205
+
206
+ def batch_get_embedding(
207
+ self, client: OpenAI, chunk_batch: Tuple[str, List[str]]
208
+ ) -> Tuple[Tuple[str, List[str]], List[List[float]]]:
209
+ """
210
+ Return the chunk_batch as well as the embeddings for each chunk so that
211
+ we can aggregate them and save them to the database together.
212
+
213
+ Args:
214
+ - client: OpenAI client
215
+ - chunk_batch: Tuple of URL and list of chunks scraped from the URL
216
+
217
+ Returns:
218
+ - Tuple of chunk_batch and list of result embeddings
219
+ """
220
+ texts = chunk_batch[1]
221
+ embeddings = self.get_embedding(client, texts)
222
+ return chunk_batch, embeddings
223
+
224
+ def save_to_db(self, chunking_results: Dict[str, List[str]]) -> None:
225
+ client = self._get_api_client()
226
+ embed_batch_size = 50
227
+ query_batch_size = 100
228
+ insert_data = []
229
+
230
+ batches: List[Tuple[str, List[str]]] = []
231
+ for url, list_chunks in chunking_results.items():
232
+ for i in range(0, len(list_chunks), embed_batch_size):
233
+ batch_chunks = list_chunks[i : i + embed_batch_size]
234
+ batches.append((url, batch_chunks))
235
+
236
+ self.logger.info(f"Embedding {len(batches)} batches of chunks ...")
237
+ partial_get_embedding = partial(self.batch_get_embedding, client)
238
+ with ThreadPoolExecutor(max_workers=10) as executor:
239
+ all_embeddings = executor.map(partial_get_embedding, batches)
240
+ self.logger.info(f"✅ Finished embedding.")
241
+
242
+ for chunk_batch, embeddings in all_embeddings:
243
+ url = chunk_batch[0]
244
+ list_chunks = chunk_batch[1]
245
+ insert_data.extend(
246
+ [
247
+ (url.replace("'", " "), chunk.replace("'", " "), embedding)
248
+ for chunk, embedding in zip(list_chunks, embeddings)
249
+ ]
250
+ )
251
+
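+ # chunk text and URLs are inlined into the SQL below (single quotes were stripped above); batching keeps each INSERT small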
252
+ for i in range(0, len(insert_data), query_batch_size):
253
+ # insert the batch into DuckDB
254
+ value_str = ", ".join(
255
+ [
256
+ f"('{url}', '{chunk}', {embedding})"
257
+ for url, chunk, embedding in insert_data[i : i + query_batch_size]
258
+ ]
259
+ )
260
+ query = f"""
261
+ INSERT INTO {self.table_name} (url, chunk, vec) VALUES {value_str};
262
+ """
263
+ self.db_con.execute(query)
264
+
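+ # build an HNSW index (vss extension) over the embedding column used by vector_search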
265
+ self.db_con.execute(
266
+ f"""
267
+ CREATE INDEX cos_idx ON {self.table_name} USING HNSW (vec)
268
+ WITH (metric = 'cosine');
269
+ """
270
+ )
271
+ self.logger.info(f"✅ Created the vector index ...")
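+ # a full-text index over the chunk text is also created, although only the vector index is queried in this version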
272
+ self.db_con.execute(
273
+ f"""
274
+ PRAGMA create_fts_index(
275
+ {self.table_name}, 'doc_id', 'chunk'
276
+ );
277
+ """
278
+ )
279
+ self.logger.info(f"✅ Created the full text search index ...")
280
+
281
+ def vector_search(self, query: str) -> List[Dict[str, Any]]:
282
+ client = self._get_api_client()
283
+ embeddings = self.get_embedding(client, [query])[0]
284
+
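+ # embed the query with the same embedding model as the chunks, then keep the 10 nearest chunks by vector distance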
285
+ query_result: duckdb.DuckDBPyRelation = self.db_con.sql(
286
+ f"""
287
+ SELECT * FROM {self.table_name}
288
+ ORDER BY array_distance(vec, {embeddings}::FLOAT[{self.embedding_dimensions}])
289
+ LIMIT 10;
290
+ """
291
+ )
292
+
293
+ self.logger.debug(query_result)
294
+
295
+ matched_chunks = []
296
+ for record in query_result.fetchall():
297
+ result_record = {
298
+ "url": record[1],
299
+ "chunk": record[2],
300
+ }
301
+ matched_chunks.append(result_record)
302
+
303
+ return matched_chunks
304
+
305
+ def _get_api_client(self) -> OpenAI:
306
+ return OpenAI(api_key=self.llm_api_key, base_url=self.llm_base_url)
307
+
308
+ def _render_template(self, template_str: str, variables: Dict[str, Any]) -> str:
309
+ env = Environment(loader=BaseLoader(), autoescape=False)
310
+ template = env.from_string(template_str)
311
+ return template.render(variables)
312
+
313
+ def run_inference(
314
+ self,
315
+ query: str,
316
+ model_name: str,
317
+ matched_chunks: List[Dict[str, Any]],
318
+ output_language: str,
319
+ output_length: int,
320
+ ) -> str:
321
+ system_prompt = (
322
+ "You are an expert summarizing the answers based on the provided contents."
323
+ )
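+ # the user prompt is a Jinja2 template; matched chunks are injected as a numbered [x] context list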
324
+ user_prompt_template = """
325
+ Given the context as a sequence of references with a reference id in the
326
+ format of a leading [x], please answer the following question using {{ language }}:
327
+
328
+ {{ query }}
329
+
330
+ In the answer, use format [1], [2], ..., [n] in line where the reference is used.
331
+ For example, "According to the research from Google[3], ...".
332
+
333
+ Please create the answer strictly related to the context. If the context has no
334
+ information about the query, please write "No related information found in the context."
335
+ using {{ language }}.
336
+
337
+ {{ length_instructions }}
338
+
339
+ Here is the context:
340
+ {{ context }}
341
+ """
342
+ context = ""
343
+ for i, chunk in enumerate(matched_chunks):
344
+ context += f"[{i+1}] {chunk['chunk']}\n"
345
+
346
+ if output_length is None or output_length == 0:
347
+ length_instructions = ""
348
+ else:
349
+ length_instructions = (
350
+ f"Please provide the answer in { output_length } words."
351
+ )
352
+
353
+ user_prompt = self._render_template(
354
+ user_prompt_template,
355
+ {
356
+ "query": query,
357
+ "context": context,
358
+ "language": output_language,
359
+ "length_instructions": length_instructions,
360
+ },
361
+ )
362
+
363
+ self.logger.debug(f"Running inference with model: {model_name}")
364
+ self.logger.debug(f"Final user prompt: {user_prompt}")
365
+
366
+ api_client = self._get_api_client()
367
+ completion = api_client.chat.completions.create(
368
+ model=model_name,
369
+ messages=[
370
+ {
371
+ "role": "system",
372
+ "content": system_prompt,
373
+ },
374
+ {
375
+ "role": "user",
376
+ "content": user_prompt,
377
+ },
378
+ ],
379
+ )
380
+ if completion is None:
381
+ raise Exception("No completion from the API")
382
+
383
+ response_str = completion.choices[0].message.content
384
+ return response_str
385
+
386
+
387
+ def _read_url_list(url_list_file: str) -> Optional[str]:
388
+ if url_list_file is None:
389
+ return None
390
+
391
+ with open(url_list_file, "r") as f:
392
+ links = f.readlines()
393
+ links = [
394
+ link.strip()
395
+ for link in links
396
+ if link.strip() != "" and not link.startswith("#")
397
+ ]
398
+ return "\n".join(links)
399
+
400
+
401
+ def _run_query(
402
+ query: str,
403
+ date_restrict: int,
404
+ target_site: str,
405
+ output_language: str,
406
+ output_length: int,
407
+ url_list_str: str,
408
+ model_name: str,
409
+ log_level: str,
410
+ ) -> str:
411
+ logger = get_logger(log_level)
412
+
413
+ ask = Ask(logger=logger)
414
+
415
+ if url_list_str is None or url_list_str.strip() == "":
416
+ logger.info("Searching the web ...")
417
+ links = ask.search_web(query, date_restrict, target_site)
418
+ logger.info(f"✅ Found {len(links)} links for query: {query}")
419
+ for i, link in enumerate(links):
420
+ logger.debug(f"{i+1}. {link}")
421
+ else:
422
+ links = url_list_str.split("\n")
423
+
424
+ logger.info("Scraping the URLs ...")
425
+ scrape_results = ask.scrape_urls(links)
426
+ logger.info(f"✅ Scraped {len(scrape_results)} URLs.")
427
+
428
+ logger.info("Chunking the text ...")
429
+ chunking_results = ask.chunk_results(scrape_results, 1000, 100)
430
+ total_chunks = 0
431
+ for url, chunks in chunking_results.items():
432
+ logger.debug(f"URL: {url}")
433
+ total_chunks += len(chunks)
434
+ for i, chunk in enumerate(chunks):
435
+ logger.debug(f"Chunk {i+1}: {chunk}")
436
+ logger.info(f"✅ Generated {total_chunks} chunks ...")
437
+
438
+ logger.info(f"Saving {total_chunks} chunks to DB ...")
439
+ ask.save_to_db(chunking_results)
440
+ logger.info(f"✅ Successfully embedded and saved chunks to DB.")
441
+
442
+ logger.info("Querying the vector DB to get context ...")
443
+ matched_chunks = ask.vector_search(query)
444
+ for i, result in enumerate(matched_chunks):
445
+ logger.debug(f"{i+1}. {result}")
446
+ logger.info(f"✅ Got {len(matched_chunks)} matched chunks.")
447
+
448
+ logger.info("Running inference with context ...")
449
+ answer = ask.run_inference(
450
+ query=query,
451
+ model_name=model_name,
452
+ matched_chunks=matched_chunks,
453
+ output_language=output_language,
454
+ output_length=output_length,
455
+ )
456
+ logger.info("✅ Finished inference API call.")
457
+ logger.info("Generating output ...")
458
+
459
+ answer = f"# Answer\n\n{answer}\n"
460
+ references = "\n".join(
461
+ [f"[{i+1}] {result['url']}" for i, result in enumerate(matched_chunks)]
462
+ )
463
+ return f"{answer}\n\n# References\n\n{references}"
464
+
465
+
466
+ def launch_gradio(
467
+ query: str,
468
+ date_restrict: int,
469
+ target_site: str,
470
+ output_language: str,
471
+ output_length: int,
472
+ url_list_str: str,
473
+ model_name: str,
474
+ log_level: str,
475
+ share_ui: bool,
476
+ ) -> None:
477
+ iface = gr.Interface(
478
+ fn=_run_query,
479
+ inputs=[
480
+ gr.Textbox(label="Query", value=query),
481
+ gr.Number(
482
+ label="Date Restrict (Optional) [0 or empty means no date limit.]",
483
+ value=date_restrict,
484
+ ),
485
+ gr.Textbox(
486
+ label="Target Sites (Optional) [Empty means search the whole web.]",
487
+ value=target_site,
488
+ ),
489
+ gr.Textbox(
490
+ label="Output Language (Optional) [Default is English.]",
491
+ value=output_language,
492
+ ),
493
+ gr.Number(
494
+ label="Output Length in words (Optional) [Default is automatically decided by LLM.]",
495
+ value=output_length,
496
+ ),
497
+ gr.Textbox(
498
+ label="URL List (Optional) [When specified, scrape the urls instead of searching the web.]",
499
+ lines=5,
500
+ max_lines=20,
501
+ value=url_list_str,
502
+ ),
503
+ ],
504
+ additional_inputs=[
505
+ gr.Textbox(label="Model Name", value=model_name),
506
+ gr.Textbox(label="Log Level", value=log_level),
507
+ ],
508
+ outputs="text",
509
+ show_progress=True,
510
+ flagging_options=[("Report Error", None)],
511
+ title="Ask.py - Web Search-Extract-Summarize",
512
+ description="Search the web with the query and summarize the results. Source code: https://github.com/pengfeng/ask.py",
513
+ )
514
+
515
+ iface.launch(share=share_ui)
516
+
517
+
518
+ @click.command(help="Search web for the query and summarize the results")
519
+ @click.option(
520
+ "--web-ui",
521
+ is_flag=True,
522
+ help="Launch the web interface",
523
+ )
524
+ @click.option("--query", "-q", required=False, help="Query to search")
525
+ @click.option(
526
+ "--date-restrict",
527
+ "-d",
528
+ type=int,
529
+ required=False,
530
+ default=None,
531
+ help="Restrict search results to a specific date range, default is no restriction",
532
+ )
533
+ @click.option(
534
+ "--target-site",
535
+ "-s",
536
+ required=False,
537
+ default=None,
538
+ help="Restrict search results to a specific site, default is no restriction",
539
+ )
540
+ @click.option(
541
+ "--output-language",
542
+ required=False,
543
+ default="English",
544
+ help="Output language for the answer",
545
+ )
546
+ @click.option(
547
+ "--output-length",
548
+ type=int,
549
+ required=False,
550
+ default=None,
551
+ help="Output length for the answer",
552
+ )
553
+ @click.option(
554
+ "--url-list-file",
555
+ type=str,
556
+ required=False,
557
+ default=None,
558
+ show_default=True,
559
+ help="Instead of doing web search, scrape the target URL list and answer the query based on the content",
560
+ )
561
+ @click.option(
562
+ "--model-name",
563
+ "-m",
564
+ required=False,
565
+ default="gpt-4o-mini",
566
+ help="Model name to use for inference",
567
+ )
568
+ @click.option(
569
+ "-l",
570
+ "--log-level",
571
+ "log_level",
572
+ default="INFO",
573
+ type=click.Choice(["DEBUG", "INFO", "WARNING", "ERROR"], case_sensitive=False),
574
+ help="Set the logging level",
575
+ show_default=True,
576
+ )
577
+ def search_extract_summarize(
578
+ web_ui: bool,
579
+ query: str,
580
+ date_restrict: int,
581
+ target_site: str,
582
+ output_language: str,
583
+ output_length: int,
584
+ url_list_file: str,
585
+ model_name: str,
586
+ log_level: str,
587
+ ):
588
+ load_dotenv(dotenv_path=default_env_file, override=False)
589
+
590
+ if web_ui or os.environ.get("RUN_GRADIO_UI", "false").lower() != "false":
591
+ if os.environ.get("SHARE_GRADIO_UI", "false").lower() == "true":
592
+ share_ui = True
593
+ else:
594
+ share_ui = False
595
+ launch_gradio(
596
+ query=query,
597
+ date_restrict=date_restrict,
598
+ target_site=target_site,
599
+ output_language=output_language,
600
+ output_length=output_length,
601
+ url_list_str=_read_url_list(url_list_file),
602
+ model_name=model_name,
603
+ log_level=log_level,
604
+ share_ui=share_ui,
605
+ )
606
+ else:
607
+ if query is None:
608
+ raise Exception("Query is required for the command line mode")
609
+
610
+ result = _run_query(
611
+ query=query,
612
+ date_restrict=date_restrict,
613
+ target_site=target_site,
614
+ output_language=output_language,
615
+ output_length=output_length,
616
+ url_list_str=_read_url_list(url_list_file),
617
+ model_name=model_name,
618
+ log_level=log_level,
619
+ )
620
+ click.echo(result)
621
+
622
+
623
+ if __name__ == "__main__":
624
+ search_extract_summarize()
requirements.txt ADDED
@@ -0,0 +1,9 @@
1
+ click==8.1.7
2
+ requests==2.31.0
3
+ openai==1.40.2
4
+ jinja2==3.1.3
5
+ bs4==0.0.2
6
+ lxml==4.8.0
7
+ python-dotenv==1.0.1
8
+ duckdb==1.1.2
9
+ gradio==5.3.0