t2ag3 commited on
Commit
02e3c4c
·
verified ·
1 Parent(s): 882f365

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +19 -11
app.py CHANGED
@@ -12,16 +12,15 @@ For more information on `huggingface_hub` Inference API support, please check th
12
  """
13
 
14
  model_name = "llama-3.3-70b-versatile"
15
- embeddings = HuggingFaceEmbeddings(
16
- model_name = "pkshatech/GLuCoSE-base-ja"
17
- )
18
- vector_store = InMemoryVectorStore.load(
19
- "kaihatsu_vector_store", embeddings
20
- )
21
- retriever = vector_store.as_retriever(search_kwargs={"k": 4})
22
 
 
 
 
 
 
23
 
24
- def fetch_response(groq_api_key, user_input):
25
  chat = ChatGroq(
26
  api_key = groq_api_key,
27
  model_name = model_name
@@ -42,15 +41,20 @@ def fetch_response(groq_api_key, user_input):
42
  # ドキュメントのリストを渡せるchainを作成
43
  question_answer_chain = create_stuff_documents_chain(chat, prompt)
44
  # RetrieverとQAチェーンを組み合わせてRAGチェーンを作成
45
- rag_chain = create_retrieval_chain(retriever, question_answer_chain)
 
46
 
47
  response = rag_chain.invoke({"input": user_input})
48
- return [response["answer"], response["context"][0], response["context"][1]]
 
49
 
50
 
51
  """
52
  For information on how to customize the ChatInterface, peruse the gradio docs: https://www.gradio.app/docs/chatinterface
53
  """
 
 
 
54
  with gr.Blocks() as demo:
55
  gr.Markdown('''# 「スマート農業技術の開発・供給に関する事業」マスター \n
56
  「スマート農業技術の開発・供給に関する事業」に関して、公募要領や審査要領を参考にRAGを使って回答します。
@@ -62,12 +66,16 @@ with gr.Blocks() as demo:
62
  user_input = gr.Textbox(label="User Input")
63
  submit = gr.Button("Submit")
64
  answer = gr.Textbox(label="Answer")
 
65
  with gr.Row():
66
  with gr.Column():
67
  source1 = gr.Textbox(label="回答ソース1")
 
68
  with gr.Column():
69
  source2 = gr.Textbox(label="回答ソース2")
70
- submit.click(fetch_response, inputs=[api_key, user_input], outputs=[answer, source1, source2])
 
 
71
 
72
  if __name__ == "__main__":
73
  demo.launch()
 
12
  """
13
 
14
  model_name = "llama-3.3-70b-versatile"
15
+ emb_model_name = "pkshatech/GLuCoSE-base-ja"
 
 
 
 
 
 
16
 
17
+ def load_vector_store(embedding_model_name, vector_store_file, k=4):
18
+ embeddings = HuggingFaceEmbeddings(model_name = embedding_model_name)
19
+ vector_store = InMemoryVectorStore.load(vector_store_file, embeddings)
20
+ retriever = vector_store.as_retriever(search_kwargs={"k": k})
21
+ return retriever
22
 
23
+ def fetch_response(groq_api_key, user_input, retriever1, retriever2):
24
  chat = ChatGroq(
25
  api_key = groq_api_key,
26
  model_name = model_name
 
41
  # ドキュメントのリストを渡せるchainを作成
42
  question_answer_chain = create_stuff_documents_chain(chat, prompt)
43
  # RetrieverとQAチェーンを組み合わせてRAGチェーンを作成
44
+ rag_chain = create_retrieval_chain(retriever1, question_answer_chain)
45
+ rag_chain2 = create_retrieval_chain(retriever2, question_answer_chain)
46
 
47
  response = rag_chain.invoke({"input": user_input})
48
+ response2 = rag_chain.invoke({"input": user_input})
49
+ return [response["answer"], response["context"][0], response["context"][1], response2["answer"], response2["context"][0], response2["context"][1]]
50
 
51
 
52
  """
53
  For information on how to customize the ChatInterface, peruse the gradio docs: https://www.gradio.app/docs/chatinterface
54
  """
55
+ retriever = load_vector_store(emb_model_name, "kaihatsu_vector_store", 4)
56
+ retriever_rechunk = load_vector_store(emb_model_name, "kaihatsu_vector_store_rechunk", 4)
57
+
58
  with gr.Blocks() as demo:
59
  gr.Markdown('''# 「スマート農業技術の開発・供給に関する事業」マスター \n
60
  「スマート農業技術の開発・供給に関する事業」に関して、公募要領や審査要領を参考にRAGを使って回答します。
 
66
  user_input = gr.Textbox(label="User Input")
67
  submit = gr.Button("Submit")
68
  answer = gr.Textbox(label="Answer")
69
+ answer2 = gr.Textbox(label="Answer")
70
  with gr.Row():
71
  with gr.Column():
72
  source1 = gr.Textbox(label="回答ソース1")
73
+ source2_1 = gr.Textbox(label="回答ソース1")
74
  with gr.Column():
75
  source2 = gr.Textbox(label="回答ソース2")
76
+ source2_2 = gr.Textbox(label="回答ソース2")
77
+
78
+ submit.click(fetch_response, inputs=[api_key, user_input, retriever, retriever_rechunk], outputs=[answer, source1, source2, answer2, source2_1, source2_2])
79
 
80
  if __name__ == "__main__":
81
  demo.launch()