import torch
import gradio as gr
from langchain import PromptTemplate
from langchain.retrievers.document_compressors import EmbeddingsFilter
from langchain.retrievers import ContextualCompressionRetriever
from langchain.chains import RetrievalQA
from langchain.vectorstores import FAISS
from langchain.embeddings import HuggingFaceEmbeddings
from langchain.document_loaders import PyPDFLoader
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.llms import OpenAI

# Completion model used for answer generation (expects OPENAI_API_KEY in the environment).
llm = OpenAI(model_name='gpt-3.5-turbo-instruct')

torch_device = "cuda" if torch.cuda.is_available() else "cpu"
print("Running on device:", torch_device)
print("CPU threads:", torch.get_num_threads())
# Load the source PDF and split its text into fixed-size chunks for retrieval.
loader = PyPDFLoader("total.pdf")
pages = loader.load()
text_splitter = RecursiveCharacterTextSplitter(chunk_size=600, chunk_overlap=0)
texts = text_splitter.split_documents(pages)
print(f"The document was split into {len(texts)} chunks.")
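# Optional peek (a sketch, not part of the app flow): inspect the first chunk to
# confirm the 600-character chunking looks sane before embedding everything.
# print(texts[0].page_content[:200])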
# Load the embedding model.
embeddings = HuggingFaceEmbeddings(model_name="intfloat/multilingual-e5-large")
# Embed the chunks and build a FAISS index over them.
index = FAISS.from_documents(
    documents=texts,
    embedding=embeddings,
)
# Save the index locally as faiss_db, then load it back.
index.save_local("faiss_db")
docsearch = FAISS.load_local("faiss_db", embeddings)
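# Optional sanity check (a sketch, not part of the app flow; the query string is a
# placeholder, in Korean because the source PDF is Korean): confirm the reloaded
# index returns neighbors before wiring it into the QA chain.
# hits = docsearch.similarity_search("암 진단 보장", k=2)
# print([d.page_content[:80] for d in hits])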
# Keep only retrieved chunks whose embedding similarity to the query clears the threshold.
embeddings_filter = EmbeddingsFilter(
    embeddings=embeddings,
    similarity_threshold=0.7,
    k=3,
)
# Build the compression retriever: the base FAISS retriever finds candidate chunks,
# then the embeddings filter compresses them down to the ones relevant to the query.
compression_retriever = ContextualCompressionRetriever(
    base_compressor=embeddings_filter,
    base_retriever=docsearch.as_retriever(),
)
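# Optional sanity check (a sketch with a placeholder query): the filter keeps at
# most the 3 candidates whose similarity to the query exceeds 0.7.
# docs = compression_retriever.get_relevant_documents("수술비 보장 범위")
# print(len(docs), [d.page_content[:60] for d in docs])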
# Per-user session state, all keyed by the user's position in id_list.
id_list = []
history = []
customer_data_list = []
customer_agree_list = []
# Literal placeholders: embedding these in the f-string prompt below keeps
# "{context}" and "{question}" intact as slots for the PromptTemplate to fill.
context = "{context}"
question = "{question}"
def gen(x, id, customer_data):
    # Find this user's session slot; register a new one if the id is unseen.
    index = 0
    matched = 0
    for s in id_list:
        if s == id:
            matched = 1
            break
        index += 1
    if matched == 0:
        index = len(id_list)
        id_list.append(id)
        customer_data_list.append(customer_data)
        if x != "terms_agree":
            customer_agree_list.append("No")
            history.append('Agent: How can I help you?\n')
            bot_str = "* Your enrollment information cannot be viewed yet. You must first agree to the personal-information terms of use before the consultation can proceed.\nHow can I help you?"
        else:
            customer_agree_list.append("Yes")
            history.append('Agent: How can I help you?\n')
            bot_str = f"You have agreed to the use of your personal information. Looking up your enrolled policies.\n\nThe insurance you are currently enrolled in is {customer_data}.\n\nIs there anything you would like to know?"
        return bot_str
    else:
        if x == "reset":
            if customer_agree_list[index] != "No":
                customer_data_list[index] = customer_data
                history[index] = 'Agent: How can I help you?\n'  # clear the logged dialogue, as the message promises
                bot_str = f"Your conversation history has been cleared.\n\nThe insurance you are currently enrolled in is {customer_data}.\n\nIs there anything you would like to know?"
                return bot_str
            else:
                customer_data_list[index] = "no enrollment info"
                history[index] = 'Agent: How can I help you?\n'
                bot_str = "Your conversation history has been cleared.\n\n* Your enrollment information cannot be viewed. You must first agree to the personal-information terms of use before the consultation can proceed.\n\nIs there anything you would like to know?"
                return bot_str
        elif x == "enrollment_info":
            if customer_agree_list[index] == "No":
                history[index] = 'Agent: How can I help you?\n'
                bot_str = "* Your enrollment information cannot be viewed. You must first agree to the personal-information terms of use before the consultation can proceed.\n\nIs there anything you would like to know?"
                return bot_str
            else:
                history[index] = 'Agent: How can I help you?\n'
                bot_str = f"The insurance you are currently enrolled in is {customer_data_list[index]}.\n\nIs there anything you would like to know?"
                return bot_str
        elif x == "terms_agree":
            if customer_agree_list[index] == "No":
                history[index] = 'Agent: How can I help you?\n'
                customer_agree_list[index] = "Yes"
                customer_data_list[index] = customer_data
                bot_str = f"You have agreed to the use of your personal information. Looking up your enrolled policies.\n\nThe insurance you are currently enrolled in is {customer_data}.\n\nIs there anything you would like to know?"
                return bot_str
            else:
                history[index] = 'Agent: How can I help you?\n'
                bot_str = "You have already agreed to the terms.\n\nIs there anything you would like to know?"
                return bot_str
        elif x == "terms_disagree":
            if customer_agree_list[index] == "Yes":
                history[index] = 'Agent: How can I help you?\n'
                customer_agree_list[index] = "No"
                customer_data_list[index] = "no enrollment info"
                bot_str = "* You have withdrawn consent to the use of your personal information. Your enrolled insurance can no longer be viewed.\n\nIs there anything you would like to know?"
                return bot_str
            else:
                history[index] = 'Agent: How can I help you?\n'
                bot_str = "* You have declined the use of your personal information. Your enrolled insurance cannot be viewed.\n\nIs there anything you would like to know?"
                return bot_str
        else:
            # Literal "{context}" / "{question}" so the f-string below leaves real
            # template slots for the PromptTemplate.
            context = "{context}"
            question = "{question}"
            if customer_agree_list[index] == "No":
                customer_data_newline = "Enrollment information cannot be viewed at this time. Please tell the customer that agreeing to the terms is required."
            else:
                customer_data_newline = customer_data_list[index].replace(",", "\n")
            prompt_template = f"""You are an insurance counselor. Below you are given policy terms related to the question, response guidelines, the customer's insurance enrollment information, and the consultation history with the customer. Write a response that appropriately completes the request.
[Full insurance product list]
라이프플래너정기보험Ⅱ
라이프플래너종신보험
라이프플래너상해보험
만기까지비갱신암보험Ⅱ
라이프플래너암보험Ⅲ
암·뇌·심장건강보험
뇌·심장건강보험
여성건강보험
건강치아보험
입원비보험
수술비보험
라이프플래너플러스어린이보험Ⅱ
라이프플래너플러스어린이종합보험
라이프플래너에듀케어저축보험Ⅱ
라이프플래너연금저축보험Ⅱ
1년부터저축보험
라이프플래너연금보험Ⅱ
{context}
### Instruction:
Referring to the guidelines below, respond to the customer as the counselor in as much detail as possible.
[Guidelines]
1. Always check the customer's enrollment information and only provide details about insurance the customer has enrolled in.
2. If the customer has enrolled in the insurance, answer the customer's question appropriately.
3. For questions about coverage under insurance the customer has not enrolled in, introduce the relevant product and explain that the coverage is not available to them.
4. For insurance the customer has not enrolled in, state the product name clearly and note that enrollment is required.
Use the customer's enrollment information and the consultation history given in the input below to provide information that helps the customer. Think step by step before answering. You can do this well.
### Input:
[Customer's enrollment information]
{customer_data_newline}
[Consultation history]
{history[index]}
Customer: {question}
### Response:
"""
            # Call the RetrievalQA.from_chain_type class method to build the question-answering chain.
            qa = RetrievalQA.from_chain_type(
                llm=llm,
                chain_type="stuff",
                retriever=compression_retriever,
                return_source_documents=False,
                verbose=True,
                chain_type_kwargs={"prompt": PromptTemplate(
                    input_variables=["context", "question"],
                    template=prompt_template,
                )},
            )
            # Prepend the enrollment data to the retrieval query when it is available.
            if customer_agree_list[index] == "No":
                query = f"{x}"
            else:
                query = f"{customer_data_list[index]}, {x}"
            response = qa({"query": query})
            # Trim any trailing sentence fragment: keep everything up to the last period.
            output_str = response['result'].rsplit(".", 1)[0] + "."
            # Strip a leading "Agent:" prefix if the model echoed one.
            if output_str.split(":")[0] == "Agent":
                output_str = output_str.split(":", 1)[1].lstrip()
            history[index] += f"Customer: {x}\nAgent: {output_str}\n"
            if customer_agree_list[index] == "No":
                output_str = "* Your enrollment information cannot be viewed. You must first agree to the personal-information terms of use before the consultation can proceed." + output_str
            return output_str
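# Example invocation (a sketch; the id and policy list are placeholder values
# mirroring the UI defaults below):
# print(gen("enrollment_info", "772727", "(무)1년부터저축보험, (무)수술비보험"))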
def reset_textbox():
    return gr.update(value='')

with gr.Blocks() as demo:
    gr.Markdown(
        "duplicated from beomi/KoRWKV-1.5B, baseModel:Llama-2-ko-7B-chat-gguf-q4_0"
    )
    with gr.Row():
        with gr.Column(scale=4):
            user_text = gr.Textbox(
                placeholder='Type here',
                label="User input"
            )
            model_output = gr.Textbox(label="Model output", lines=10, interactive=False)
            button_submit = gr.Button(value="Submit")
        with gr.Column(scale=1):
            id_text = gr.Textbox(
                placeholder='772727',
                label="User id"
            )
            customer_data = gr.Textbox(
                placeholder='(무)1년부터저축보험, (무)수술비보험',
                label="customer_data"
            )
    button_submit.click(gen, [user_text, id_text, customer_data], model_output)

demo.queue().launch()