import os
import pickle
import queue
import threading
import time
from collections import deque

import faiss
import gradio as gr
import numpy as np
import requests
from tqdm import tqdm

from langchain.prompts import PromptTemplate
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_community.document_loaders import TextLoader, DirectoryLoader
from langchain_community.vectorstores import FAISS
from langchain_core.documents import Document
from langchain_core.embeddings import Embeddings
from langchain_core.messages import HumanMessage
from langchain_openai import ChatOpenAI
# Read API keys from the environment
os.environ["OPENROUTER_API_KEY"] = os.getenv("OPENROUTER_API_KEY", "")
if not os.environ["OPENROUTER_API_KEY"]:
    raise ValueError("OPENROUTER_API_KEY is not set; configure it as an environment variable or add it to a .env file")

SILICONFLOW_API_KEY = os.getenv("SILICONFLOW_API_KEY")
if not SILICONFLOW_API_KEY:
    raise ValueError("SILICONFLOW_API_KEY is not set; add it under Settings > Secrets in your Hugging Face Space")

# SiliconFlow API configuration
SILICONFLOW_API_URL = "https://api.siliconflow.cn/v1/rerank"  # confirm against the official documentation
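# Expected request/response shape, inferred from how this script parses the
# reply (a sketch only; the SiliconFlow documentation is authoritative):
#   POST /v1/rerank
#   {"model": "BAAI/bge-reranker-v2-m3", "query": "...", "documents": ["..."], "top_n": 15}
#   -> {"results": [{"index": 3, "relevance_score": 0.97}, ...]}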

# Custom embeddings client (calls BAAI/bge-m3 through the Hugging Face Inference API)
class APIEmbeddings(Embeddings):
    def __init__(self, model_name="BAAI/bge-m3"):
        self.model_name = model_name
        self.api_key = os.getenv("HUGGINGFACE_API_KEY")
        if not self.api_key:
            raise ValueError("HUGGINGFACE_API_KEY is not set; configure it as an environment variable or add it to a .env file")

    def embed_documents(self, texts):
        embeddings = []
        batch_size = 1000  # adjust the batch size as needed
        for i in tqdm(range(0, len(texts), batch_size), desc="Embedding progress"):
            batch_texts = texts[i:i + batch_size]
            batch_embeddings = self._request_embeddings(batch_texts)
            embeddings.extend(batch_embeddings)
        return embeddings

    def embed_query(self, text):
        query_embeddings = self._request_embeddings([text])
        return query_embeddings[0]

    def _request_embeddings(self, texts):
        headers = {
            "Authorization": f"Bearer {self.api_key}",
            "Content-Type": "application/json"
        }
        payload = {"inputs": texts}
        response = requests.post(
            f"https://api-inference.huggingface.co/models/{self.model_name}",
            headers=headers,
            json=payload,
        )
        response.raise_for_status()
        # The feature-extraction endpoint is assumed to return one vector per
        # input text; indexing only the first element here would silently drop
        # the rest of the batch.
        return response.json()
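
# Minimal smoke test (commented out so it never runs at import time); assumes
# the endpoint returns one embedding vector per input string:
#   emb = APIEmbeddings()
#   vecs = emb.embed_documents(["hello", "world"])
#   assert len(vecs) == 2 and len(vecs[0]) == len(vecs[1])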

# Rerank retrieved documents with BAAI/bge-reranker-v2-m3 via the SiliconFlow API
def rerank_documents(query, documents, top_n=15):
    try:
        if not documents or not query:
            raise ValueError("Query or document list is empty")
        # Extract document text and metadata, truncating each text to 2048 characters
        doc_texts = [
            (doc.page_content[:2048].replace("\n", " ").strip(), doc.metadata.get("book", "unknown source"))
            for doc in documents[:50]
        ]
        print(f"Query: {query[:100]}... (length: {len(query)})")
        print(f"Number of documents (first 50): {len(doc_texts)}")
        for i, (doc, book) in enumerate(doc_texts[:5]):  # print only the first 5 for debugging
            print(f"  Doc {i}: {doc[:100]}... (source: {book})")
        # Build the SiliconFlow API request
        headers = {
            "Authorization": f"Bearer {SILICONFLOW_API_KEY}",
            "Content-Type": "application/json"
        }
        payload = {
            "model": "BAAI/bge-reranker-v2-m3",
            "query": query,
            "documents": [text for text, _ in doc_texts],
            "top_n": top_n
        }
        start_time = time.time()
        response = requests.post(SILICONFLOW_API_URL, headers=headers, json=payload)
        response.raise_for_status()  # raise if the HTTP request failed
        rerank_time = time.time() - start_time
        print(f"Rerank took {rerank_time:.2f} s")
        # Parse the SiliconFlow API response
        result = response.json()
        print(f"SiliconFlow API response: {result}")
        # Validate the response structure
        if "results" not in result or not isinstance(result["results"], list):
            raise ValueError(f"Unexpected SiliconFlow API response format: {result}")
        # Rebuild documents from the results; the score lives under "relevance_score"
        reranked_docs = []
        for res in result["results"]:
            if "index" not in res or "relevance_score" not in res:
                raise ValueError(f"Malformed entry in SiliconFlow API response: {res}")
            index = res["index"]
            score = res["relevance_score"]
            if index < len(doc_texts):  # indices refer to doc_texts, which is capped at 50 entries
                text, book = doc_texts[index]
                reranked_docs.append((Document(page_content=text, metadata={"book": book}), score))
        # Sort by score and keep the top_n
        reranked_docs = sorted(reranked_docs, key=lambda x: x[1], reverse=True)[:top_n]
        print(f"Reranked results (count: {len(reranked_docs)}):")
        for i, (doc, score) in enumerate(reranked_docs):
            print(f"  Doc {i}: {doc.page_content[:100]}... (source: {doc.metadata.get('book', 'unknown source')}, score: {score:.4f})")
        return reranked_docs
    except Exception as e:
        error_msg = str(e)
        print(f"Error details: {error_msg}")
        raise Exception(f"Reranking failed: {error_msg}") from e
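
# Usage sketch (commented out): rerank the coarse FAISS hits for one query.
#   hits = vector_store.similarity_search("李敖如何评价胡适?", k=50)
#   top = rerank_documents("李敖如何评价胡适?", hits, top_n=15)
#   best_doc, best_score = top[0]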

# Build the HNSW index over the knowledge base
def build_hnsw_index(knowledge_base_path, index_path):
    print("Loading documents...")
    start_time = time.time()
    loader = DirectoryLoader(
        knowledge_base_path,
        glob="*.txt",
        loader_cls=lambda path: TextLoader(path, encoding="utf-8"),
        use_multithreading=False,
    )
    documents = loader.load()
    load_time = time.time() - start_time
    print(f"Loaded {len(documents)} documents in {load_time:.2f} s")

    text_splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=200)
    if not os.path.exists("chunks.pkl"):
        print("Splitting documents into chunks...")
        start_time = time.time()
        texts = []
        total_chars = 0
        total_bytes = 0
        for i, doc in enumerate(documents):
            doc_chunks = text_splitter.split_documents([doc])
            for chunk in doc_chunks:
                content = chunk.page_content
                file_path = chunk.metadata.get("source", "")
                # Derive the book name from the file name, e.g. "李敖_回忆录.txt" -> "李敖·回忆录"
                book_name = os.path.basename(file_path).replace(".txt", "").replace("_", "·")
                texts.append(Document(page_content=content, metadata={"book": book_name or "unknown source"}))
                total_chars += len(content)
                total_bytes += len(content.encode("utf-8"))
            if i < 5:
                print(f"File {i}: {len(doc.page_content)} chars, {len(doc.page_content.encode('utf-8'))} bytes, source: {file_path}")
            if (i + 1) % 10 == 0:
                print(f"Split progress: {i + 1}/{len(documents)} files processed, {len(texts)} chunks so far")
        with open("chunks.pkl", "wb") as f:
            pickle.dump(texts, f)
        split_time = time.time() - start_time
        print(f"Splitting done: {len(texts)} chunks, {total_chars} chars, {total_bytes} bytes, took {split_time:.2f} s")
    else:
        with open("chunks.pkl", "rb") as f:
            texts = pickle.load(f)
        print(f"Loaded existing chunks: {len(texts)}")

    # Create the client up front: assigning to `embeddings` only inside the
    # if-branch would make the name function-local and leave it unbound on the
    # else path, where FAISS.from_embeddings below still needs it
    embeddings = APIEmbeddings()
    if not os.path.exists("embeddings.npy"):
        print("Generating embeddings (BAAI/bge-m3 API, batched)...")
        texts_content = [text.page_content for text in texts]
        # embed_documents returns a plain list; convert it so .shape works,
        # and cache it so the else-branch can np.load it on later runs
        embeddings_array = np.array(embeddings.embed_documents(texts_content), dtype=np.float32)
        np.save("embeddings.npy", embeddings_array)
        if os.path.exists("embeddings_temp.npy"):
            os.remove("embeddings_temp.npy")
        print(f"Embeddings generated, shape: {embeddings_array.shape}")
    else:
        embeddings_array = np.load("embeddings.npy")
        print(f"Loaded existing embeddings, shape: {embeddings_array.shape}")

    dimension = embeddings_array.shape[1]
    index = faiss.IndexHNSWFlat(dimension, 16)
    index.hnsw.efConstruction = 100
    print("Building the HNSW index...")
    batch_size = 5000
    total_vectors = embeddings_array.shape[0]
    for i in range(0, total_vectors, batch_size):
        batch = embeddings_array[i:i + batch_size]
        index.add(batch)
        print(f"Index progress: {min(i + batch_size, total_vectors)} / {total_vectors}")

    # Wrap the raw FAISS index in a LangChain vector store and rebuild the
    # docstore so that vector i maps to docstore id str(i)
    text_embeddings = [(text.page_content, embeddings_array[i]) for i, text in enumerate(texts)]
    vector_store = FAISS.from_embeddings(text_embeddings, embeddings, normalize_L2=True)
    vector_store.index = index
    vector_store.docstore._dict.clear()
    vector_store.index_to_docstore_id.clear()
    for i, text in enumerate(texts):
        doc_id = str(i)
        vector_store.docstore._dict[doc_id] = text
        vector_store.index_to_docstore_id[i] = doc_id

    print("Saving the index...")
    vector_store.save_local(index_path)
    print(f"HNSW index built and saved to '{index_path}'")
    return vector_store, texts
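
# HNSW tuning notes (a rough sketch; good values depend on corpus size and
# latency budget). M=16 and efConstruction=100 trade build time for recall;
# efSearch (set after loading, below) trades query latency for recall:
#   index = faiss.IndexHNSWFlat(dimension, 32)  # larger M: better recall, more memory
#   index.hnsw.efConstruction = 200             # slower build, better graph quality
#   index.hnsw.efSearch = 300                   # slower queries, higher recall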

# Initialise the embedding model
embeddings = APIEmbeddings(model_name="BAAI/bge-m3")
print("Initialised the BAAI/bge-m3 embedding model (API-backed)")

# Load or build the index
index_path = "faiss_index_hnsw_new"
knowledge_base_path = "knowledge_base"
if not os.path.exists(index_path):
    if os.path.exists(knowledge_base_path):
        print("Found knowledge_base; building the HNSW index...")
        vector_store, all_documents = build_hnsw_index(knowledge_base_path, index_path)
    else:
        raise FileNotFoundError("'knowledge_base' not found; please provide the knowledge base data")
else:
    vector_store = FAISS.load_local(index_path, embeddings=embeddings, allow_dangerous_deserialization=True)
    vector_store.index.hnsw.efSearch = 300
    print("Loaded HNSW index 'faiss_index_hnsw_new' with efSearch set to 300")
    with open("chunks.pkl", "rb") as f:
        all_documents = pickle.load(f)

# Report how chunks are distributed across books
book_counts = {}
for doc in all_documents:
    book = doc.metadata.get("book", "unknown source")
    book_counts[book] = book_counts.get(book, 0) + 1
print(f"Book distribution in all_documents: {book_counts}")

# Initialise ChatOpenAI (DeepSeek R1 served via OpenRouter)
llm = ChatOpenAI(
    model="deepseek/deepseek-r1:free",
    api_key=os.environ["OPENROUTER_API_KEY"],
    base_url="https://openrouter.ai/api/v1",
    timeout=60,
    temperature=0.3,
    max_tokens=130000,  # completion-token cap; verify against the model's actual limit
    streaming=True
)

# Prompt template
prompt_template = PromptTemplate(
    input_variables=["context", "question", "chat_history"],
    template="""
You are an expert on Li Ao (李敖). Answer the user's question {question} using the last 10 turns of conversation history {chat_history} and at least 10 passages retrieved from Li Ao's books and related commentary {context}.
When answering, observe the following:
- Drawing on Li Ao's writing style and thought, select the retrieved passages most relevant to the question and the conversation history, and avoid irrelevant material.
- Cite at least 10 distinct passages in the answer, in the format [Citation: passage number], e.g. [Citation: 1][Citation: 2], and make sure every cited passage is explicitly used in the answer.
- End the answer with a "References" section listing every cited passage number with a summary (at most 50 characters each) and its bibliographic source (e.g. book title and chapter), in the format:
  - References:
    1. [Passage 1] Summary... From: book title, page/chapter X.
    2. [Passage 2] Summary... From: book title, page/chapter X.
    (and so on, at least 10 entries)
- If the question concerns Li Ao's assessment of a person or event, prefer quoting Li Ao's own statements or writing, and state the source.
- Structure the answer in paragraphs with clear logic and vivid language, echoing Li Ao's sharp style.
- If the retrieved passages and the history are insufficient to answer directly, you may infer Li Ao's likely view from his character and opinions, but make clear that this is speculation.
- Answer only from the provided knowledge base content {context} and the conversation history {chat_history}; do not introduce outside information.
- For enumeration questions, keep to at most 10 points and list the most relevant items first.
- For long answers, summarise in structured paragraphs and keep bullet points to at most 8.
- For factual questions whose answer is very short, you may add one or two sentences of related information to enrich the reply.
- Choose a clear, readable format appropriate to the user's request and the content of the answer.
- Synthesise several relevant passages rather than citing the same passage repeatedly.
- Unless the user requests otherwise, answer in the same language as the question.
"""
)
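
# Usage sketch (commented out; the book title is illustrative only):
#   prompt = prompt_template.format(
#       context="[Passage 1] ... (source: 北京法源寺)",
#       question="李敖如何评价胡适?",
#       chat_history="",
#   )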

# Conversation history, bounded to the most recent turns
class ConversationHistory:
    def __init__(self, max_length=10):
        self.history = deque(maxlen=max_length)

    def add_turn(self, question, answer):
        self.history.append((question, answer))

    def get_history(self):
        return [(turn[0], turn[1]) for turn in self.history]

    def clear(self):
        self.history.clear()


# Per-user session state
class UserSession:
    def __init__(self):
        self.conversation = ConversationHistory()
        self.output_queue = queue.Queue()
        self.stop_flag = threading.Event()
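
# Each Gradio session gets its own UserSession, so histories, output queues,
# and stop flags never leak between concurrent users. Usage sketch:
#   session = UserSession()
#   session.conversation.add_turn("question", "answer")
#   session.conversation.get_history()  # -> [("question", "answer")]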

# Worker thread that produces the answer and streams it through the session queue
def generate_answer_thread(question, session):
    stop_flag = session.stop_flag
    output_queue = session.output_queue
    conversation = session.conversation
    stop_flag.clear()
    try:
        history_list = conversation.get_history()
        history_text = "\n".join([f"Q: {q}\nA: {a}" for q, a in history_list]) if history_list else ""
        query_with_context = f"{history_text}\nCurrent question: {question}" if history_text else question

        # 1. Embed the query with the BAAI/bge-m3 API (timed here only for
        #    reporting; similarity_search_with_score embeds the query itself)
        start_time = time.time()
        embeddings = APIEmbeddings()
        query_embedding = embeddings.embed_query(query_with_context)
        embed_time = time.time() - start_time
        output_queue.put(f"Embedding time (BAAI/bge-m3 API): {embed_time:.2f} s\n")
        if stop_flag.is_set():
            output_queue.put("Generation stopped")
            return

        # 2. Coarse retrieval with the FAISS HNSW index
        start_time = time.time()
        initial_docs_with_scores = vector_store.similarity_search_with_score(query_with_context, k=50)
        search_time = time.time() - start_time
        output_queue.put(f"Initial hits: {len(initial_docs_with_scores)}\nRetrieval time: {search_time:.2f} s\n")
        if stop_flag.is_set():
            output_queue.put("Generation stopped")
            return
        initial_docs = [doc for doc, _ in initial_docs_with_scores]

        # 3. Rerank with SiliconFlow's BAAI/bge-reranker-v2-m3
        start_time = time.time()
        reranked_docs_with_scores = rerank_documents(query_with_context, initial_docs, top_n=15)
        rerank_time = time.time() - start_time
        output_queue.put(f"Rerank time (BAAI/bge-reranker-v2-m3): {rerank_time:.2f} s\n")
        if stop_flag.is_set():
            output_queue.put("Generation stopped")
            return

        # Keep the top 10 passages for the prompt
        final_docs = [doc for doc, _ in reranked_docs_with_scores][:10]
        if len(final_docs) < 10:
            output_queue.put(f"Warning: only {len(final_docs)} passages retrieved; the 10-citation requirement may not be met")

        # Build the context with passage text and bibliographic info
        context = "\n\n".join([
            f"[Passage {i+1}] {doc.page_content} (source: {doc.metadata.get('book', 'unknown source')})"
            for i, doc in enumerate(final_docs)
        ])
        prompt = prompt_template.format(context=context, question=question, chat_history=history_text)

        # 4. Generate the answer with the LLM, streaming partial text into the queue
        answer = ""
        start_time = time.time()
        for chunk in llm.stream([HumanMessage(content=prompt)]):
            if stop_flag.is_set():
                output_queue.put(answer + "\n\n(Generation stopped)")
                return
            answer += chunk.content
            output_queue.put(answer)
        llm_time = time.time() - start_time
        output_queue.put(f"\nLLM generation time: {llm_time:.2f} s")
        conversation.add_turn(question, answer)
        output_queue.put(answer)
    except Exception as e:
        output_queue.put(f"Error: {str(e)}")

# Gradio interface functions
def answer_question(question, session_state):
    if session_state is None:
        session_state = UserSession()
    thread = threading.Thread(target=generate_answer_thread, args=(question, session_state))
    thread.start()
    # Stream everything the worker puts on the queue back to the UI
    while thread.is_alive() or not session_state.output_queue.empty():
        try:
            output = session_state.output_queue.get(timeout=0.1)
            yield output, session_state
        except queue.Empty:
            continue
    # Drain anything left after the worker finishes
    while not session_state.output_queue.empty():
        yield session_state.output_queue.get(), session_state


def stop_generation(session_state):
    if session_state is not None:
        session_state.stop_flag.set()
    return "Generation stopped; aborting..."


def clear_conversation():
    return "Conversation history cleared; start a new conversation.", UserSession()

# Build the Gradio interface
with gr.Blocks(title="AI Li Ao Assistant") as interface:
    gr.Markdown("### AI Li Ao Assistant")
    gr.Markdown("A knowledge base built from 163 books by and about Li Ao. Supports conversational context (remembers the last 10 turns). Enter a question to get an answer in Li Ao's style.")
    session_state = gr.State(value=None)
    with gr.Row():
        with gr.Column(scale=3):
            question_input = gr.Textbox(label="Your question", placeholder="Type your question...")
            submit_button = gr.Button("Submit")
        with gr.Column(scale=1):
            clear_button = gr.Button("New conversation")
            stop_button = gr.Button("Stop generation")
    output_text = gr.Textbox(label="Answer", interactive=False)

    submit_button.click(fn=answer_question, inputs=[question_input, session_state], outputs=[output_text, session_state])
    clear_button.click(fn=clear_conversation, inputs=None, outputs=[output_text, session_state])
    stop_button.click(fn=stop_generation, inputs=[session_state], outputs=output_text)

# Launch the app
if __name__ == "__main__":
    interface.launch(share=True)