esbatmop committed on
Commit 7266379 · verified · 1 Parent(s): 5cca75a

Update app.py

Files changed (1)
  1. app.py +264 -129
app.py CHANGED
@@ -1,136 +1,271 @@
-import os
-from threading import Thread
-from typing import Iterator
 
 import gradio as gr
-import spaces
-import torch
-from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
-
-DESCRIPTION = """\
-# Llama 3.2 3B Instruct
-Llama 3.2 3B is Meta's latest iteration of open LLMs.
-This is a demo of [`meta-llama/Llama-3.2-3B-Instruct`](https://huggingface.co/meta-llama/Llama-3.2-3B-Instruct), fine-tuned for instruction following.
-For more details, please check [our post](https://huggingface.co/blog/llama32).
-"""
-
-MAX_MAX_NEW_TOKENS = 2048
-DEFAULT_MAX_NEW_TOKENS = 1024
-MAX_INPUT_TOKEN_LENGTH = int(os.getenv("MAX_INPUT_TOKEN_LENGTH", "4096"))
-
-device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
-
-model_id = "liwu/liwu_forum_post_2.0"
-tokenizer = AutoTokenizer.from_pretrained(model_id)
-model = AutoModelForCausalLM.from_pretrained(
-    model_id,
-    device_map="auto",
-    torch_dtype=torch.bfloat16,
-)
-model.eval()
-
-
-@spaces.GPU(duration=90)
-def generate(
-    message: str,
-    chat_history: list[tuple[str, str]],
-    max_new_tokens: int = 1024,
-    temperature: float = 0.6,
-    top_p: float = 0.9,
-    top_k: int = 50,
-    repetition_penalty: float = 1.2,
-) -> Iterator[str]:
-    conversation = []
-    for user, assistant in chat_history:
-        conversation.extend(
-            [
-                {"role": "user", "content": user},
-                {"role": "assistant", "content": assistant},
-            ]
         )
-    conversation.append({"role": "user", "content": message})
-
-    input_ids = tokenizer.apply_chat_template(conversation, add_generation_prompt=True, return_tensors="pt")
-    if input_ids.shape[1] > MAX_INPUT_TOKEN_LENGTH:
-        input_ids = input_ids[:, -MAX_INPUT_TOKEN_LENGTH:]
-        gr.Warning(f"Trimmed input from conversation as it was longer than {MAX_INPUT_TOKEN_LENGTH} tokens.")
-    input_ids = input_ids.to(model.device)
-
-    streamer = TextIteratorStreamer(tokenizer, timeout=20.0, skip_prompt=True, skip_special_tokens=True)
-    generate_kwargs = dict(
-        {"input_ids": input_ids},
-        streamer=streamer,
-        max_new_tokens=max_new_tokens,
-        do_sample=True,
-        top_p=top_p,
-        top_k=top_k,
-        temperature=temperature,
-        num_beams=1,
-        repetition_penalty=repetition_penalty,
     )
-    t = Thread(target=model.generate, kwargs=generate_kwargs)
-    t.start()
-
-    outputs = []
-    for text in streamer:
-        outputs.append(text)
-        yield "".join(outputs)
-
-
-chat_interface = gr.ChatInterface(
-    fn=generate,
-    additional_inputs=[
-        gr.Slider(
-            label="Max new tokens",
-            minimum=1,
-            maximum=MAX_MAX_NEW_TOKENS,
-            step=1,
-            value=DEFAULT_MAX_NEW_TOKENS,
-        ),
-        gr.Slider(
-            label="Temperature",
-            minimum=0.1,
-            maximum=4.0,
-            step=0.1,
-            value=0.6,
-        ),
-        gr.Slider(
-            label="Top-p (nucleus sampling)",
-            minimum=0.05,
-            maximum=1.0,
-            step=0.05,
-            value=0.9,
-        ),
-        gr.Slider(
-            label="Top-k",
-            minimum=1,
-            maximum=1000,
-            step=1,
-            value=50,
-        ),
-        gr.Slider(
-            label="Repetition penalty",
-            minimum=1.0,
-            maximum=2.0,
-            step=0.05,
-            value=1.2,
-        ),
-    ],
-    stop_btn=None,
-    examples=[
-        ["Hello there! How are you doing?"],
-        ["Can you explain briefly to me what is the Python programming language?"],
-        ["Explain the plot of Cinderella in a sentence."],
-        ["How many hours does it take a man to eat a Helicopter?"],
-        ["Write a 100-word article on 'Benefits of Open-Source in AI research'"],
-    ],
-    cache_examples=False,
 )
 
-with gr.Blocks(css="style.css", fill_height=True) as demo:
-    gr.Markdown(DESCRIPTION)
-    gr.DuplicateButton(value="Duplicate Space for private use", elem_id="duplicate-button")
-    chat_interface.render()
 
-if __name__ == "__main__":
-    demo.queue(max_size=20).launch()
+from pathlib import Path
+from shutil import rmtree
+from typing import Union, List, Dict, Tuple, Optional
+from tqdm import tqdm
 
+import requests
 import gradio as gr
+from llama_cpp import Llama
+
+
+# ================== ANNOTATIONS ========================
+
+CHAT_HISTORY = List[Optional[Dict[str, Optional[str]]]]
+MODEL_DICT = Dict[str, Llama]
+
+
+# ================== FUNCS =============================
+
+def download_file(file_url: str, file_path: Union[str, Path]) -> None:
+    response = requests.get(file_url, stream=True)
+    if response.status_code != 200:
+        raise Exception(f'File is not available for download at the link: {file_url}')
+    total_size = int(response.headers.get('content-length', 0))
+    progress_tqdm = tqdm(desc='Loading GGUF file', total=total_size, unit='iB', unit_scale=True)
+    progress_gradio = gr.Progress()
+    completed_size = 0
+    with open(file_path, 'wb') as file:
+        for data in response.iter_content(chunk_size=4096):
+            size = file.write(data)
+            progress_tqdm.update(size)
+            completed_size += size
+            desc = f'Loading GGUF file, {completed_size/1024**3:.3f}/{total_size/1024**3:.3f} GB'
+            progress_gradio(completed_size/total_size, desc=desc)
+
+
+def download_gguf_and_init_model(gguf_url: str, model_dict: MODEL_DICT) -> Tuple[MODEL_DICT, bool, str]:
+    log = ''
+    if not gguf_url.endswith('.gguf'):
+        log += 'The link must be a direct link to the GGUF file\n'
+        return model_dict, False, log
+
+    gguf_filename = gguf_url.rsplit('/')[-1]
+    model_path = MODELS_PATH / gguf_filename
+    progress = gr.Progress()
+
+    if not model_path.is_file():
+        progress(0.3, desc='Step 1/2: Loading GGUF model file')
+        try:
+            download_file(gguf_url, model_path)
+            log += f'Model file {gguf_filename} successfully loaded\n'
+        except Exception as ex:
+            log += f'Error loading model from link {gguf_url}, error code:\n{ex}\n'
+            curr_model = model_dict.get('model')
+            if curr_model is None:
+                log += 'Model is missing from dictionary "model_dict"\n'
+                return model_dict, False, log
+            curr_model_filename = Path(curr_model.model_path).name
+            log += f'Current initialized model: {curr_model_filename}\n'
+            support_system_role = 'System role not supported' not in curr_model.metadata['tokenizer.chat_template']
+            return model_dict, support_system_role, log
+    else:
+        log += f'Model file {gguf_filename} loaded, initializing model...\n'
+
+    progress(0.7, desc='Step 2/2: Model initialization')
+    model = Llama(model_path=str(model_path), n_gpu_layers=-1, verbose=True)
+    model_dict = {'model': model}
+    support_system_role = 'System role not supported' not in model.metadata['tokenizer.chat_template']
+    log += f'Model {gguf_filename} initialized\n'
+    return model_dict, support_system_role, log
+
+
+def user_message_to_chatbot(user_message: str, chatbot: CHAT_HISTORY) -> Tuple[str, CHAT_HISTORY]:
+    if user_message:
+        chatbot.append({'role': 'user', 'metadata': {'title': None}, 'content': user_message})
+    return '', chatbot
+
+
+def bot_response_to_chatbot(
+        chatbot: CHAT_HISTORY,
+        model_dict: MODEL_DICT,
+        system_prompt: str,
+        support_system_role: bool,
+        history_len: int,
+        do_sample: bool,
+        *generate_args,
+        ):
+
+    model = model_dict.get('model')
+    if model is None:
+        gr.Info('Model not initialized')
+        yield chatbot
+        return
+
+    if len(chatbot) == 0 or chatbot[-1]['role'] == 'assistant':
+        yield chatbot
+        return
+
+    messages = []
+    if support_system_role and system_prompt:
+        messages.append({'role': 'system', 'metadata': {'title': None}, 'content': system_prompt})
+
+    if history_len != 0:
+        messages.extend(chatbot[:-1][-(history_len*2):])
+
+    messages.append(chatbot[-1])
+
+    gen_kwargs = dict(zip(GENERATE_KWARGS.keys(), generate_args))
+    gen_kwargs['top_k'] = int(gen_kwargs['top_k'])
+    if not do_sample:  # greedy decoding: keep only the most likely token, no repetition penalty
+        gen_kwargs['top_p'] = 0.0
+        gen_kwargs['top_k'] = 1
+        gen_kwargs['repeat_penalty'] = 1.0
+
+    stream_response = model.create_chat_completion(
+        messages=messages,
+        stream=True,
+        **gen_kwargs,
     )
+
+    chatbot.append({'role': 'assistant', 'metadata': {'title': None}, 'content': ''})
+    for chunk in stream_response:
+        token = chunk['choices'][0]['delta'].get('content')
+        if token is not None:
+            chatbot[-1]['content'] += token
+            yield chatbot
+
+
+def get_system_prompt_component(interactive: bool) -> gr.Textbox:
+    value = '' if interactive else 'System prompt is not supported by this model'
+    return gr.Textbox(value=value, label='System prompt', interactive=interactive)
+
+
+def get_generate_args(do_sample: bool) -> List[gr.component]:
+    generate_args = [
+        gr.Slider(minimum=0.1, maximum=3, value=GENERATE_KWARGS['temperature'], step=0.1, label='temperature', visible=do_sample),
+        gr.Slider(minimum=0, maximum=1, value=GENERATE_KWARGS['top_p'], step=0.01, label='top_p', visible=do_sample),
+        gr.Slider(minimum=1, maximum=50, value=GENERATE_KWARGS['top_k'], step=1, label='top_k', visible=do_sample),
+        gr.Slider(minimum=1, maximum=5, value=GENERATE_KWARGS['repeat_penalty'], step=0.1, label='repeat_penalty', visible=do_sample),
+    ]
+    return generate_args
+
+
+# ================== VARIABLES =============================
+
+MODELS_PATH = Path('models')
+MODELS_PATH.mkdir(exist_ok=True)
+DEFAULT_GGUF_URL = 'https://huggingface.co/bartowski/gemma-2-2b-it-GGUF/resolve/main/gemma-2-2b-it-Q8_0.gguf'
+
+start_model_dict, start_support_system_role, start_load_log = download_gguf_and_init_model(
+    gguf_url=DEFAULT_GGUF_URL, model_dict={},
 )
+
+GENERATE_KWARGS = dict(
+    temperature=0.2,
+    top_p=0.95,
+    top_k=40,
+    repeat_penalty=1.0,
+)
+
+theme = gr.themes.Base(primary_hue='green', secondary_hue='yellow', neutral_hue='zinc').set(
+    loader_color='rgb(0, 255, 0)',
+    slider_color='rgb(0, 200, 0)',
+    body_text_color_dark='rgb(0, 200, 0)',
+    button_secondary_background_fill_dark='green',
 )
+css = '''.gradio-container {width: 60% !important}'''
 
 
+# ================== INTERFACE =============================
+
+with gr.Blocks(theme=theme, css=css) as interface:
+    model_dict = gr.State(start_model_dict)
+    support_system_role = gr.State(start_support_system_role)
+
+    # ================= CHAT BOT PAGE ======================
+    with gr.Tab('Chatbot'):
+        with gr.Row():
+            with gr.Column(scale=3):
+                chatbot = gr.Chatbot(
+                    type='messages',  # new in gradio 5+
+                    show_copy_button=True,
+                    bubble_full_width=False,
+                    height=480,
+                )
+                user_message = gr.Textbox(label='User')
+
+                with gr.Row():
+                    user_message_btn = gr.Button('Send')
+                    stop_btn = gr.Button('Stop')
+                    clear_btn = gr.Button('Clear')
+
+                system_prompt = get_system_prompt_component(interactive=support_system_role.value)
+
+            with gr.Column(scale=1, min_width=80):
+                with gr.Group():
+                    gr.Markdown('Length of message history')
+                    history_len = gr.Slider(
+                        minimum=0,
+                        maximum=10,
+                        value=0,
+                        step=1,
+                        info='Number of previous messages taken into account in history',
+                        label='history_len',
+                        show_label=False,
+                    )
+
+                with gr.Group():
+                    gr.Markdown('Generation parameters')
+                    do_sample = gr.Checkbox(
+                        value=False,
+                        label='do_sample',
+                        info='Activate random sampling',
+                    )
+                    generate_args = get_generate_args(do_sample.value)
+                    do_sample.change(
+                        fn=get_generate_args,
+                        inputs=do_sample,
+                        outputs=generate_args,
+                        show_progress=False,
+                    )
+
+        generate_event = gr.on(
+            triggers=[user_message.submit, user_message_btn.click],
+            fn=user_message_to_chatbot,
+            inputs=[user_message, chatbot],
+            outputs=[user_message, chatbot],
+        ).then(
+            fn=bot_response_to_chatbot,
+            inputs=[chatbot, model_dict, system_prompt, support_system_role, history_len, do_sample, *generate_args],
+            outputs=[chatbot],
+        )
+        stop_btn.click(
+            fn=None,
+            inputs=None,
+            outputs=None,
+            cancels=generate_event,
+        )
+        clear_btn.click(
+            fn=lambda: None,
+            inputs=None,
+            outputs=[chatbot],
+        )
+
+    # ================= LOAD MODELS PAGE ======================
+    with gr.Tab('Load model'):
+        gguf_url = gr.Textbox(
+            value='',
+            label='Link to GGUF',
+            placeholder='URL link to the model in GGUF format',
+        )
+        load_model_btn = gr.Button('Downloading GGUF and initializing the model')
+        load_log = gr.Textbox(
+            value=start_load_log,
+            label='Model loading status',
+            lines=3,
+        )
+
+        load_model_btn.click(
+            fn=download_gguf_and_init_model,
+            inputs=[gguf_url, model_dict],
+            outputs=[model_dict, support_system_role, load_log],
+        ).success(
+            fn=get_system_prompt_component,
+            inputs=[support_system_role],
+            outputs=[system_prompt],
+        )
+
+    gr.HTML("""<h3 style='text-align: center'>
+    <a href="https://github.com/sergey21000/gradio-llamacpp-chatbot" target='_blank'>GitHub Repository</a></h3>
+    """)
+
+interface.launch(server_name='0.0.0.0', server_port=7860)
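
For reference, a minimal standalone sketch of the llama-cpp-python streaming pattern that bot_response_to_chatbot uses, assuming the default GGUF file has already been downloaded into models/ (the path and the prompt below are illustrative, not part of the commit):

from llama_cpp import Llama

# Illustrative path: the default model that app.py downloads on startup.
model = Llama(model_path='models/gemma-2-2b-it-Q8_0.gguf', n_gpu_layers=-1, verbose=False)

messages = [{'role': 'user', 'content': 'Hello! What can you do?'}]

# create_chat_completion(stream=True) yields chunks whose 'delta' dict may omit
# the 'content' key, which is why the app checks the token for None before appending.
for chunk in model.create_chat_completion(messages=messages, stream=True, temperature=0.2):
    token = chunk['choices'][0]['delta'].get('content')
    if token is not None:
        print(token, end='', flush=True)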