Update app.py
app.py CHANGED
@@ -1,184 +1,2 @@
-import json
-
-import gradio as gr
-import time
-from pydantic import BaseModel, Field
-from typing import Any, Optional, Dict, List
-from huggingface_hub import InferenceClient
-from langchain.llms.base import LLM
-from langchain.embeddings import HuggingFaceInstructEmbeddings
-from langchain.vectorstores import Chroma
-import os
-from dotenv import load_dotenv
-load_dotenv()
-
-path_work = "."
-hf_token = os.getenv("HUGGINGFACEHUB_API_TOKEN")
-
-embeddings = HuggingFaceInstructEmbeddings(
-    model_name="sentence-transformers/all-MiniLM-L6-v2",
-    model_kwargs={"device": "cpu"}
-)
-
-vectordb = Chroma(
-    persist_directory = path_work + '/cromadb_llama2-papers',
-    embedding_function=embeddings)
-
-retriever = vectordb.as_retriever(search_kwargs={"k": 5})
-
-class KwArgsModel(BaseModel):
-    kwargs: Dict[str, Any] = Field(default_factory=dict)
-
-class CustomInferenceClient(LLM, KwArgsModel):
-    model_name: str
-    inference_client: InferenceClient
-
-    def __init__(self, model_name: str, hf_token: str, kwargs: Optional[Dict[str, Any]] = None):
-        inference_client = InferenceClient(model=model_name, token=hf_token)
-        super().__init__(
-            model_name=model_name,
-            hf_token=hf_token,
-            kwargs=kwargs,
-            inference_client=inference_client
-        )
-
-    def _call(
-        self,
-        prompt: str,
-        stop: Optional[List[str]] = None
-    ) -> str:
-        if stop is not None:
-            raise ValueError("stop kwargs are not permitted.")
-        response_gen = self.inference_client.text_generation(prompt, **self.kwargs, stream=True)
-        response = ''.join(response_gen)
-        return response
-
-    @property
-    def _llm_type(self) -> str:
-        return "custom"
-
-    @property
-    def _identifying_params(self) -> dict:
-        return {"model_name": self.model_name}
-
-kwargs = {"max_new_tokens":256, "temperature":0.9, "top_p":0.6, "repetition_penalty":1.3, "do_sample":True}
-
-model_list=[
-    "meta-llama/Llama-2-13b-chat-hf",
-    "HuggingFaceH4/zephyr-7b-alpha",
-    "meta-llama/Llama-2-70b-chat-hf",
-    "tiiuae/falcon-180B-chat"
-]
-
-qa_chain = None
-
-def load_model(model_selected):
-    global qa_chain
-    model_name = model_selected
-    llm = CustomInferenceClient(model_name=model_name, hf_token=hf_token, kwargs=kwargs)
-
-    from langchain.chains import RetrievalQA
-    qa_chain = RetrievalQA.from_chain_type(
-        llm=llm,
-        chain_type="stuff",
-        retriever=retriever,
-        return_source_documents=True,
-        verbose=True,
-    )
-    qa_chain
-
-load_model("meta-llama/Llama-2-70b-chat-hf")
-
-def model_select(model_selected):
-    load_model(model_selected)
-    return f"모델 {model_selected} 로딩 완료."
-
-def predict(message, chatbot, temperature=0.9, max_new_tokens=512, top_p=0.6, repetition_penalty=1.3,):
-
-    temperature = float(temperature)
-    if temperature < 1e-2: temperature = 1e-2
-    top_p = float(top_p)
-
-    llm_response = qa_chain(message)
-    res_result = llm_response['result']
-
-    res_relevant_doc = [source.metadata['source'] for source in llm_response["source_documents"]]
-    response = f"{res_result}" + "\n\n" + "[답변 근거 소스 논문 (ctrl + click 하세요!)] :" + "\n" + f" \n {res_relevant_doc}"
-    print("response: =====> \n", response, "\n\n")
-
-    tokens = response.split('\n')
-    token_list = []
-    for idx, token in enumerate(tokens):
-        token_dict = {"id": idx + 1, "text": token}
-        token_list.append(token_dict)
-    response = {"data": {"token": token_list}}
-    response = json.dumps(response, indent=4)
-
-    response = json.loads(response)
-    data_dict = response.get('data', {})
-    token_list = data_dict.get('token', [])
-
-    partial_message = ""
-    for token_entry in token_list:
-        if token_entry:
-            try:
-                token_id = token_entry.get('id', None)
-                token_text = token_entry.get('text', None)
-
-                if token_text:
-                    for char in token_text:
-                        partial_message += char
-                        yield partial_message
-                        time.sleep(0.01)
-                else:
-                    print(f"[[워닝]] ==> The key 'text' does not exist or is None in this token entry: {token_entry}")
-                    pass
-
-            except KeyError as e:
-                gr.Warning(f"KeyError: {e} occurred for token entry: {token_entry}")
-                continue
-
-title = "Llama-2 모델 관련 논문 Generative QA (with RAG) 서비스 (Llama-2-70b 모델 등 활용)"
-description = """Chat history 유지 보다는 QA에 충실하도록 제작되었으므로 Single turn으로 활용 하여 주세요. Default로 Llama-2 70b 모델로 설정되어 있으나 GPU 서비스 한도 초과로 Error가 발생할 수 있으니 양해부탁드리며, 화면 하단의 모델 변경/로딩하시어 다른 모델로 변경하여 사용을 부탁드립니다. (다만, Llama-2 70b가 가장 정확하오니 참고하여 주시기 바랍니다.) """
-css = """.toast-wrap { display: none !important } """
-examples=[['Can you tell me about the llama-2 model?'],['What is percent accuracy, using the SPP layer as features on the SPP (ZF-5) model?'], ["How much less accurate is using the SPP layer as features on the SPP (ZF-5) model compared to using the same model on the undistorted full image?"], ["tell me about method for human pose estimation based on DNNs"]]
-
-def vote(data: gr.LikeData):
-    if data.liked: print("You upvoted this response: " + data.value)
-    else: print("You downvoted this response: " + data.value)
-
-additional_inputs = [
-    gr.Slider(label="Temperature", value=0.9, minimum=0.0, maximum=1.0, step=0.05, interactive=True, info="Higher values produce more diverse outputs"),
-    gr.Slider(label="Max new tokens", value=256, minimum=0, maximum=4096, step=64, interactive=True, info="The maximum numbers of new tokens"),
-    gr.Slider(label="Top-p (nucleus sampling)", value=0.6, minimum=0.0, maximum=1, step=0.05, interactive=True, info="Higher values sample more low-probability tokens"),
-    gr.Slider(label="Repetition penalty", value=1.2, minimum=1.0, maximum=2.0, step=0.05, interactive=True, info="Penalize repeated tokens")
-]
-
-chatbot_stream = gr.Chatbot(avatar_images=(
-    "https://drive.google.com/uc?id=18xKoNOHN15H_qmGhK__VKnGjKjirrquW",
-    "https://drive.google.com/uc?id=1tfELAQW_VbPCy6QTRbexRlwAEYo8rSSv"
-), bubble_full_width = False)
-
-chat_interface_stream = gr.ChatInterface(
-    predict,
-    title=title,
-    description=description,
-    chatbot=chatbot_stream,
-    css=css,
-    examples=examples,
-)
-
-with gr.Blocks() as demo:
-    with gr.Tab("스트리밍"):
-        chatbot_stream.like(vote, None, None)
-        chat_interface_stream.render()
-        with gr.Row():
-            with gr.Column(scale=6):
-                with gr.Row():
-                    model_selector = gr.Dropdown(model_list, label="모델 선택", value= "meta-llama/Llama-2-70b-chat-hf", scale=5)
-                    submit_btn1 = gr.Button(value="모델 로드", scale=1)
-            with gr.Column(scale=4):
-                model_status = gr.Textbox(value="", label="모델 상태")
-        submit_btn1.click(model_select, inputs=[model_selector], outputs=[model_status])
-
-demo.queue(concurrency_count=75, max_size=100).launch(debug=True)
+import sys
+print(sys.version)