|
import threading |
|
from typing import Optional |
|
|
|
from flask import Flask, current_app |
|
|
|
from core.rag.data_post_processor.data_post_processor import DataPostProcessor |
|
from core.rag.datasource.keyword.keyword_factory import Keyword |
|
from core.rag.datasource.vdb.vector_factory import Vector |
|
from core.rag.rerank.rerank_type import RerankMode |
|
from core.rag.retrieval.retrieval_methods import RetrievalMethod |
|
from extensions.ext_database import db |
|
from models.dataset import Dataset |
|
from services.external_knowledge_service import ExternalDatasetService |
|
|
|
# Default retrieval configuration: semantic search, reranking disabled (empty
# provider/model names), top 2 results, score threshold disabled.
# NOTE(review): presumably the fallback when a caller supplies no retrieval
# model — its consumers are outside this file view.
default_retrieval_model = dict(
    search_method=RetrievalMethod.SEMANTIC_SEARCH.value,
    reranking_enable=False,
    reranking_model=dict(reranking_provider_name="", reranking_model_name=""),
    top_k=2,
    score_threshold_enabled=False,
)
|
|
|
|
|
class RetrievalService:
    """Run keyword, semantic and full-text retrieval against a dataset and merge the results."""

    @classmethod
    def retrieve(
        cls,
        retrieval_method: str,
        dataset_id: str,
        query: str,
        top_k: int,
        score_threshold: Optional[float] = 0.0,
        reranking_model: Optional[dict] = None,
        reranking_mode: Optional[str] = "reranking_model",
        weights: Optional[dict] = None,
    ):
        """Retrieve documents for ``query`` from dataset ``dataset_id``.

        Each search strategy enabled by ``retrieval_method`` runs on its own
        thread; all threads append into one shared result list.
        ``"keyword_search"`` starts keyword search. The
        ``RetrievalMethod.is_support_*`` helpers decide whether the semantic
        and full-text searches run. For hybrid search, the merged results are
        post-processed by ``DataPostProcessor`` with ``reranking_mode`` and
        ``weights``, using ``top_n=top_k``.

        Returns:
            The list of retrieved documents. Returns ``[]`` when ``query`` is
            empty, the dataset does not exist, or the dataset has no available
            documents or segments.

        Raises:
            Exception: if any worker failed; the message joins every worker
                error with ``";\\n"``.
        """
        if not query:
            return []

        dataset = db.session.query(Dataset).filter(Dataset.id == dataset_id).first()
        # Nothing to search: missing dataset, or no available documents/segments.
        if not dataset or dataset.available_document_count == 0 or dataset.available_segment_count == 0:
            return []

        all_documents: list = []
        threads: list = []
        exceptions: list = []
        # Worker threads have no app context of their own; hand them the real
        # app object (not the context-local proxy) so each can push one.
        flask_app = current_app._get_current_object()

        if retrieval_method == "keyword_search":
            cls._start_worker(
                threads,
                cls.keyword_search,
                {
                    "flask_app": flask_app,
                    "dataset_id": dataset_id,
                    "query": query,
                    "top_k": top_k,
                    "all_documents": all_documents,
                    "exceptions": exceptions,
                },
            )

        if RetrievalMethod.is_support_semantic_search(retrieval_method):
            cls._start_worker(
                threads,
                cls.embedding_search,
                {
                    "flask_app": flask_app,
                    "dataset_id": dataset_id,
                    "query": query,
                    "top_k": top_k,
                    "score_threshold": score_threshold,
                    "reranking_model": reranking_model,
                    "all_documents": all_documents,
                    "retrieval_method": retrieval_method,
                    "exceptions": exceptions,
                },
            )

        if RetrievalMethod.is_support_fulltext_search(retrieval_method):
            cls._start_worker(
                threads,
                cls.full_text_index_search,
                {
                    "flask_app": flask_app,
                    "dataset_id": dataset_id,
                    "query": query,
                    "retrieval_method": retrieval_method,
                    "score_threshold": score_threshold,
                    "top_k": top_k,
                    "reranking_model": reranking_model,
                    "all_documents": all_documents,
                    "exceptions": exceptions,
                },
            )

        for thread in threads:
            thread.join()

        # Workers record failures as strings instead of raising, so surface them here.
        if exceptions:
            raise Exception(";\n".join(exceptions))

        if retrieval_method == RetrievalMethod.HYBRID_SEARCH.value:
            data_post_processor = DataPostProcessor(
                str(dataset.tenant_id), reranking_mode, reranking_model, weights, False
            )
            all_documents = data_post_processor.invoke(
                query=query, documents=all_documents, score_threshold=score_threshold, top_n=top_k
            )

        return all_documents

    @classmethod
    def external_retrieve(cls, dataset_id: str, query: str, external_retrieval_model: Optional[dict] = None):
        """Retrieve documents for ``query`` via ``ExternalDatasetService``.

        Returns ``[]`` when the dataset does not exist; otherwise returns whatever
        ``ExternalDatasetService.fetch_external_knowledge_retrieval`` returns.
        """
        dataset = db.session.query(Dataset).filter(Dataset.id == dataset_id).first()
        if not dataset:
            return []
        return ExternalDatasetService.fetch_external_knowledge_retrieval(
            dataset.tenant_id, dataset_id, query, external_retrieval_model
        )

    @staticmethod
    def _start_worker(threads: list, target, kwargs: dict) -> None:
        """Start ``target(**kwargs)`` on a new thread and record the thread in ``threads``."""
        thread = threading.Thread(target=target, kwargs=kwargs)
        threads.append(thread)
        thread.start()

    @classmethod
    def _collect_documents(
        cls,
        dataset,
        query: str,
        documents,
        score_threshold: Optional[float],
        reranking_model: Optional[dict],
        rerank_allowed: bool,
        all_documents: list,
    ) -> None:
        """Append ``documents`` to ``all_documents``, reranking them first when configured.

        Reranking happens only when ``rerank_allowed`` is true and
        ``reranking_model`` names both a model and a provider. Empty or falsy
        ``documents`` add nothing.
        """
        if not documents:
            return
        if (
            reranking_model
            and reranking_model.get("reranking_model_name")
            and reranking_model.get("reranking_provider_name")
            and rerank_allowed
        ):
            data_post_processor = DataPostProcessor(
                str(dataset.tenant_id), RerankMode.RERANKING_MODEL.value, reranking_model, None, False
            )
            # Rerank with the raw (unescaped) query and keep every document.
            documents = data_post_processor.invoke(
                query=query, documents=documents, score_threshold=score_threshold, top_n=len(documents)
            )
        all_documents.extend(documents)

    @classmethod
    def keyword_search(
        cls, flask_app: "Flask", dataset_id: str, query: str, top_k: int, all_documents: list, exceptions: list
    ):
        """Thread worker: keyword search. Results go to ``all_documents``, error text to ``exceptions``."""
        with flask_app.app_context():
            try:
                dataset = db.session.query(Dataset).filter(Dataset.id == dataset_id).first()
                keyword = Keyword(dataset=dataset)
                documents = keyword.search(cls.escape_query_for_search(query), top_k=top_k)
                all_documents.extend(documents)
            except Exception as e:
                exceptions.append(str(e))

    @classmethod
    def embedding_search(
        cls,
        flask_app: "Flask",
        dataset_id: str,
        query: str,
        top_k: int,
        score_threshold: Optional[float],
        reranking_model: Optional[dict],
        all_documents: list,
        retrieval_method: str,
        exceptions: list,
    ):
        """Thread worker: vector similarity search, reranked only for pure semantic search.

        Results go to ``all_documents``; error text goes to ``exceptions``.
        """
        with flask_app.app_context():
            try:
                dataset = db.session.query(Dataset).filter(Dataset.id == dataset_id).first()
                vector = Vector(dataset=dataset)
                documents = vector.search_by_vector(
                    cls.escape_query_for_search(query),
                    search_type="similarity_score_threshold",
                    top_k=top_k,
                    score_threshold=score_threshold,
                    filter={"group_id": [dataset.id]},
                )
                cls._collect_documents(
                    dataset,
                    query,
                    documents,
                    score_threshold,
                    reranking_model,
                    retrieval_method == RetrievalMethod.SEMANTIC_SEARCH.value,
                    all_documents,
                )
            except Exception as e:
                exceptions.append(str(e))

    @classmethod
    def full_text_index_search(
        cls,
        flask_app: "Flask",
        dataset_id: str,
        query: str,
        top_k: int,
        score_threshold: Optional[float],
        reranking_model: Optional[dict],
        all_documents: list,
        retrieval_method: str,
        exceptions: list,
    ):
        """Thread worker: full-text search, reranked only for pure full-text search.

        Results go to ``all_documents``; error text goes to ``exceptions``.
        """
        with flask_app.app_context():
            try:
                dataset = db.session.query(Dataset).filter(Dataset.id == dataset_id).first()
                vector_processor = Vector(dataset=dataset)
                documents = vector_processor.search_by_full_text(cls.escape_query_for_search(query), top_k=top_k)
                cls._collect_documents(
                    dataset,
                    query,
                    documents,
                    score_threshold,
                    reranking_model,
                    retrieval_method == RetrievalMethod.FULL_TEXT_SEARCH.value,
                    all_documents,
                )
            except Exception as e:
                exceptions.append(str(e))

    @staticmethod
    def escape_query_for_search(query: str) -> str:
        """Backslash-escape backslashes and double quotes in ``query``.

        NOTE(review): presumably the result is embedded in a double-quoted
        string by the search backends; confirm against the store adapters.
        """
        # Escape backslashes first. Otherwise an input backslash could combine
        # with the quote escape ("\\\"" -> "\\\\\"") or escape a closing quote.
        return query.replace("\\", "\\\\").replace('"', '\\"')
|
|