import json
import os
import time
import uuid
from datetime import datetime

import gradio as gr
import openai
from huggingface_hub import HfApi
from langchain.document_loaders import (PyPDFLoader, UnstructuredPDFLoader,
                                        PyPDFium2Loader, PyMuPDFLoader,
                                        PDFPlumberLoader)

from knowledge.faiss_handler import create_faiss_index_from_zip, load_faiss_index_from_zip
from knowledge.img_handler import process_image, add_markup
from llms.chatbot import OpenAIChatBot
from llms.embeddings import EMBEDDINGS_MAPPING
from utils import make_archive

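# Runtime configuration is read from environment variables; UPLOAD_REPO_ID and
# HF_TOKEN are only needed when uploading generated indexes to the Hub.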
UPLOAD_REPO_ID = os.getenv("UPLOAD_REPO_ID")
HF_TOKEN = os.getenv("HF_TOKEN")
openai.api_key = os.getenv("OPENAI_API_KEY")
openai.api_base = os.getenv("OPENAI_API_BASE")
hf_api = HfApi(token=HF_TOKEN)

ALL_PDF_LOADERS = [PyPDFLoader, UnstructuredPDFLoader, PyPDFium2Loader, PyMuPDFLoader, PDFPlumberLoader]
ALL_EMBEDDINGS = list(EMBEDDINGS_MAPPING.keys())
PDF_LOADER_MAPPING = {loader.__name__: loader for loader in ALL_PDF_LOADERS}
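# e.g. PDF_LOADER_MAPPING["PyPDFLoader"] -> PyPDFLoader, so the UI dropdown can
# pass loader class names around as plain strings.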


#######################################################################################################################
# Host multiple vector databases for use
#######################################################################################################################
# todo: add this feature in the future


INSTRUCTIONS = '''# FAISS Chat: Chat with your local database!

***Update 2023-06-06:***
1. Added support for reading chart data from images (currently JPG and PNG).
2. A demo of this module is available in the "Summarize Chart (Demo)" tab.

***Update 2023-06-04:***
1. Added support for more embedding models (currently [text-embedding-ada-002](https://openai.com/blog/new-and-improved-embedding-model), [text2vec-large-chinese](https://huggingface.co/GanymedeNil/text2vec-large-chinese), and [distilbert-dot-tas_b-b256-msmarco](https://huggingface.co/sebastian-hofstaetter/distilbert-dot-tas_b-b256-msmarco)).
2. Added support for more file formats (PDF, TXT, TEX, and MD).
3. All generated databases are now available in [this dataset](https://huggingface.co/datasets/shaocongma/shared-faiss-vdb)! If you prefer not to have your files uploaded, you can disable this in the advanced settings.
'''


def load_zip_as_db(file_from_gradio,
                   pdf_loader,
                   embedding_model,
                   chunk_size=300,
                   chunk_overlap=20,
                   upload_to_cloud=True):
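    """Build a FAISS index from a zip of documents uploaded via Gradio.

    Returns a status message, the archived index file name, the in-memory db,
    and the db metadata dict (used later to locate image files).
    """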
    if chunk_size <= chunk_overlap:
        return "chunk_size must be greater than chunk_overlap. Creation failed.", None, None, None
    if file_from_gradio is None:
        return "The file is empty. Creation failed.", None, None, None
    pdf_loader = PDF_LOADER_MAPPING[pdf_loader]
    zip_file_path = file_from_gradio.name
    project_name = uuid.uuid4().hex
    db, project_name, db_meta = create_faiss_index_from_zip(zip_file_path, embeddings=embedding_model,
                                                            pdf_loader=pdf_loader, chunk_size=chunk_size,
                                                            chunk_overlap=chunk_overlap, project_name=project_name)
    index_name = project_name + ".zip"
    make_archive(project_name, index_name)
    date = datetime.today().strftime('%Y-%m-%d')
    if upload_to_cloud:
        hf_api.upload_file(path_or_fileobj=index_name,
                           path_in_repo=f"{date}/faiss_{index_name}.zip",
                           repo_id=UPLOAD_REPO_ID,
                           repo_type="dataset")
    return "成功创建知识库. 可以开始聊天了!", index_name, db, db_meta


def load_local_db(file_from_gradio):
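    """Load a previously created FAISS index from a local zip file.

    Returns a status message and the loaded db.
    """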
    if file_from_gradio is None:
        return "The file is empty. Loading failed.", None
    zip_file_path = file_from_gradio.name
    db = load_faiss_index_from_zip(zip_file_path)

    return "成功读取知识库. 可以开始聊天了!", db


def extract_image(image_path):
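    """Convert a chart image on disk into a table.

    Returns the raw table and a marked-up version suitable for prompting.
    """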
    from PIL import Image
    print("Image Path:", image_path)
    im = Image.open(image_path)
    table = process_image(im)
    print(f"Success in processing the image. Table: {table}")
    return table, add_markup(table)


def describe(image):
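    """Summarize a chart image: linearize it into a table, then ask gpt-3.5-turbo
    for a detailed summary of that table."""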
    # add_markup is applied once here; the prompt below reuses `table` directly
    table = add_markup(process_image(image))
    _INSTRUCTION = 'Read the table below to answer the following questions.'
    question = "Please refer to the above table, and write a summary of no less than 200 words based on it in Chinese, ensuring that your response is detailed and precise. "
    prompt_0shot = _INSTRUCTION + "\n" + table + "\n" + "Q: " + question + "\n" + "A:"

    messages = [{"role": "user", "content": prompt_0shot}]
    response = openai.ChatCompletion.create(
        model="gpt-3.5-turbo",
        messages=messages,
        temperature=0.7,
        top_p=1,
        frequency_penalty=0,
        presence_penalty=0,
    )
    ret = response.choices[0].message['content']
    return ret


with gr.Blocks() as demo:
    local_db = gr.State(None)

    def get_augmented_message(message, local_db, query_count, preprocessing, meta):
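        """Augment the user message with an extracted image table (if the message
        names an image in the knowledge base) and with chunks retrieved from the
        local FAISS db."""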
        print(f"Receiving message: {message}")

        print("Detecting if the user need to read image from the local database...")
        # read the db_meta.json from the local file
        # read the images file list
        files = meta["files"]
        source_path = meta["source_path"]
        # with open(meta.name, "r", encoding="utf-8") as f:
        #     files = json.load(f)["files"]
        img_files = []
        for file in files:
            if os.path.splitext(file)[1] in [".png", ".jpg"]:
                img_files.append(file)

        # scan the user's input to see if it mentions an image file name
        do_extract_image = False
        target_file = None
        for file in img_files:
            img = os.path.splitext(file)[0]
            if img in message:
                do_extract_image = True
                target_file = file
                break

        # extract the image into a table
        image_info = ""
        if do_extract_image:
            print("The user needs to read an image from the local database. Extracting image ...")
            target_file = os.path.join(source_path, target_file)
            _, image_info = extract_image(target_file)
        if len(image_info) > 0:
            image_content = {"content": image_info, "source": os.path.basename(target_file)}
        else:
            image_content = None

        print("Querying references from the local database...")
        contents = []
        try:
            if query_count > 0:
                docs = local_db.similarity_search(message, k=query_count)
                for doc in docs:
                    # pre-process each chunk: collapse newlines inside the chunk
                    contents.append(doc.page_content.replace('\n', ' '))
        except Exception as e:
            print(f"Failed to query the local database: {e}")
        print("Retrieved references: {}".format(contents))
        # generate the augmented message
        if image_content is not None:
            augmented_message = json.dumps(image_content, ensure_ascii=False) + "\n\n---\n\n" + "\n\n---\n\n".join(contents) + "\n\n-----\n\n"
        else:
            augmented_message = "\n\n---\n\n".join(contents) + "\n\n-----\n\n"
        return augmented_message + "\n\n" + f"'user_input': {message}"
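
    # The augmented message assembled above looks roughly like:
    #   <image table (optional)>\n\n---\n\n<chunk 1>\n\n---\n\n<chunk 2>\n\n-----\n\n
    #   'user_input': <original user message>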


    def respond(message, local_db, chat_history, meta, query_count=5, test_mode=False, response_delay=5, preprocessing=False):
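        """Gradio callback: answer `message` with the OpenAI chatbot, optionally
        augmenting it with context retrieved from the local knowledge base."""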
        gpt_chatbot = OpenAIChatBot()
        print("Chat History: ", chat_history)
        print("Local DB: ", local_db is None)
        for chat in chat_history:
            gpt_chatbot.load_chat(chat)
        if local_db is None or query_count == 0:
            bot_message = gpt_chatbot(message)
            print(bot_message)
            print(message)
            chat_history.append((message, bot_message))
            return "", chat_history
        else:
            augmented_message = get_augmented_message(message, local_db, query_count, preprocessing, meta)
            bot_message = gpt_chatbot(augmented_message, original_message=message)
            print(message)
            print(augmented_message)
            print(bot_message)
            if test_mode:
                chat_history.append((augmented_message, bot_message))
            else:
                chat_history.append((message, bot_message))
            time.sleep(response_delay)  # brief pause to avoid hitting the API rate limit
            return "", chat_history

    with gr.Row():
        with gr.Column():
            gr.Markdown(INSTRUCTIONS)

            with gr.Row():
                with gr.Tab("从本地PDF文件创建知识库"):
                    zip_file = gr.File(file_types=[".zip"], label="本地PDF文件(.zip)")
                    create_db = gr.Button("创建知识库", variant="primary")
                    with gr.Accordion("高级设置", open=False):
                        embedding_selector = gr.Dropdown(ALL_EMBEDDINGS,
                                                         value="distilbert-dot-tas_b-b256-msmarco",
                                                         label="Embedding Models")
                        pdf_loader_selector = gr.Dropdown([loader.__name__ for loader in ALL_PDF_LOADERS],
                                                          value=PyPDFLoader.__name__, label="PDF Loader")
                        chunk_size_slider = gr.Slider(minimum=50, maximum=2000, step=50, value=500,
                                                      label="Chunk size (tokens)")
                        chunk_overlap_slider = gr.Slider(minimum=0, maximum=500, step=1, value=50,
                                                         label="Chunk overlap (tokens)")
                        save_to_cloud_checkbox = gr.Checkbox(value=False, label="把数据库上传到云端")


                    file_dp_output = gr.File(file_types=[".zip"], label="(输出)知识库文件(.zip)")
                with gr.Tab("读取本地知识库文件"):
                    file_local = gr.File(file_types=[".zip"], label="本地知识库文件(.zip)")
                    load_db = gr.Button("读取已创建知识库", variant="primary")

                with gr.Tab("总结图表(Demo)"):
                    gr.Markdown(r"代码来源于: https://huggingface.co/spaces/fl399/deplot_plus_llm")
                    input_image = gr.Image(label="Input Image", type="pil", interactive=True)
                    extract = gr.Button("总结", variant="primary")

                    output_text = gr.Textbox(lines=8, label="Output")

        with gr.Column():
            status = gr.Textbox(label="Displays the program's running status")
            chatbot = gr.Chatbot()

            msg = gr.Textbox()
            submit = gr.Button("Submit", variant="primary")
            with gr.Accordion("高级设置", open=False):
                json_output = gr.JSON()
                with gr.Row():
                    query_count_slider = gr.Slider(minimum=0, maximum=10, step=1, value=3,
                                                  label="Query counts")
                    test_mode_checkbox = gr.Checkbox(label="Test mode")


    msg.submit(respond, [msg, local_db, chatbot, json_output, query_count_slider, test_mode_checkbox], [msg, chatbot])
    submit.click(respond, [msg, local_db, chatbot, json_output, query_count_slider, test_mode_checkbox], [msg, chatbot])

    create_db.click(load_zip_as_db, [zip_file, pdf_loader_selector, embedding_selector, chunk_size_slider, chunk_overlap_slider, save_to_cloud_checkbox],
                    [status, file_dp_output, local_db, json_output])
    load_db.click(load_local_db, [file_local], [status, local_db])

    extract.click(describe, [input_image], [output_text])

demo.launch(show_api=False)