# -*- coding: utf-8 -*-
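"""Gradio demo: upload PDF documents, build a FAISS index over them, generate a
summary plus sample questions, then chat over the documents with a map-reduce
retrieval QA chain using OpenAI or Cohere embeddings."""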
import gradio as gr
import os
from glob import glob
from langchain.vectorstores import FAISS
from langchain.embeddings import CohereEmbeddings, OpenAIEmbeddings
from langchain.chains import VectorDBQA
from langchain.chat_models import ChatOpenAI
from prompts import MyTemplate
from build_index.run import process_files
from langchain.prompts.chat import (
    ChatPromptTemplate,
    SystemMessagePromptTemplate,
    HumanMessagePromptTemplate,
)

from langchain.chains.summarize import load_summarize_chain
from langchain.chains import QAGenerationChain

# Cohere API key used when the "cohere" embedding backend is selected
cohere_key = '5IRbILAbjTI0VcqTsktBfKsr13Lych9iBAFbLpkj'
# Directory holding the FAISS index produced by process_files()
faiss_store = './output/'


def process(files, openai_api_key, max_tokens, model, n_sample):
    """
    Summarize the uploaded documents, generate sample questions, and build the document index.
    """
    model = model[0]  # CheckboxGroup returns a list; use the first selected backend
    os.environ['OPENAI_API_KEY'] = openai_api_key
    print('Displaying uploaded files')
    print(glob('/tmp/*'))
    docs = process_files([i.name for i in files], model, max_tokens)
    print('Displaying FAISS index files')
    print(glob('./output/*'))
    question = get_question(docs, openai_api_key, max_tokens, n_sample)
    summary = get_summary(docs, openai_api_key, max_tokens, n_sample)
    return question, summary


def get_question(docs, openai_api_key, max_tokens, n_sample=5):
    q_list = []
    llm = ChatOpenAI(openai_api_key=openai_api_key, max_tokens=max_tokens, temperature=0)
    chain = QAGenerationChain.from_llm(llm)
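    # QAGenerationChain prompts the LLM for (question, answer) pairs; run() returns a
    # list of dicts, so [0] below takes the first pair generated for each document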
    print('Generating questions from template')
    for i in range(min(n_sample, len(docs))):
        qa = chain.run(docs[i].page_content)[0]
        print(qa)
        q_list.append(f"Question {i + 1}: {qa['question']}")
    return '\n'.join(q_list)


def get_summary(docs, openai_api_key, max_tokens, n_sample=5):
    llm = ChatOpenAI(openai_api_key=openai_api_key, max_tokens=max_tokens)
    chain = load_summarize_chain(llm, chain_type="map_reduce")
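    # map_reduce: summarize each document chunk independently (map), then merge the
    # partial summaries into one final summary (reduce)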
    print('Generating summary from template')
    summary = chain.run(docs[:n_sample])
    print(summary)
    return summary


def predict(inputs, openai_api_key, max_tokens, model, chat_counter, chatbot=None, history=None):
    # Avoid mutable default arguments; Gradio supplies these from component state
    chatbot = chatbot if chatbot is not None else []
    history = history if history is not None else []
    model = model[0]  # CheckboxGroup returns a list; use the first selected backend
    print(f"chat_counter - {chat_counter}")
    print(f'History - {history}')  # history: original inputs and outputs in a flat list
    print(f'chatbot - {chatbot}')  # chatbot: [user, AI] pairs from the previous turn

    history.append(inputs)
    print(f'loading faiss store from {faiss_store}')
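    # The embeddings used to load the index must match the ones used when it was built,
    # otherwise the stored vectors and the query vectors would live in different spaces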
    if model == 'openai':
        docsearch = FAISS.load_local(faiss_store, OpenAIEmbeddings(openai_api_key=openai_api_key))
    else:
        docsearch = FAISS.load_local(faiss_store, CohereEmbeddings(cohere_api_key=cohere_key))
    # Build the prompt templates
    llm = ChatOpenAI(openai_api_key=openai_api_key, max_tokens=max_tokens)
    messages_combine = [
        SystemMessagePromptTemplate.from_template(MyTemplate['chat_combine_template']),
        HumanMessagePromptTemplate.from_template("{question}")
    ]
    p_chat_combine = ChatPromptTemplate.from_messages(messages_combine)
    messages_reduce = [
        SystemMessagePromptTemplate.from_template(MyTemplate['chat_reduce_template']),
        HumanMessagePromptTemplate.from_template("{question}")
    ]
    p_chat_reduce = ChatPromptTemplate.from_messages(messages_reduce)
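    # map_reduce QA: the question prompt answers from each retrieved chunk (map step),
    # then the combine prompt merges those per-chunk answers into a single reply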
    chain = VectorDBQA.from_chain_type(llm=llm, chain_type="map_reduce", vectorstore=docsearch,
                                       k=4,
                                       chain_type_kwargs={"question_prompt": p_chat_reduce,
                                                          "combine_prompt": p_chat_combine}
                                       )
    result = chain({"query": inputs})
    print(result)
    result = result['result']
    # Build the return values: pair the flat history into (user, AI) tuples for the Chatbot
    history.append(result)
    chat = [(history[i], history[i + 1]) for i in range(0, len(history) - 1, 2)]
    chat_counter += 1
    yield chat, history, chat_counter


def reset_textbox():
    return gr.update(value='')


with gr.Blocks(css="""#col_container {width: 1000px; margin-left: auto; margin-right: auto;}
                #chatbot {height: 520px; overflow: auto;}""") as demo:
    gr.HTML("""<h1 align="center">🚀Your Doc Reader🚀</h1>""")
    with gr.Column(elem_id="col_container"):
        openai_api_key = gr.Textbox(type='password', label="Enter API Key")

        with gr.Accordion("Parameters", open=True):
            with gr.Row():
                max_tokens = gr.Slider(minimum=100, maximum=2000, value=1000, step=100, interactive=True,
                                       label="Max tokens")
                model = gr.CheckboxGroup(["cohere", "openai"], label="Embeddings")
                chat_counter = gr.Number(value=0, precision=0, label='Chat turns')
                n_sample = gr.Slider(minimum=3, maximum=5, value=3, step=1, interactive=True,
                                     label="Number of questions")

        # Upload files, then generate a summary and sample questions
        with gr.Row():
            with gr.Column():
                files = gr.File(file_count="multiple", file_types=[".pdf"], label='Upload PDF files')
                run = gr.Button('Analyze report')

            with gr.Column():
                summary = gr.Textbox(type='text', label="Document summary")
                question = gr.Textbox(type='text', label='Suggested questions')

        chatbot = gr.Chatbot(elem_id='chatbot')
        inputs = gr.Textbox(placeholder="What is this document about?", label="What questions do you have about the document?")
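        # gr.State keeps the flat [user, ai, user, ai, ...] history across turns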
        state = gr.State([])

        with gr.Row():
            clear = gr.Button("Clear")
            start = gr.Button("Ask")

    run.click(process, [files, openai_api_key, max_tokens, model, n_sample], [question, summary])
    inputs.submit(predict,
                  [inputs, openai_api_key, max_tokens, model, chat_counter, chatbot, state],
                  [chatbot, state, chat_counter])
    start.click(predict,
                [inputs, openai_api_key, max_tokens, model, chat_counter, chatbot, state],
                [chatbot, state, chat_counter])

    # Reset the input textbox after each turn
    clear.click(reset_textbox, [], [inputs], queue=False)
    inputs.submit(reset_textbox, [], [inputs])

    demo.queue().launch(debug=True)