Spaces:
Runtime error
Runtime error
| import json | |
| from typing import List | |
| from langchain.pydantic_v1 import BaseModel, Field | |
| from langchain.schema import BaseRetriever, Document | |
| from langchain.tools import Tool | |
| from backend.chat_bot.json_decoder import CustomJSONEncoder | |
# Argument schema for the retriever tool: a single free-text query string.
# NOTE: documented with comments rather than a class docstring on purpose;
# pydantic copies a class docstring into the generated schema's
# "description", which could change what is exposed to the language model.
class RetrieverInput(BaseModel):
    # The text the retriever is asked to look up (passed through as the
    # tool's single input by create_retriever_tool via args_schema).
    query: str = Field(description="query to look up in retriever")
def create_retriever_tool(
    retriever: BaseRetriever,
    tool_name: str,
    description: str,
) -> Tool:
    """Create a tool to do retrieval of documents.

    The synchronous and asynchronous entry points of the returned tool both
    produce the same output: a JSON string encoding the list of retrieved
    documents, each serialized with ``Document.dict()`` and
    ``CustomJSONEncoder``.

    Args:
        retriever: The retriever to use for the retrieval.
        tool_name: The name for the tool. This will be passed to the language
            model, so should be unique and somewhat descriptive.
        description: The description for the tool. This will be passed to the
            language model, so should be descriptive.

    Returns:
        Tool class to pass to an agent.
    """

    def _to_json(docs: List[Document]) -> str:
        # CustomJSONEncoder presumably handles metadata values the stdlib
        # encoder cannot serialize -- see backend.chat_bot.json_decoder.
        return json.dumps([d.dict() for d in docs], cls=CustomJSONEncoder)

    def retrieve(*args, **kwargs) -> str:
        docs: List[Document] = retriever.get_relevant_documents(*args, **kwargs)
        return _to_json(docs)

    async def aretrieve(*args, **kwargs) -> str:
        # Serialize the async result too, so async callers get the same JSON
        # string as sync callers instead of a raw List[Document].
        docs: List[Document] = await retriever.aget_relevant_documents(
            *args, **kwargs
        )
        return _to_json(docs)

    return Tool(
        name=tool_name,
        description=description,
        func=retrieve,
        coroutine=aretrieve,
        args_schema=RetrieverInput,
    )