from langchain.tools import tool
from langchain_google_community import GoogleSearchAPIWrapper
from langchain_community.embeddings.sentence_transformer import (
    SentenceTransformerEmbeddings,
)
from langchain_community.vectorstores import Chroma
import chromadb
import os

from rag_app.loading_data.load_S3_vector_stores import get_chroma_vs
from rag_app.utils.utils import (
    parse_list_to_dicts, format_search_results
)
from config import db, PERSIST_DIRECTORY, EMBEDDING_MODEL

# Fetch the Chroma vector store from S3 if it is not present locally yet
if not os.path.exists(PERSIST_DIRECTORY):
    get_chroma_vs()

@tool
def memory_search(query: str) -> str:
    """Search the memory vector store for existing knowledge and relevant previous research. \
        This is your primary source: check what you have already learned in the past before going online."""
    # NOTE: since there is more than one collection, this tool's name should be made more specific.
    client = chromadb.PersistentClient(
        path=PERSIST_DIRECTORY,
    )

    # The collection name is configured via an environment variable
    collection_name = os.getenv('CONVERSATION_COLLECTION_NAME')

    embedding_function = SentenceTransformerEmbeddings(
        model_name=EMBEDDING_MODEL,
    )

    vector_db = Chroma(
        client=client,
        collection_name=collection_name,
        embedding_function=embedding_function,
    )

    retriever = vector_db.as_retriever()
    docs = retriever.invoke(query)

    # Tag each retrieved document with the current session id, then persist it
    for doc in docs:
        doc.metadata["session_id"] = db.session_id
    db.add_many(docs)

    return str(docs)


@tool
def knowledgeBase_search(query: str) -> str:
    """Search the internal database for matching insurance products and information about the insurance policies."""
    # NOTE: since there is more than one collection, this tool's name should be made more specific.
    embedding_function = SentenceTransformerEmbeddings(
        model_name=EMBEDDING_MODEL,
    )

    # Load the store straight from the persist directory (default collection)
    vector_db = Chroma(
        persist_directory=PERSIST_DIRECTORY,
        embedding_function=embedding_function,
    )
    # MMR retrieval: fetch 10 candidates, return the 5 most diverse
    retriever = vector_db.as_retriever(
        search_type="mmr",
        search_kwargs={'k': 5, 'fetch_k': 10},
    )
    docs = retriever.invoke(query)

    # Tag each retrieved document with the current session id, then persist it
    for doc in docs:
        doc.metadata["session_id"] = db.session_id
    db.add_many(docs)

    for doc in docs:
        print(doc)

    return str(docs)


@tool
def google_search(query: str) -> str:
    """Improve the results with a search on the insurer's website. Formulate a new search query to improve the chances of success."""
    websearch = GoogleSearchAPIWrapper()
    search_results: list = websearch.results(query, 3)
    print(search_results)

    if len(search_results) > 1:
        cleaner_sources = format_search_results(search_results)
        parsed_csources = parse_list_to_dicts(cleaner_sources)

        # Tag each parsed source with the current session id, then persist it
        for source in parsed_csources:
            source.update({"session_id": db.session_id})
        db.add_many(parsed_csources)
    else:
        # A single entry usually means no useful result was returned
        cleaner_sources = search_results

    return str(cleaner_sources)
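

# A minimal usage sketch, assuming GOOGLE_API_KEY / GOOGLE_CSE_ID and
# CONVERSATION_COLLECTION_NAME are set in the environment; the queries are
# illustrative placeholders only. @tool-decorated functions are Runnables,
# so a single string argument can be passed straight to `invoke`.
if __name__ == "__main__":
    print(memory_search.invoke("liability insurance coverage limits"))
    print(knowledgeBase_search.invoke("household insurance deductible"))
    print(google_search.invoke("private liability insurance terms and conditions"))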