ldhldh committed
Commit 4b7cafe · 1 Parent(s): b5d147c

Create app.py

Files changed (1)
  1. app.py +157 -0
app.py ADDED
@@ -0,0 +1,157 @@
import torch
import gradio as gr
from langchain.llms import LlamaCpp
from langchain.prompts import PromptTemplate
from langchain.retrievers.document_compressors import EmbeddingsFilter
from langchain.retrievers import ContextualCompressionRetriever
from langchain.chains import RetrievalQA
from langchain.vectorstores import FAISS
from langchain.embeddings import HuggingFaceEmbeddings

torch_device = "cuda" if torch.cuda.is_available() else "cpu"
print("Running on device:", torch_device)
print("CPU threads:", torch.get_num_threads())

# Local GGUF weights served through llama.cpp (see the fetch sketch after the diff)
llm = LlamaCpp(
    model_path='Llama-2-ko-7B-chat-gguf-q4_0.bin',
    temperature=0.5,
    top_p=0.9,
    max_tokens=128,
    verbose=True,
    n_ctx=2048,
    n_gpu_layers=-1,
    f16_kv=True,
)

# Load the embedding model
embeddings = HuggingFaceEmbeddings(model_name="intfloat/multilingual-e5-large")

# Load the prebuilt FAISS index from the local faiss_db directory
# (see the index-build sketch after the diff)
docsearch = FAISS.load_local("faiss_db", embeddings)

# Keep only retrieved chunks whose similarity to the query clears the threshold
embeddings_filter = EmbeddingsFilter(
    embeddings=embeddings,
    similarity_threshold=0.7,
    k=2,
)
# Compression retriever: the base retriever fetches candidates for the query,
# then embeddings_filter compresses them down to the relevant ones
compression_retriever = ContextualCompressionRetriever(
    base_compressor=embeddings_filter,
    base_retriever=docsearch.as_retriever(),
)

# Per-user session state, kept in parallel lists indexed by position
id_list = []
history = []

def gen(x, id, customer_data):
    # Find this user's slot in the session lists
    index = 0
    matched = False
    for s in id_list:
        if s == id:
            matched = True
            break
        index += 1

    if not matched:
        # New user: register the id and seed the dialogue history
        index = len(id_list)
        id_list.append(id)
        history.append('상담원:무엇을 도와드릴까요?\n')
        return f"현재 고객님께서 가입된 보험은 {customer_data}입니다.\n\n궁금하신 것이 있으신가요?"
    else:
        if x == "초기화":
            # Reset this user's dialogue history
            history[index] = '상담원:무엇을 도와드릴까요?\n'
            return f"대화기록이 초기화되었습니다.\n\n현재 고객님께서 가입된 보험은 {customer_data}입니다.\n\n궁금하신 것이 있으신가요?"
        elif x == "가입정보":
            return f"현재 고객님께서 가입된 보험은 {customer_data}입니다.\n\n궁금하신 것이 있으신가요?"
        else:
            # These locals hold the literal text "{context}" / "{question}" so the
            # f-string below leaves them as placeholders for PromptTemplate
            context = "{context}"
            question = "{question}"
            customer_data_newline = customer_data.replace(",", "\n")

            prompt_template = f"""당신은 보험 상담원입니다. 아래에 질문과 관련된 약관 정보, 응답 지침과 고객의 보험 가입 정보, 고객과의 상담기록이 주어집니다. 요청을 적절히 완료하는 응답을 작성하세요.

{context}

### 명령어:
다음 지침을 참고하여 상담원으로서 고객에게 필요한 응답을 제공하세요.

[지침]
1.고객의 가입 정보를 꼭 확인하여 고객이 가입한 보험에 대한 내용만 제공하세요.
2.고객이 가입한 보험이라면 고객의 질문에 대해 적절히 답변하세요.
3.고객이 가입하지 않은 보험의 보상에 관한 질문은 관련 보험을 소개하며 보상이 불가능하다는 점을 안내하세요.
4.고객이 가입하지 않은 보험은 가입이 필요하다고 보험명을 확실하게 언급하세요.

다음 입력에 주어지는 고객의 보험 가입 정보와 상담 기록을 보고 고객에게 도움되는 정보를 제공하세요. 차근차근 생각하여 답변하세요. 당신은 잘 할 수 있습니다.

### 입력:
[고객의 가입 정보]
{customer_data_newline}

[상담 기록]
{history[index]}
고객:{question}

### 응답:
"""

            # Build the question-answering chain with the class method
            # RetrievalQA.from_chain_type
            qa = RetrievalQA.from_chain_type(
                llm=llm,
                chain_type="stuff",
                retriever=compression_retriever,
                return_source_documents=False,
                verbose=True,
                chain_type_kwargs={"prompt": PromptTemplate(
                    input_variables=["context", "question"],
                    template=prompt_template,
                )},
            )
            query = f"나는 현재 {customer_data}만 가입한 상황이야. {x}"
            # The chain returns a dict; the generated text is under "result"
            response = qa({"query": query})
            output_str = response["result"].split("###")[0].split("\u200b")[0]
            history[index] += f"고객:{x}\n상담원:{output_str}\n"
            return output_str

def reset_textbox():
    return gr.update(value='')

with gr.Blocks() as demo:
    gr.Markdown(
        "duplicated from beomi/KoRWKV-1.5B, baseModel:Llama-2-ko-7B-chat-gguf-q4_0"
    )

    with gr.Row():
        with gr.Column(scale=4):
            user_text = gr.Textbox(
                placeholder='입력',
                label="User input"
            )
            model_output = gr.Textbox(label="Model output", lines=10, interactive=False)
            button_submit = gr.Button(value="Submit")
        with gr.Column(scale=1):
            id_text = gr.Textbox(
                placeholder='772727',
                label="User id"
            )
            customer_data = gr.Textbox(
                placeholder='(무)1년부터저축보험, (무)수술비보험',
                label="customer_data"
            )
    button_submit.click(gen, [user_text, id_text, customer_data], model_output)

demo.queue().launch()
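
The commit does not include the GGUF weights that LlamaCpp loads from Llama-2-ko-7B-chat-gguf-q4_0.bin. A minimal fetch sketch using huggingface_hub; the repo_id here is a placeholder assumption, not taken from this repo:

    # fetch_model.py: hypothetical helper; replace repo_id with the real source.
    import shutil
    from huggingface_hub import hf_hub_download

    # Download (or reuse from cache) the quantized GGUF file.
    path = hf_hub_download(
        repo_id="your-namespace/llama-2-ko-7b-chat-gguf",  # placeholder
        filename="Llama-2-ko-7B-chat-gguf-q4_0.bin",
    )
    # Put it next to app.py, where LlamaCpp expects it.
    shutil.copy(path, "Llama-2-ko-7B-chat-gguf-q4_0.bin")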
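
Similarly, FAISS.load_local("faiss_db", embeddings) assumes an index that was built ahead of time. A sketch of how such an index could be produced with the same embedding model; the source file and chunking parameters are illustrative assumptions:

    # build_index.py: hypothetical helper. "terms.txt" stands in for the real
    # policy documents; the chunk sizes are illustrative.
    from langchain.document_loaders import TextLoader
    from langchain.text_splitter import RecursiveCharacterTextSplitter
    from langchain.embeddings import HuggingFaceEmbeddings
    from langchain.vectorstores import FAISS

    docs = TextLoader("terms.txt", encoding="utf-8").load()
    chunks = RecursiveCharacterTextSplitter(
        chunk_size=500, chunk_overlap=50
    ).split_documents(docs)

    # Use the same model app.py queries with, so index and query
    # embeddings live in the same vector space.
    embeddings = HuggingFaceEmbeddings(model_name="intfloat/multilingual-e5-large")
    FAISS.from_documents(chunks, embeddings).save_local("faiss_db")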
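
Because gen keeps per-user state in id_list and history, it can be smoke-tested without the Gradio UI once the weights and index are in place. A hypothetical console check; the id and product name are just examples:

    # First call for a new id registers the session and returns the greeting.
    print(gen("안녕하세요", "772727", "(무)수술비보험"))
    # Subsequent calls run the RetrievalQA chain against the FAISS index.
    print(gen("수술비 보상 한도가 어떻게 되나요?", "772727", "(무)수술비보험"))
    # "초기화" resets that user's dialogue history.
    print(gen("초기화", "772727", "(무)수술비보험"))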