avilum commited on
Commit
d30ec9f
·
verified ·
1 Parent(s): 621249e

Create ingest.py

Browse files
Files changed (1) hide show
  1. ingest.py +158 -0
ingest.py ADDED
@@ -0,0 +1,158 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ from tqdm import tqdm
3
+ import pathlib
4
+
5
+ from langchain_community.document_loaders import TextLoader
6
+ from langchain.docstore.document import Document
7
+ from langchain_community.embeddings import HuggingFaceEmbeddings
8
+ from langchain.text_splitter import RecursiveCharacterTextSplitter
9
+ from langchain_community.vectorstores import FAISS
10
+
11
+ os.environ["RAY_memory_monitor_refresh_ms"] = "0"
12
+ os.environ["RAY_DEDUP_LOGS"] = "0"
13
+ import ray
14
+
15
+ from common import DATASET_DIR, EMBEDDING_MODEL_NAME, MODEL_KWARGS, VECTORSTORE_FILENAME
16
+
17
+ # Each document is parsed on the same CPU, to decrease paging and data copies, and up to the the number of vCPUs.
18
+ CONCURRENCY = 32
19
+
20
+
21
+ # @ray.remote(num_cpus=1) # Outside a container, num_cpus=1 might speed things dramatically.
22
+ @ray.remote
23
+ def parse_doc(document_path: str) -> Document:
24
+ print("Loading", document_path)
25
+ loader = TextLoader(document_path)
26
+ langchain_dataset_documents = loader.load()
27
+
28
+ # Update the metadata with the proper metadata JSON file, parsed from Arxiv.com
29
+ return langchain_dataset_documents
30
+
31
+
32
+ def add_documents_to_vector_store(
33
+ vector_store, new_documents, text_splitter, embeddings
34
+ ):
35
+ split_docs = text_splitter.split_documents(new_documents)
36
+ # print("Embedding vectors...")
37
+ store = FAISS.from_documents(split_docs, embeddings)
38
+ if vector_store is None:
39
+ vector_store = store
40
+ else:
41
+ print("Updating vector store", store)
42
+ vector_store.merge_from(store)
43
+ return vector_store
44
+
45
+
46
+ def ingest_dataset_to_vectore_store(
47
+ vectorstore_filename: str, dataset_directory: os.PathLike
48
+ ):
49
+ ray.init()
50
+ vector_store = None
51
+ text_splitter = RecursiveCharacterTextSplitter(
52
+ chunk_size=160, # TODO: Finetune
53
+ chunk_overlap=40, # TODO: Finetune
54
+ length_function=len,
55
+ )
56
+
57
+ dataset_documents = []
58
+ dataset_dir_path = pathlib.Path(dataset_directory)
59
+ dataset_dir_path.mkdir(exist_ok=True)
60
+
61
+ for _dirname in os.listdir(str(dataset_dir_path)):
62
+ if _dirname.startswith("."):
63
+ continue
64
+ catagory_path = dataset_dir_path / pathlib.Path(_dirname)
65
+ for filename in os.listdir(str(dataset_dir_path / catagory_path)):
66
+ dataset_path = dataset_dir_path / catagory_path / pathlib.Path(filename)
67
+ dataset_documents.append(str(dataset_path))
68
+ print(dataset_documents)
69
+ print(f"Found {len(dataset_documents)} items in dataset: ")
70
+ langchain_documents = []
71
+
72
+ model_name = EMBEDDING_MODEL_NAME
73
+ model_kwargs = MODEL_KWARGS
74
+ print("Creating huggingface embeddings for ", model_name)
75
+ embeddings = HuggingFaceEmbeddings(model_name=model_name, model_kwargs=model_kwargs)
76
+
77
+ if vector_store is None and os.path.exists(vectorstore_filename):
78
+ print("Loading existing vector store from", vectorstore_filename)
79
+ vector_store = FAISS.load_local(
80
+ vectorstore_filename, embeddings, allow_dangerous_deserialization=True
81
+ )
82
+
83
+ jobs = []
84
+ docs_count = len(dataset_documents)
85
+ failed = 0
86
+ print(f"Embedding {docs_count} documents with Ray...")
87
+ for i, document in enumerate(tqdm(dataset_documents)):
88
+ try:
89
+ # print(f"Submitting job ", i)
90
+ job = parse_doc.remote(document)
91
+ jobs.append(job)
92
+
93
+ if i > 1 and i <= docs_count and i % CONCURRENCY == 0:
94
+ if langchain_documents:
95
+ vector_store = add_documents_to_vector_store(
96
+ vector_store, langchain_documents, text_splitter, embeddings
97
+ )
98
+ print(f"\nSaving vector store to disk at {vectorstore_filename}...")
99
+ try:
100
+ os.unlink(vectorstore_filename)
101
+ except:
102
+ ...
103
+
104
+ vector_store.save_local(vectorstore_filename)
105
+ langchain_documents = []
106
+ jobs = []
107
+
108
+ # Block jobs every CONCURRENCY iterations
109
+ if i > 1 and i % CONCURRENCY == 0:
110
+ # print(f"Collecting {len(jobs)} jobs...")
111
+ for _ in jobs:
112
+ try:
113
+ # print("waiting for ray job ", _)
114
+ data = ray.get(_)
115
+ langchain_documents.extend(data)
116
+ except Exception as e:
117
+ print("error in job: ", e)
118
+ continue
119
+ except Exception as e:
120
+ print(f"\n\nERROR reading dataset {i}:", e)
121
+ failed = failed + 1
122
+ continue
123
+
124
+ # print(f"Collecting {len(jobs)} jobs...")
125
+ for _ in jobs:
126
+ try:
127
+ print("waiting for ray job ", _)
128
+ data = ray.get(_)
129
+ langchain_documents.extend(data)
130
+ except Exception as e:
131
+ print("error in job: ", e)
132
+ continue
133
+
134
+ if langchain_documents:
135
+ vector_store = add_documents_to_vector_store(
136
+ vector_store, langchain_documents, text_splitter, embeddings
137
+ )
138
+ print(f"\nSaving vector store to disk at {vectorstore_filename}...")
139
+ try:
140
+ os.unlink(vectorstore_filename)
141
+ except:
142
+ ...
143
+
144
+ vector_store.save_local(vectorstore_filename)
145
+
146
+ return vector_store
147
+
148
+
149
+ def main():
150
+ vectorstore_filename = VECTORSTORE_FILENAME
151
+ dataset_directory = DATASET_DIR
152
+ ingest_dataset_to_vectore_store(
153
+ vectorstore_filename=vectorstore_filename, dataset_directory=dataset_directory
154
+ )
155
+
156
+
157
+ if __name__ == "__main__":
158
+ main()