import asyncio
import time
import traceback
from typing import List, Optional

import logfire
import tiktoken
from cohere import AsyncClient
from dotenv import load_dotenv
from llama_index.core import QueryBundle
from llama_index.core.callbacks import CBEventType, EventPayload
from llama_index.core.retrievers import (
    BaseRetriever,
    KeywordTableSimpleRetriever,
    VectorIndexRetriever,
)
from llama_index.core.schema import MetadataMode, NodeWithScore

# MetadataFilter and FilterOperator are only referenced by the commented-out
# metadata filtering block in CustomRetriever._process_retrieval below.
from llama_index.core.vector_stores import (
    FilterOperator,
    MetadataFilter,
)
from llama_index.postprocessor.cohere_rerank import CohereRerank

load_dotenv()

class AsyncCohereRerank(CohereRerank):
    """CohereRerank subclass that adds an async `apostprocess_nodes` method."""

    def __init__(
        self,
        top_n: int = 5,
        model: str = "rerank-english-v3.0",
        api_key: Optional[str] = None,
    ) -> None:
        super().__init__(top_n=top_n, model=model, api_key=api_key)
        # Keep our own copies so the async path does not depend on the
        # parent's private attributes.
        self._api_key = api_key
        self._model = model
        self._top_n = top_n

    async def apostprocess_nodes(
        self,
        nodes: List[NodeWithScore],
        query_bundle: Optional[QueryBundle] = None,
    ) -> List[NodeWithScore]:
        """Rerank nodes against the query using Cohere's async client."""
        if query_bundle is None:
            raise ValueError("Query bundle must be provided.")

        if len(nodes) == 0:
            return []

        async_client = AsyncClient(api_key=self._api_key)

        with self.callback_manager.event(
            CBEventType.RERANKING,
            payload={
                EventPayload.NODES: nodes,
                EventPayload.MODEL_NAME: self._model,
                EventPayload.QUERY_STR: query_bundle.query_str,
                EventPayload.TOP_K: self._top_n,
            },
        ) as event:
            texts = [
                node.node.get_content(metadata_mode=MetadataMode.EMBED)
                for node in nodes
            ]
            results = await async_client.rerank(
                model=self._model,
                top_n=self._top_n,
                query=query_bundle.query_str,
                documents=texts,
            )

            # Map Cohere's ranked results back onto the original nodes.
            new_nodes = []
            for result in results.results:
                new_node_with_score = NodeWithScore(
                    node=nodes[result.index].node, score=result.relevance_score
                )
                new_nodes.append(new_node_with_score)
            event.on_end(payload={EventPayload.NODES: new_nodes})

        return new_nodes
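
# A minimal usage sketch (assuming a Cohere API key is set in the environment
# and `nodes` holds NodeWithScore objects from a retriever; both are
# hypothetical stand-ins here):
#
#   reranker = AsyncCohereRerank(top_n=3)
#   reranked = asyncio.run(
#       reranker.apostprocess_nodes(nodes, QueryBundle(query_str="what is RAG?"))
#   )
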
class CustomRetriever(BaseRetriever):
    """Custom retriever that performs both semantic search and hybrid search."""

    def __init__(
        self,
        vector_retriever: VectorIndexRetriever,
        document_dict: dict,
        keyword_retriever: Optional[KeywordTableSimpleRetriever] = None,
        mode: str = "AND",
    ) -> None:
        """Init params."""
        self._vector_retriever = vector_retriever
        self._document_dict = document_dict
        self._keyword_retriever = keyword_retriever
        if mode not in ("AND", "OR"):
            raise ValueError("Invalid mode: expected 'AND' or 'OR'.")
        self._mode = mode
        super().__init__()

    async def _process_retrieval(
        self, query_bundle: QueryBundle, is_async: bool = True
    ) -> List[NodeWithScore]:
        """Common processing logic for both sync and async retrieval."""
        # Clean query string
        query_bundle.query_str = query_bundle.query_str.replace(
            "\ninput is ", ""
        ).rstrip()
        logfire.info(f"Retrieving nodes with string: '{query_bundle.query_str}'")
        start = time.time()

        # Get nodes from both retrievers
        if is_async:
            nodes = await self._vector_retriever.aretrieve(query_bundle)
        else:
            nodes = self._vector_retriever.retrieve(query_bundle)

        keyword_nodes = []
        if self._keyword_retriever:
            if is_async:
                keyword_nodes = await self._keyword_retriever.aretrieve(query_bundle)
            else:
                keyword_nodes = self._keyword_retriever.retrieve(query_bundle)

        logfire.info(f"Number of vector nodes: {len(nodes)}")
        logfire.info(f"Number of keyword nodes: {len(keyword_nodes)}")

        # # Filter keyword nodes based on metadata filters from vector retriever
        # if (
        #     hasattr(self._vector_retriever, "_filters")
        #     and self._vector_retriever._filters
        # ):
        #     filtered_keyword_nodes = []
        #     for node in keyword_nodes:
        #         node_source = node.node.metadata.get("source")
        #         # Check if node's source matches any of the filter conditions
        #         for filter in self._vector_retriever._filters.filters:
        #             if (
        #                 isinstance(filter, MetadataFilter)
        #                 and filter.key == "source"
        #                 and filter.operator == FilterOperator.EQ
        #                 and filter.value == node_source
        #             ):
        #                 filtered_keyword_nodes.append(node)
        #                 break
        #     keyword_nodes = filtered_keyword_nodes
        #     logfire.info(
        #         f"Number of keyword nodes after filtering: {len(keyword_nodes)}"
        #     )

        # Combine results based on mode
        vector_ids = {n.node.node_id for n in nodes}
        keyword_ids = {n.node.node_id for n in keyword_nodes}
        combined_dict = {n.node.node_id: n for n in nodes}
        combined_dict.update({n.node.node_id: n for n in keyword_nodes})

        # If no keyword retriever or no keyword nodes, just use vector nodes
        if not self._keyword_retriever or not keyword_nodes:
            retrieve_ids = vector_ids
        else:
            retrieve_ids = (
                vector_ids.intersection(keyword_ids)
                if self._mode == "AND"
                else vector_ids.union(keyword_ids)
            )

        nodes = [combined_dict[rid] for rid in retrieve_ids]
        logfire.info(f"Number of combined nodes: {len(nodes)}")

        # Filter unique doc IDs
        nodes = self._filter_nodes_by_unique_doc_id(nodes)
        logfire.info(f"Number of nodes without duplicate doc IDs: {len(nodes)}")

        # Swap each chunk for its full source document when the node is
        # flagged with `retrieve_doc` in its metadata.
        for node in nodes:
            doc_id = node.node.source_node.node_id
            if node.node.metadata.get("retrieve_doc"):
                doc = self._document_dict[doc_id]
                node.node.text = doc.text
            node.node.node_id = doc_id

        # Rerank results
        try:
            reranker = (
                AsyncCohereRerank(top_n=5, model="rerank-english-v3.0")
                if is_async
                else CohereRerank(top_n=5, model="rerank-english-v3.0")
            )
            nodes = (
                await reranker.apostprocess_nodes(nodes, query_bundle)
                if is_async
                else reranker.postprocess_nodes(nodes, query_bundle)
            )
        except Exception as e:
            error_msg = f"Error during reranking: {type(e).__name__}: {str(e)}\n"
            error_msg += "Traceback:\n"
            error_msg += traceback.format_exc()
            logfire.error(error_msg)

        # Filter by score and token count
        nodes_filtered = self._filter_by_score_and_tokens(nodes)

        duration = time.time() - start
        logfire.info(f"Retrieving nodes took {duration:.2f}s")
        logfire.info(f"Nodes sent to LLM: {nodes_filtered[:5]}")
        return nodes_filtered[:5]

    def _filter_nodes_by_unique_doc_id(
        self, nodes: List[NodeWithScore]
    ) -> List[NodeWithScore]:
        """Filter nodes to keep only unique doc IDs."""
        unique_nodes = {}
        for node in nodes:
            # Nodes without a source node carry no doc ID and are dropped.
            doc_id = node.node.source_node.node_id if node.node.source_node else None
            if doc_id is not None and doc_id not in unique_nodes:
                unique_nodes[doc_id] = node
        return list(unique_nodes.values())

    def _filter_by_score_and_tokens(
        self, nodes: List[NodeWithScore]
    ) -> List[NodeWithScore]:
        """Filter nodes by score and token count."""
        nodes_filtered = []
        total_tokens = 0
        enc = tiktoken.encoding_for_model("gpt-4o-mini")

        for node in nodes:
            # Scores can be None if reranking failed upstream; treat those
            # like low-relevance nodes rather than crashing on the comparison.
            if node.score is None or node.score < 0.10:
                logfire.info(f"Skipping node with score {node.score}")
                continue
            node_tokens = len(enc.encode(node.node.text))
            if total_tokens + node_tokens > 100_000:
                logfire.info("Skipping node due to token count exceeding 100k")
                break
            total_tokens += node_tokens
            nodes_filtered.append(node)

        return nodes_filtered

    async def _aretrieve(self, query_bundle: QueryBundle) -> List[NodeWithScore]:
        """Async retrieve nodes given query."""
        return await self._process_retrieval(query_bundle, is_async=True)

    def _retrieve(self, query_bundle: QueryBundle) -> List[NodeWithScore]:
        """Sync retrieve nodes given query."""
        # Drive the shared coroutine on a fresh event loop; this assumes the
        # sync path is never called from inside an already-running loop.
        return asyncio.run(self._process_retrieval(query_bundle, is_async=False))
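
# A minimal end-to-end sketch (not part of the module): `index`, `doc_dict`,
# and the query below are hypothetical stand-ins for objects built elsewhere.
#
#   vector_retriever = VectorIndexRetriever(index=index, similarity_top_k=15)
#   retriever = CustomRetriever(
#       vector_retriever=vector_retriever,
#       document_dict=doc_dict,  # maps source-document IDs to full Documents
#       mode="OR",
#   )
#   results = asyncio.run(retriever.aretrieve("How does reranking work?"))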