# db.py — build a persisted Chroma vector store over the Paul Graham essays CSV.
# Provenance: HuggingFace Hub upload by abhinand2, commit bfe9335 ("Rename and
# update db.py", 1.65 kB); Hub page chrome removed so the file parses as Python.
from langchain_community.document_loaders.csv_loader import CSVLoader
from langchain_text_splitters import RecursiveCharacterTextSplitter
import torch
from langchain_huggingface import HuggingFaceEmbeddings
from langchain_chroma import Chroma
# Path to the source CSV of essays (expected columns: id, title, date, text).
file_path="./paul_graham_essays.csv"
# Directory where the Chroma vector store is persisted on disk.
db_persist_directory = './docs/chroma/'
def load_data():
    """Load the essays CSV into LangChain documents.

    Uses the module-level ``file_path``. Because explicit ``fieldnames`` are
    supplied, the CSV's own header line is parsed as an ordinary data row,
    so the first returned document is dropped.
    """
    csv_loader = CSVLoader(
        file_path=file_path,
        source_column='title',
        metadata_columns=['date'],
        content_columns=['text'],
        csv_args={
            "delimiter": ",",
            "fieldnames": ['id', 'title', 'date', 'text'],
        },
    )
    documents = csv_loader.load()
    # Skip the header row, which was parsed as data (see docstring).
    return documents[1:]
def split_data(data, chunk_size, chunk_overlap):
    """Split documents into overlapping character chunks.

    Args:
        data: list of LangChain Documents to split.
        chunk_size: maximum number of characters per chunk.
        chunk_overlap: number of characters shared by consecutive chunks.

    Returns:
        List of chunked Documents.
    """
    splitter = RecursiveCharacterTextSplitter(
        chunk_size=chunk_size,
        chunk_overlap=chunk_overlap,
        separators=['\n\n', '\n', r'(?<=\. )', ' ', ''],
        # Fix: the sentence-boundary lookbehind above only works as a regex.
        # Without this flag the splitter escapes separators and matches the
        # pattern literally, so it never fired.
        is_separator_regex=True,
    )
    return splitter.split_documents(data)
def generate_embeddings(model_path='intfloat/multilingual-e5-large-instruct'):
    """Create a HuggingFace embedding model, on CUDA when available.

    Args:
        model_path: model id or local path of the sentence-embedding model.
            Fix: this was previously a free variable (NameError at call
            time); it is now a parameter with the same default get_db uses.

    Returns:
        A HuggingFaceEmbeddings instance (embeddings NOT normalized).
    """
    device = "cuda" if torch.cuda.is_available() else "cpu"
    model_kwargs = {'device': device}
    encode_kwargs = {'normalize_embeddings': False}
    return HuggingFaceEmbeddings(
        model_name=model_path,
        model_kwargs=model_kwargs,
        encode_kwargs=encode_kwargs,
    )


def get_db(
    chunk_size=1000,
    chunk_overlap=200,
    model_path = 'intfloat/multilingual-e5-large-instruct',
):
    """Build and persist a Chroma vector store over the essays CSV.

    Args:
        chunk_size: characters per chunk for text splitting.
        chunk_overlap: overlap between consecutive chunks.
        model_path: embedding model id, forwarded to generate_embeddings.

    Returns:
        A Chroma vector store built from the split documents and persisted
        to the module-level db_persist_directory.
    """
    data = load_data()
    splits = split_data(data, chunk_size, chunk_overlap)
    # Fix: model_path was previously accepted but silently ignored.
    embedding = generate_embeddings(model_path)
    return Chroma.from_documents(
        documents=splits,
        embedding=embedding,
        # Fix: referenced undefined `persist_directory`; the module-level
        # constant is named db_persist_directory.
        persist_directory=db_persist_directory,
    )