aihuashanying committed on
Commit 239aab2 · verified · 1 Parent(s): d85f6de

Delete app.py

Files changed (1)
  1. app.py +0 -297
app.py DELETED
@@ -1,297 +0,0 @@
import os
import gradio as gr
import requests
from langchain_community.document_loaders import TextLoader, DirectoryLoader
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_community.vectorstores import FAISS
from langchain_openai import ChatOpenAI
from langchain.prompts import PromptTemplate
import numpy as np
import faiss
from collections import deque
from langchain_core.embeddings import Embeddings
import threading
import queue
from langchain_core.messages import HumanMessage, AIMessage
from sentence_transformers import SentenceTransformer
import pickle
import torch
import time
from tqdm import tqdm
import logging

# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Read required API keys from environment variables
os.environ["OPENROUTER_API_KEY"] = os.getenv("OPENROUTER_API_KEY", "")
if not os.environ["OPENROUTER_API_KEY"]:
    raise ValueError("OPENROUTER_API_KEY is not set")
SILICONFLOW_API_KEY = os.getenv("SILICONFLOW_API_KEY")
if not SILICONFLOW_API_KEY:
    raise ValueError("SILICONFLOW_API_KEY is not set")

# SiliconFlow rerank API configuration
SILICONFLOW_API_URL = "https://api.siliconflow.cn/v1/rerank"

# Custom embedding class with a bounded query cache
class SentenceTransformerEmbeddings(Embeddings):
    def __init__(self, model_name="BAAI/bge-m3"):
        device = "cuda" if torch.cuda.is_available() else "cpu"
        self.model = SentenceTransformer(model_name, device=device)
        self.batch_size = 32  # smaller encode batches to fit limited memory
        self.query_cache = {}
        self.cache_lock = threading.Lock()

    def embed_documents(self, texts):
        # Note: unlike the standard Embeddings interface, this takes LangChain
        # Document objects rather than raw strings (it reads .page_content).
        embeddings_list = []
        batch_size = 1000  # smaller outer batches to reduce memory pressure
        total_chunks = len(texts)
        logger.info(f"Generating embeddings for {total_chunks} documents")
        with torch.no_grad():
            for i in tqdm(range(0, total_chunks, batch_size), desc="Generating embeddings"):
                batch_texts = [text.page_content for text in texts[i:i + batch_size]]
                batch_emb = self.model.encode(
                    batch_texts,
                    normalize_embeddings=True,
                    batch_size=self.batch_size
                )
                embeddings_list.append(batch_emb)
        embeddings_array = np.vstack(embeddings_list)
        np.save("embeddings.npy", embeddings_array)
        return embeddings_array

    def embed_query(self, text):
        with self.cache_lock:
            if text in self.query_cache:
                return self.query_cache[text]
        with torch.no_grad():
            emb = self.model.encode([text], normalize_embeddings=True, batch_size=1)[0]
        with self.cache_lock:
            self.query_cache[text] = emb
            if len(self.query_cache) > 1000:  # cap the cache size (FIFO eviction)
                self.query_cache.pop(next(iter(self.query_cache)))
        return emb

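For reference, a minimal sketch exercising the query cache above (not part of the original file; assumes the BAAI/bge-m3 model can be downloaded, and uses a hypothetical query string):

emb_model = SentenceTransformerEmbeddings()
v1 = emb_model.embed_query("Who is Li Ao?")  # cache miss: encoded, then stored
v2 = emb_model.embed_query("Who is Li Ao?")  # cache hit: returned without re-encoding
assert np.array_equal(v1, v2)
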
# Rerank retrieved documents via the SiliconFlow rerank API
def rerank_documents(query, documents, top_n=15):
    try:
        doc_texts = [(doc.page_content[:2048], doc.metadata.get("book", "unknown source")) for doc in documents[:50]]
        headers = {"Authorization": f"Bearer {SILICONFLOW_API_KEY}", "Content-Type": "application/json"}
        payload = {"model": "BAAI/bge-reranker-v2-m3", "query": query, "documents": [text for text, _ in doc_texts], "top_n": top_n}
        response = requests.post(SILICONFLOW_API_URL, headers=headers, json=payload)
        response.raise_for_status()
        result = response.json()
        reranked_docs = []
        for res in result["results"]:
            index = res["index"]
            score = res["relevance_score"]
            if index < len(documents):
                reranked_docs.append((documents[index], score))
        return sorted(reranked_docs, key=lambda x: x[1], reverse=True)[:top_n]
    except Exception as e:
        logger.error(f"Reranking failed: {str(e)}")
        raise

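If the SiliconFlow endpoint were unavailable, the same reranker family can run locally. A minimal sketch of such a fallback (a hypothetical helper, not part of the original app; assumes sentence-transformers can load BAAI/bge-reranker-v2-m3):

from sentence_transformers import CrossEncoder

local_reranker = CrossEncoder("BAAI/bge-reranker-v2-m3")

def rerank_documents_local(query, documents, top_n=15):
    # Score (query, passage) pairs locally instead of calling the rerank API
    pairs = [(query, doc.page_content[:2048]) for doc in documents[:50]]
    scores = local_reranker.predict(pairs)
    ranked = sorted(zip(documents, scores), key=lambda x: float(x[1]), reverse=True)
    return ranked[:top_n]
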
# Build the HNSW index over the knowledge base
def build_hnsw_index(knowledge_base_path, index_path):
    loader = DirectoryLoader(knowledge_base_path, glob="*.txt", loader_cls=lambda path: TextLoader(path, encoding="utf-8"))
    documents = loader.load()
    text_splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=200)
    texts = text_splitter.split_documents(documents)
    for doc in texts:
        doc.metadata["book"] = os.path.basename(doc.metadata.get("source", "unknown source")).replace(".txt", "")
    embeddings_array = embeddings.embed_documents(texts)
    dimension = embeddings_array.shape[1]
    index = faiss.IndexHNSWFlat(dimension, 16)
    index.hnsw.efConstruction = 100
    index.add(embeddings_array)
    vector_store = FAISS.from_embeddings([(doc.page_content, embeddings_array[i]) for i, doc in enumerate(texts)], embeddings)
    vector_store.index = index
    vector_store.save_local(index_path)
    with open("chunks.pkl", "wb") as f:
        pickle.dump(texts, f)
    return vector_store, texts

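The HNSW knobs used above follow the usual FAISS pattern. A self-contained sketch of IndexHNSWFlat with the same parameters, on toy data, for illustration only:

import numpy as np
import faiss

dim = 8
demo = faiss.IndexHNSWFlat(dim, 16)   # 16 = M, the number of graph neighbors per node
demo.hnsw.efConstruction = 100        # search breadth while building the graph
demo.add(np.random.rand(100, dim).astype("float32"))
demo.hnsw.efSearch = 200              # search breadth at query time
distances, ids = demo.search(np.random.rand(1, dim).astype("float32"), 5)
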
# Initialize the embedding model and the index
embeddings = SentenceTransformerEmbeddings()
index_path = "faiss_index_hnsw_new"
knowledge_base_path = "knowledge_base"

if not os.path.exists(index_path):
    vector_store, all_documents = build_hnsw_index(knowledge_base_path, index_path)
else:
    vector_store = FAISS.load_local(index_path, embeddings=embeddings, allow_dangerous_deserialization=True)
    vector_store.index.hnsw.efSearch = 200  # efSearch lowered to speed up queries
    with open("chunks.pkl", "rb") as f:
        all_documents = pickle.load(f)

# Initialize the LLM (DeepSeek R1 via OpenRouter)
llm = ChatOpenAI(
    model="deepseek/deepseek-r1:free",
    api_key=os.environ["OPENROUTER_API_KEY"],
    base_url="https://openrouter.ai/api/v1",
    timeout=60,
    temperature=0.3,
    max_tokens=130000,
    streaming=True
)

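For reference, a one-shot, non-streaming smoke test of this client (illustration only, never executed by the app; assumes the OpenRouter key is valid and the free DeepSeek R1 route is available):

reply = llm.invoke([HumanMessage(content="Introduce Li Ao in one sentence.")])
print(reply.content)
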
# Prompt template
prompt_template = PromptTemplate(
    input_variables=["context", "question", "chat_history"],
    template="""
You are an expert on Li Ao. Answer the user's question {question} using the last 10 turns of conversation history {chat_history} and at least 10 passages {context} retrieved from books by and about Li Ao.
When answering, observe the following:
- Drawing on Li Ao's writing style and thought, select the retrieved passages most relevant to the question and the conversation history, and avoid irrelevant material.
- You must cite at least 10 distinct passages, using the format [citation: passage number], e.g. [citation: 1][citation: 2], and make sure every cited passage is actually used in the answer.
- At the end of the answer, under the heading "References", list every cited passage number with a summary of its content (at most 50 characters each) and its bibliographic details (e.g. book title and chapter), in the format:
  - References:
    1. [Passage 1] Summary... Source: book title, page/chapter X.
    2. [Passage 2] Summary... Source: book title, page/chapter X.
    (and so on, at least 10 entries)
- If the question concerns Li Ao's assessment of a person or event, prefer quoting Li Ao's own words or writing, and state the source.
- Structure the answer in paragraphs with clear logic and vivid language, in Li Ao's sharp style.
- If the retrieved content and the history are insufficient to answer directly, you may infer Li Ao's likely view from his character and opinions, but state explicitly that this is an inference.
- Answer only from the provided knowledge base content {context} and conversation history {chat_history}; do not bring in outside information.
- For enumeration questions, stay within 10 items and list the most relevant ones first.
- If the answer is long, summarize it in structured sections with at most 8 bullet points.
- For factual questions with very short answers, you may add one or two sentences of related information to enrich the reply.
- Choose a clear, readable layout appropriate to the user's request and the content of the answer.
- Synthesize multiple relevant passages; do not cite the same passage repeatedly.
- Unless asked otherwise, answer in the same language as the user's question.
"""
)

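For reference, a minimal sketch of how this template is rendered before being sent to the model (the placeholder values are hypothetical):

preview = prompt_template.format(
    context="[Passage 1] ...",
    question="How did Li Ao assess Lu Xun?",
    chat_history="",
)
print(preview[:200])  # inspect the start of the rendered prompt
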
# Conversation history management
class ConversationHistory:
    def __init__(self, max_length=10):  # keep only the most recent turns
        self.history = deque(maxlen=max_length)

    def add_turn(self, question, answer):
        self.history.append((question, answer))

    def get_history(self):
        return [(q, a) for q, a in self.history]

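A minimal sketch of the eviction behavior this class inherits from deque(maxlen=...), for illustration only:

hist = ConversationHistory(max_length=2)
hist.add_turn("q1", "a1")
hist.add_turn("q2", "a2")
hist.add_turn("q3", "a3")  # the oldest turn ("q1", "a1") is silently dropped
assert hist.get_history() == [("q2", "a2"), ("q3", "a3")]
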
# Per-user session state
class UserSession:
    def __init__(self):
        self.conversation = ConversationHistory()
        self.output_queue = queue.Queue()
        self.stop_flag = threading.Event()

# Generate an answer in a worker thread
def generate_answer_thread(question, session):
    stop_flag = session.stop_flag
    output_queue = session.output_queue
    conversation = session.conversation

    stop_flag.clear()
    try:
        # Log the user's question
        logger.info(f"User question: {question}")

        history_list = conversation.get_history()
        history_text = "\n".join([f"Q: {q}\nA: {a}" for q, a in history_list[-5:]])  # only the last 5 turns
        query_with_context = f"{history_text}\nQuestion: {question}" if history_text else question

        # Compute the query embedding in a worker thread
        # (joined immediately below, so this is effectively synchronous)
        embed_queue = queue.Queue()
        def embed_task():
            start = time.time()
            emb = embeddings.embed_query(query_with_context)
            embed_queue.put((emb, time.time() - start))
        embed_thread = threading.Thread(target=embed_task)
        embed_thread.start()
        embed_thread.join()
        query_embedding, embed_time = embed_queue.get()
        output_queue.put(f"Embedding took {embed_time:.2f} s\n")

        if stop_flag.is_set():
            output_queue.put("Generation stopped")
            return

        # Initial retrieval
        start = time.time()
        docs_with_scores = vector_store.similarity_search_with_score_by_vector(query_embedding, k=50)
        search_time = time.time() - start
        output_queue.put(f"Retrieval took {search_time:.2f} s\n")

        if stop_flag.is_set():
            output_queue.put("Generation stopped")
            return

        initial_docs = [doc for doc, _ in docs_with_scores]
        reranked_docs_with_scores = rerank_documents(query_with_context, initial_docs)
        final_docs = [doc for doc, _ in reranked_docs_with_scores][:10]

        # Log the reranking results (the retained passages and their scores)
        logger.info("Reranking results (final passages and scores):")
        for i, (doc, score) in enumerate(reranked_docs_with_scores[:10], 1):
            logger.info(f"Passage {i}:")
            logger.info(f"  Content: {doc.page_content[:100]}...")  # first 100 characters, to keep logs short
            logger.info(f"  Source: {doc.metadata.get('book', 'unknown source')}")
            logger.info(f"  Score: {score:.4f}")

        context = "\n".join([f"[Passage {i+1}] {doc.page_content} (source: {doc.metadata.get('book')})" for i, doc in enumerate(final_docs)])
        prompt = prompt_template.format(context=context, question=question, chat_history=history_text)

        answer = ""
        start = time.time()
        for chunk in llm.stream([HumanMessage(content=prompt)]):
            if stop_flag.is_set():
                output_queue.put(answer + "\n(generation stopped)")
                return
            answer += chunk.content
            output_queue.put(answer)
        output_queue.put(f"\nGeneration took {time.time() - start:.2f} s")

        conversation.add_turn(question, answer)
        output_queue.put(answer)  # final yield: the bare answer, without timing info

    except Exception as e:
        output_queue.put(f"Error: {str(e)}")

# Gradio callbacks
def answer_question(question, session_state):
    if session_state is None:
        session_state = UserSession()

    thread = threading.Thread(target=generate_answer_thread, args=(question, session_state))
    thread.start()

    # Stream partial outputs from the worker thread to the UI
    while thread.is_alive() or not session_state.output_queue.empty():
        try:
            output = session_state.output_queue.get(timeout=0.1)
            yield output, session_state
        except queue.Empty:
            continue

def stop_generation(session_state):
    if session_state:
        session_state.stop_flag.set()
    return "Generation stopped"

def clear_conversation():
    return "Conversation cleared", UserSession()

# Gradio UI
with gr.Blocks(title="AI Li Ao Assistant") as interface:
    gr.Markdown("### AI Li Ao Assistant")
    gr.Markdown("A knowledge base built from 163 books by and about Li Ao. Supports conversational context (the last 10 turns are remembered). Enter a question to get an answer in Li Ao's style.")
    session_state = gr.State(value=None)
    question_input = gr.Textbox(label="Question")
    submit_button = gr.Button("Submit")
    clear_button = gr.Button("New conversation")
    stop_button = gr.Button("Stop")
    output_text = gr.Textbox(label="Answer", interactive=False)

    submit_button.click(fn=answer_question, inputs=[question_input, session_state], outputs=[output_text, session_state])
    clear_button.click(fn=clear_conversation, inputs=None, outputs=[output_text, session_state])
    stop_button.click(fn=stop_generation, inputs=[session_state], outputs=output_text)

if __name__ == "__main__":
    interface.launch(share=True)