zeerd commited on
Commit
fdfcf53
·
verified ·
1 Parent(s): b2369fc

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +4 -9
app.py CHANGED
@@ -1,5 +1,5 @@
1
  import streamlit as st
2
- from langchain_huggingface import HuggingFaceEndpoint
3
  from langchain_community.embeddings import SentenceTransformerEmbeddings
4
  from langchain_community.vectorstores import FAISS
5
 
@@ -24,12 +24,7 @@ except Exception as e:
24
  def answer_question(repo_id, temperature, max_length, question):
25
  # 4. 初始化 Gemma 模型
26
  try:
27
- llm = HuggingFaceEndpoint(
28
- repo_id=repo_id,
29
- temperature=temperature,
30
- max_length=max_length,
31
- huggingfacehub_api_token=HUGGINGFACEHUB_API_TOKEN
32
- )
33
  except Exception as e:
34
  st.error(f"Gemma 模型加载失败:{e}")
35
  st.stop()
@@ -47,7 +42,7 @@ def answer_question(repo_id, temperature, max_length, question):
47
  prompt = f"请根据以下知识库回答问题:\n{context}\n问题:{question}"
48
  print('prompt: ' + prompt)
49
 
50
- answer = llm.invoke(prompt)
51
  return answer
52
  except Exception as e:
53
  st.error(f"问答过程出错:{e}")
@@ -56,7 +51,7 @@ def answer_question(repo_id, temperature, max_length, question):
56
  # 6. Streamlit 界面
57
  st.title("Gemma 知识库问答系统")
58
 
59
- gemma = st.selectbox("repo-id", ("google/gemma-7b-it", "google/gemma-2b-it", "google/recurrentgemma-2b-it"), 2)
60
  temperature = st.number_input("temperature", value=1.0)
61
  max_length = st.number_input("max_length", value=1024)
62
  question = st.text_area("请输入问题", "Gemma 有哪些特点?")
 
1
  import streamlit as st
2
+ from langchain_community.llms import HuggingFaceHub
3
  from langchain_community.embeddings import SentenceTransformerEmbeddings
4
  from langchain_community.vectorstores import FAISS
5
 
 
24
  def answer_question(repo_id, temperature, max_length, question):
25
  # 4. 初始化 Gemma 模型
26
  try:
27
+ llm = HuggingFaceHub(repo_id=repo_id, model_kwargs={"temperature": temperature, "max_length": max_length})
 
 
 
 
 
28
  except Exception as e:
29
  st.error(f"Gemma 模型加载失败:{e}")
30
  st.stop()
 
42
  prompt = f"请根据以下知识库回答问题:\n{context}\n问题:{question}"
43
  print('prompt: ' + prompt)
44
 
45
+ answer = llm(prompt)
46
  return answer
47
  except Exception as e:
48
  st.error(f"问答过程出错:{e}")
 
51
  # 6. Streamlit 界面
52
  st.title("Gemma 知识库问答系统")
53
 
54
+ gemma = st.selectbox("repo-id", ("google/gemma-2-9b-it", "google/gemma-2-2b-it", "google/recurrentgemma-2b-it"), 2)
55
  temperature = st.number_input("temperature", value=1.0)
56
  max_length = st.number_input("max_length", value=1024)
57
  question = st.text_area("请输入问题", "Gemma 有哪些特点?")