import json
import os
import time
import uuid
from datetime import datetime
import gradio as gr
import openai
from huggingface_hub import HfApi
from langchain.document_loaders import (PyPDFLoader, UnstructuredPDFLoader,
                                        PyPDFium2Loader, PyMuPDFLoader, PDFPlumberLoader)
from knowledge.faiss_handler import create_faiss_index_from_zip, load_faiss_index_from_zip
from knowledge.img_handler import process_image, add_markup
from llms.chatbot import OpenAIChatBot
from llms.embeddings import EMBEDDINGS_MAPPING
from utils import make_archive
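# Configuration comes from environment variables (e.g. set as Space secrets).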
UPLOAD_REPO_ID = os.getenv("UPLOAD_REPO_ID")
HF_TOKEN = os.getenv("HF_TOKEN")
openai.api_key = os.getenv("OPENAI_API_KEY")
openai.api_base = os.getenv("OPENAI_API_BASE")
hf_api = HfApi(token=HF_TOKEN)
ALL_PDF_LOADERS = [PyPDFLoader, UnstructuredPDFLoader, PyPDFium2Loader, PyMuPDFLoader, PDFPlumberLoader]
ALL_EMBEDDINGS = EMBEDDINGS_MAPPING.keys()
PDF_LOADER_MAPPING = {loader.__name__: loader for loader in ALL_PDF_LOADERS}
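# Map loader class names (as shown in the UI dropdown) back to their classes.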
#######################################################################################################################
# Host multiple vector databases for use
#######################################################################################################################
# todo: add this feature in the future
INSTRUCTIONS = '''# FAISS Chat: Chat with a local knowledge base!

***Update 2023-06-06:***
1. Supports reading chart data from images (currently JPG and PNG).
2. A demo of this module is available under the "Summarize Chart (Demo)" tab.

***Update 2023-06-04:***
1. Supports more embedding models (currently [text-embedding-ada-002](https://openai.com/blog/new-and-improved-embedding-model), [text2vec-large-chinese](https://huggingface.co/GanymedeNil/text2vec-large-chinese), and [distilbert-dot-tas_b-b256-msmarco](https://huggingface.co/sebastian-hofstaetter/distilbert-dot-tas_b-b256-msmarco)).
2. Supports more file formats (PDF, TXT, TEX, and MD).
3. All generated databases are available in [this dataset](https://huggingface.co/datasets/shaocongma/shared-faiss-vdb)! If you prefer your files not to be uploaded, disable this in the advanced settings.
'''
def load_zip_as_db(file_from_gradio,
pdf_loader,
embedding_model,
chunk_size=300,
chunk_overlap=20,
upload_to_cloud=True):
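    """Create a FAISS knowledge base from an uploaded .zip of documents.

    Returns a status message, the packed index archive's filename, the FAISS
    db object, and the db metadata, matching the Gradio output components.
    """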
    if chunk_size <= chunk_overlap:
        return "chunk_size must be larger than chunk_overlap. Creation failed.", None, None, None
    if file_from_gradio is None:
        return "The file is empty. Creation failed.", None, None, None
pdf_loader = PDF_LOADER_MAPPING[pdf_loader]
zip_file_path = file_from_gradio.name
project_name = uuid.uuid4().hex
db, project_name, db_meta = create_faiss_index_from_zip(zip_file_path, embeddings=embedding_model,
pdf_loader=pdf_loader, chunk_size=chunk_size,
chunk_overlap=chunk_overlap, project_name=project_name)
index_name = project_name + ".zip"
make_archive(project_name, index_name)
date = datetime.today().strftime('%Y-%m-%d')
if upload_to_cloud:
hf_api.upload_file(path_or_fileobj=index_name,
path_in_repo=f"{date}/faiss_{index_name}.zip",
repo_id=UPLOAD_REPO_ID,
repo_type="dataset")
return "成功创建知识库. 可以开始聊天了!", index_name, db, db_meta
def load_local_db(file_from_gradio):
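    """Load a previously created FAISS knowledge base from a local .zip file."""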
    if file_from_gradio is None:
        return "The file is empty. Loading failed.", None
    zip_file_path = file_from_gradio.name
    db = load_faiss_index_from_zip(zip_file_path)
    return "Knowledge base loaded successfully. You can start chatting!", db
def extract_image(image_path):
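    """Extract a table from a chart image; returns the raw table and a marked-up version."""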
from PIL import Image
print("Image Path:", image_path)
im = Image.open(image_path)
table = process_image(im)
print(f"Success in processing the image. Table: {table}")
return table, add_markup(table)
def describe(image):
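    """Convert a chart image into a linearized table and ask GPT-3.5 for a Chinese summary."""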
    table = add_markup(process_image(image))
    _INSTRUCTION = 'Read the table below to answer the following questions.'
    question = "Please refer to the above table, and write a summary of no less than 200 words based on it in Chinese, ensuring that your response is detailed and precise. "
    # `table` is already marked up above, so use it directly here.
    prompt_0shot = _INSTRUCTION + "\n" + table + "\n" + "Q: " + question + "\n" + "A:"
    messages = [{"role": "user", "content": prompt_0shot}]
response = openai.ChatCompletion.create(
model="gpt-3.5-turbo",
messages=messages,
temperature=0.7,
top_p=1,
frequency_penalty=0,
presence_penalty=0,
)
ret = response.choices[0].message['content']
return ret
with gr.Blocks() as demo:
local_db = gr.State(None)
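    # Holds the currently loaded FAISS database across Gradio callbacks.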
def get_augmented_message(message, local_db, query_count, preprocessing, meta):
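        """Build the augmented prompt: table text extracted from any referenced image,
        plus the top-`query_count` chunks retrieved from the local FAISS database."""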
print(f"Receiving message: {message}")
print("Detecting if the user need to read image from the local database...")
# read the db_meta.json from the local file
# read the images file list
files = meta["files"]
source_path = meta["source_path"]
img_files = []
for file in files:
if os.path.splitext(file)[1] in [".png", ".jpg"]:
img_files.append(file)
# scan user's input to see if it contains images' name
do_extract_image = False
target_file = None
for file in img_files:
img = os.path.splitext(file)[0]
if img in message:
do_extract_image = True
target_file = file
break
# extract image to tables
image_info = ""
if do_extract_image:
print("The user needs to read image from the local database. Extract image ... ")
target_file = os.path.join(source_path, target_file)
_, image_info = extract_image(target_file)
        if len(image_info) > 0:
image_content = {"content": image_info, "source": os.path.basename(target_file)}
else:
image_content = None
print("Querying references from the local database...")
contents = []
try:
            if query_count > 0:
                docs = local_db.similarity_search(message, k=query_count)
                # pre-process each retrieved chunk (there may be fewer than query_count)
                for doc in docs:
                    content = doc.page_content.replace('\n', ' ')
                    contents.append(content)
        except Exception as e:
            print(f"Failed to query the local database: {e}")
        # generate the augmented message
        print(f"Success in querying references: {contents}")
if image_content is not None:
augmented_message = f"{image_content}\n\n---\n\n" + "\n\n---\n\n".join(contents) + "\n\n-----\n\n"
else:
augmented_message = "\n\n---\n\n".join(contents) + "\n\n-----\n\n"
return augmented_message + "\n\n" + f"'user_input': {message}"
def respond(message, local_db, chat_history, meta, query_count=5, test_mode=False, response_delay=5, preprocessing=False):
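        """Chat callback: answers directly when no knowledge base is loaded,
        otherwise augments the message with retrieved context first."""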
gpt_chatbot = OpenAIChatBot()
print("Chat History: ", chat_history)
print("Local DB: ", local_db is None)
for chat in chat_history:
gpt_chatbot.load_chat(chat)
if local_db is None or query_count == 0:
bot_message = gpt_chatbot(message)
print(bot_message)
print(message)
chat_history.append((message, bot_message))
return "", chat_history
else:
augmented_message = get_augmented_message(message, local_db, query_count, preprocessing, meta)
bot_message = gpt_chatbot(augmented_message, original_message=message)
print(message)
print(augmented_message)
print(bot_message)
if test_mode:
chat_history.append((augmented_message, bot_message))
else:
chat_history.append((message, bot_message))
            time.sleep(response_delay)  # brief pause to avoid hitting the API rate limit
return "", chat_history
with gr.Row():
with gr.Column():
gr.Markdown(INSTRUCTIONS)
with gr.Row():
with gr.Tab("从本地PDF文件创建知识库"):
zip_file = gr.File(file_types=[".zip"], label="本地PDF文件(.zip)")
create_db = gr.Button("创建知识库", variant="primary")
with gr.Accordion("高级设置", open=False):
embedding_selector = gr.Dropdown(ALL_EMBEDDINGS,
value="distilbert-dot-tas_b-b256-msmarco",
label="Embedding Models")
pdf_loader_selector = gr.Dropdown([loader.__name__ for loader in ALL_PDF_LOADERS],
value=PyPDFLoader.__name__, label="PDF Loader")
chunk_size_slider = gr.Slider(minimum=50, maximum=2000, step=50, value=500,
label="Chunk size (tokens)")
chunk_overlap_slider = gr.Slider(minimum=0, maximum=500, step=1, value=50,
label="Chunk overlap (tokens)")
                        save_to_cloud_checkbox = gr.Checkbox(value=False, label="Upload the database to the cloud")
                    file_dp_output = gr.File(file_types=[".zip"], label="(Output) knowledge base file (.zip)")
                with gr.Tab("Load a local knowledge base file"):
                    file_local = gr.File(file_types=[".zip"], label="Local knowledge base file (.zip)")
                    load_db = gr.Button("Load existing knowledge base", variant="primary")
                with gr.Tab("Summarize Chart (Demo)"):
                    gr.Markdown(r"Code adapted from: https://huggingface.co/spaces/fl399/deplot_plus_llm")
input_image = gr.Image(label="Input Image", type="pil", interactive=True)
extract = gr.Button("总结", variant="primary")
output_text = gr.Textbox(lines=8, label="Output")
with gr.Column():
            status = gr.Textbox(label="Status")
chatbot = gr.Chatbot()
msg = gr.Textbox()
submit = gr.Button("Submit", variant="primary")
with gr.Accordion("高级设置", open=False):
json_output = gr.JSON()
with gr.Row():
query_count_slider = gr.Slider(minimum=0, maximum=10, step=1, value=3,
label="Query counts")
test_mode_checkbox = gr.Checkbox(label="Test mode")
msg.submit(respond, [msg, local_db, chatbot, json_output, query_count_slider, test_mode_checkbox], [msg, chatbot])
submit.click(respond, [msg, local_db, chatbot, json_output, query_count_slider, test_mode_checkbox], [msg, chatbot])
create_db.click(load_zip_as_db, [zip_file, pdf_loader_selector, embedding_selector, chunk_size_slider, chunk_overlap_slider, save_to_cloud_checkbox],
[status, file_dp_output, local_db, json_output])
load_db.click(load_local_db, [file_local], [status, local_db])
extract.click(describe, [input_image], [output_text])
demo.launch(show_api=False)