import gradio as gr
from langchain_core.vectorstores import InMemoryVectorStore
from langchain_community.embeddings import HuggingFaceEmbeddings
from langchain_groq import ChatGroq
from langchain_core.prompts import ChatPromptTemplate
from langchain.chains import create_retrieval_chain
from langchain.chains.combine_documents import create_stuff_documents_chain

"""
For more information on `huggingface_hub` Inference API support, please check the docs: https://huggingface.co/docs/huggingface_hub/v0.22.2/en/guides/inference
"""

# Groq-hosted chat model and Japanese sentence-embedding model
model_name = "llama-3.3-70b-versatile"
emb_model_name = "pkshatech/GLuCoSE-base-ja"

def load_vector_store(embedding_model_name, vector_store_file, k=4):
    """Load a dumped InMemoryVectorStore and expose it as a top-k retriever."""
    embeddings = HuggingFaceEmbeddings(model_name=embedding_model_name)
    vector_store = InMemoryVectorStore.load(vector_store_file, embeddings)
    retriever = vector_store.as_retriever(search_kwargs={"k": k})
    return retriever
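
# A minimal sketch (not executed by this app) of how a store file such as
# "kaihatsu_vector_store" could have been built and dumped. The loader, source
# file name, and chunking parameters below are illustrative assumptions:
#
#   from langchain_community.document_loaders import PyPDFLoader
#   from langchain_text_splitters import RecursiveCharacterTextSplitter
#
#   docs = PyPDFLoader("koubo_youryou.pdf").load()  # hypothetical source PDF
#   chunks = RecursiveCharacterTextSplitter(
#       chunk_size=500, chunk_overlap=50
#   ).split_documents(docs)
#   store = InMemoryVectorStore.from_documents(
#       chunks, HuggingFaceEmbeddings(model_name=emb_model_name)
#   )
#   store.dump("kaihatsu_vector_store")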

def fetch_response(groq_api_key, user_input):
    """Answer a question with the RAG chain and return the top two source passages."""
    chat = ChatGroq(
        api_key=groq_api_key,
        model_name=model_name,
    )
    # System prompt (Japanese): "You are a helpful assistant. Answer based on the
    # contents of the manual." The retrieved documents are injected into {context}.
    system_prompt = (
        "あなたは便利なアシスタントです。"
        "マニュアルの内容から回答してください。"
        "\n\n"
        "{context}"
    )

    prompt = ChatPromptTemplate.from_messages(
        [
            ("system", system_prompt),
            ("human", "{input}"),
        ]
    )
    # Chain that stuffs the retrieved documents into the prompt's {context} slot
    question_answer_chain = create_stuff_documents_chain(chat, prompt)
    # Combine the module-level retriever (defined below) with the QA chain
    rag_chain = create_retrieval_chain(retriever, question_answer_chain)

    response = rag_chain.invoke({"input": user_input})
    # "context" holds the retrieved Documents; return their text, not the objects
    return [
        response["answer"],
        response["context"][0].page_content,
        response["context"][1].page_content,
    ]


"""
For information on how to customize the ChatInterface, peruse the gradio docs: https://www.gradio.app/docs/chatinterface
"""
retriever = load_vector_store(emb_model_name, "kaihatsu_vector_store", 4)
#retriever_rechunk = load_vector_store(emb_model_name, "kaihatsu_vector_store_rechunk", 4)
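
# Quick retrieval sanity check (illustrative; run in a REPL, not part of the app):
#
#   docs = retriever.invoke("公募要領")
#   print(docs[0].page_content[:200])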

with gr.Blocks() as demo:
    # UI text is Japanese; the app answers questions about the program using RAG
    # over the application guidelines (公募要領) and screening guidelines (審査要領).
    gr.Markdown(
        "# 「スマート農業技術の開発・供給に関する事業」マスター\n"
        "「スマート農業技術の開発・供給に関する事業」に関して、公募要領や審査要領を参考にRAGを使って回答します。"
    )
    with gr.Row():
        api_key = gr.Textbox(label="Groq API key")
    with gr.Row():
        with gr.Column():
            user_input = gr.Textbox(label="User Input")
            submit = gr.Button("Submit")
            answer = gr.Textbox(label="Answer")
    with gr.Row():
        with gr.Column():
            source1 = gr.Textbox(label="回答ソース1")  # answer source 1
        with gr.Column():
            source2 = gr.Textbox(label="回答ソース2")  # answer source 2

    submit.click(fetch_response, inputs=[api_key, user_input], outputs=[answer, source1, source2])

if __name__ == "__main__":
    demo.launch()