Spaces:
Configuration error
Configuration error
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. ========= | |
# Licensed under the Apache License, Version 2.0 (the "License"); | |
# you may not use this file except in compliance with the License. | |
# You may obtain a copy of the License at | |
# | |
# http://www.apache.org/licenses/LICENSE-2.0 | |
# | |
# Unless required by applicable law or agreed to in writing, software | |
# distributed under the License is distributed on an "AS IS" BASIS, | |
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
# See the License for the specific language governing permissions and | |
# limitations under the License. | |
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. ========= | |
import os | |
from typing import Any, Dict, List, Optional | |
from camel.retrievers import BaseRetriever | |
from camel.utils import dependencies_required | |
DEFAULT_TOP_K_RESULTS = 1 | |
class CohereRerankRetriever(BaseRetriever): | |
r"""An implementation of the `BaseRetriever` using the `Cohere Re-ranking` | |
model. | |
Attributes: | |
model_name (str): The model name to use for re-ranking. | |
api_key (Optional[str]): The API key for authenticating with the | |
Cohere service. | |
References: | |
https://txt.cohere.com/rerank/ | |
""" | |
def __init__( | |
self, | |
model_name: str = "rerank-multilingual-v2.0", | |
api_key: Optional[str] = None, | |
) -> None: | |
r"""Initializes an instance of the CohereRerankRetriever. This | |
constructor sets up a client for interacting with the Cohere API using | |
the specified model name and API key. If the API key is not provided, | |
it attempts to retrieve it from the COHERE_API_KEY environment | |
variable. | |
Args: | |
model_name (str): The name of the model to be used for re-ranking. | |
Defaults to 'rerank-multilingual-v2.0'. | |
api_key (Optional[str]): The API key for authenticating requests | |
to the Cohere API. If not provided, the method will attempt to | |
retrieve the key from the environment variable | |
'COHERE_API_KEY'. | |
Raises: | |
ImportError: If the 'cohere' package is not installed. | |
ValueError: If the API key is neither passed as an argument nor | |
set in the environment variable. | |
""" | |
import cohere | |
try: | |
self.api_key = api_key or os.environ["COHERE_API_KEY"] | |
except ValueError as e: | |
raise ValueError( | |
"Must pass in cohere api key or specify via COHERE_API_KEY" | |
" environment variable." | |
) from e | |
self.co = cohere.Client(self.api_key) | |
self.model_name = model_name | |
def query( | |
self, | |
query: str, | |
retrieved_result: List[Dict[str, Any]], | |
top_k: int = DEFAULT_TOP_K_RESULTS, | |
) -> List[Dict[str, Any]]: | |
r"""Queries and compiles results using the Cohere re-ranking model. | |
Args: | |
query (str): Query string for information retriever. | |
retrieved_result (List[Dict[str, Any]]): The content to be | |
re-ranked, should be the output from `BaseRetriever` like | |
`VectorRetriever`. | |
top_k (int, optional): The number of top results to return during | |
retriever. Must be a positive integer. Defaults to | |
`DEFAULT_TOP_K_RESULTS`. | |
Returns: | |
List[Dict[str, Any]]: Concatenated list of the query results. | |
""" | |
rerank_results = self.co.rerank( | |
query=query, | |
documents=retrieved_result, | |
top_n=top_k, | |
model=self.model_name, | |
) | |
formatted_results = [] | |
for result in rerank_results.results: | |
selected_chunk = retrieved_result[result.index] | |
selected_chunk['similarity score'] = result.relevance_score | |
formatted_results.append(selected_chunk) | |
return formatted_results | |