File size: 19,337 Bytes
592acab
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
import os
import gradio as gr
import requests
from langchain_community.document_loaders import TextLoader, DirectoryLoader
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_community.vectorstores import FAISS
from langchain_openai import ChatOpenAI
from langchain.prompts import PromptTemplate
import numpy as np
import faiss
from collections import deque
from langchain_core.embeddings import Embeddings
import threading
import queue
from langchain_core.messages import HumanMessage, AIMessage
from sentence_transformers import SentenceTransformer
import pickle
import torch
from langchain_core.documents import Document
import time
from tqdm import tqdm

# 获取环境变量
os.environ["OPENROUTER_API_KEY"] = os.getenv("OPENROUTER_API_KEY", "")
if not os.environ["OPENROUTER_API_KEY"]:
    raise ValueError("OPENROUTER_API_KEY 未设置,请在环境变量中配置或在 .env 文件中添加")
SILICONFLOW_API_KEY = os.getenv("SILICONFLOW_API_KEY")
if not SILICONFLOW_API_KEY:
    raise ValueError("SILICONFLOW_API_KEY 未设置,请在 Hugging Face Spaces 的 Settings > Secrets 中添加 SILICONFLOW_API_KEY")

# SiliconFlow API 配置
SILICONFLOW_API_URL = "https://api.siliconflow.cn/v1/rerank"  # 需根据实际文档确认

# 自定义 APIEmbeddings 类(使用 Hugging Face API 调用 BAAI/bge-m3)
class APIEmbeddings(Embeddings):
    def __init__(self, model_name="BAAI/bge-m3"):
        self.model_name = model_name
        self.api_key = os.getenv("HUGGINGFACE_API_KEY")
        if not self.api_key:
            raise ValueError("HUGGINGFACE_API_KEY 未设置,请在环境变量中配置或在 .env 文件中添加")

    def embed_documents(self, texts):
        embeddings = []
        batch_size = 1000  # 根据需要调整批次大小

        for i in tqdm(range(0, len(texts), batch_size), desc="生成嵌入进度"):
            batch_texts = texts[i:i + batch_size]
            batch_embeddings = self._request_embeddings(batch_texts)
            embeddings.extend(batch_embeddings)

        return embeddings

    def embed_query(self, text):
        query_embeddings = self._request_embeddings([text])
        return query_embeddings[0]

    def _request_embeddings(self, texts):
        headers = {
            "Authorization": f"Bearer {self.api_key}",
            "Content-Type": "application/json"
        }
        payload = {
            "inputs": texts,
            "model": self.model_name
        }

        response = requests.post("https://api-inference.huggingface.co/models/BAAI/bge-m3", headers=headers, json=payload)
        response.raise_for_status()

        return response.json()[0]["embedding"]

# 重排序函数,使用 SiliconFlow API 调用 BAAI/bge-reranker-v2-m3
def rerank_documents(query, documents, top_n=15):
    try:
        if not documents or not query:
            raise ValueError("查询或文档列表为空")
        
        # 提取文档内容和元数据,限制长度为 2048 字符
        doc_texts = [(doc.page_content[:2048].replace("\n", " ").strip(), doc.metadata.get("book", "未知来源")) for doc in documents[:50]]
        print(f"Query: {query[:100]}... (长度: {len(query)})")
        print(f"文档数量 (前50个): {len(doc_texts)}")
        for i, (doc, book) in enumerate(doc_texts[:5]):  # 仅打印前5个用于调试
            print(f"  Doc {i}: {doc[:100]}... (来源: {book})")
        
        # 构造 SiliconFlow API 请求
        headers = {
            "Authorization": f"Bearer {SILICONFLOW_API_KEY}",
            "Content-Type": "application/json"
        }
        payload = {
            "model": "BAAI/bge-reranker-v2-m3",
            "query": query,
            "documents": [text for text, _ in doc_texts],
            "top_n": top_n
        }
        
        start_time = time.time()
        response = requests.post(SILICONFLOW_API_URL, headers=headers, json=payload)
        response.raise_for_status()  # 检查请求是否成功
        rerank_time = time.time() - start_time
        print(f"重排序耗时: {rerank_time:.2f} 秒")
        
        # 解析 SiliconFlow API 响应
        result = response.json()
        print(f"SiliconFlow API 响应: {result}")
        
        # 验证返回结果
        if "results" not in result or not isinstance(result["results"], list):
            raise ValueError(f"SiliconFlow API 返回格式错误: {result}")
        
        # 构建重排序结果,修正键名为 "relevance_score"
        reranked_docs = []
        for res in result["results"]:
            if "index" not in res or "relevance_score" not in res:
                raise ValueError(f"SiliconFlow API 返回的条目格式错误: {res}")
            index = res["index"]
            score = res["relevance_score"]
            if index < len(documents):
                text, book = doc_texts[index]
                reranked_docs.append((Document(page_content=text, metadata={"book": book}), score))
        
        # 按得分排序并截取 top_n
        reranked_docs = sorted(reranked_docs, key=lambda x: x[1], reverse=True)[:top_n]
        print(f"重排序结果 (数量: {len(reranked_docs)}):")
        for i, (doc, score) in enumerate(reranked_docs):
            print(f"  Doc {i}: {doc.page_content[:100]}... (来源: {doc.metadata.get('book', '未知来源')}, 得分: {score:.4f})")
        
        return reranked_docs
    except Exception as e:
        error_msg = str(e)
        print(f"错误详情: {error_msg}")
        raise Exception(f"重排序失败: {error_msg}")

# 构建 HNSW 索引
def build_hnsw_index(knowledge_base_path, index_path):
    print("开始加载文档...")
    start_time = time.time()
    loader = DirectoryLoader(knowledge_base_path, glob="*.txt", loader_cls=lambda path: TextLoader(path, encoding="utf-8"), use_multithreading=False)
    documents = loader.load()
    load_time = time.time() - start_time
    print(f"加载完成,共 {len(documents)} 个文档,耗时 {load_time:.2f} 秒")
    
    text_splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=200)
    if not os.path.exists("chunks.pkl"):
        print("开始分片...")
        start_time = time.time()
        texts = []
        total_chars = 0
        total_bytes = 0
        for i, doc in enumerate(documents):
            doc_chunks = text_splitter.split_documents([doc])
            for chunk in doc_chunks:
                content = chunk.page_content
                file_path = chunk.metadata.get("source", "")
                book_name = os.path.basename(file_path).replace(".txt", "").replace("_", "·")
                texts.append(Document(page_content=content, metadata={"book": book_name or "未知来源"}))
                total_chars += len(content)
                total_bytes += len(content.encode('utf-8'))
            if i < 5:
                print(f"文件 {i} 字符数: {len(doc.page_content)}, 字节数: {len(doc.page_content.encode('utf-8'))}, 来源: {file_path}")
            if (i + 1) % 10 == 0:
                print(f"分片进度: 已处理 {i + 1}/{len(documents)} 个文件,当前分片总数: {len(texts)}")
        with open("chunks.pkl", "wb") as f:
            pickle.dump(texts, f)
        split_time = time.time() - start_time
        print(f"分片完成,共 {len(texts)} 个 chunk,总字符数: {total_chars},总字节数: {total_bytes},耗时 {split_time:.2f} 秒")
    else:
        with open("chunks.pkl", "rb") as f:
            texts = pickle.load(f)
        print(f"加载已有分片,共 {len(texts)} 个 chunk")
    
    if not os.path.exists("embeddings.npy"):
        print("开始生成嵌入(使用 BAAI/bge-m3 API,分批处理)...")
        embeddings = APIEmbeddings()
        texts_content = [text.page_content for text in texts]
        embeddings_array = embeddings.embed_documents(texts_content)
        if os.path.exists("embeddings_temp.npy"):
            os.remove("embeddings_temp.npy")
        print(f"嵌入生成完成,维度: {embeddings_array.shape}")
    else:
        embeddings_array = np.load("embeddings.npy")
        print(f"加载已有嵌入,维度: {embeddings_array.shape}")
    
    dimension = embeddings_array.shape[1]
    index = faiss.IndexHNSWFlat(dimension, 16)
    index.hnsw.efConstruction = 100
    print("开始构建 HNSW 索引...")
    
    batch_size = 5000
    total_vectors = embeddings_array.shape[0]
    for i in range(0, total_vectors, batch_size):
        batch = embeddings_array[i:i + batch_size]
        index.add(batch)
        print(f"索引构建进度: {min(i + batch_size, total_vectors)} / {total_vectors}")
    
    text_embeddings = [(text.page_content, embeddings_array[i]) for i, text in enumerate(texts)]
    vector_store = FAISS.from_embeddings(text_embeddings, embeddings, normalize_L2=True)
    vector_store.index = index
    vector_store.docstore._dict.clear()
    vector_store.index_to_docstore_id.clear()
    
    for i, text in enumerate(texts):
        doc_id = str(i)
        vector_store.docstore._dict[doc_id] = text
        vector_store.index_to_docstore_id[i] = doc_id
    
    print("开始保存索引...")
    vector_store.save_local(index_path)
    print(f"HNSW 索引已生成并保存到 '{index_path}'")
    return vector_store, texts

# 初始化嵌入模型
embeddings = APIEmbeddings(model_name="BAAI/bge-m3")
print("已初始化 BAAI/bge-m3 嵌入模型,通过 API 调用")

# 加载或生成索引
index_path = "faiss_index_hnsw_new"
knowledge_base_path = "knowledge_base"

if not os.path.exists(index_path):
    if os.path.exists(knowledge_base_path):
        print("检测到 knowledge_base,正在生成 HNSW 索引...")
        vector_store, all_documents = build_hnsw_index(knowledge_base_path, index_path)
    else:
        raise FileNotFoundError("未找到 'knowledge_base',请提供知识库数据")
else:
    vector_store = FAISS.load_local(index_path, embeddings=embeddings, allow_dangerous_deserialization=True)
    vector_store.index.hnsw.efSearch = 300
    print("已加载 HNSW 索引 'faiss_index_hnsw_new',efSearch 设置为 300")
    with open("chunks.pkl", "rb") as f:
        all_documents = pickle.load(f)
    book_counts = {}
    for doc in all_documents:
        book = doc.metadata.get("book", "未知来源")
        book_counts[book] = book_counts.get(book, 0) + 1
    print(f"all_documents 书籍分布: {book_counts}")

# 初始化 ChatOpenAI
llm = ChatOpenAI(
    model="deepseek/deepseek-r1:free",
    api_key=os.environ["OPENROUTER_API_KEY"],
    base_url="https://openrouter.ai/api/v1",
    timeout=60,
    temperature=0.3,
    max_tokens=130000,
    streaming=True
)

# 定义提示词模板
prompt_template = PromptTemplate(
    input_variables=["context", "question", "chat_history"],
    template="""  

    你是一个研究李敖的专家,根据用户提出的问题{question}、最近10轮对话历史{chat_history}以及从李敖相关书籍和评论中检索的至少10篇文本内容{context}回答问题。  

    在回答时,请注意以下几点:  

    - 结合李敖的写作风格和思想,筛选出与问题和对话历史最相关的检索内容,避免无关信息。  

    - 必须在回答中引用至少10篇不同的文本内容,引用格式为[引用: 文本序号],例如[引用: 1][引用: 2],并确保每篇文本在回答中都有明确使用。  

    - 在回答的末尾,必须以“引用文献”标题列出所有引用的文本序号及其内容摘要(每篇不超过50字)以及具体的书目信息(例如书名和章节),格式为:  

      - 引用文献:  

        1. [文本 1] 摘要... 出自:书名,第X页/章节。  

        2. [文本 2] 摘要... 出自:书名,第X页/章节。  

        (依此类推,至少10篇)  

    - 如果问题涉及李敖对某人或某事的评价,优先引用李敖的直接言论或文字,并说明出处。  

    - 回答应结构化、分段落,确保逻辑清晰,语言生动,类似李敖的犀利风格。  

    - 如果检索内容和历史不足以直接回答问题,可根据李敖的性格和观点推测其可能的看法,但需说明这是推测。  

    - 只能基于提供的知识库内容{context}和对话历史{chat_history}回答,不得引入外部信息。  

    - 对于列举类问题,控制在10个要点以内,并优先提供最相关项。  

    - 如果回答较长,结构化分段总结,分点作答控制在8个点以内。  

    - 对于客观类的问答,如果问题的答案非常简短,可以适当补充一到两句相关信息,以丰富内容。  

    - 你需要根据用户要求和回答内容选择合适、美观的回答格式,确保可读性强。  

    - 你的回答应该综合多个相关知识库内容来回答,不能重复引用一个知识库内容。  

    - 除非用户要求,否则你回答的语言需要和用户提问的语言保持一致。

    """
)

# 对话历史管理类
class ConversationHistory:
    def __init__(self, max_length=10):
        self.history = deque(maxlen=max_length)

    def add_turn(self, question, answer):
        self.history.append((question, answer))

    def get_history(self):
        return [(turn[0], turn[1]) for turn in self.history]

    def clear(self):
        self.history.clear()

# 用户会话状态类
class UserSession:
    def __init__(self):
        self.conversation = ConversationHistory()
        self.output_queue = queue.Queue()
        self.stop_flag = threading.Event()

# 生成回答的线程函数
def generate_answer_thread(question, session):
    stop_flag = session.stop_flag
    output_queue = session.output_queue
    conversation = session.conversation
    
    stop_flag.clear()
    try:
        history_list = conversation.get_history()
        history_text = "\n".join([f"问: {q}\n答: {a}" for q, a in history_list]) if history_list else ""
        query_with_context = f"{history_text}\n当前问题: {question}" if history_text else question
        
        # 1. 使用 BAAI/bge-m3 API 生成查询嵌入
        start_time = time.time()
        embeddings = APIEmbeddings()
        query_embedding = embeddings.embed_query(query_with_context)
        embed_time = time.time() - start_time
        output_queue.put(f"嵌入耗时 (BAAI/bge-m3 API): {embed_time:.2f} 秒\n")
        
        if stop_flag.is_set():
            output_queue.put("生成已停止")
            return
        
        # 2. 使用 FAISS HNSW 索引进行初始检索
        start_time = time.time()
        initial_docs_with_scores = vector_store.similarity_search_with_score(query_with_context, k=50)
        search_time = time.time() - start_time
        output_queue.put(f"初始检索数量: {len(initial_docs_with_scores)}\n检索耗时: {search_time:.2f} 秒\n")
        
        if stop_flag.is_set():
            output_queue.put("生成已停止")
            return
        
        initial_docs = [doc for doc, _ in initial_docs_with_scores]
        
        # 3. 使用 SiliconFlow 的 BAAI/bge-reranker-v2-m3 进行重排序
        start_time = time.time()
        reranked_docs_with_scores = rerank_documents(query_with_context, initial_docs, top_n=15)
        rerank_time = time.time() - start_time
        output_queue.put(f"重排序耗时 (BAAI/bge-reranker-v2-m3): {rerank_time:.2f} 秒\n")
        
        if stop_flag.is_set():
            output_queue.put("生成已停止")
            return
        
        # 调整 final_docs 数量,取前 10 篇
        final_docs = [doc for doc, _ in reranked_docs_with_scores][:10]
        if len(final_docs) < 10:
            output_queue.put(f"警告:仅检索到 {len(final_docs)} 篇文本,可能无法满足引用 10 篇的要求")
        
        # 构造 context,包含文本内容和书目信息
        context = "\n\n".join([f"[文本 {i+1}] {doc.page_content} (出处: {doc.metadata.get('book', '未知来源')})" for i, doc in enumerate(final_docs)])
        chat_history = [HumanMessage(content=q) if i % 2 == 0 else AIMessage(content=a) 
                        for i, (q, a) in enumerate(history_list)]
        prompt = prompt_template.format(context=context, question=question, chat_history=history_text)
        
        # 4. 使用 LLM 生成回答
        answer = ""
        start_time = time.time()
        for chunk in llm.stream([HumanMessage(content=prompt)]):
            if stop_flag.is_set():
                output_queue.put(answer + "\n\n(生成已停止)")
                return
            answer += chunk.content
            output_queue.put(answer)
        llm_time = time.time() - start_time
        output_queue.put(f"\nLLM 生成耗时: {llm_time:.2f} 秒")
        
        conversation.add_turn(question, answer)
        output_queue.put(answer)

    except Exception as e:
        output_queue.put(f"Error: {str(e)}")

# Gradio 接口函数
def answer_question(question, session_state):
    if session_state is None:
        session_state = UserSession()
    
    thread = threading.Thread(target=generate_answer_thread, args=(question, session_state))
    thread.start()
    
    while thread.is_alive() or not session_state.output_queue.empty():
        try:
            output = session_state.output_queue.get(timeout=0.1)
            yield output, session_state
        except queue.Empty:
            continue
    
    while not session_state.output_queue.empty():
        yield session_state.output_queue.get(), session_state

def stop_generation(session_state):
    if session_state is not None:
        session_state.stop_flag.set()
    return "生成已停止,正在中止..."

def clear_conversation():
    return "对话历史已清空,请开始新的对话。", UserSession()

# 创建 Gradio 界面
with gr.Blocks(title="AI李敖助手") as interface:
    gr.Markdown("### AI李敖助手")
    gr.Markdown("基于李敖163本相关书籍构建的知识库,支持上下文关联,记住最近10轮对话,输入问题以获取李敖风格的回答。")
    
    session_state = gr.State(value=None)
    
    with gr.Row():
        with gr.Column(scale=3):
            question_input = gr.Textbox(label="请输入您的问题", placeholder="输入您的问题...")
            submit_button = gr.Button("提交")
        with gr.Column(scale=1):
            clear_button = gr.Button("新建对话")
            stop_button = gr.Button("停止生成")
    
    output_text = gr.Textbox(label="回答", interactive=False)
    
    submit_button.click(fn=answer_question, inputs=[question_input, session_state], outputs=[output_text, session_state])
    clear_button.click(fn=clear_conversation, inputs=None, outputs=[output_text, session_state])
    stop_button.click(fn=stop_generation, inputs=[session_state], outputs=output_text)

# 启动应用
if __name__ == "__main__":
    interface.launch(share=True)