# -*- coding: utf-8 -*-
# Author: 江信宗, Fiscal Information Agency, Ministry of Finance (Taiwan)
import os
from dotenv import load_dotenv

# NOTE(review): load_dotenv() is called before the langchain imports,
# presumably so that values from .env (e.g. a user agent setting) are already
# in the environment when those modules are imported -- keep this ordering.
load_dotenv()
from langchain_community.utils import user_agent
from langchain_groq import ChatGroq
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_huggingface import HuggingFaceEmbeddings
from langchain_community.vectorstores import Chroma
from langchain.chains import RetrievalQA
from langchain_community.document_loaders import WebBaseLoader, TextLoader
from langchain.prompts import PromptTemplate
from langchain.schema import Document
import gradio as gr
import re
import time

# Model used when the caller does not ask for a specific one.
# NOTE(review): confirm this model id is still served by the Groq API.
DEFAULT_MODEL_NAME = 'llama-3.1-70b-versatile'


def initialize_llm(api_key, model_name=DEFAULT_MODEL_NAME):
    """Return a ``ChatGroq`` chat model client.

    Args:
        api_key: Groq API key forwarded to ``ChatGroq``.
        model_name: Groq model identifier. Defaults to the model this
            application has always used, so existing callers are unaffected.

    Returns:
        A configured ``ChatGroq`` instance.
    """
    return ChatGroq(
        groq_api_key=api_key,
        model_name=model_name,
    )


print("成功初始化大型語言模型(LLM)")


def load_documents(sources):
    """Load every entry of *sources* into a flat list of ``Document`` objects.

    Each entry may be:

    * a string starting with ``http`` -- loaded with ``WebBaseLoader``;
    * any other string -- treated as a local path and loaded with
      ``TextLoader``;
    * a dict with a ``content`` key and an optional ``metadata`` key --
      wrapped directly in a ``Document``.

    Loading is best-effort: an entry that fails to load, or whose type is not
    supported, is reported on stdout and skipped.

    Args:
        sources: Iterable of source descriptors as described above.

    Returns:
        List of the documents that loaded successfully (possibly empty).
    """
    documents = []
    for source in sources:
        try:
            if isinstance(source, str):
                loader_cls = WebBaseLoader if source.startswith('http') else TextLoader
                documents.extend(loader_cls(source).load())
            elif isinstance(source, dict):
                documents.append(
                    Document(
                        page_content=source['content'],
                        metadata=source.get('metadata', {}),
                    )
                )
            else:
                # Report unsupported entries instead of dropping them silently.
                print(
                    f"Error loading source {source}: "
                    f"unsupported source type {type(source).__name__}"
                )
        except Exception as e:
            # Deliberately broad: one unreadable file or unreachable URL must
            # not prevent the remaining sources from loading.
            print(f"Error loading source {source}: {str(e)}")
    return documents


# Local tax Q&A data sets and statute texts that make up the knowledge base.
sources = [
    "TaxQADataSet_Slim1.txt",
    "TaxQADataSet_Slim2.txt",
    "TaxQADataSet_Slim3.txt",
    "TaxQADataSet_Slim4.txt",
    "TaxQADataSet_Slim5.txt",
    "TaxQADataSet_Slim6.txt",
    "TaxQADataSet_ntpc1.txt",
    "TaxQADataSet_ntpc2.txt",
    "TaxQADataSet_kctax.txt",
    "TaxQADataSet_chutax.txt",
    "LandTaxAct1100623.txt",
    "TheEnforcementRulesoftheLandTaxAct1100923.txt",
    "HouseTaxAct1130103.txt",
    "VehicleLicenseTaxAct1101230.txt",
    "TaxCollectionAct1101217.txt",
    "StampTaxAct910515.txt",
    "DeedTaxAct990505.txt",
    "AmusementTaxAct960523.txt",
]

documents = load_documents(sources)
print(f"\n成功載入 {len(documents)} 個網址或檔案")

# Split on paragraph breaks first, then single newlines, then the CJK full stop.
text_splitter = RecursiveCharacterTextSplitter(
    chunk_size=512,
    chunk_overlap=52,
    length_function=len,
    is_separator_regex=False,
    separators=["\n\n\n", "\n\n", "\n", "。"],
)
split_docs = text_splitter.split_documents(documents)
print(f"分割後的文件數量:{len(split_docs)}")

embeddings = HuggingFaceEmbeddings(model_name="intfloat/multilingual-e5-large")
print("\n成功初始化微軟嵌入模型")

print("\n開始建立向量資料庫")
# NOTE(review): the store is rebuilt with from_documents on every start-up
# while pointing at a persisted directory; this looks like it would add the
# same chunks again on each run -- confirm, and consider reusing an existing
# store when one is already present.
vectorstore = Chroma.from_documents(
    split_docs,
    embeddings,
    persist_directory="./Knowledge-base",
)
print("成功建立 Chroma 向量資料庫")

retriever = vectorstore.as_retriever(
    search_type="mmr",
    search_kwargs={
        "k": 4,
        "fetch_k": 20,
        "lambda_mult": 0.8,
    },
)
print("成功建立檢索器,搜尋演算法:Maximum Marginal Relevance Retrieval")

# Answer prompt: each section sits on its own line so the retrieved context,
# the question and the answer slot are clearly separated for the model.
template = """Let's work this out in a step by step way to be sure we have the right answer. Must reply to me in Taiwanese Traditional Chinese.
在回答之前,請仔細分析檢索到的上下文,確保你的回答準確完整反映了上下文中的訊息,而不是依賴先前的知識,在回應的答案中絕對不要提到是根據上下文回答。
如果檢索到的多個上下文之間存在聯繫,請整合這些訊息以提供更全面的回答,但要避免過度推斷。
如果檢索到的上下文不包含足夠回答問題的訊息,請誠實的說明,不要試圖編造答案。

上下文: {context}

問題: {question}

答案:"""

PROMPT = PromptTemplate(
    template=template,
    input_variables=["context", "question"],
)
print("成功定義 Prompt Template")


def create_chain(llm):
    """Build the RetrievalQA chain that answers from the module-level retriever.

    Args:
        llm: Chat model used to generate the answer.

    Returns:
        A "stuff"-type ``RetrievalQA`` chain using ``PROMPT`` that also returns
        its source documents.
    """
    return RetrievalQA.from_chain_type(
        llm=llm,
        chain_type="stuff",
        retriever=retriever,
        return_source_documents=True,
        chain_type_kwargs={"prompt": PROMPT},
    )


print("成功建立 RAG Chain")

# Text shown on an insight button when the model supplied too few questions.
FALLBACK_INSIGHT_QUESTION = "提供更多地方稅資訊"


def _pad_insight_questions(questions, count=3):
    """Return exactly *count* questions, padding with the fallback text."""
    padded = list(questions[:count])
    padded.extend([FALLBACK_INSIGHT_QUESTION] * (count - len(padded)))
    return padded


def generate_insight_questions(query, api_key):
    """Ask the LLM for three follow-up ("insight") questions about *query*.

    Args:
        query: The user's original question.
        api_key: Groq API key used to create the model client.

    Returns:
        A list of exactly three non-empty question strings. If the model
        call fails, three fixed fallback questions are returned instead.
    """
    llm = initialize_llm(api_key)
    prompt = f"""Let's work this out in a step by step way to be sure we have the right answer. Must reply to me in Taiwanese Traditional Chinese.
根據以下回答,生成3個相關的洞察問題:

原始問題: {query}

請提供3個簡短但有深度的問題,這些問題應該符合:
1. 與原始問題緊密相關
2. 重新準確描述原始問題
3. 引導更深入的解決原始問題

請直接列出這3個問題,每個問題一行,不要添加編號或其他文字。
"""
    try:
        response = llm.invoke(prompt)
        text = response.content if hasattr(response, 'content') else str(response)
        # Drop blank lines: models often separate the questions with empty
        # lines, which would otherwise occupy the three slots as empty strings.
        questions = [line.strip() for line in text.split('\n') if line.strip()]
        return _pad_insight_questions(questions)
    except Exception as e:
        print(f"Error generating insight questions: {str(e)}")
        return ["提供更多地方稅資訊", "提供其他地方稅問題", "還想了解什麼地方稅目"]


def answer_question(query, api_key):
    """Answer one question through the RAG chain.

    Args:
        query: A single question.
        api_key: Groq API key used to create the model client.

    Returns:
        ``(answer, insight_questions)``. On success ``insight_questions`` has
        three entries; on any error ``answer`` is an apology message that
        includes the error text and ``insight_questions`` is empty.
    """
    try:
        gr.Info("檢索地方稅知識庫中,請稍待片刻......")
        llm = initialize_llm(api_key)
        chain = create_chain(llm)
        result = chain.invoke({"query": query})
        answer = result["result"]
        insight_questions = generate_insight_questions(query, api_key)
        return answer, _pad_insight_questions(insight_questions)
    except Exception as e:
        return f"抱歉,處理您的問題時發生錯誤:{str(e)}", []


def split_questions(query):
    """Split *query* into individual questions.

    The text is cut at ``?``, ``!``, ``。`` and spaces; empty fragments are
    discarded and the rest are stripped of surrounding whitespace.
    """
    questions = re.split(r'[?!。 ]', query)
    return [q.strip() for q in questions if q.strip()]


def answer_multiple_questions(query, api_key):
    """Answer every question found in *query* and merge the results.

    Args:
        query: User input that may contain several questions.
        api_key: Groq API key used to create the model client.

    Returns:
        ``(combined_answer, insight_questions)``. With more than one question
        each answer is prefixed with its question and the answers are
        separated by blank lines. At most three insight questions are
        returned, taken in order from the individual answers.
    """
    questions = split_questions(query)
    multiple = len(questions) > 1
    all_answers = []
    all_insight_questions = []
    for question in questions:
        answer, insight_questions = answer_question(question, api_key)
        if multiple:
            all_answers.append(f"【問題】{question}\n答案:{answer}")
        else:
            all_answers.append(answer)
        all_insight_questions.extend(insight_questions)
    separator = "\n\n\n" if multiple else "\n"
    combined_answer = separator.join(all_answers)
    selected_insight_questions = all_insight_questions[:3]
    return combined_answer, selected_insight_questions


# Full-width punctuation mapped to the ASCII characters split_questions cuts
# on. Written as escapes so the mapping survives Unicode normalization of
# this file.
_FULLWIDTH_TO_ASCII = str.maketrans({
    '\uff1f': '?',  # FULLWIDTH QUESTION MARK
    '\uff0c': ',',  # FULLWIDTH COMMA
    '\uff01': '!',  # FULLWIDTH EXCLAMATION MARK
    '\u3000': ' ',  # IDEOGRAPHIC SPACE
})


def convert_punctuation(text):
    """Replace full-width ``?``, ``,``, ``!`` and the ideographic space with ASCII."""
    return text.translate(_FULLWIDTH_TO_ASCII)


def handle_interaction(query, api_key, state):
    """Gradio callback: answer the submitted text and refresh the UI values.

    Args:
        query: Raw text from the question box.
        api_key: Key typed by the user; when empty, the ``YOUR_API_KEY``
            environment variable is used instead.
        state: Session state dict, or ``None`` on the first call.

    Returns:
        ``(answer, q1, q2, q3, state, query)`` where ``q1``..``q3`` are the
        insight questions and ``query`` is the punctuation-normalized input.
    """
    gr.Info("開始處理問題,請稍待片刻......")
    start_time = time.time()
    if state is None:
        state = {"history": []}
    if not api_key:
        api_key = os.getenv("YOUR_API_KEY")
    query = convert_punctuation(query)
    answer, insight_questions = answer_multiple_questions(query, api_key)
    state["history"].append((query, answer))
    # The UI always has three insight buttons to fill.
    insight_questions = _pad_insight_questions(insight_questions)
    end_time = time.time()
    gr.Info(f"Model 已完成回覆,總執行時間: {(end_time - start_time):.2f} 秒。")
    return answer, insight_questions[0], insight_questions[1], insight_questions[2], state, query


custom_css = """
.query-input {
    background-color: #B7E0FF !important;
    padding: 15px !important;
    border-radius: 10px !important;
    margin: 0 !important;
}
.query-input textarea {
    font-size: 18px !important;
    background-color: #ffffff;
    border: 1px solid #f0f8ff;
    border-radius: 8px;
}
.answer-box {
    background-color: #FFF5CD !important;
    padding: 10px !important;
    border-radius: 10px !important;
    margin: 0 !important;
}
.answer-box textarea {
    font-size: 18px !important;
    background-color: #ffffff;
    border: 1px solid #f0f8ff;
    border-radius: 8px;
}
.center-text {
    text-align: center !important;
    color: #ff4081;
    text-shadow: 2px 2px 4px rgba(0,0,0,0.1);
    margin-bottom: 0 !important;
}
#submit-btn {
    border-radius: 10px !important;
    background-color: #ff4081 !important;
    color: white !important;
    font-weight: bold !important;
    transition: all 0.3s ease !important;
    margin: 0 !important;
}
#submit-btn:hover {
    background-color: #f50057 !important;
    transform: scale(1.05);
}
.insight-btn {
    border-radius: 10px !important;
    background-color: #00bcd4 !important;
}
.insight-btn:hover {
    background-color: #00acc1 !important;
}
.gr-form {
    background-color: #e8f5e9 !important;
    padding: 15px !important;
    border-radius: 10px !important;
}
.api-key-input {
    background-color: #FFCFB3 !important;
    padding: 15px !important;
    border-radius: 10px !important;
    margin: 0 !important;
}
.clear-button {
    color: white !important;
    background-color: #000000 !important;
    padding: 5px !important;
    border-radius: 10px !important;
    margin: 0 !important;
}
.clear-button:hover {
    background-color: #000000 !important;
    transform: scale(1.05);
}
"""

with gr.Blocks(theme=gr.themes.Monochrome(), css=custom_css) as iface:
    # Heading and disclaimer are on separate lines so they render as a title
    # followed by a blockquote.
    gr.Markdown("""
    # 地方稅知識庫系統 - 財政部財政資訊中心
    > ### **※ RAG-based 系統部署:江信宗,LLM:Llama-3.1-70B,以地方稅極少知識資料示範,僅供參考,準確資訊請依據地方稅稽徵機關回覆為準。**
    """, elem_classes="center-text")
    with gr.Row():
        query_input = gr.Textbox(
            label="輸入您的問題,系統將基於學習到的知識資料提供相關答案。",
            placeholder="請輸入您的問題(支援同時輸入多個問題,例如:問題1?問題2?)",
            scale=3,
            elem_classes="query-input",
        )
        api_key_input = gr.Textbox(
            label="請輸入您的 API Key",
            type="password",
            placeholder="API authentication key",
            scale=1,
            elem_classes="api-key-input",
        )
    answer_output = gr.Textbox(label="答案:", max_lines=40, elem_classes="answer-box")
    with gr.Row():
        insight_q1 = gr.Button("洞察問題 1", visible=False, elem_classes=["insight-btn"])
        insight_q2 = gr.Button("洞察問題 2", visible=False, elem_classes=["insight-btn"])
        insight_q3 = gr.Button("洞察問題 3", visible=False, elem_classes=["insight-btn"])
    state = gr.State()
    current_question = gr.Textbox(lines=2, label="當前問題", visible=False)
    with gr.Row():
        submit_btn = gr.Button("傳送", variant="primary", scale=3, elem_id="submit-btn")
        clear_button = gr.Button("清除", variant="secondary", scale=1, elem_classes="clear-button")

    def update_ui(answer, q1, q2, q3, state, current_q):
        """Show each insight button only when it has a non-empty question."""
        return [
            answer,
            gr.update(value=q1, visible=bool(q1)),
            gr.update(value=q2, visible=bool(q2)),
            gr.update(value=q3, visible=bool(q3)),
            state,
            current_q,
        ]

    # First fill in the values, then a second step toggles button visibility.
    submit_btn.click(
        fn=handle_interaction,
        inputs=[query_input, api_key_input, state],
        outputs=[answer_output, insight_q1, insight_q2, insight_q3, state, current_question],
    ).then(
        fn=update_ui,
        inputs=[answer_output, insight_q1, insight_q2, insight_q3, state, current_question],
        outputs=[answer_output, insight_q1, insight_q2, insight_q3, state, current_question],
    )

    # Clicking an insight button copies its text into the question box.
    for btn in [insight_q1, insight_q2, insight_q3]:
        btn.click(
            lambda x: x,
            inputs=[btn],
            outputs=[query_input],
        )

    def clear_outputs():
        """Return empty values for the question box and the answer box."""
        return "", ""

    clear_button.click(
        fn=clear_outputs,
        inputs=[],
        outputs=[query_input, answer_output],
    )

if __name__ == "__main__":
    # When the SPACE_ID environment variable is set, launch with defaults;
    # otherwise create a public share link and hide the API page.
    if "SPACE_ID" in os.environ:
        iface.launch()
    else:
        iface.launch(share=True, show_api=False)