# NOTE(review): removed extraction artifacts (file-size header, git-blame hashes,
# line-number gutter) that were not part of the source and broke Python parsing.
import os
from bs4 import BeautifulSoup
from llama_index.core import Document
from llama_index.core import Settings
from llama_index.core import SimpleDirectoryReader
from llama_index.core import StorageContext
from llama_index.core import VectorStoreIndex
from llama_index.readers.web import SimpleWebPageReader
from llama_index.vector_stores.chroma import ChromaVectorStore
import chromadb
import re
from llama_index.llms.cohere import Cohere
from llama_index.embeddings.cohere import CohereEmbedding
from llama_index.core.memory import ChatMemoryBuffer
from llama_index.core.chat_engine import CondensePlusContextChatEngine
import gradio as gr
import uuid
# Credentials come from the environment so secrets stay out of source control.
api_key = os.environ.get("API_KEY")
# NOTE(review): base_url is read but never used in this file — confirm it is needed.
base_url = os.environ.get("BASE_URL")
# Cohere chat model used to generate answers.
llm = Cohere(
api_key=api_key,
model="command")
# Cohere multilingual embedding model; int8 embeddings reduce index memory.
embedding_model = CohereEmbedding(
api_key=api_key,
model_name="embed-multilingual-v3.0",
input_type="search_document",
embedding_type="int8",)
# Conversation memory shared with the chat engine built in infer().
memory = ChatMemoryBuffer.from_defaults(token_limit=3900)
# Set Global settings
Settings.llm = llm
Settings.embed_model=embedding_model
# set context window
Settings.context_window = 4096
# set number of output tokens
Settings.num_output = 512
# Filesystem path of the Chroma database for the current session; set in infer().
db_path=""
def extract_web(url):
    """Download a web page and wrap its paragraph text in a Document.

    Args:
        url: Address of the page to fetch.

    Returns:
        Tuple ``(documents, option)`` — a single-element list holding one
        ``Document`` with the concatenated <p> text, and the literal
        string ``"web"`` marking the source type.
    """
    web_documents = SimpleWebPageReader().load_data([url])
    html_content = web_documents[0].text
    # Parse the HTML and keep only paragraph text; nav/script noise is dropped.
    soup = BeautifulSoup(html_content, 'html.parser')
    # find_all is the current bs4 API (findAll is a deprecated alias);
    # joining once avoids quadratic string concatenation in a loop.
    text_content = "".join(p.text + "\n" for p in soup.find_all('p'))
    # Convert back to llama_index Document format for indexing.
    documents = [Document(text=text_content)]
    return documents, "web"
def extract_doc(path):
    """Load local files into llama_index Documents.

    Args:
        path: List of file paths understood by ``SimpleDirectoryReader``.

    Returns:
        Tuple ``(documents, option)`` where ``option`` is the literal
        string ``"doc"`` marking the source type.
    """
    reader = SimpleDirectoryReader(input_files=path)
    return reader.load_data(), "doc"
def create_col(documents):
    """Embed *documents* into a fresh on-disk Chroma collection.

    Args:
        documents: llama_index ``Document`` objects to index.

    Returns:
        The filesystem path of the newly created Chroma database.
    """
    # A short random suffix keeps each ingest in its own database directory.
    path = f'database/{str(uuid.uuid4())[:4]}'
    client = chromadb.PersistentClient(path=path)
    collection = client.get_or_create_collection("quickstart")
    # Wire the collection into llama_index storage primitives.
    store = ChromaVectorStore(chroma_collection=collection)
    ctx = StorageContext.from_defaults(vector_store=store)
    # Building the index embeds the documents and persists them to disk.
    VectorStoreIndex.from_documents(documents, storage_context=ctx)
    return path
def infer(message: dict, history: list):
    """Gradio chat handler: ingest a URL/file, or answer a question via RAG.

    Args:
        message: Multimodal message dict with ``"text"`` and ``"files"`` keys.
        history: Prior chat turns supplied by ``gr.ChatInterface``.

    Returns:
        The assistant's reply as a string.

    Raises:
        gr.Error: When the first message is neither a URL nor a file upload.
    """
    # BUG FIX: ``memory`` must be declared global alongside ``db_path``.
    # The assignment in the URL branch below previously made ``memory``
    # function-local, so every call that skipped that branch (file upload,
    # or any follow-up question) hit UnboundLocalError at chat-engine setup.
    global db_path, memory
    option = ""
    print(f'message: {message}')
    print(f'history: {history}')
    files_list = message["files"]
    if files_list:
        # User uploaded file(s): index them into a fresh collection.
        documents, option = extract_doc(files_list)
        db_path = create_col(documents)
    elif message["text"].startswith(("http://", "https://")):
        # User pasted a URL: scrape and index the page, then reset the
        # conversation memory for the new source.
        documents, option = extract_web(message["text"])
        db_path = create_col(documents)
        memory = ChatMemoryBuffer.from_defaults(token_limit=3900)
    elif len(history) == 0:
        # First turn must supply a source to chat about.
        raise gr.Error("请先输入网址或上传文档。")
    # Re-open the persisted collection and rebuild an index view over it.
    load_client = chromadb.PersistentClient(path=db_path)
    chroma_collection = load_client.get_collection("quickstart")
    vector_store = ChromaVectorStore(chroma_collection=chroma_collection)
    index = VectorStoreIndex.from_vector_store(
        vector_store,
    )
    if option == "web" and len(history) == 0:
        # A fresh URL ingest only acknowledges success; no LLM call yet.
        response = "获取成功,开始对话吧!"
    else:
        question = message['text']
        chat_engine = CondensePlusContextChatEngine.from_defaults(
            index.as_retriever(),
            memory=memory,
            context_prompt=(
                "You are an assistant for question-answering tasks."
                "Use the following context to answer the question:\n"
                "{context_str}"
                "\nIf you don't know the answer, just say that you don't know."
                "Use five sentences maximum and keep the answer concise."
                "\nInstruction: Use the previous chat history, or the context above, to interact and help the user."
            ),
            verbose=True,
        )
        response = chat_engine.chat(question)
    print(f'response: {response}')
    return str(response)
# Placeholder text (Chinese): "enter a URL or upload a document first, then chat".
chatbot = gr.Chatbot(placeholder="请先输入网址或上传文档<br>然后进行对话")
with gr.Blocks(theme="soft", fill_height="true") as demo:
    gr.ChatInterface(
        fn=infer,
        title="RAG demo",
        multimodal=True,  # enables file uploads alongside text input
        chatbot=chatbot,
    )
if __name__ == "__main__":
    # API endpoints disabled: the demo is meant to be used via the web UI only.
    # (Removed a stray trailing "|" extraction artifact from the launch line.)
    demo.queue(api_open=False).launch(show_api=False, share=False)