msy127 committed
Commit a38cc11 · 1 Parent(s): 9531d4c

Update app.py

Files changed (1)
  1. app.py +2 -184
app.py CHANGED
@@ -1,184 +1,2 @@
- import json
- import os
- import gradio as gr
- import time
- from pydantic import BaseModel, Field
- from typing import Any, Optional, Dict, List
- from huggingface_hub import InferenceClient
- from langchain.llms.base import LLM
- from langchain.embeddings import HuggingFaceInstructEmbeddings
- from langchain.vectorstores import Chroma
- import os
- from dotenv import load_dotenv
- load_dotenv()
-
- path_work = "."
- hf_token = os.getenv("HUGGINGFACEHUB_API_TOKEN")
-
- embeddings = HuggingFaceInstructEmbeddings(
-     model_name="sentence-transformers/all-MiniLM-L6-v2",
-     model_kwargs={"device": "cpu"}
- )
-
- vectordb = Chroma(
-     persist_directory=path_work + '/cromadb_llama2-papers',
-     embedding_function=embeddings)
-
- retriever = vectordb.as_retriever(search_kwargs={"k": 5})
-
- class KwArgsModel(BaseModel):
-     kwargs: Dict[str, Any] = Field(default_factory=dict)
-
- class CustomInferenceClient(LLM, KwArgsModel):
-     model_name: str
-     inference_client: InferenceClient
-
-     def __init__(self, model_name: str, hf_token: str, kwargs: Optional[Dict[str, Any]] = None):
-         inference_client = InferenceClient(model=model_name, token=hf_token)
-         super().__init__(
-             model_name=model_name,
-             hf_token=hf_token,
-             kwargs=kwargs,
-             inference_client=inference_client
-         )
-
-     def _call(
-         self,
-         prompt: str,
-         stop: Optional[List[str]] = None
-     ) -> str:
-         if stop is not None:
-             raise ValueError("stop kwargs are not permitted.")
-         response_gen = self.inference_client.text_generation(prompt, **self.kwargs, stream=True)
-         response = ''.join(response_gen)
-         return response
-
-     @property
-     def _llm_type(self) -> str:
-         return "custom"
-
-     @property
-     def _identifying_params(self) -> dict:
-         return {"model_name": self.model_name}
-
- kwargs = {"max_new_tokens": 256, "temperature": 0.9, "top_p": 0.6, "repetition_penalty": 1.3, "do_sample": True}
-
- model_list = [
-     "meta-llama/Llama-2-13b-chat-hf",
-     "HuggingFaceH4/zephyr-7b-alpha",
-     "meta-llama/Llama-2-70b-chat-hf",
-     "tiiuae/falcon-180B-chat"
- ]
-
- qa_chain = None
-
- def load_model(model_selected):
-     global qa_chain
-     model_name = model_selected
-     llm = CustomInferenceClient(model_name=model_name, hf_token=hf_token, kwargs=kwargs)
-
-     from langchain.chains import RetrievalQA
-     qa_chain = RetrievalQA.from_chain_type(
-         llm=llm,
-         chain_type="stuff",
-         retriever=retriever,
-         return_source_documents=True,
-         verbose=True,
-     )
-     qa_chain
-
- load_model("meta-llama/Llama-2-70b-chat-hf")
-
- def model_select(model_selected):
-     load_model(model_selected)
-     return f"Model {model_selected} loaded."
-
- def predict(message, chatbot, temperature=0.9, max_new_tokens=512, top_p=0.6, repetition_penalty=1.3,):
-
-     temperature = float(temperature)
-     if temperature < 1e-2: temperature = 1e-2
-     top_p = float(top_p)
-
-     llm_response = qa_chain(message)
-     res_result = llm_response['result']
-
-     res_relevant_doc = [source.metadata['source'] for source in llm_response["source_documents"]]
-     response = f"{res_result}" + "\n\n" + "[Source papers supporting this answer (ctrl + click!)] :" + "\n" + f" \n {res_relevant_doc}"
-     print("response: =====> \n", response, "\n\n")
-
-     tokens = response.split('\n')
-     token_list = []
-     for idx, token in enumerate(tokens):
-         token_dict = {"id": idx + 1, "text": token}
-         token_list.append(token_dict)
-     response = {"data": {"token": token_list}}
-     response = json.dumps(response, indent=4)
-
-     response = json.loads(response)
-     data_dict = response.get('data', {})
-     token_list = data_dict.get('token', [])
-
-     partial_message = ""
-     for token_entry in token_list:
-         if token_entry:
-             try:
-                 token_id = token_entry.get('id', None)
-                 token_text = token_entry.get('text', None)
-
-                 if token_text:
-                     for char in token_text:
-                         partial_message += char
-                         yield partial_message
-                         time.sleep(0.01)
-                 else:
-                     print(f"[[Warning]] ==> The key 'text' does not exist or is None in this token entry: {token_entry}")
-                     pass
-
-             except KeyError as e:
-                 gr.Warning(f"KeyError: {e} occurred for token entry: {token_entry}")
-                 continue
-
- title = "Generative QA (with RAG) service for papers related to the Llama-2 model (using Llama-2-70b and other models)"
- description = """This service was built to focus on QA rather than on keeping chat history, so please use it in single-turn mode. Llama-2 70b is the default model, but errors may occur when the GPU service quota is exceeded; in that case, please switch to and load a different model at the bottom of the page. (Note that Llama-2 70b is the most accurate.) """
- css = """.toast-wrap { display: none !important } """
- examples = [['Can you tell me about the llama-2 model?'], ['What is percent accuracy, using the SPP layer as features on the SPP (ZF-5) model?'], ["How much less accurate is using the SPP layer as features on the SPP (ZF-5) model compared to using the same model on the undistorted full image?"], ["tell me about method for human pose estimation based on DNNs"]]
-
- def vote(data: gr.LikeData):
-     if data.liked: print("You upvoted this response: " + data.value)
-     else: print("You downvoted this response: " + data.value)
-
- additional_inputs = [
-     gr.Slider(label="Temperature", value=0.9, minimum=0.0, maximum=1.0, step=0.05, interactive=True, info="Higher values produce more diverse outputs"),
-     gr.Slider(label="Max new tokens", value=256, minimum=0, maximum=4096, step=64, interactive=True, info="The maximum numbers of new tokens"),
-     gr.Slider(label="Top-p (nucleus sampling)", value=0.6, minimum=0.0, maximum=1, step=0.05, interactive=True, info="Higher values sample more low-probability tokens"),
-     gr.Slider(label="Repetition penalty", value=1.2, minimum=1.0, maximum=2.0, step=0.05, interactive=True, info="Penalize repeated tokens")
- ]
-
- chatbot_stream = gr.Chatbot(avatar_images=(
-     "https://drive.google.com/uc?id=18xKoNOHN15H_qmGhK__VKnGjKjirrquW",
-     "https://drive.google.com/uc?id=1tfELAQW_VbPCy6QTRbexRlwAEYo8rSSv"
- ), bubble_full_width=False)
-
- chat_interface_stream = gr.ChatInterface(
-     predict,
-     title=title,
-     description=description,
-     chatbot=chatbot_stream,
-     css=css,
-     examples=examples,
- )
-
- with gr.Blocks() as demo:
-     with gr.Tab("Streaming"):
-         chatbot_stream.like(vote, None, None)
-         chat_interface_stream.render()
-     with gr.Row():
-         with gr.Column(scale=6):
-             with gr.Row():
-                 model_selector = gr.Dropdown(model_list, label="Select model", value="meta-llama/Llama-2-70b-chat-hf", scale=5)
-                 submit_btn1 = gr.Button(value="Load model", scale=1)
-         with gr.Column(scale=4):
-             model_status = gr.Textbox(value="", label="Model status")
-     submit_btn1.click(model_select, inputs=[model_selector], outputs=[model_status])
-
- demo.queue(concurrency_count=75, max_size=100).launch(debug=True)
 
+ import sys
+ print(sys.version)