File size: 2,582 Bytes
dc92420
 
 
 
 
 
66f2ba6
 
 
dc92420
 
 
 
 
 
66f2ba6
dc92420
 
 
 
 
 
 
 
 
 
66f2ba6
dc92420
66f2ba6
dc92420
66f2ba6
 
 
dc92420
 
 
bfe9335
66f2ba6
dc92420
bfe9335
 
dc92420
 
66f2ba6
dc92420
66f2ba6
bfe9335
dc92420
 
66f2ba6
dc92420
 
 
 
 
66f2ba6
dc92420
66f2ba6
dc92420
 
 
 
 
bfe9335
 
 
66f2ba6
bfe9335
66f2ba6
bfe9335
 
66f2ba6
bfe9335
dc92420
66f2ba6
 
 
 
dc92420
 
 
746e929
bfe9335
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
from langchain_community.document_loaders.csv_loader import CSVLoader
from langchain_text_splitters import RecursiveCharacterTextSplitter
import torch
from langchain_huggingface import HuggingFaceEmbeddings
from langchain_chroma import Chroma

import logging
logger = logging.getLogger(__name__)


file_path="./paul_graham_essays.csv"
db_persist_directory = './docs/chroma/'


def load_data():
    """Load the Paul Graham essays CSV into LangChain Documents.

    Each row's 'text' column becomes the document content, 'title' is used
    as the source, and 'date' is kept as metadata.

    Returns:
        The loaded Documents, minus the first record (see note below).
    """
    logger.info(f'Instantiating CSVLoader with file_path={file_path}')
    loader = CSVLoader(
        file_path=file_path,
        csv_args={
            "delimiter": ",",
            "fieldnames": ['id', 'title', 'date', 'text'],
        },
        source_column='title',
        metadata_columns=['date'],
        content_columns=['text'],
    )
    logger.info('Instantiating CSVLoader complete')

    logger.info('Loading data')
    documents = loader.load()
    logger.info('Loading data complete')

    logger.info('Returning data')
    # Because explicit fieldnames are supplied above, the CSV's own header
    # line is parsed as an ordinary data row — drop it before returning.
    return documents[1:]


def split_data(data, chunk_size, chunk_overlap):
    """Split Documents into overlapping character chunks.

    Args:
        data: iterable of Documents to split.
        chunk_size: maximum number of characters per chunk.
        chunk_overlap: number of characters shared between adjacent chunks.

    Returns:
        A list of chunked Documents.
    """
    logger.info(f'Instantiating RecursiveCharacterTextSplitter with chunk_size={chunk_size} and chunk_overlap={chunk_overlap}')
    # The sentence-boundary separator r'(?<=\. )' is a regex lookbehind.
    # RecursiveCharacterTextSplitter defaults to is_separator_regex=False,
    # which would match it as the *literal* string "(?<=\. )" (never found),
    # silently disabling sentence-level splitting — so enable regex mode.
    splitter = RecursiveCharacterTextSplitter(
        chunk_size=chunk_size,
        chunk_overlap=chunk_overlap,
        separators=['\n\n', '\n', r'(?<=\. )', ' ', ''],
        is_separator_regex=True,
    )
    logger.info('Instantiating RecursiveCharacterTextSplitter complete')

    logger.info('Generating and returning splits')
    return splitter.split_documents(data)


def generate_embeddings(model_name):
    """Create a HuggingFaceEmbeddings instance for *model_name*.

    Runs on CUDA when a GPU is visible to torch, otherwise on CPU.
    Embeddings are left un-normalized.
    """
    use_cuda = torch.cuda.is_available()
    model_kwargs = {'device': 'cuda' if use_cuda else 'cpu'}
    encode_kwargs = {'normalize_embeddings': False}

    logger.info(f'Instantiating and returning HuggingFaceEmbeddings with model_name={model_name}, model_kwargs={model_kwargs} and encode_kwargs={encode_kwargs}')
    return HuggingFaceEmbeddings(
        model_name=model_name,
        model_kwargs=model_kwargs,
        encode_kwargs=encode_kwargs,
    )


def get_db(
    chunk_size=1000,
    chunk_overlap=200,
    model_name='intfloat/multilingual-e5-large-instruct',
):
    """Build and return a Chroma vector store over the essays CSV.

    Pipeline: load the CSV -> split into chunks -> embed -> index in Chroma,
    persisted under the module-level db_persist_directory.

    Args:
        chunk_size: maximum characters per chunk.
        chunk_overlap: characters of overlap between adjacent chunks.
        model_name: HuggingFace embedding model to use.
    """
    logger.info('Getting data')
    documents = load_data()

    logger.info('Getting splits')
    chunks = split_data(documents, chunk_size, chunk_overlap)

    logger.info('Getting embedding')
    embedding_fn = generate_embeddings(model_name)

    logger.info(f'Instantiating and returning Chroma DB with persist_directory={db_persist_directory}')
    return Chroma.from_documents(
        documents=chunks,
        embedding=embedding_fn,
        persist_directory=db_persist_directory,
    )