import os
import re
import uuid

import chromadb
import gradio as gr
from bs4 import BeautifulSoup

from llama_index.core import (
    Document,
    Settings,
    SimpleDirectoryReader,
    StorageContext,
    VectorStoreIndex,
)
from llama_index.core.chat_engine import CondensePlusContextChatEngine
from llama_index.core.memory import ChatMemoryBuffer
from llama_index.embeddings.cohere import CohereEmbedding
from llama_index.llms.cohere import Cohere
from llama_index.readers.web import SimpleWebPageReader
from llama_index.vector_stores.chroma import ChromaVectorStore
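
# Credentials are read from the environment; BASE_URL is read but not used below.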
api_key = os.environ.get("API_KEY")
base_url = os.environ.get("BASE_URL")

llm = Cohere(api_key=api_key, model="command")

# "search_document" is the indexing-side input type for Cohere's embed v3 models;
# int8 embeddings keep the index small at a slight cost in accuracy.
embedding_model = CohereEmbedding(
    api_key=api_key,
    model_name="embed-multilingual-v3.0",
    input_type="search_document",
    embedding_type="int8",
)

# NOTE: this buffer is module-level, so chat history is shared across sessions;
# token_limit caps how much history is replayed to the LLM.
memory = ChatMemoryBuffer.from_defaults(token_limit=3900)

Settings.llm = llm
Settings.embed_model = embedding_model
Settings.context_window = 4096
Settings.num_output = 512

# Path of the current Chroma database; set once data has been ingested.
db_path = ""
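

# Fetch a web page and keep only its <p> text, dropping markup and navigation.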
def extract_web(url):
    web_documents = SimpleWebPageReader().load_data([url])
    html_content = web_documents[0].text

    soup = BeautifulSoup(html_content, "html.parser")
    p_tags = soup.find_all("p")
    text_content = ""
    for each in p_tags:
        text_content += each.text + "\n"

    documents = [Document(text=text_content)]
    option = "web"
    return documents, option
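

# Load user-uploaded files; `path` is the list of file paths from the chat box.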
def extract_doc(path):
    documents = SimpleDirectoryReader(input_files=path).load_data()
    option = "doc"
    return documents, option
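

# Embed the documents into a fresh on-disk Chroma collection; a short uuid
# segment gives every upload its own database directory.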
def create_col(documents):
    db_path = f"database/{str(uuid.uuid4())[:4]}"
    client = chromadb.PersistentClient(path=db_path)
    chroma_collection = client.get_or_create_collection("quickstart")

    vector_store = ChromaVectorStore(chroma_collection=chroma_collection)
    storage_context = StorageContext.from_defaults(vector_store=vector_store)

    # from_documents() embeds and persists the documents as a side effect;
    # the index itself is rebuilt from the vector store at query time.
    VectorStoreIndex.from_documents(documents, storage_context=storage_context)
    return db_path
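

# Chat handler for gr.ChatInterface(multimodal=True): `message` is a dict with
# "text" and "files" keys, and `history` is the list of previous turns.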
def infer(message: dict, history: list):
    global db_path
    option = ""
    print(f"message: {message}")
    print(f"history: {history}")
    files_list = message["files"]

    if files_list:
        # Files attached: ingest them into a new collection.
        documents, option = extract_doc(files_list)
        db_path = create_col(documents)
    elif message["text"].startswith(("http://", "https://")):
        # A URL was typed: scrape the page and ingest it.
        documents, option = extract_web(message["text"])
        db_path = create_col(documents)
    elif len(history) == 0:
        # First turn with nothing to ingest and nothing ingested yet.
        raise gr.Error("Please enter a URL or upload a file first.")

    # Reopen the session's collection and rebuild the index from the stored vectors.
    load_client = chromadb.PersistentClient(path=db_path)
    chroma_collection = load_client.get_collection("quickstart")
    vector_store = ChromaVectorStore(chroma_collection=chroma_collection)
    index = VectorStoreIndex.from_vector_store(vector_store)
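
    # The first turn after a URL ingest just acknowledges success;
    # later turns run the retrieval chat engine over the stored index.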
    if option == "web" and len(history) == 0:
        response = "Fetched the web page! You can now ask questions about it."
    else:
        question = message["text"]

        # Condense the question against chat history, retrieve context, then answer.
        chat_engine = CondensePlusContextChatEngine.from_defaults(
            index.as_retriever(),
            memory=memory,
            context_prompt=(
                "You are a chatbot, able to have normal interactions, as well as talk"
                " about the ingested documents."
                " Here are the relevant documents for the context:\n"
                "{context_str}"
                "\nInstruction: Use the previous chat history, or the context above, to interact and help the user."
            ),
            verbose=True,
        )
        response = chat_engine.chat(question)

    print(type(response))
    print(f"response: {response}")

    return str(response)
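

# Minimal Gradio UI: a multimodal ChatInterface that accepts text and file uploads.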
chatbot = gr.Chatbot()

with gr.Blocks(theme="soft") as demo:
    gr.ChatInterface(
        fn=infer,
        title="RAG demo",
        multimodal=True,
        chatbot=chatbot,
    )

if __name__ == "__main__":
    demo.queue(api_open=False).launch(show_api=False, share=False)