import torch
import gradio as gr
from langchain.llms import LlamaCpp
from langchain.prompts import PromptTemplate
from langchain.retrievers.document_compressors import EmbeddingsFilter
from langchain.retrievers import ContextualCompressionRetriever
from langchain.chains import RetrievalQA
from langchain.vectorstores import FAISS
from langchain.embeddings import HuggingFaceEmbeddings
torch_device = "cuda" if torch.cuda.is_available() else "cpu"
print("Running on device:", torch_device)
print("CPU threads:", torch.get_num_threads())
llm = LlamaCpp(
    model_path='Llama-2-ko-7B-chat-gguf-q4_0.bin',
    temperature=0.5,
    top_p=0.9,
    max_tokens=128,
    verbose=True,
    n_ctx=2048,
    n_gpu_layers=-1,
    f16_kv=True
)
# Load the embedding model
embeddings = HuggingFaceEmbeddings(model_name="intfloat/multilingual-e5-large")
# Load the prebuilt faiss_db index from local disk
# (it must have been built with the same embedding model)
docsearch = FAISS.load_local("faiss_db", embeddings)
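# The "faiss_db" index is assumed to have been built offline from the policy
# documents with this same embedding model, e.g. (illustrative sketch only):
#   db = FAISS.from_documents(policy_docs, embeddings)
#   db.save_local("faiss_db")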
# Keep at most k=2 retrieved chunks whose embedding similarity
# to the query is at least 0.7
embeddings_filter = EmbeddingsFilter(
    embeddings=embeddings,
    similarity_threshold=0.7,
    k=2,
)
# Build the compression retriever
compression_retriever = ContextualCompressionRetriever(
    # use embeddings_filter as the compressor
    base_compressor=embeddings_filter,
    # call the base retriever to find text similar to the search query
    base_retriever=docsearch.as_retriever()
)
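# Illustrative only: compression_retriever.get_relevant_documents("some query")
# returns the (at most k=2) document chunks that passed the similarity filter.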
# In-memory per-session state: id_list maps a user id to an index into history
id_list = []
history = []
customer_data = ""
# Module-level defaults mirroring the trick used inside gen(): the literal
# strings "{context}" / "{question}" survive f-string interpolation and are
# left for PromptTemplate to fill in later
context = "{context}"
question = "{question}"
def gen(x, id, customer_data):
    # Look the user id up in the session list; an unseen id starts a new session
    index = 0
    matched = 0
    for s in id_list:
        if s == id:
            matched = 1
            break
        index += 1
    if matched == 0:
        index = len(id_list)
        id_list.append(id)
        history.append('Agent: How can I help you?\n')
        bot_str = f"The insurance you are currently enrolled in is {customer_data}.\n\nIs there anything you would like to know?"
        return bot_str
    else:
        if x == "reset":
            # Wipe this user's conversation history
            history[index] = 'Agent: How can I help you?\n'
            bot_str = f"The conversation history has been reset.\n\nThe insurance you are currently enrolled in is {customer_data}.\n\nIs there anything you would like to know?"
            return bot_str
        elif x == "enrollment info":
            bot_str = f"The insurance you are currently enrolled in is {customer_data}.\n\nIs there anything you would like to know?"
            return bot_str
        else:
            # Keep "{context}" and "{question}" literal here so they survive the
            # f-string below and remain PromptTemplate input variables
            context = "{context}"
            question = "{question}"
            customer_data_newline = customer_data.replace(",", "\n")
            prompt_template = f"""You are an insurance counselor. Below you are given policy terms relevant to the question, response guidelines, the customer's insurance enrollment information, and the record of the conversation with the customer. Write a response that appropriately completes the request.
{context}

### Instruction:
Refer to the following guidelines and, as a counselor, provide the response the customer needs.
[Guidelines]
1. Always check the customer's enrollment information and only provide details about insurance the customer is enrolled in.
2. If the customer is enrolled in the insurance, answer the customer's question appropriately.
3. For questions about coverage under insurance the customer is not enrolled in, introduce the related product and explain that coverage is not available.
4. For insurance the customer is not enrolled in, clearly name the product and state that enrollment is required.
Look at the customer's enrollment information and the consultation record given in the input below, and provide information that helps the customer. Think step by step before answering. You can do this well.

### Input:
[Customer's enrollment information]
{customer_data_newline}
[Consultation record]
{history[index]}
Customer: {question}

### Response:
"""
            # Call the RetrievalQA class method from_chain_type to build the QA chain;
            # the "stuff" chain type pastes all retrieved chunks into the {context} slot
            qa = RetrievalQA.from_chain_type(
                llm=llm,
                chain_type="stuff",
                retriever=compression_retriever,
                return_source_documents=False,
                verbose=True,
                chain_type_kwargs={"prompt": PromptTemplate(
                    input_variables=["context", "question"],
                    template=prompt_template,
                )},
            )
query=f"λλ νμ¬ {customer_data}λ§ κ°μ
ν μν©μ΄μΌ. {x}"
response = qa({"query":query})
output_str = response.split("###")[0].split("\u200b")[0]
history[index] += f"κ³ κ°:{x}\nμλ΄μ:{output_str}\n"
return output_str
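# Expected flow (illustrative, hypothetical values):
#   gen("What does my policy cover?", "772727", "savings insurance")
# The first call for a new id returns the greeting plus the enrollment info;
# subsequent calls run the retrieval-augmented QA chain above.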
def reset_textbox():
    return gr.update(value='')
with gr.Blocks() as demo:
    gr.Markdown(
        "duplicated from beomi/KoRWKV-1.5B, baseModel:Llama-2-ko-7B-chat-gguf-q4_0"
    )
    with gr.Row():
        with gr.Column(scale=4):
            user_text = gr.Textbox(
                placeholder='Type a message',
                label="User input"
            )
            model_output = gr.Textbox(label="Model output", lines=10, interactive=False)
            button_submit = gr.Button(value="Submit")
        with gr.Column(scale=1):
            id_text = gr.Textbox(
                placeholder='772727',
                label="User id"
            )
            customer_data = gr.Textbox(
                placeholder='e.g. (non-dividend) savings insurance, (non-dividend) surgery expense insurance',
                label="customer_data"
            )
    button_submit.click(gen, [user_text, id_text, customer_data], model_output)
demo.queue().launch()