File size: 4,802 Bytes
8364e36
5379f04
 
 
b2c2e74
5379f04
 
 
 
 
 
 
 
f77c387
 
5379f04
fe68312
 
 
 
a710661
5379f04
 
 
f77c387
5379f04
f77c387
 
2fb331a
f77c387
 
 
d6d54c5
 
5379f04
 
fe68312
5379f04
 
 
d6d54c5
 
 
 
fe68312
 
 
5379f04
a9460a2
5379f04
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b2c2e74
5379f04
 
b2c2e74
 
5379f04
 
226eb83
5379f04
 
 
 
 
 
 
 
 
 
 
 
 
b2c2e74
5379f04
f229ceb
9a6b2aa
5379f04
 
8364e36
5379f04
 
fe68312
5379f04
 
a9460a2
5379f04
 
1ebf8b3
a9460a2
5379f04
fe68312
5379f04
 
 
 
 
 
 
 
 
 
 
 
 
d6d54c5
5379f04
b2c2e74
5379f04
 
 
 
fe68312
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
14b0fd8
 
 
 
5379f04
14b0fd8
51b1469
 
5379f04
 
8364e36
dd3fe36
 
de693c7
 
 
 
62fd8b9
5379f04
de693c7
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
import os
from bs4 import BeautifulSoup
from llama_index.core import Document
from llama_index.core import Settings
from llama_index.core import SimpleDirectoryReader
from llama_index.core import StorageContext
from llama_index.core import VectorStoreIndex
from llama_index.readers.web import SimpleWebPageReader

from llama_index.vector_stores.chroma import ChromaVectorStore

import chromadb
import re
from llama_index.llms.cohere import Cohere
from llama_index.embeddings.cohere import CohereEmbedding


from llama_index.core.memory import ChatMemoryBuffer
from llama_index.core.chat_engine import CondensePlusContextChatEngine

import gradio as gr
import uuid

# Credentials/config read from the environment.
# NOTE(review): BASE_URL is read but never used in this file — confirm whether
# it is still needed.
api_key = os.environ.get("API_KEY")
base_url = os.environ.get("BASE_URL")

# Cohere LLM; installed below as the LlamaIndex-wide default via Settings.llm.
llm = Cohere(
    api_key=api_key, 
    model="command")
# Cohere multilingual embedding model; installed below as Settings.embed_model
# and therefore used when documents are indexed and when queries are embedded.
embedding_model = CohereEmbedding(
    api_key=api_key, 
    model_name="embed-multilingual-v3.0",
    input_type="search_document",
    embedding_type="int8",)


# Chat history buffer handed to the chat engine in infer(), capped at 3900 tokens.
# NOTE(review): this is one module-level object, so every Gradio session shares
# the same conversation memory — confirm that is intended.
memory = ChatMemoryBuffer.from_defaults(token_limit=3900)

# Set Global settings
Settings.llm = llm
Settings.embed_model=embedding_model
# set context window
Settings.context_window = 4096
# set number of output tokens
Settings.num_output = 512



# Directory of the most recently created Chroma database. infer() assigns it
# from create_col() whenever a URL or file is loaded; empty until then.
db_path=""

def extract_web(url):
    """Fetch ``url`` and return its paragraph text as a one-document list.

    The fetched page content is parsed as HTML and the text of every ``<p>``
    element is concatenated, each followed by a newline. If the page contains
    no ``<p>`` elements, the text of the whole page is used instead so the
    index is not built from an empty document.

    Args:
        url: http(s) address of the page to load.

    Returns:
        ``(documents, "web")`` where ``documents`` is a one-element list of
        ``Document``; the ``"web"`` tag tells the caller which kind of source
        was loaded.
    """
    web_documents = SimpleWebPageReader().load_data([url])
    html_content = web_documents[0].text

    soup = BeautifulSoup(html_content, "html.parser")
    # find_all is the current bs4 API; findAll is the legacy camelCase alias.
    paragraphs = soup.find_all("p")
    if paragraphs:
        # join instead of repeated += to avoid quadratic string building.
        text_content = "".join(f"{p.text}\n" for p in paragraphs)
    else:
        # Fallback for pages that do not use <p> for their content.
        text_content = soup.get_text("\n")

    return [Document(text=text_content)], "web"

def extract_doc(path):
    """Read the uploaded files and tag the result as a document source.

    Args:
        path: List of file paths, passed as ``input_files`` to
            ``SimpleDirectoryReader``.

    Returns:
        ``(documents, "doc")`` — the loaded documents plus the source tag the
        caller uses to tell uploads apart from web pages.
    """
    reader = SimpleDirectoryReader(input_files=path)
    loaded = reader.load_data()
    return loaded, "doc"


def create_col(documents):
    """Embed ``documents`` into a fresh persistent Chroma collection.

    Every call writes to its own directory under ``database/`` so a newly
    loaded source is never mixed with a previously indexed one.

    Args:
        documents: LlamaIndex ``Document`` objects to index.

    Returns:
        Path of the Chroma database directory, for reopening the collection
        later.
    """
    # Use the full UUID: a 4-hex-char prefix has only 65,536 values and is
    # likely to collide after a few hundred uploads, at which point
    # get_or_create_collection would append the new documents to another
    # source's existing collection.
    db_path = f"database/{uuid.uuid4().hex}"
    client = chromadb.PersistentClient(path=db_path)
    chroma_collection = client.get_or_create_collection("quickstart")

    # Wrap the collection so LlamaIndex can write into it.
    vector_store = ChromaVectorStore(chroma_collection=chroma_collection)
    storage_context = StorageContext.from_defaults(vector_store=vector_store)

    # Building the index embeds the documents and stores them in the collection
    # on disk; the index object itself is not needed afterwards.
    VectorStoreIndex.from_documents(documents, storage_context=storage_context)
    return db_path

def infer(message: dict, history: list):
    """Handle one chat turn: index a new source if one is given, else answer.

    If files are attached they are loaded and indexed into a new Chroma
    collection. Otherwise, if the text is an http(s) URL, that page is fetched
    and indexed. Any other text is treated as a question against the most
    recently indexed source (the module-level ``db_path``).

    Args:
        message: Multimodal chat payload with a ``"text"`` string and a
            ``"files"`` list (presumably uploaded file paths — confirm against
            the Gradio version in use).
        history: Prior chat turns from Gradio; only its emptiness is checked.

    Returns:
        The reply text to display in the chat window.

    Raises:
        gr.Error: when no source has been loaded yet and the message is neither
            a URL nor a file upload.
    """
    global db_path
    print(f'message: {message}')
    print(f'history: {history}')

    text = message["text"]
    files_list = message["files"]
    option = ""

    if files_list:
        documents, option = extract_doc(files_list)
    elif text.startswith(("http://", "https://")):
        documents, option = extract_web(text)
    elif not history or not db_path:
        # Raise (not return) so Gradio shows an error instead of posting the
        # exception object as a chat reply. Also guards against opening an
        # empty db_path when nothing has been indexed yet.
        raise gr.Error("Please input an url or upload file at first.")

    if option:
        db_path = create_col(documents)
        # Forget the conversation about the previous source so it is not
        # condensed into questions about the new one.
        memory.reset()

    if option == "web":
        # The message text was the URL itself, so there is no question yet.
        response = "Get the web data! You can ask it."
    elif option == "doc" and not text.strip():
        # File uploaded without a question: acknowledge instead of running an
        # empty query.
        response = "Get the document data! You can ask it."
    else:
        # Reopen the persisted collection of the current source.
        load_client = chromadb.PersistentClient(path=db_path)
        chroma_collection = load_client.get_collection("quickstart")
        vector_store = ChromaVectorStore(chroma_collection=chroma_collection)
        index = VectorStoreIndex.from_vector_store(vector_store)

        chat_engine = CondensePlusContextChatEngine.from_defaults(
            index.as_retriever(),
            memory=memory,
            context_prompt=(
                "You are a chatbot, able to have normal interactions, as well as talk"
                " about the documents and web pages the user has provided.\n"
                "Here are the relevant documents for the context:\n"
                "{context_str}"
                "\nInstruction: Use the previous chat history, or the context above, to interact and help the user."
            ),
            verbose=True,
        )
        response = chat_engine.chat(text)

    print(f'response: {response}')
    return str(response)
    






# Chat display component handed to the ChatInterface below.
chatbot = gr.Chatbot()

# multimodal=True delivers each user turn as a dict with "text" and "files",
# which is the shape infer() reads.
with gr.Blocks(theme="soft") as demo:
    gr.ChatInterface(
        fn=infer,
        multimodal=True,
        chatbot=chatbot,
        title="RAG demo",
    )

if __name__ == "__main__":
    # Enable the request queue with the API page closed, then serve locally
    # without a public share link.
    demo.queue(api_open=False)
    demo.launch(share=False, show_api=False)