from threading import Thread
from huggingface_hub import hf_hub_download
from llama_cpp import Llama
import torch
import gradio as gr
import re
import asyncio
import requests
import shutil
from langchain import PromptTemplate, LLMChain
from langchain.retrievers.document_compressors import EmbeddingsFilter
from langchain.retrievers import ContextualCompressionRetriever
from langchain.chains import RetrievalQA
from langchain.vectorstores import FAISS
from langchain.embeddings import HuggingFaceEmbeddings
from langchain.text_splitter import CharacterTextSplitter, RecursiveCharacterTextSplitter
from langchain.document_loaders import PyPDFLoader
import os
from langchain.llms import OpenAI
llm = OpenAI(model_name='gpt-3.5-turbo-instruct')
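
# Note: OpenAI() reads the OPENAI_API_KEY environment variable, and
# gpt-3.5-turbo-instruct is a completions-style model, which suits the raw
# instruction-formatted prompt template built in gen() below.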
torch_device = "cuda" if torch.cuda.is_available() else "cpu"
print("Running on device:", torch_device)
print("CPU threads:", torch.get_num_threads())
loader = PyPDFLoader("total.pdf")
pages = loader.load()

# Load the data and split the text into fixed-size chunks
text_splitter = RecursiveCharacterTextSplitter(chunk_size=300, chunk_overlap=0)
texts = text_splitter.split_documents(pages)
print(f"The document was split into {len(texts)} chunks.")
# Load the embedding model
embeddings = HuggingFaceEmbeddings(model_name="intfloat/multilingual-e5-large")

# Embed the document chunks and build a FAISS index over them
index = FAISS.from_documents(
    documents=texts,
    embedding=embeddings,
)

# Save the index locally as faiss_db
index.save_local("faiss_db")

# Load the index back from faiss_db
docsearch = FAISS.load_local("faiss_db", embeddings)
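
# Optional sanity check (illustrative, not part of the original app): confirm
# the index answers nearest-neighbour queries before wiring up the retriever.
# The query string here is a made-up example.
# for doc in docsearch.similarity_search("surgery cost coverage", k=2):
#     print(doc.page_content[:80])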
embeddings_filter = EmbeddingsFilter(
    embeddings=embeddings,
    similarity_threshold=0.7,
    k=2,
)
# Build a compression retriever: the base retriever finds chunks similar to
# the query, and the embeddings filter discards the ones below the threshold
compression_retriever = ContextualCompressionRetriever(
    base_compressor=embeddings_filter,
    base_retriever=docsearch.as_retriever(),
)
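
# What the retriever does at query time (sketch, commented out to avoid an
# extra embedding call at startup): it embeds the query, scores each fetched
# chunk by similarity, keeps at most k=2 chunks scoring above 0.7, and those
# chunks become the {context} passed to the LLM. The query is illustrative.
# docs = compression_retriever.get_relevant_documents("deductible rules")
# print([d.page_content[:60] for d in docs])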
# Per-user session state; a user's position in id_list is its session index
id_list = []
history = []
customer_data_list = []
# Literal placeholder strings that RetrievalQA fills in at query time
context = "{context}"
question = "{question}"
def gen(x, id, customer_data):
    # Find the session index for this user id, if one exists
    index = 0
    matched = 0
    for s in id_list:
        if s == id:
            matched = 1
            break
        index += 1
    if matched == 0:
        # New user: register a session and greet with their policy data
        index = len(id_list)
        id_list.append(id)
        customer_data_list.append(customer_data)
        history.append('Agent:How can I help you?\n')
        bot_str = f"The insurance you are currently enrolled in is {customer_data}.\n\nIs there anything you would like to know?"
        return bot_str
    else:
        if x == "reset":
            # Clear the conversation history and refresh the stored policy data
            customer_data_list[index] = customer_data
            history[index] = 'Agent:How can I help you?\n'
            bot_str = f"The conversation history has been reset.\n\nThe insurance you are currently enrolled in is {customer_data}.\n\nIs there anything you would like to know?"
            return bot_str
        elif x == "policy info":
            bot_str = f"The insurance you are currently enrolled in is {customer_data_list[index]}.\n\nIs there anything you would like to know?"
            return bot_str
        else:
            # Literal placeholders; RetrievalQA substitutes them at query time
            context = "{context}"
            question = "{question}"
            customer_data_newline = customer_data_list[index].replace(",", "\n")
            prompt_template = f"""You are an insurance agent. Below you are given policy-terms information related to the question, response guidelines, the customer's insurance enrollment information, and the record of conversations with the customer. Write a response that appropriately completes the request.

{context}

### Instruction:
Referring to the guidelines below, concisely provide the response the customer needs, acting as their agent.

[Guidelines]
1. Always check the customer's enrollment information and only provide details about insurance the customer is enrolled in.
2. If the customer is enrolled in the insurance, answer their question appropriately.
3. For questions about coverage under insurance the customer is not enrolled in, introduce the relevant product and explain that coverage is not available.
4. For insurance the customer is not enrolled in, state clearly, by product name, that enrollment is required.

Look at the customer's enrollment information and the consultation record given in the input below and provide information that helps the customer. Think step by step before answering. You can do this well.

### Input:
[Customer's enrollment information]
{customer_data_newline}

[Consultation record]
{history[index]}
Customer:{question}

### Response:
"""
            # Call RetrievalQA.from_chain_type to build a question-answering
            # chain that stuffs the retrieved chunks into the {context} slot
            qa = RetrievalQA.from_chain_type(
                llm=llm,
                chain_type="stuff",
                retriever=compression_retriever,
                return_source_documents=False,
                verbose=True,
                chain_type_kwargs={"prompt": PromptTemplate(
                    input_variables=["context", "question"],
                    template=prompt_template,
                )},
            )
            query = f"I am currently enrolled only in {customer_data_list[index]}. {x}"
            response = qa({"query": query})
            # Trim any unfinished trailing sentence from the completion
            output_str = response['result'].rsplit(".", 1)[0] + "."
            print(prompt_template + output_str)
            history[index] += f"Customer:{x}\nAgent:{output_str}\n"
            return output_str
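
# Session behaviour of gen(), with illustrative arguments:
#   gen("hello", "772727", "policy A")        # unknown id: registers a session
#   gen("policy info", "772727", "policy A")  # replays the stored policy data
#   gen("reset", "772727", "policy B")        # clears history, stores policy B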
def reset_textbox():
return gr.update(value='')
with gr.Blocks() as demo:
    gr.Markdown(
        "duplicated from beomi/KoRWKV-1.5B, baseModel:Llama-2-ko-7B-chat-gguf-q4_0"
    )
    with gr.Row():
        with gr.Column(scale=4):
            user_text = gr.Textbox(
                placeholder='Input',
                label="User input"
            )
            model_output = gr.Textbox(label="Model output", lines=10, interactive=False)
            button_submit = gr.Button(value="Submit")
        with gr.Column(scale=1):
            id_text = gr.Textbox(
                placeholder='772727',
                label="User id"
            )
            customer_data = gr.Textbox(
                placeholder='(무)1년부터저축보험, (무)수술비보험',
                label="customer_data"
            )
    button_submit.click(gen, [user_text, id_text, customer_data], model_output)

demo.queue().launch()
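
# Note: reset_textbox() is defined above but never wired to the UI. One way to
# clear the input box after each submit (a sketch using event chaining, which
# requires a reasonably recent Gradio 3.x) would be:
# button_submit.click(gen, [user_text, id_text, customer_data], model_output) \
#     .then(reset_textbox, [], user_text)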