zeerd commited on
Commit
5b5abf5
·
verified ·
1 Parent(s): b6628bb

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +16 -17
app.py CHANGED
@@ -2,12 +2,8 @@ import streamlit as st
2
  from langchain_community.llms import HuggingFaceHub
3
  from langchain_community.embeddings import SentenceTransformerEmbeddings
4
  from langchain_community.vectorstores import FAISS
5
- import numpy as np
6
 
7
- # gemma = 'google/gemma-7b-it';
8
- gemma = 'google/recurrentgemma-2b-it';
9
-
10
- # 2. 准备知识库数据 (示例)
11
  knowledge_base = [
12
  "Gemma 是 Google 开发的大型语言模型。",
13
  "Gemma 具有强大的自然语言处理能力。",
@@ -16,7 +12,7 @@ knowledge_base = [
16
  "Gemma 支持多种语言。"
17
  ]
18
 
19
- # 3. 构建向量数据库 (如果需要,仅构建一次)
20
  try:
21
  embeddings = SentenceTransformerEmbeddings(model_name="all-mpnet-base-v2")
22
  db = FAISS.from_texts(knowledge_base, embeddings)
@@ -24,24 +20,27 @@ except Exception as e:
24
  st.error(f"向量数据库构建失败:{e}")
25
  st.stop()
26
 
27
- # 4. 问答函数
28
  def answer_question(gemma, temperature, max_length, question):
29
- # 1. 初始化 Gemma 模型
30
  try:
31
  llm = HuggingFaceHub(repo_id=gemma, model_kwargs={"temperature": temperature, "max_length": max_length})
32
  except Exception as e:
33
  st.error(f"Gemma 模型加载失败:{e}")
34
  st.stop()
 
 
35
  try:
36
  question_embedding = embeddings.embed_query(question)
37
- question_embedding_np = " ".join(map(str, question_embedding))
38
- docs_and_scores = db.similarity_search_with_score(question_embedding_np)
39
- # 正确处理 docs_and_scores 列表
40
- context = "\n".join([doc.page_content for doc, _ in docs_and_scores]) # 使用 join() 方法
41
- print(context)
 
42
 
43
  prompt = f"请根据以下知识库回答问题:\n{context}\n问题:{question}"
44
- print(prompt)
45
 
46
  answer = llm(prompt)
47
  return answer
@@ -53,8 +52,8 @@ def answer_question(gemma, temperature, max_length, question):
53
  st.title("Gemma 知识库问答系统")
54
 
55
  gemma = st.selectbox("模型", ("google/gemma-2-9b-it", "google/gemma-2-2b-it", "google/recurrentgemma-2b-it"), 2)
56
- temperature = st.text_area("temperature", "1.0")
57
- max_length = st.text_area("max_length", "1024")
58
  question = st.text_area("请输入问题", "Gemma 有哪些特点?")
59
 
60
  if st.button("提交"):
@@ -62,6 +61,6 @@ if st.button("提交"):
62
  st.warning("请输入问题!")
63
  else:
64
  with st.spinner("正在查询..."):
65
- answer = answer_question(gemma, temperature, max_length, question)
66
  st.write("答案:")
67
  st.write(answer)
 
2
  from langchain_community.llms import HuggingFaceHub
3
  from langchain_community.embeddings import SentenceTransformerEmbeddings
4
  from langchain_community.vectorstores import FAISS
 
5
 
6
+ # 1. 准备知识库数据 (示例)
 
 
 
7
  knowledge_base = [
8
  "Gemma 是 Google 开发的大型语言模型。",
9
  "Gemma 具有强大的自然语言处理能力。",
 
12
  "Gemma 支持多种语言。"
13
  ]
14
 
15
+ # 2. 构建向量数据库 (如果需要,仅构建一次)
16
  try:
17
  embeddings = SentenceTransformerEmbeddings(model_name="all-mpnet-base-v2")
18
  db = FAISS.from_texts(knowledge_base, embeddings)
 
20
  st.error(f"向量数据库构建失败:{e}")
21
  st.stop()
22
 
23
+ # 3. 问答函数
24
  def answer_question(gemma, temperature, max_length, question):
25
+ # 4. 初始化 Gemma 模型
26
  try:
27
  llm = HuggingFaceHub(repo_id=gemma, model_kwargs={"temperature": temperature, "max_length": max_length})
28
  except Exception as e:
29
  st.error(f"Gemma 模型加载失败:{e}")
30
  st.stop()
31
+
32
+ # 5. 获取答案
33
  try:
34
  question_embedding = embeddings.embed_query(question)
35
+ question_embedding_str = " ".join(map(str, question_embedding))
36
+ print('question_embedding: ' + question_embedding_str)
37
+ docs_and_scores = db.similarity_search_with_score(question_embedding_str)
38
+
39
+ context = "\n".join([doc.page_content for doc, _ in docs_and_scores])
40
+ print('context: ' + context)
41
 
42
  prompt = f"请根据以下知识库回答问题:\n{context}\n问题:{question}"
43
+ print('prompt: ' + prompt)
44
 
45
  answer = llm(prompt)
46
  return answer
 
52
  st.title("Gemma 知识库问答系统")
53
 
54
  gemma = st.selectbox("模型", ("google/gemma-2-9b-it", "google/gemma-2-2b-it", "google/recurrentgemma-2b-it"), 2)
55
+ temperature = st.text_input("temperature", "1.0")
56
+ max_length = st.text_input("max_length", "1024")
57
  question = st.text_area("请输入问题", "Gemma 有哪些特点?")
58
 
59
  if st.button("提交"):
 
61
  st.warning("请输入问题!")
62
  else:
63
  with st.spinner("正在查询..."):
64
+ answer = answer_question(gemma, float(temperature), int(max_length), question)
65
  st.write("答案:")
66
  st.write(answer)