# Source: abhinand2 — "Update db.py" (commit 66f2ba6, verified)
from langchain_community.document_loaders.csv_loader import CSVLoader
from langchain_text_splitters import RecursiveCharacterTextSplitter
import torch
from langchain_huggingface import HuggingFaceEmbeddings
from langchain_chroma import Chroma
import logging
# Module-level logger following the stdlib convention of one logger per module.
logger = logging.getLogger(__name__)
# Path to the source CSV of essays; expected columns: id, title, date, text.
file_path="./paul_graham_essays.csv"
# Directory where the Chroma vector store is persisted between runs.
db_persist_directory = './docs/chroma/'
def load_data():
    """Load the essays CSV into LangChain documents.

    Uses the module-level ``file_path``. Each row's ``text`` column becomes
    the document content, ``title`` the source, and ``date`` metadata.

    Returns:
        The loaded documents minus the first one (presumably the CSV header
        row, which is read as data because explicit fieldnames are supplied
        to the loader — confirm against the CSV).
    """
    logger.info(f'Instantiating CSVLoader with file_path={file_path}')
    csv_loader = CSVLoader(
        file_path=file_path,
        csv_args={'delimiter': ',', 'fieldnames': ['id', 'title', 'date', 'text']},
        source_column='title',
        metadata_columns=['date'],
        content_columns=['text'],
    )
    logger.info('Instantiating CSVLoader complete')
    logger.info('Loading data')
    documents = csv_loader.load()
    logger.info('Loading data complete')
    logger.info('Returning data')
    # Drop the first record (header row loaded as a document).
    return documents[1:]
def split_data(data, chunk_size, chunk_overlap):
    """Split documents into overlapping chunks for embedding.

    Args:
        data: Sequence of LangChain documents to split.
        chunk_size: Maximum characters per chunk.
        chunk_overlap: Characters of overlap between consecutive chunks.

    Returns:
        The list of chunked documents produced by
        ``RecursiveCharacterTextSplitter.split_documents``.
    """
    logger.info(f'Instantiating RecursiveCharacterTextSplitter with chunk_size={chunk_size} and chunk_overlap={chunk_overlap}')
    splitter = RecursiveCharacterTextSplitter(
        chunk_size=chunk_size,
        chunk_overlap=chunk_overlap,
        # Raw string: '\.' in a plain string is an invalid escape sequence.
        separators=['\n\n', '\n', r'(?<=\. )', ' ', ''],
        # Required for the sentence-boundary lookbehind to work: without this
        # flag the splitter escapes each separator and would treat the
        # pattern as the literal text "(?<=\. )", which never matches.
        is_separator_regex=True,
    )
    logger.info('Instantiating RecursiveCharacterTextSplitter complete')
    logger.info('Generating and returning splits')
    return splitter.split_documents(data)
def generate_embeddings(model_name):
    """Create a HuggingFaceEmbeddings instance for the given model.

    Selects CUDA when available, otherwise CPU; embedding vectors are left
    unnormalized (``normalize_embeddings=False``).

    Args:
        model_name: HuggingFace model identifier to load.

    Returns:
        A configured ``HuggingFaceEmbeddings`` object.
    """
    # Prefer the GPU when torch can see one.
    if torch.cuda.is_available():
        device = 'cuda'
    else:
        device = 'cpu'
    model_kwargs = {'device': device}
    encode_kwargs = {'normalize_embeddings': False}
    logger.info(f'Instantiating and returning HuggingFaceEmbeddings with model_name={model_name}, model_kwargs={model_kwargs} and encode_kwargs={encode_kwargs}')
    return HuggingFaceEmbeddings(
        model_name=model_name,
        model_kwargs=model_kwargs,
        encode_kwargs=encode_kwargs,
    )
def get_db(
    chunk_size=1000,
    chunk_overlap=200,
    model_name = 'intfloat/multilingual-e5-large-instruct',
):
    """Build a Chroma vector store over the essay corpus.

    Loads the CSV, splits it into chunks, embeds the chunks with the given
    model, and persists the resulting store to ``db_persist_directory``.

    Args:
        chunk_size: Maximum characters per text chunk.
        chunk_overlap: Overlap in characters between consecutive chunks.
        model_name: HuggingFace embedding model identifier.

    Returns:
        A ``Chroma`` instance built from the chunked, embedded documents.
    """
    logger.info('Getting data')
    documents = load_data()
    logger.info('Getting splits')
    chunks = split_data(documents, chunk_size, chunk_overlap)
    logger.info('Getting embedding')
    embedding_fn = generate_embeddings(model_name)
    logger.info(f'Instantiating and returning Chroma DB with persist_directory={db_persist_directory}')
    db = Chroma.from_documents(
        documents=chunks,
        embedding=embedding_fn,
        persist_directory=db_persist_directory,
    )
    return db