# Source: commit 30adccc ("first commit") by SUMANA SUMANAKUL (ING), 4.73 kB.
from langchain_huggingface import HuggingFaceEmbeddings
from langchain_mongodb.vectorstores import MongoDBAtlasVectorSearch
from langchain_mongodb.retrievers.hybrid_search import MongoDBAtlasHybridSearchRetriever
import os
from dotenv import load_dotenv
# Load variables from a local .env file (if present) into the process
# environment so the os.getenv() lookups below can see them.
load_dotenv()
# ---- MongoDB connection settings (read from the environment / .env) ----
# MONGO_USERNAME / MONGO_PASSWORD are not read here; only the full connection
# string, database name and collection name are used.
mongo_database = os.environ.get('MONGO_DATABASE')
mongo_connection_str = os.environ.get('MONGO_CONNECTION_STRING')
mongo_collection_name = os.environ.get('MONGO_COLLECTION')

# ---- Shared settings for the embedder and the hybrid retriever ----
MODEL_NAME = "BAAI/bge-m3"
EMBEDDING_DIMENSIONS = 1024
MODEL_KWARGS = dict(device="cpu")
ENCODE_KWARGS = dict(normalize_embeddings=True, batch_size=32)

FINAL_TOP_K = 50              # passed to the retriever as top_k (previously 30)
HYBRID_FULLTEXT_PENALTY = 0   # retriever fulltext_penalty (previously 60)
HYBRID_VECTOR_PENALTY = 0.8   # retriever vector_penalty (previously 60)
# ---- Embedding model ----
# A single HuggingFace embedder built from MODEL_NAME / MODEL_KWARGS /
# ENCODE_KWARGS; the same instance is handed to the vector store and to the
# hybrid retriever.
embed_model = HuggingFaceEmbeddings(
    encode_kwargs=ENCODE_KWARGS,
    model_kwargs=MODEL_KWARGS,
    model_name=MODEL_NAME,
)
# ---- Vector Search ----
# Candidate pool sizes: twice the final result count, with a floor of 20.
num_vector_candidates = num_text_candidates = max(20, 2 * FINAL_TOP_K)
vector_k = num_vector_candidates
# Value sent as "numCandidates" in the vector search params: 10x vector_k.
vector_num_candidates_for_operator = 10 * vector_k
# ---- Vector Store ----
# Atlas vector store over "<database>.<collection>", using the index named
# "search_index_v1". NOTE: if MONGO_DATABASE / MONGO_COLLECTION are unset, the
# f-string below produces the namespace "None.None" rather than failing here.
vector_store = MongoDBAtlasVectorSearch.from_connection_string(
    index_name="search_index_v1",
    embedding=embed_model,
    namespace=f"{mongo_database}.{mongo_collection_name}",
    connection_string=mongo_connection_str,
)
# ---- Retriever (Hybrid) ----
def _combine_filters(*filters):
    """Return the ``$and`` of the non-empty MQL filters, a lone filter as-is,
    or ``None`` when every filter is empty/None."""
    active = [f for f in filters if f]
    if not active:
        return None
    if len(active) == 1:
        return active[0]
    return {"$and": active}


def get_retriever(**kwargs):
    """Create a hybrid (vector + full-text) retriever over ``vector_store``.

    Keyword Args:
        vector_search_filter: Optional MQL match expression restricting which
            documents may be returned (e.g. a metadata filter).
        **kwargs: Any other keyword arguments are used as-is as an MQL match
            expression (``field=condition``). When ``vector_search_filter`` is
            also given, the two are combined with ``$and``.

    Returns:
        A configured ``MongoDBAtlasHybridSearchRetriever``.
    """
    # Pop the special key so it is not treated as a field condition below.
    vector_search_filter = kwargs.pop("vector_search_filter", None)

    # NOTE(review): in the langchain_mongodb releases checked, the hybrid
    # retriever filters only through its declared ``pre_filter`` field (used by
    # both the vector and the full-text stage); ``vector_search_params`` /
    # ``text_search_params`` are not declared fields and appear to be ignored,
    # so a filter placed only there may never have taken effect. It is
    # therefore also routed through ``pre_filter`` -- confirm against the
    # installed version.
    pre_filter = _combine_filters(kwargs, vector_search_filter)

    vector_search_params = {
        "k": vector_k,
        "numCandidates": vector_num_candidates_for_operator,
    }
    # Only include a filter when one was given, instead of an explicit None.
    if vector_search_filter:
        vector_search_params["filter"] = vector_search_filter

    # NOTE(review): "search_index_v1" is also the index name given to
    # vector_store; confirm the Atlas deployment really serves both the vector
    # and the full-text index under that name.
    return MongoDBAtlasHybridSearchRetriever(
        vectorstore=vector_store,
        search_index_name='search_index_v1',
        embedding=embed_model,
        text_key='text',
        embedding_key='embedding',
        top_k=FINAL_TOP_K,
        vector_penalty=HYBRID_VECTOR_PENALTY,
        fulltext_penalty=HYBRID_FULLTEXT_PENALTY,
        vector_search_params=vector_search_params,
        # Same value as before (max(20, 2 * FINAL_TOP_K)), via the shared constant.
        text_search_params={
            "limit": num_text_candidates
        },
        pre_filter=pre_filter,
    )
# def get_retriever(**kwargs):
# retriever = MongoDBAtlasHybridSearchRetriever(
# vectorstore=vector_store,
# search_index_name='search_index_v1',
# embedding=embed_model,
# text_key= 'text', #'token',
# embedding_key='embedding',
# top_k=FINAL_TOP_K,
# vector_penalty=HYBRID_VECTOR_PENALTY,
# fulltext_penalty=HYBRID_FULLTEXT_PENALTY,
# vector_search_params={
# "k": vector_k,
# "numCandidates": vector_num_candidates_for_operator
# },
# text_search_params={
# "limit": num_text_candidates
# },
# pre_filter=kwargs
# )
# return retriever
# ---------- Earlier variant: metadata filter passed via the `filter` kwarg ----------
# def get_retriever(**kwargs):
# # ดึง filter ที่เราจะส่งมาจาก tool ออกมาจาก kwargs
# # เราใช้ .pop() เพื่อเอามันออกมา จะได้ไม่ถูกส่งไปที่ pre_filter ซ้ำซ้อน
# search_filter = kwargs.pop('filter', None)
# retriever = MongoDBAtlasHybridSearchRetriever(
# vectorstore=vector_store,
# search_index_name='search_index_v1',
# embedding=embed_model,
# text_key= 'text',
# embedding_key='embedding',
# top_k=FINAL_TOP_K,
# vector_penalty=HYBRID_VECTOR_PENALTY,
# fulltext_penalty=HYBRID_FULLTEXT_PENALTY,
# vector_search_params={
# "k": vector_k,
# "numCandidates": vector_num_candidates_for_operator,
# "filter": search_filter
# },
# text_search_params={
# "limit": num_text_candidates
# },
# pre_filter=kwargs
# )
# return retriever