Smolagents-ExtraSearchTools / extra_search_tools.py
Akjava's picture
init
ca5d696
raw
history blame
7.68 kB
# Copyright 2024-2025 Akihito Miyazaki.
# This code is derived from the DuckDuckGoSearchTool class,
# originally part of the HuggingFace smolagents library.
# https://github.com/huggingface/smolagents
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from smolagents import Tool
import json
import os
from datetime import datetime
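"""Web search tools for smolagents.

All tools in this module share the tool name "web_search" because they are
designed to be mutually exclusive: only one of them is used per query.
PrioritySearchTool wraps several search tools and tries them in order under
that same name.

This code is derived from the DuckDuckGoSearchTool class in the HuggingFace
smolagents library.

TODO:
    GoogleCustomSearchTool and BraveSearchTool do not forward their **kwargs
    to Tool.__init__; the kwargs are only used as extra search-API query
    parameters.

Requires:
    google-api-python-client  (for GoogleCustomSearchTool)
    brave-search              (for BraveSearchTool)
"""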
class PrioritySearchTool(Tool):
    """A tool that runs a query through several search tools in priority order.

    Each tool in ``search_tools`` is tried in list order, and the result of the
    first one that does not raise is returned. When ``save_json_path`` is set,
    successful results are cached in that JSON file keyed by the query string.
    A cached entry is returned without calling any search tool.

    Attributes:
        name (str): The name of the tool.
        description (str): A description of the tool.
        inputs (dict): The input schema for the tool.
        output_type (str): The output type of the tool.
        search_tools (list[Tool]): Search tools, tried in list order.
        save_json_path (str, optional): Path of the JSON cache file. ``None``
            (the default) disables caching.
        history_results (dict): Cached results, mapping each query to a dict
            with keys "cdate" (timestamp string), "name" (class name of the
            tool that answered) and "data" (the tool's result).
    """

    name = "web_search"
    description = """Performs a google-custom web search based on your query (think a Google search) then returns the top search results."""
    inputs = {
        "query": {"type": "string", "description": "The search query to perform."}
    }
    output_type = "string"

    def __init__(
        self,
        search_tools: list[Tool],
        save_json_path: str = None,
        **kwargs,
    ):
        """
        Args:
            search_tools: Search tools to try, highest priority first.
            save_json_path: Optional path of a JSON file used as a result cache.
            **kwargs: Accepted for interface compatibility; currently unused.
        """
        super().__init__()
        self.search_tools = search_tools
        self.save_json_path = save_json_path
        self.history_results = {}

    def forward(self, query: str) -> str:
        """Return the cached result for ``query``, or the first successful search.

        Raises:
            Exception: If every search tool raised.
        """
        # Only touch the filesystem when caching is enabled. os.path.exists(None)
        # raises TypeError, which would make the default configuration unusable.
        if self.save_json_path and os.path.exists(self.save_json_path):
            # Reload on every call so results written by earlier runs are reused.
            with open(self.save_json_path, "r", encoding="utf-8") as file:
                self.history_results = json.load(file)
        if query in self.history_results:
            return self.history_results[query]["data"]

        last_error = None
        for search_tool in self.search_tools:
            try:
                result = search_tool(query=query)
            except Exception as e:
                # Best-effort fallback: report the failure and try the next tool.
                print(f"{e}")
                last_error = e
                continue

            if self.save_json_path:
                self.history_results[query] = {
                    "cdate": str(datetime.now()),
                    "name": search_tool.__class__.__name__,
                    "data": result,
                }
                # A failed cache write must not discard a successful result
                # (or trigger a needless query against the next tool).
                try:
                    with open(self.save_json_path, "w", encoding="utf-8") as file:
                        json.dump(self.history_results, file)
                except (OSError, TypeError, ValueError) as e:
                    print(f"Failed to save search cache: {e}")
            return result

        raise Exception("All search tools failed.") from last_error
class GoogleCustomSearchTool(Tool):
    """Web search backed by the Google Custom Search JSON API.

    Uses https://github.com/googleapis/google-api-python-client/ .
    The API key is read from the ``GOOGLE_CUSTOM_SEARCH_KEY`` environment
    variable. Extra keyword arguments are forwarded verbatim as query
    parameters of ``cse.list``; see
    https://developers.google.com/custom-search/v1/reference/rest/v1/cse/list

    Example (results restricted to another country/language):
        search = GoogleCustomSearchTool("33ec073e195bc4fcf", cr="countryJP", lr="lang_ja")
    """

    name = "web_search"
    description = """Performs a google-custom web search based on your query (think a Google search) then returns the top search results."""
    inputs = {
        "query": {"type": "string", "description": "The search query to perform."}
    }
    output_type = "string"

    def __init__(self, cx, max_results=10, **kwargs):
        """
        Args:
            cx: Search Engine ID, created in the Custom Search control panel.
            max_results: Passed to ``cse.list`` as ``num``.
                NOTE(review): the cse.list reference documents ``num`` as
                1-10 — confirm before passing larger values.
            **kwargs: Extra ``cse.list`` query parameters (e.g. ``cr``, ``lr``).

        Raises:
            ValueError: If ``cx`` is None.
            ImportError: If google-api-python-client is not installed.
        """
        super().__init__()
        if cx is None:
            raise ValueError(
                "Need CX(Search Engine ID) need create in custom-search controlpanel"
            )
        self.cx = cx
        self.max_results = max_results
        api_key_env_name = "GOOGLE_CUSTOM_SEARCH_KEY"
        self.kwargs = kwargs
        try:
            from googleapiclient.discovery import build
        except ImportError as e:
            raise ImportError(
                "You must install package `google-api-python-client` to run this tool: for instance run `pip install google-api-python-client`."
            ) from e
        self.key = os.getenv(api_key_env_name)
        self.custom_search = build("customsearch", "v1")

    def forward(self, query: str) -> str:
        """Run ``query`` and return the results as a Markdown list.

        Raises:
            Exception: If the search returned no items.
        """
        response = (
            self.custom_search.cse()
            .list(
                key=self.key, q=query, cx=self.cx, num=self.max_results, **self.kwargs
            )
            .execute()
        )
        # The response may omit "items" when nothing matched; a plain
        # response["items"] would then raise KeyError instead of the message below.
        items = response.get("items") or []
        if not items:
            raise Exception("No results found! Try a less restrictive/shorter query.")
        postprocessed_results = [
            # Some items carry no snippet; render an empty line instead of failing.
            f"[{item['title']}]({item['link']})\n{item.get('snippet', '')}"
            for item in items
        ]
        return "## Search Results\n\n" + "\n\n".join(postprocessed_results)
from smolagents import Tool
import json
class BraveSearchTool(Tool):
    """Web search backed by the Brave Search API.

    Uses https://github.com/kayvane1/brave-api (PyPI package ``brave-search``).
    The API key is read from the ``BRAVE_SEARCH_KEY`` environment variable.
    Extra keyword arguments are forwarded verbatim as query parameters to
    ``Brave.search``; see
    https://api-dashboard.search.brave.com/app/documentation/web-search/query

    Example (results in another country/language):
        search = BraveSearchTool(country="JP", search_lang="jp")
    """

    name = "web_search"
    description = """Performs a google-custom web search based on your query (think a Google search) then returns the top search results."""
    inputs = {
        "query": {"type": "string", "description": "The search query to perform."}
    }
    output_type = "string"

    def __init__(self, max_results=10, **kwargs):
        """
        Args:
            max_results: Passed to ``Brave.search`` as ``count``.
            **kwargs: Extra Brave query parameters (e.g. ``country``, ``search_lang``).

        Raises:
            ImportError: If the brave-search package is not installed.
        """
        super().__init__()
        self.max_results = max_results
        api_key_env_name = "BRAVE_SEARCH_KEY"
        self.kwargs = kwargs
        try:
            from brave import Brave
        except ImportError as e:
            raise ImportError(
                # Other Brave client libraries exist, but this one works on Python 3.10.
                "You must install package `brave-search` to run this tool: for instance run `pip install brave-search`."
            ) from e
        self.brave = Brave(api_key=os.getenv(api_key_env_name))

    @staticmethod
    def clean(text):
        """Remove ``<strong>``/``</strong>`` markup (any letter case) from ``text``.

        Returns an empty string for ``None`` or empty input.
        """
        import re

        if not text:
            return ""
        return re.sub(r"</?strong>", "", text, flags=re.IGNORECASE)

    def forward(self, query: str) -> str:
        """Run ``query`` and return the web results as a Markdown list.

        Raises:
            Exception: If the search returned no web results.
        """
        search_results = self.brave.search(
            q=query, count=self.max_results, **self.kwargs
        )
        results = search_results.web_results
        # `not results` also covers a missing (None) result list.
        if not results:
            raise Exception("No results found! Try a less restrictive/shorter query.")
        postprocessed_results = [
            # str() of the URL object instead of its private `_url` attribute.
            f"[{result['title']}]({str(result['url'])})\n{self.clean(result['description'])}"
            for result in results
        ]
        return "## Search Results\n\n" + "\n\n".join(postprocessed_results)