import os
import uuid

import chromadb
import gradio as gr
from bs4 import BeautifulSoup
from llama_index.core import (
    Document,
    Settings,
    SimpleDirectoryReader,
    StorageContext,
    VectorStoreIndex,
)
from llama_index.core.chat_engine import CondensePlusContextChatEngine
from llama_index.core.memory import ChatMemoryBuffer
from llama_index.embeddings.cohere import CohereEmbedding
from llama_index.llms.cohere import Cohere
from llama_index.readers.web import SimpleWebPageReader
from llama_index.vector_stores.chroma import ChromaVectorStore

api_key = os.environ.get("API_KEY")
base_url = os.environ.get("BASE_URL")

llm = Cohere(api_key=api_key, model="command")
embedding_model = CohereEmbedding(
    api_key=api_key,
    model_name="embed-multilingual-v3.0",
    input_type="search_document",
    embedding_type="int8",
)
memory = ChatMemoryBuffer.from_defaults(token_limit=3900)

# Set global settings.
Settings.llm = llm
Settings.embed_model = embedding_model
# Set the context window and the number of output tokens.
Settings.context_window = 4096
Settings.num_output = 512

db_path = ""


def extract_web(url):
    web_documents = SimpleWebPageReader().load_data([url])
    html_content = web_documents[0].text
    # Parse the HTML and keep only the paragraph text.
    soup = BeautifulSoup(html_content, "html.parser")
    p_tags = soup.find_all("p")
    text_content = ""
    for each in p_tags:
        text_content += each.text + "\n"
    # Convert back to the Document format.
    documents = [Document(text=text_content)]
    option = "web"
    return documents, option


def extract_doc(path):
    documents = SimpleDirectoryReader(input_files=path).load_data()
    option = "doc"
    return documents, option


def create_col(documents):
    # Create a client and a new collection.
    db_path = f"database/{str(uuid.uuid4())[:4]}"
    client = chromadb.PersistentClient(path=db_path)
    chroma_collection = client.get_or_create_collection("quickstart")
    # Create a vector store.
    vector_store = ChromaVectorStore(chroma_collection=chroma_collection)
    # Create a storage context.
    storage_context = StorageContext.from_defaults(vector_store=vector_store)
    # Create an index from the documents and save it to disk.
    VectorStoreIndex.from_documents(documents, storage_context=storage_context)
    return db_path


def infer(message: dict, history: list):
    global db_path
    option = ""
    print(f"message: {message}")
    print(f"history: {history}")

    files_list = message["files"]

    if files_list:
        documents, option = extract_doc(files_list)
        db_path = create_col(documents)
    elif message["text"].startswith(("http://", "https://")):
        documents, option = extract_web(message["text"])
        db_path = create_col(documents)
    elif len(history) == 0:
        # No file, no URL, and no prior conversation: nothing to index yet.
        raise gr.Error("Please input a URL or upload a file first.")

    # Load from disk.
    load_client = chromadb.PersistentClient(path=db_path)
    # Fetch the collection.
    chroma_collection = load_client.get_collection("quickstart")
    # Fetch the vector store.
    vector_store = ChromaVectorStore(chroma_collection=chroma_collection)
    # Get the index from the vector store.
    index = VectorStoreIndex.from_vector_store(vector_store)

    if option == "web" and len(history) == 0:
        response = "Got the web data! You can ask questions about it now."
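    # The branch below answers a question with retrieval-augmented generation.
    # CondensePlusContextChatEngine first condenses the chat history and the
    # new message into a standalone question, then retrieves context for that
    # question from the index and answers using the retrieved passages.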
    else:
        question = message["text"]
        chat_engine = CondensePlusContextChatEngine.from_defaults(
            index.as_retriever(),
            memory=memory,
            context_prompt=(
                "You are a chatbot, able to have normal interactions as well"
                " as answer questions about the indexed documents.\n"
                "Here are the relevant documents for the context:\n"
                "{context_str}"
                "\nInstruction: Use the previous chat history, or the context"
                " above, to interact with and help the user."
            ),
            verbose=True,
        )
        response = chat_engine.chat(question)

    print(f"response: {response}")
    return str(response)


chatbot = gr.Chatbot()

with gr.Blocks(theme="soft") as demo:
    gr.ChatInterface(
        fn=infer,
        title="RAG demo",
        multimodal=True,
        chatbot=chatbot,
    )

if __name__ == "__main__":
    demo.queue(api_open=False).launch(show_api=False, share=False)
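# A minimal sketch of how to run this demo locally. The package names below
# are assumptions inferred from the imports above, and API_KEY must hold a
# valid Cohere API key:
#
#   pip install gradio chromadb beautifulsoup4 llama-index \
#       llama-index-llms-cohere llama-index-embeddings-cohere \
#       llama-index-readers-web llama-index-vector-stores-chroma
#   export API_KEY=your-cohere-api-key
#   python app.py  # substitute this file's actual name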