File size: 7,875 Bytes
141b0a0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
import os
import warnings
import shutil
from langchain_openai import AzureChatOpenAI, AzureOpenAIEmbeddings
from langchain_community.vectorstores import Chroma
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.chains import RetrievalQA
from langchain_community.document_loaders import PyPDFLoader, Docx2txtLoader, TextLoader, WikipediaLoader
from typing import List, Optional, Dict, Any
from langchain.schema import Document
import chromadb
# from langchain_community.embeddings.sentence_transformer import (SentenceTransformerEmbeddings)
from langchain_community.vectorstores import FAISS



# Suppress every Python warning, process-wide, as soon as this module is imported.
warnings.filterwarnings("ignore")

# Persistent Chroma directory. os.getcwd() is evaluated once, at import time,
# so the location depends on the working directory at that moment.
CHROMA_DB_PATH = os.path.join(
    os.getcwd(),
    "Stock Sentiment Analysis",
    "chroma_db",
)

# Name of the Chroma collection that DBStorage reads from and writes to.
tesla_10k_collection = "tesla-10k-2019-to-2023"

# Embedding function handed to the LangChain Chroma wrapper. The empty string is
# only a placeholder, not a usable embedding object; presumably it is meant to
# hold something like SentenceTransformerEmbeddings(model_name="thenlper/gte-large").
embedding_model = ""


class DBStorage:
    """Store and query document chunks in a Chroma collection or a FAISS index.

    Two persistence back ends are supported:

    * Chroma, persisted under ``CHROMA_DB_PATH`` in the collection named by
      ``tesla_10k_collection`` (``create_embeddings`` / ``load_embeddings``).
    * FAISS, saved to / loaded from a caller-supplied directory
      (``embed_vectors`` / ``load_vectors``).

    ``get_context_for_query`` searches whichever store was most recently put
    into ``self.vector_store``.
    """

    def __init__(self):
        self.CHROMA_PATH = CHROMA_DB_PATH
        # Populated by create_embeddings / load_embeddings / load_vectors.
        self.vector_store = None
        # Raw chromadb client, used for direct collection access (counts, metadata filters).
        self.client = chromadb.PersistentClient(path=self.CHROMA_PATH)
        print(self.client.list_collections())
        self.collection = self.client.get_or_create_collection(name=tesla_10k_collection)
        print(self.collection.count())

    @staticmethod
    def _azure_embeddings():
        """Build the Azure OpenAI embedding client from the AZURE_OPENAI_EMBEDDING_* env vars."""
        return AzureOpenAIEmbeddings(
            model=os.getenv("AZURE_OPENAI_EMBEDDING_NAME"),
            api_key=os.getenv("AZURE_OPENAI_EMBEDDING_API_KEY"),
            api_version=os.getenv("AZURE_OPENAI_EMBEDDING_API_VERSION"),
            azure_endpoint=os.getenv("AZURE_OPENAI_EMBEDDING_ENDPOINT"),
        )

    def _chroma_embeddings(self):
        """Return the embedding function to use for the Chroma store.

        The module-level ``embedding_model`` wins when it has been configured.
        Its default is the empty-string placeholder, which has no
        ``embed_documents``/``embed_query``, so fall back to Azure OpenAI then.
        """
        if embedding_model:
            return embedding_model
        return self._azure_embeddings()

    def chunk_data(self, data, chunk_size=10000, chunk_overlap=0):
        """Split LangChain documents into chunks targeting *chunk_size* characters.

        Args:
            data: documents to split (anything ``split_documents`` accepts).
            chunk_size: target chunk length passed to the splitter.
            chunk_overlap: characters shared between neighbouring chunks
                (defaults to 0, the previously hard-coded value).

        Returns:
            The list of chunked documents.
        """
        text_splitter = RecursiveCharacterTextSplitter(
            chunk_size=chunk_size, chunk_overlap=chunk_overlap
        )
        return text_splitter.split_documents(data)

    def create_embeddings(self, chunks):
        """Embed *chunks* into the persisted Chroma collection.

        Sets ``self.vector_store`` to the resulting LangChain ``Chroma`` wrapper
        and refreshes ``self.collection`` so its count reflects the new data.
        """
        self.vector_store = Chroma.from_documents(
            documents=chunks,
            embedding=self._chroma_embeddings(),
            collection_name=tesla_10k_collection,
            persist_directory=self.CHROMA_PATH,
        )
        print("Here B")
        self.collection = self.client.get_or_create_collection(name=tesla_10k_collection)
        print("here" + str(self.collection.count()))

    def create_vector_store(self, chunks):
        """Build and return an in-memory FAISS index over *chunks* (Azure OpenAI embeddings)."""
        return FAISS.from_documents(chunks, embedding=self._azure_embeddings())

    def load_embeddings(self):
        """Attach ``self.vector_store`` to the existing persisted Chroma collection."""
        self.vector_store = Chroma(
            collection_name=tesla_10k_collection,
            persist_directory=self.CHROMA_PATH,
            embedding_function=self._chroma_embeddings(),
        )
        print("loaded vector store: ")
        print(self.vector_store)

    def load_vectors(self, FAISS_DB_PATH):
        """Load the FAISS index stored in directory *FAISS_DB_PATH* into ``self.vector_store``."""
        # NOTE(review): load_local unpickles the saved docstore, hence the
        # "dangerous" opt-in. Only point this at directories this application
        # wrote itself (e.g. via embed_vectors), never at untrusted files.
        self.vector_store = FAISS.load_local(
            folder_path=FAISS_DB_PATH,
            embeddings=self._azure_embeddings(),
            allow_dangerous_deserialization=True,
        )

    def fetch_documents(self, metadata_filter: Dict[str, Any]) -> List[Document]:
        """Return every record in ``self.collection`` whose metadata matches *metadata_filter*.

        *metadata_filter* is passed straight through as a chromadb ``where`` clause.
        """
        results = self.collection.get(
            where=metadata_filter,
            include=["documents", "metadatas"],
        )
        # Collection.get returns flat, parallel lists (one entry per record),
        # unlike Collection.query which nests one list per query text, so the
        # lists must not be indexed with [0]. Records may carry no metadata.
        return [
            Document(page_content=content, metadata=metadata or {})
            for content, metadata in zip(results["documents"], results["metadatas"])
        ]

    def get_context_for_query(self, question, k=3):
        """Return the text of the *k* chunks most similar to *question*, joined by ". ".

        Raises:
            ValueError: if no vector store has been created or loaded yet.
        """
        print(self.vector_store)
        if self.vector_store is None:
            raise ValueError(
                "Vector store not initialized. Call create_embeddings(), "
                "load_embeddings() or load_vectors() first."
            )

        relevant_document_chunks = self.vector_store.similarity_search(question, k=k)
        print(relevant_document_chunks)

        context_for_query = ". ".join(d.page_content for d in relevant_document_chunks)
        print("context_for_query: " + str(len(context_for_query)))
        return context_for_query

    def embed_vectors(self, social_media_document, FAISS_DB_PATH):
        """Chunk *social_media_document*, embed it with FAISS and save it to *FAISS_DB_PATH*.

        Any existing directory at *FAISS_DB_PATH* is deleted before saving so
        no files from a previous index remain alongside the new one.
        """
        print("here A")
        chunks = self.chunk_data(social_media_document)
        print(len(chunks))
        vector_store = self.create_vector_store(chunks)
        check_and_delete(FAISS_DB_PATH)
        vector_store.save_local(FAISS_DB_PATH)

def check_and_delete(PATH):
    """Recursively delete the directory *PATH* if it exists; do nothing otherwise.

    When removing an entry fails (typically because it is read-only), the
    entry is made writable and the failed operation is retried once. If the
    retry also fails, the exception propagates instead of being swallowed, so
    the "Deleted" message is not printed for a tree that is still on disk.
    """
    import sys

    if not os.path.isdir(PATH):
        return

    def _make_writable_and_retry(func, path, _exc):
        # Clear permission bits, then re-run the os function (unlink/rmdir/...) that failed.
        os.chmod(path, 0o777)
        func(path)

    if sys.version_info >= (3, 12):
        shutil.rmtree(PATH, onexc=_make_writable_and_retry)
    else:
        # ``onexc`` was added in Python 3.12; earlier versions only accept ``onerror``.
        shutil.rmtree(PATH, onerror=_make_writable_and_retry)
    print(f'Deleted {PATH}')
    
def clear_db():
    """Wipe the persisted Chroma database directory (``CHROMA_DB_PATH``), if present.

    Only the Chroma store is removed. FAISS indexes live at caller-supplied
    paths (see ``DBStorage.embed_vectors``) and are left untouched here.
    """
    chroma_dir = CHROMA_DB_PATH
    check_and_delete(chroma_dir)


# Usage example
if __name__ == "__main__":
    qa_system = DBStorage()

    # Load and process documents: populate this list (e.g. with one of the
    # imported LangChain loaders) before running the example.
    social_media_document = []
    chunks = qa_system.chunk_data(social_media_document)

    # Create embeddings only when there is something to embed, rather than
    # sending an empty batch to the Chroma store.
    if chunks:
        qa_system.create_embeddings(chunks)
    else:
        print("No documents to embed; populate social_media_document first.")