Spaces:
Running
Running
import threading | |
from typing import Optional | |
from flask import Flask, current_app | |
from core.rag.data_post_processor.data_post_processor import DataPostProcessor | |
from core.rag.datasource.keyword.keyword_factory import Keyword | |
from core.rag.datasource.vdb.vector_factory import Vector | |
from extensions.ext_database import db | |
from models.dataset import Dataset | |
default_retrieval_model = { | |
'search_method': 'semantic_search', | |
'reranking_enable': False, | |
'reranking_model': { | |
'reranking_provider_name': '', | |
'reranking_model_name': '' | |
}, | |
'top_k': 2, | |
'score_threshold_enabled': False | |
} | |
class RetrievalService: | |
def retrieve(cls, retrival_method: str, dataset_id: str, query: str, | |
top_k: int, score_threshold: Optional[float] = .0, reranking_model: Optional[dict] = None): | |
dataset = db.session.query(Dataset).filter( | |
Dataset.id == dataset_id | |
).first() | |
if not dataset or dataset.available_document_count == 0 or dataset.available_segment_count == 0: | |
return [] | |
all_documents = [] | |
threads = [] | |
# retrieval_model source with keyword | |
if retrival_method == 'keyword_search': | |
keyword_thread = threading.Thread(target=RetrievalService.keyword_search, kwargs={ | |
'flask_app': current_app._get_current_object(), | |
'dataset_id': dataset_id, | |
'query': query, | |
'top_k': top_k, | |
'all_documents': all_documents | |
}) | |
threads.append(keyword_thread) | |
keyword_thread.start() | |
# retrieval_model source with semantic | |
if retrival_method == 'semantic_search' or retrival_method == 'hybrid_search': | |
embedding_thread = threading.Thread(target=RetrievalService.embedding_search, kwargs={ | |
'flask_app': current_app._get_current_object(), | |
'dataset_id': dataset_id, | |
'query': query, | |
'top_k': top_k, | |
'score_threshold': score_threshold, | |
'reranking_model': reranking_model, | |
'all_documents': all_documents, | |
'retrival_method': retrival_method | |
}) | |
threads.append(embedding_thread) | |
embedding_thread.start() | |
# retrieval source with full text | |
if retrival_method == 'full_text_search' or retrival_method == 'hybrid_search': | |
full_text_index_thread = threading.Thread(target=RetrievalService.full_text_index_search, kwargs={ | |
'flask_app': current_app._get_current_object(), | |
'dataset_id': dataset_id, | |
'query': query, | |
'retrival_method': retrival_method, | |
'score_threshold': score_threshold, | |
'top_k': top_k, | |
'reranking_model': reranking_model, | |
'all_documents': all_documents | |
}) | |
threads.append(full_text_index_thread) | |
full_text_index_thread.start() | |
for thread in threads: | |
thread.join() | |
if retrival_method == 'hybrid_search': | |
data_post_processor = DataPostProcessor(str(dataset.tenant_id), reranking_model, False) | |
all_documents = data_post_processor.invoke( | |
query=query, | |
documents=all_documents, | |
score_threshold=score_threshold, | |
top_n=top_k | |
) | |
return all_documents | |
def keyword_search(cls, flask_app: Flask, dataset_id: str, query: str, | |
top_k: int, all_documents: list): | |
with flask_app.app_context(): | |
dataset = db.session.query(Dataset).filter( | |
Dataset.id == dataset_id | |
).first() | |
keyword = Keyword( | |
dataset=dataset | |
) | |
documents = keyword.search( | |
query, | |
top_k=top_k | |
) | |
all_documents.extend(documents) | |
def embedding_search(cls, flask_app: Flask, dataset_id: str, query: str, | |
top_k: int, score_threshold: Optional[float], reranking_model: Optional[dict], | |
all_documents: list, retrival_method: str): | |
with flask_app.app_context(): | |
dataset = db.session.query(Dataset).filter( | |
Dataset.id == dataset_id | |
).first() | |
vector = Vector( | |
dataset=dataset | |
) | |
documents = vector.search_by_vector( | |
query, | |
search_type='similarity_score_threshold', | |
top_k=top_k, | |
score_threshold=score_threshold, | |
filter={ | |
'group_id': [dataset.id] | |
} | |
) | |
if documents: | |
if reranking_model and retrival_method == 'semantic_search': | |
data_post_processor = DataPostProcessor(str(dataset.tenant_id), reranking_model, False) | |
all_documents.extend(data_post_processor.invoke( | |
query=query, | |
documents=documents, | |
score_threshold=score_threshold, | |
top_n=len(documents) | |
)) | |
else: | |
all_documents.extend(documents) | |
def full_text_index_search(cls, flask_app: Flask, dataset_id: str, query: str, | |
top_k: int, score_threshold: Optional[float], reranking_model: Optional[dict], | |
all_documents: list, retrival_method: str): | |
with flask_app.app_context(): | |
dataset = db.session.query(Dataset).filter( | |
Dataset.id == dataset_id | |
).first() | |
vector_processor = Vector( | |
dataset=dataset, | |
) | |
documents = vector_processor.search_by_full_text( | |
query, | |
top_k=top_k | |
) | |
if documents: | |
if reranking_model and retrival_method == 'full_text_search': | |
data_post_processor = DataPostProcessor(str(dataset.tenant_id), reranking_model, False) | |
all_documents.extend(data_post_processor.invoke( | |
query=query, | |
documents=documents, | |
score_threshold=score_threshold, | |
top_n=len(documents) | |
)) | |
else: | |
all_documents.extend(documents) | |