# -*-coding:utf-8 -*-
import gradio as gr
import os
import json
from glob import glob
import requests
from langchain import FAISS
from langchain.embeddings import CohereEmbeddings, OpenAIEmbeddings
from langchain import VectorDBQA
from langchain.chat_models import ChatOpenAI
from prompts import MyTemplate
from build_index.run import process_files
from langchain.prompts.chat import (
    ChatPromptTemplate,
    SystemMessagePromptTemplate,
    HumanMessagePromptTemplate,
)
from langchain.prompts import PromptTemplate
from langchain.chains.llm import LLMChain
from langchain.chains.combine_documents.map_reduce import MapReduceDocumentsChain
from langchain.chains import QAGenerationChain
from langchain.chains.combine_documents.stuff import StuffDocumentsChain

# Streaming endpoint
API_URL = "https://api.openai.com/v1/chat/completions"
# SECURITY: hardcoded API credential checked into source — move to an
# environment variable or secret store and rotate this key.
cohere_key = '5IRbILAbjTI0VcqTsktBfKsr13Lych9iBAFbLpkj'
# Directory where the FAISS index is persisted (written by build_index,
# read back by predict() via FAISS.load_local).
faiss_store = './output/'


def process(files, openai_api_key, max_tokens, model, n_sample):
    """Ingest uploaded files: build the FAISS index, then produce suggested
    questions and a summary.

    Args:
        files: Gradio file objects; each exposes a ``.name`` temp-file path.
        openai_api_key: OpenAI key, exported to the environment for downstream chains.
        max_tokens: token budget forwarded to indexing and generation.
        model: CheckboxGroup value — a list; only the first selection is used.
        n_sample: how many documents/questions to sample.

    Returns:
        Tuple of (questions text, summary text) for the two output textboxes.
    """
    selected = model[0]
    os.environ['OPENAI_API_KEY'] = openai_api_key
    print('Displaying uploading files ')
    print(glob('/tmp/*'))
    uploaded_paths = [f.name for f in files]
    docs = process_files(uploaded_paths, selected, max_tokens)
    print('Display Faiss index')
    print(glob('./output/*'))
    return (
        get_question(docs, openai_api_key, max_tokens, n_sample),
        get_summary(docs, openai_api_key, max_tokens, n_sample),
    )


def get_question(docs, openai_api_key, max_tokens, n_sample=5):
    """Generate up to ``n_sample`` questions, one per leading document chunk.

    Args:
        docs: list of LangChain documents (each has ``.page_content``).
        openai_api_key: OpenAI key for the chat model.
        max_tokens: generation token budget.
        n_sample: maximum number of questions to generate.

    Returns:
        Newline-joined, numbered question strings.
    """
    q_list = []
    llm = ChatOpenAI(openai_api_key=openai_api_key, max_tokens=max_tokens, temperature=0)
    # Build the QA-generation prompt from the project's system/user templates.
    prompt = ChatPromptTemplate.from_messages(
        [
            SystemMessagePromptTemplate.from_template(MyTemplate['qa_sys_template']),
            HumanMessagePromptTemplate.from_template(MyTemplate['qa_user_template']),
        ]
    )
    chain = QAGenerationChain.from_llm(llm, prompt=prompt)
    print('Generating Question from template')
    # Clamp to the documents actually available: the original indexed
    # docs[i] for i in range(n_sample) and raised IndexError whenever
    # fewer than n_sample chunks were produced.
    for i, doc in enumerate(docs[:n_sample]):
        qa = chain.run(doc.page_content)[0]
        print(qa)
        q_list.append(f"问题{i + 1}: {qa['question']}")
    return '\n'.join(q_list)


def get_summary(docs, openai_api_key, max_tokens, n_sample=5, verbose=None):
    """Summarize the first ``n_sample`` documents with a map-reduce chain.

    The same summary template is applied in both the map step (per-document)
    and the combine step (merging the partial summaries).

    Args:
        docs: list of LangChain documents.
        openai_api_key: OpenAI key for the chat model.
        max_tokens: generation token budget.
        n_sample: number of leading documents to include in the summary.
        verbose: forwarded to every sub-chain for debug logging.

    Returns:
        The combined summary string.
    """
    llm = ChatOpenAI(openai_api_key=openai_api_key, max_tokens=max_tokens)
    # Fixed typo in log message ('tempalte' -> 'template').
    print('Generating Summary from template')

    map_prompt = PromptTemplate(template=MyTemplate['summary_template'], input_variables=["text"])
    combine_prompt = PromptTemplate(template=MyTemplate['summary_template'], input_variables=["text"])
    map_chain = LLMChain(llm=llm, prompt=map_prompt, verbose=verbose)
    reduce_chain = LLMChain(llm=llm, prompt=combine_prompt, verbose=verbose)
    # Combine step: stuff all partial summaries into one prompt.
    combine_document_chain = StuffDocumentsChain(
        llm_chain=reduce_chain,
        document_variable_name='text',
        verbose=verbose,
    )
    chain = MapReduceDocumentsChain(
        llm_chain=map_chain,
        combine_document_chain=combine_document_chain,
        document_variable_name='text',
        collapse_document_chain=None,  # no intermediate collapse step
        verbose=verbose,
    )
    summary = chain.run(docs[:n_sample])
    print(summary)
    return summary


def predict(inputs, openai_api_key, max_tokens, model, chat_counter, chatbot=None, history=None):
    """Answer a user question against the persisted FAISS index (generator).

    Args:
        inputs: the user's question.
        openai_api_key: OpenAI key for the chat model / embeddings.
        max_tokens: generation token budget.
        model: CheckboxGroup value — a list; only the first selection is used.
        chat_counter: running count of completed dialogue turns.
        chatbot: previous [[user, ai], ...] pairs (Gradio Chatbot state).
        history: flat list alternating user inputs and AI answers.

    Yields:
        (chat pairs for the Chatbot, updated flat history, incremented counter).
    """
    # Fix: the original used mutable default arguments (chatbot=[], history=[]),
    # which are shared across calls and silently accumulate state.
    chatbot = [] if chatbot is None else chatbot
    history = [] if history is None else history
    model = model[0]
    print(f"chat_counter - {chat_counter}")
    # Fixed typo in log message ('Histroy' -> 'History').
    print(f'History - {history}')  # History: original inputs and outputs in a flat list
    print(f'chatbot - {chatbot}')  # Chatbot: previous round's [[user, AI]] pairs

    history.append(inputs)
    print(f'loading faiss store from {faiss_store}')
    # The index must be loaded with the same embedding family it was built with.
    if model == 'openai':
        docsearch = FAISS.load_local(faiss_store, OpenAIEmbeddings(openai_api_key=openai_api_key))
    else:
        docsearch = FAISS.load_local(faiss_store, CohereEmbeddings(cohere_api_key=cohere_key))
    # Build the map-reduce QA prompts.
    llm = ChatOpenAI(openai_api_key=openai_api_key, max_tokens=max_tokens)
    messages_combine = [
        SystemMessagePromptTemplate.from_template(MyTemplate['chat_combine_template']),
        HumanMessagePromptTemplate.from_template("{question}")
    ]
    p_chat_combine = ChatPromptTemplate.from_messages(messages_combine)
    messages_reduce = [
        SystemMessagePromptTemplate.from_template(MyTemplate['chat_reduce_template']),
        HumanMessagePromptTemplate.from_template("{question}")
    ]
    p_chat_reduce = ChatPromptTemplate.from_messages(messages_reduce)
    chain = VectorDBQA.from_chain_type(llm=llm, chain_type="map_reduce", vectorstore=docsearch,
                                       k=4,
                                       chain_type_kwargs={"question_prompt": p_chat_reduce,
                                                          "combine_prompt": p_chat_combine}
                                       )
    result = chain({"query": inputs})
    print(result)
    result = result['result']
    # Build the return values: pair up the flat history into (user, ai) tuples.
    history.append(result)
    chat = [(history[i], history[i + 1]) for i in range(0, len(history) - 1, 2)]
    chat_counter += 1
    yield chat, history, chat_counter


def reset_textbox():
    """Return a Gradio update that clears the input textbox."""
    cleared = gr.update(value='')
    return cleared


# Top-level Gradio UI: layout plus event wiring for the three callbacks
# (process, predict, reset_textbox) defined above.
with gr.Blocks(css="""#col_container {width: 1000px; margin-left: auto; margin-right: auto;}
                #chatbot {height: 520px; overflow: auto;}""") as demo:
    gr.HTML("""<h1 align="center">🚀Your Doc Reader🚀</h1>""")
    with gr.Column(elem_id="col_container"):
        # API key entry (masked).
        openai_api_key = gr.Textbox(type='password', label="输入 API Key")

        with gr.Accordion("Parameters", open=True):
            with gr.Row():
                max_tokens = gr.Slider(minimum=100, maximum=2000, value=1000, step=100, interactive=True,
                                       label="字数")
                # NOTE(review): CheckboxGroup returns a list; the callbacks use
                # only model[0], so multiple selections are silently ignored.
                model = gr.CheckboxGroup(["cohere", "openai"])
                chat_counter = gr.Number(value=0, precision=0, label='对话轮数')
                n_sample = gr.Slider(minimum=3, maximum=5, value=3, step=1, interactive=True,
                                     label="问题数")

        # Upload files, then run summarization and question generation.
        with gr.Row():
            with gr.Column():
                files = gr.File(file_count="multiple", file_types=[".pdf"], label='上传pdf文件')
                run = gr.Button('研报解读')

            with gr.Column():
                summary = gr.Textbox(type='text', label="一眼看尽 - 文档概览")
                question = gr.Textbox(type='text', label='推荐问题 - 问别的也行哟')

        # Chat area: question input plus conversation state (flat history list).
        chatbot = gr.Chatbot(elem_id='chatbot')
        inputs = gr.Textbox(placeholder="这篇文档是关于什么的", label="针对文档你有哪些问题?")
        state = gr.State([])

        with gr.Row():
            clear = gr.Button("清空")
            start = gr.Button("提问")

    # Event wiring: index/summarize on button click; QA on submit or button.
    run.click(process, [files, openai_api_key, max_tokens, model, n_sample], [question, summary])
    inputs.submit(predict,
                  [inputs, openai_api_key, max_tokens, model, chat_counter, chatbot, state],
                  [chatbot, state, chat_counter], )
    start.click(predict,
                [inputs, openai_api_key, max_tokens, model, chat_counter, chatbot, state],
                [chatbot, state, chat_counter], )

    # Clear the input textbox after every turn.
    clear.click(reset_textbox, [], [inputs], queue=False)
    inputs.submit(reset_textbox, [], [inputs])

    demo.queue().launch(debug=True)