import datasets
import faiss
from langchain_community.docstore.in_memory import InMemoryDocstore
from langchain_community.vectorstores import FAISS
from langchain_core.documents import Document
from langchain_text_splitters import RecursiveCharacterTextSplitter

from rag.settings import get_embeddings_model


def get_vector_store():
    """Create an empty FAISS vector store backed by an in-memory docstore."""
    embeddings = get_embeddings_model()
    # Probe the embedding model once to size the L2 index to the embedding dimension.
    index = faiss.IndexFlatL2(len(embeddings.embed_query("hello world")))

    vector_store = FAISS(
        embedding_function=embeddings,
        index=index,
        docstore=InMemoryDocstore(),
        index_to_docstore_id={},
    )
    return vector_store


def get_docs(dataset):
    """Wrap each model card in a Document and split it into overlapping chunks."""
    source_docs = [
        Document(
            page_content=model["model_card"],
            metadata={
                "model_id": model["model_id"],
                "model_labels": model["model_labels"],
            },
        )
        for model in dataset
    ]
    # Split the cards into small overlapping chunks so each piece embeds well
    # and retrieval returns focused passages rather than whole model cards.
    text_splitter = RecursiveCharacterTextSplitter(
        chunk_size=500,
        chunk_overlap=50,
        add_start_index=True,
        strip_whitespace=True,
        separators=["\n\n", "\n", ".", " ", ""],
    )
    docs_processed = text_splitter.split_documents(source_docs)
    print(f"Knowledge base prepared with {len(docs_processed)} document chunks")
    return docs_processed


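# Example helper (a sketch, not used by the build step below): reload the
# persisted index and run a similarity search against it. The query and k
# values are illustrative; recent langchain_community releases require
# allow_dangerous_deserialization=True on load because the docstore is pickled.
def load_and_search(query: str, k: int = 4):
    store = FAISS.load_local(
        folder_path="vector_store",
        embeddings=get_embeddings_model(),
        index_name="object_detection_models_faiss_index",
        allow_dangerous_deserialization=True,
    )
    return store.similarity_search(query, k=k)

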
if __name__ == "__main__":
    # Build the knowledge base: load the model-card dataset, chunk it,
    # embed it into FAISS, and persist the index to disk.
    dataset = datasets.load_dataset(
        "stevenbucaille/object-detection-models-dataset", split="train"
    )
    docs_processed = get_docs(dataset)
    vector_store = get_vector_store()
    vector_store.add_documents(docs_processed)
    vector_store.save_local(
        folder_path="vector_store",
        index_name="object_detection_models_faiss_index",
    )