Create app.py
app.py
ADDED
@@ -0,0 +1,157 @@
import torch
import gradio as gr

from langchain.llms import LlamaCpp
from langchain.prompts import PromptTemplate
from langchain.retrievers.document_compressors import EmbeddingsFilter
from langchain.retrievers import ContextualCompressionRetriever
from langchain.chains import RetrievalQA
from langchain.vectorstores import FAISS
from langchain.embeddings import HuggingFaceEmbeddings

torch_device = "cuda" if torch.cuda.is_available() else "cpu"
print("Running on device:", torch_device)
print("CPU threads:", torch.get_num_threads())

# Quantized Llama-2-ko chat model served through llama.cpp
llm = LlamaCpp(
    model_path="Llama-2-ko-7B-chat-gguf-q4_0.bin",
    temperature=0.5,
    top_p=0.9,
    max_tokens=128,
    verbose=True,
    n_ctx=2048,
    n_gpu_layers=-1,
    f16_kv=True,
)

# Load the embedding model
embeddings = HuggingFaceEmbeddings(model_name="intfloat/multilingual-e5-large")

# Load the prebuilt FAISS index from the local "faiss_db" directory
docsearch = FAISS.load_local("faiss_db", embeddings)

# Keep only retrieved passages that are similar enough to the query
embeddings_filter = EmbeddingsFilter(
    embeddings=embeddings,
    similarity_threshold=0.7,
    k=2,
)
# Compression retriever: the base retriever fetches candidates for the
# query, then the embeddings filter drops weakly related passages
compression_retriever = ContextualCompressionRetriever(
    base_compressor=embeddings_filter,
    base_retriever=docsearch.as_retriever(),
)

# Per-user session state; a user's history lives at the same
# position in `history` as their id in `id_list`
id_list = []
history = []


def gen(x, user_id, customer_data):
    # Look up the session for this user, creating one on first contact
    if user_id in id_list:
        index = id_list.index(user_id)
    else:
        index = len(id_list)
        id_list.append(user_id)
        history.append("Counselor: How can I help you?\n")
        return (
            f"The insurance you are currently enrolled in is {customer_data}.\n\n"
            "Is there anything you would like to know?"
        )

    if x == "reset":
        history[index] = "Counselor: How can I help you?\n"
        return (
            "The conversation history has been reset.\n\n"
            f"The insurance you are currently enrolled in is {customer_data}.\n\n"
            "Is there anything you would like to know?"
        )
    if x == "enrollment info":
        return (
            f"The insurance you are currently enrolled in is {customer_data}.\n\n"
            "Is there anything you would like to know?"
        )

    # Left as literal placeholders so PromptTemplate can fill them in
    # later, even though the surrounding string is an f-string
    context = "{context}"
    question = "{question}"
    customer_data_newline = customer_data.replace(",", "\n")

    prompt_template = f"""You are an insurance counselor. Below you are given policy terms related to the question, response guidelines, the customer's insurance enrollment information, and the conversation history with the customer. Write a response that appropriately completes the request.

{context}

### Instruction:
Referring to the guidelines below, provide the response the customer needs, acting as a counselor.

[Guidelines]
1. Always check the customer's enrollment information and only discuss insurance the customer has actually enrolled in.
2. If the customer is enrolled in the insurance, answer the question appropriately.
3. For questions about coverage under insurance the customer has not enrolled in, introduce the relevant product and explain that coverage is not available.
4. For insurance the customer has not enrolled in, state the product name clearly and that enrollment is required.

Provide information helpful to the customer based on the enrollment information and conversation history given in the input below. Think step by step before answering. You can do this.

### Input:
[Customer's enrollment information]
{customer_data_newline}

[Conversation history]
{history[index]}
Customer:{question}

### Response:
"""

    # Build the question-answering chain with the RetrievalQA
    # class method from_chain_type
    qa = RetrievalQA.from_chain_type(
        llm=llm,
        chain_type="stuff",
        retriever=compression_retriever,
        return_source_documents=False,
        verbose=True,
        chain_type_kwargs={
            "prompt": PromptTemplate(
                input_variables=["context", "question"],
                template=prompt_template,
            )
        },
    )
    query = f"I am currently enrolled only in {customer_data}. {x}"
    response = qa({"query": query})
    # The chain returns a dict; the answer is under the "result" key.
    # Truncate the reply at any trailing "###" section or zero-width space.
    output_str = response["result"].split("###")[0].split("\u200b")[0]
    history[index] += f"Customer:{x}\nCounselor:{output_str}\n"
    return output_str


def reset_textbox():
    return gr.update(value="")


with gr.Blocks() as demo:
    gr.Markdown(
        "duplicated from beomi/KoRWKV-1.5B, baseModel:Llama-2-ko-7B-chat-gguf-q4_0"
    )

    with gr.Row():
        with gr.Column(scale=4):
            user_text = gr.Textbox(
                placeholder="Input",
                label="User input",
            )
            model_output = gr.Textbox(label="Model output", lines=10, interactive=False)
            button_submit = gr.Button(value="Submit")
        with gr.Column(scale=1):
            id_text = gr.Textbox(
                placeholder="772727",
                label="User id",
            )
            customer_data = gr.Textbox(
                placeholder="Savings Insurance, Surgery Expense Insurance",
                label="customer_data",
            )
    button_submit.click(gen, [user_text, id_text, customer_data], model_output)

demo.queue().launch()
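
The app restores a FAISS index from a local faiss_db directory that is not part of this commit, so the Space will also fail at startup if that directory is missing. Below is a minimal sketch of how such an index could be built offline with the same embedding model; the source file policy_terms.txt and the chunking parameters are assumptions for illustration, not taken from this repository.

from langchain.document_loaders import TextLoader
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.embeddings import HuggingFaceEmbeddings
from langchain.vectorstores import FAISS

# Hypothetical source document holding the policy terms the retriever searches
docs = TextLoader("policy_terms.txt").load()

# Chunk sizes are illustrative; tune them to the actual documents
splitter = RecursiveCharacterTextSplitter(chunk_size=500, chunk_overlap=50)
chunks = splitter.split_documents(docs)

# Must match the embedding model used in app.py
embeddings = HuggingFaceEmbeddings(model_name="intfloat/multilingual-e5-large")
FAISS.from_documents(chunks, embeddings).save_local("faiss_db")

FAISS.load_local("faiss_db", embeddings) in app.py then restores this index at startup.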