import gradio as gr
from langchain_core.vectorstores import InMemoryVectorStore
from langchain_community.embeddings import HuggingFaceEmbeddings
from langchain_groq import ChatGroq
from langchain_core.prompts import ChatPromptTemplate
from langchain.chains import create_retrieval_chain
from langchain.chains.combine_documents import create_stuff_documents_chain
"""
For more information on `huggingface_hub` Inference API support, please check the docs: https://huggingface.co/docs/huggingface_hub/v0.22.2/en/guides/inference
"""
model_name = "llama-3.3-70b-versatile"
emb_model_name = "pkshatech/GLuCoSE-base-ja"
def load_vector_store(embedding_model_name, vector_store_file, k=4):
    """Load a persisted InMemoryVectorStore and expose it as a top-k retriever."""
    embeddings = HuggingFaceEmbeddings(model_name=embedding_model_name)
    vector_store = InMemoryVectorStore.load(vector_store_file, embeddings)
    retriever = vector_store.as_retriever(search_kwargs={"k": k})
    return retriever
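
# A minimal sketch (not part of the original app) of how a store file such as
# "kaihatsu_vector_store" could have been produced, assuming the guideline PDFs
# were already split into LangChain Document chunks:
#
#   from langchain_core.documents import Document
#   docs = [Document(page_content="...")]  # chunked text from the guidelines
#   embeddings = HuggingFaceEmbeddings(model_name=emb_model_name)
#   InMemoryVectorStore.from_documents(docs, embeddings).dump("kaihatsu_vector_store")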
def fetch_response(groq_api_key, user_input):
    """Answer a question with RAG: retrieve relevant chunks, then ask the Groq-hosted LLM."""
    chat = ChatGroq(
        api_key=groq_api_key,
        model_name=model_name
    )

    # System prompt (Japanese): "You are a helpful assistant. Answer from the contents of the manual."
    system_prompt = (
        "あなたは便利なアシスタントです。"
        "マニュアルの内容から回答してください。"
        "\n\n"
        "{context}"
    )
    prompt = ChatPromptTemplate.from_messages(
        [
            ("system", system_prompt),
            ("human", "{input}"),
        ]
    )

    # Create a chain that stuffs the retrieved documents into the prompt
    question_answer_chain = create_stuff_documents_chain(chat, prompt)
    # Combine the retriever and the QA chain into a RAG chain
    rag_chain = create_retrieval_chain(retriever, question_answer_chain)

    # The result is a dict with "input", "context" (retrieved Documents), and "answer" keys
    response = rag_chain.invoke({"input": user_input})
    return [response["answer"], response["context"][0], response["context"][1]]
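
# Example invocation outside the Gradio UI (a sketch: the key and question below
# are placeholders, not values from the original app):
#
#   answer, src1, src2 = fetch_response("<YOUR_GROQ_API_KEY>", "Which costs are eligible under this program?")
#   print(answer)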
"""
For information on how to customize the ChatInterface, peruse the gradio docs: https://www.gradio.app/docs/chatinterface
"""
retriever = load_vector_store(emb_model_name, "kaihatsu_vector_store", 4)
#retriever_rechunk = load_vector_store(emb_model_name, "kaihatsu_vector_store_rechunk", 4)
with gr.Blocks() as demo:
    # Page header (Japanese): this app answers questions about the "Smart Agriculture
    # Technology Development and Supply Program" using RAG over the application and
    # screening guidelines.
    gr.Markdown('''# 「スマート農業技術の開発・供給に関する事業」マスター \n
「スマート農業技術の開発・供給に関する事業」に関して、公募要領や審査要領を参考にRAGを使って回答します。
''')
    with gr.Row():
        api_key = gr.Textbox(label="Groq API key")
    with gr.Row():
        with gr.Column():
            user_input = gr.Textbox(label="User Input")
            submit = gr.Button("Submit")
        answer = gr.Textbox(label="Answer")
    with gr.Row():
        with gr.Column():
            source1 = gr.Textbox(label="回答ソース1")  # "Answer source 1"
        with gr.Column():
            source2 = gr.Textbox(label="回答ソース2")  # "Answer source 2"
    submit.click(fetch_response, inputs=[api_key, user_input], outputs=[answer, source1, source2])

if __name__ == "__main__":
    demo.launch()