# -*- coding: utf-8 -*-
import gradio as gr
import os
import json
from glob import glob
import requests
from langchain import FAISS
from langchain.embeddings import CohereEmbeddings, OpenAIEmbeddings
from langchain import VectorDBQA
from langchain.chat_models import ChatOpenAI
from prompts import MyTemplate
from build_index.run import process_files
from langchain.prompts.chat import (
    ChatPromptTemplate,
    SystemMessagePromptTemplate,
    HumanMessagePromptTemplate,
)
from langchain.prompts import PromptTemplate
from langchain.chains.llm import LLMChain
from langchain.chains.combine_documents.map_reduce import MapReduceDocumentsChain
from langchain.chains import QAGenerationChain
from langchain.chains.combine_documents.stuff import StuffDocumentsChain
# Streaming endpoint
API_URL = "https://api.openai.com/v1/chat/completions"
cohere_key = '5IRbILAbjTI0VcqTsktBfKsr13Lych9iBAFbLpkj'
faiss_store = './output/'
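# NOTE: faiss_store is the directory where build_index persists the FAISS index and
# where predict() reloads it from. API_URL is not used directly in this module, and
# the Cohere key above is hard-coded; ideally it would come from an environment variable.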


def process(files, openai_api_key, max_tokens, model, n_sample):
    """
    Summarise the documents, generate questions, and build the document index.
    """
    model = model[0]
    os.environ['OPENAI_API_KEY'] = openai_api_key
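    # Exporting the key via the environment lets downstream LangChain components
    # (embeddings, chat models) pick it up implicitly as well as via explicit kwargs.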
    print('Displaying uploaded files')
    print(glob('/tmp/*'))
    docs = process_files([i.name for i in files], model, max_tokens)
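    # process_files (from build_index) is assumed to split the uploaded PDFs into
    # chunks, embed them with the selected model, and persist a FAISS index under
    # ./output/; it returns the parsed document chunks used below.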
    print('Displaying FAISS index files')
    print(glob('./output/*'))
    question = get_question(docs, openai_api_key, max_tokens, n_sample)
    summary = get_summary(docs, openai_api_key, max_tokens, n_sample)
    return question, summary


def get_question(docs, openai_api_key, max_tokens, n_sample=5):
    q_list = []
    llm = ChatOpenAI(openai_api_key=openai_api_key, max_tokens=max_tokens, temperature=0)
    # Generate QA pairs from the documents
    prompt = ChatPromptTemplate.from_messages(
        [
            SystemMessagePromptTemplate.from_template(MyTemplate['qa_sys_template']),
            HumanMessagePromptTemplate.from_template(MyTemplate['qa_user_template']),
        ]
    )
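    # QAGenerationChain prompts the LLM to emit question/answer pairs as JSON; the
    # custom system/user templates are assumed to keep that JSON contract so the
    # chain can parse the response.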
    chain = QAGenerationChain.from_llm(llm, prompt=prompt)
    print('Generating questions from template')
    for i in range(n_sample):
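        # chain.run returns a list of {'question': ..., 'answer': ...} dicts parsed
        # from the model output; keep the first pair generated for this chunk.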
        qa = chain.run(docs[i].page_content)[0]
        print(qa)
        q_list.append(f"Question {i + 1}: {qa['question']}")
    return '\n'.join(q_list)


def get_summary(docs, openai_api_key, max_tokens, n_sample=5, verbose=None):
    llm = ChatOpenAI(openai_api_key=openai_api_key, max_tokens=max_tokens)
    # chain = load_summarize_chain(llm, chain_type="map_reduce")
    # summary = chain.run(docs[:n_sample])
    print('Generating summary from template')
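    # Map-reduce summarisation: map_chain summarises each chunk independently,
    # StuffDocumentsChain concatenates the partial summaries, and reduce_chain
    # condenses them into the final summary. Both stages reuse the same
    # summary_template here.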
    map_prompt = PromptTemplate(template=MyTemplate['summary_template'], input_variables=["text"])
    combine_prompt = PromptTemplate(template=MyTemplate['summary_template'], input_variables=["text"])
    map_chain = LLMChain(llm=llm, prompt=map_prompt, verbose=verbose)
    reduce_chain = LLMChain(llm=llm, prompt=combine_prompt, verbose=verbose)
    combine_document_chain = StuffDocumentsChain(
        llm_chain=reduce_chain,
        document_variable_name='text',
        verbose=verbose,
    )
    chain = MapReduceDocumentsChain(
        llm_chain=map_chain,
        combine_document_chain=combine_document_chain,
        document_variable_name='text',
        collapse_document_chain=None,
        verbose=verbose
    )
    summary = chain.run(docs[:n_sample])
    print(summary)
    return summary


def predict(inputs, openai_api_key, max_tokens, model, chat_counter, chatbot=[], history=[]):
    model = model[0]
    print(f"chat_counter - {chat_counter}")
    print(f'History - {history}')  # History: original inputs and outputs as a flat list
    print(f'chatbot - {chatbot}')  # Chatbot: [[user, AI]] pairs from the previous turn
    history.append(inputs)
    print(f'loading faiss store from {faiss_store}')
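    # The index must be reloaded with the same embedding model that was used to build
    # it, so the choice here is assumed to mirror the one made in process_files.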
    if model == 'openai':
        docsearch = FAISS.load_local(faiss_store, OpenAIEmbeddings(openai_api_key=openai_api_key))
    else:
        docsearch = FAISS.load_local(faiss_store, CohereEmbeddings(cohere_api_key=cohere_key))
    # Build the chat prompt templates
    llm = ChatOpenAI(openai_api_key=openai_api_key, max_tokens=max_tokens)
    messages_combine = [
        SystemMessagePromptTemplate.from_template(MyTemplate['chat_combine_template']),
        HumanMessagePromptTemplate.from_template("{question}")
    ]
    p_chat_combine = ChatPromptTemplate.from_messages(messages_combine)
    messages_reduce = [
        SystemMessagePromptTemplate.from_template(MyTemplate['chat_reduce_template']),
        HumanMessagePromptTemplate.from_template("{question}")
    ]
    p_chat_reduce = ChatPromptTemplate.from_messages(messages_reduce)
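    # Map-reduce QA over the retrieved chunks: question_prompt is applied to each of
    # the k=4 retrieved chunks and combine_prompt merges the per-chunk answers into
    # the final response.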
    chain = VectorDBQA.from_chain_type(llm=llm, chain_type="map_reduce", vectorstore=docsearch,
                                       k=4,
                                       chain_type_kwargs={"question_prompt": p_chat_reduce,
                                                          "combine_prompt": p_chat_combine}
                                       )
    result = chain({"query": inputs})
    print(result)
    result = result['result']
    # Build the return values
    history.append(result)
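    # history is a flat [user, ai, user, ai, ...] list; pair it up into
    # (user, ai) tuples for gr.Chatbot.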
    chat = [(history[i], history[i + 1]) for i in range(0, len(history) - 1, 2)]
    chat_counter += 1
    yield chat, history, chat_counter


def reset_textbox():
    return gr.update(value='')


with gr.Blocks(css="""#col_container {width: 1000px; margin-left: auto; margin-right: auto;}
                      #chatbot {height: 520px; overflow: auto;}""") as demo:
    gr.HTML("""<h1 align="center">🚀Your Doc Reader🚀</h1>""")
    with gr.Column(elem_id="col_container"):
        openai_api_key = gr.Textbox(type='password', label="Enter API Key")
        with gr.Accordion("Parameters", open=True):
            with gr.Row():
                max_tokens = gr.Slider(minimum=100, maximum=2000, value=1000, step=100, interactive=True,
                                       label="Max tokens")
                model = gr.CheckboxGroup(["cohere", "openai"])
                chat_counter = gr.Number(value=0, precision=0, label='Conversation turns')
                n_sample = gr.Slider(minimum=3, maximum=5, value=3, step=1, interactive=True,
                                     label="Number of questions")
        # Upload files, then generate the summary and suggested questions
        with gr.Row():
            with gr.Column():
                files = gr.File(file_count="multiple", file_types=[".pdf"], label='Upload PDF files')
                run = gr.Button('Analyze report')
            with gr.Column():
                summary = gr.Textbox(type='text', label="At a glance - document overview")
                question = gr.Textbox(type='text', label='Suggested questions - feel free to ask your own')
        chatbot = gr.Chatbot(elem_id='chatbot')
        inputs = gr.Textbox(placeholder="What is this document about?", label="What questions do you have about the document?")
        state = gr.State([])
        with gr.Row():
            clear = gr.Button("Clear")
            start = gr.Button("Ask")
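    # Event wiring: the "Analyze report" button runs the summary/question pipeline,
    # while submitting the textbox or clicking "Ask" streams an answer from predict
    # back into the chatbot, the state and the turn counter.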
    run.click(process, [files, openai_api_key, max_tokens, model, n_sample], [question, summary])
    inputs.submit(predict,
                  [inputs, openai_api_key, max_tokens, model, chat_counter, chatbot, state],
                  [chatbot, state, chat_counter], )
    start.click(predict,
                [inputs, openai_api_key, max_tokens, model, chat_counter, chatbot, state],
                [chatbot, state, chat_counter], )
    # Reset the conversation input after each turn
    clear.click(reset_textbox, [], [inputs], queue=False)
    inputs.submit(reset_textbox, [], [inputs])
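
# queue() is required for the generator-based predict() above to stream results back
# to the UI; debug=True surfaces server errors in the console.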
demo.queue().launch(debug=True)