|
import streamlit as st |
|
from langchain_community.llms import HuggingFaceHub |
|
from langchain_community.embeddings import SentenceTransformerEmbeddings |
|
from langchain.vectorstores import FAISS |
|
import numpy as np |
|
|
|
|
|
# Hugging Face Hub repository id of the model queried for answers.
gemma = 'google/recurrentgemma-2b-it'


@st.cache_resource
def _load_llm(repo_id):
    """Create the Hub-backed LLM client for *repo_id*.

    Wrapped in ``st.cache_resource`` because Streamlit re-executes this
    script on every widget interaction; caching keeps one client alive
    instead of rebuilding it on each rerun. Exceptions are not cached, so a
    failed load is retried on the next rerun.
    """
    # NOTE(review): the HF text-generation API documents `max_new_tokens`
    # for output length; confirm the backend actually honours `max_length`.
    return HuggingFaceHub(
        repo_id=repo_id,
        model_kwargs={"temperature": 0.5, "max_length": 512},
    )


try:
    llm = _load_llm(gemma)
except Exception as e:
    # Without a model the app cannot answer anything: report and halt the run.
    st.error(f"Gemma 模型加载失败:{e}")
    st.stop()
|
|
|
|
|
# Small in-memory corpus that is indexed into the FAISS store below;
# each string is indexed as its own retrievable passage.
knowledge_base = [
    "Gemma 是 Google 开发的大型语言模型。",
    "Gemma 具有强大的自然语言处理能力。",
    "Gemma 可以用于问答、对话、文本生成等任务。",
    "Gemma 基于 Transformer 架构。",
    "Gemma 支持多种语言。",
]
|
|
|
|
|
@st.cache_resource
def _build_vector_store(texts):
    """Load the sentence-embedding model and index *texts* in FAISS.

    Cached with ``st.cache_resource`` so the embedding model is loaded and
    the index built once per distinct *texts*, rather than on every
    Streamlit rerun (i.e. every button click).

    Args:
        texts: The passages to index.

    Returns:
        ``(embeddings, db)``: the embedding model and the FAISS store built
        with it. Queries against ``db`` must use this same model.
    """
    model = SentenceTransformerEmbeddings(model_name="all-mpnet-base-v2")
    return model, FAISS.from_texts(list(texts), model)


try:
    # Pass a tuple so the cache key is an immutable snapshot of the corpus.
    embeddings, db = _build_vector_store(tuple(knowledge_base))
except Exception as e:
    # Retrieval is impossible without the index: report and halt the run.
    st.error(f"向量数据库构建失败:{e}")
    st.stop()
|
|
|
|
|
def answer_question(question, k=4):
    """Answer *question* using retrieved knowledge-base passages as context.

    The question is matched against the FAISS store ``db``. The text of the
    top ``k`` hits is joined into a prompt, and that prompt is sent to ``llm``.

    Args:
        question: The user's question in natural language.
        k: Number of passages to retrieve. The default of 4 equals
            LangChain's FAISS default, so existing callers are unaffected.

    Returns:
        The model's answer text, or a fixed error message if retrieval or
        generation fails. In that case the exception is also shown in the
        page via ``st.error``.
    """
    try:
        # Pass the raw question text: the store embeds it with the same model
        # it was built with. Do not pre-embed and stringify the vector here,
        # because the store would then embed that digit string as if it were
        # text, and retrieval would ignore the actual question.
        docs_and_scores = db.similarity_search_with_score(question, k=k)
        context = "\n".join(doc.page_content for doc, _score in docs_and_scores)

        prompt = f"请根据以下知识库回答问题:\n{context}\n问题:{question}"

        # `invoke` is the Runnable entry point; calling the LLM object
        # directly (`llm(prompt)`) is deprecated in LangChain.
        return llm.invoke(prompt)
    except Exception as e:
        # UI boundary: surface the error in the page instead of crashing the app.
        st.error(f"问答过程出错:{e}")
        return "An error occurred during the answering process."
|
|
|
|
|
st.title("Gemma 知识库问答系统")

question = st.text_area("请输入问题", height=100)

if st.button("提交"):
    # Treat whitespace-only input as empty so a blank question does not
    # trigger a retrieval + LLM round-trip.
    question = (question or "").strip()
    if not question:
        st.warning("请输入问题!")
    else:
        with st.spinner("正在查询..."):
            answer = answer_question(question)
        st.write("答案:")
        st.write(answer)