DiDustin committed on
Commit
15dcf22
·
verified ·
1 Parent(s): 9412e8b

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +232 -0
app.py ADDED
@@ -0,0 +1,232 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import spaces
3
+ import gradio as gr
4
+ import torch
5
+ from transformers import AutoModelForCausalLM, AutoTokenizer
6
+
7
# Access token for the gated issai/* checkpoints; supplied as a Space secret.
HF_TOKEN = os.getenv("HF_TOKEN_KAZLLM")

# Registry of selectable checkpoints. Each entry names the model/tokenizer
# repos on the Hub, an estimated duration in seconds (NOTE(review): this
# value is not read anywhere in this file — confirm whether it was meant to
# feed @spaces.GPU), and the default generation settings the UI seeds its
# controls with.
MODELS = {
    "V-1: LLama-3.1-KazLLM-8B": {
        "model_name": "issai/LLama-3.1-KazLLM-1.0-8B",
        "tokenizer_name": "issai/LLama-3.1-KazLLM-1.0-8B",
        "duration": 120,
        "defaults": {
            "max_length": 100,
            "temperature": 0.7,
            "top_p": 0.9,
            "do_sample": True
        }
    },
    "V-2: LLama-3.1-KazLLM-70B-AWQ4": {
        "model_name": "issai/LLama-3.1-KazLLM-1.0-70B-AWQ4",
        "tokenizer_name": "issai/LLama-3.1-KazLLM-1.0-70B-AWQ4",
        "duration": 180,
        "defaults": {
            "max_length": 150,
            "temperature": 0.8,
            "top_p": 0.95,
            "do_sample": True
        }
    }
}
33
+
34
# UI strings keyed by interface language (Russian / Kazakh). The keys of the
# inner dicts are referenced by update_language() and by the widget
# constructors below; the values are user-facing runtime strings and must not
# be altered.
LANGUAGES = {
    "Русский": {
        "title": "LLama-3.1 KazLLM с выбором модели и языка",
        "description": "Выберите модель, язык интерфейса, введите запрос и получите сгенерированный текст с использованием выбранной модели LLama-3.1 KazLLM.",
        "select_model": "Выберите модель",
        "enter_prompt": "Введите запрос",
        "max_length": "Максимальная длина текста",
        "temperature": "Креативность (Температура)",
        "top_p": "Top-p (ядро вероятности)",
        "do_sample": "Использовать выборку (Do Sample)",
        "generate_button": "Сгенерировать текст",
        "generated_text": "Сгенерированный текст",
        "language": "Выберите язык интерфейса"
    },
    "Қазақша": {
        "title": "LLama-3.1 KazLLM модель таңдауы және тілін қолдау",
        "description": "Модельді, интерфейс тілін таңдаңыз, сұрауыңызды енгізіңіз және таңдалған LLama-3.1 KazLLM моделін пайдаланып генерирленген мәтінді алыңыз.",
        "select_model": "Модельді таңдаңыз",
        "enter_prompt": "Сұрауыңызды енгізіңіз",
        "max_length": "Мәтіннің максималды ұзындығы",
        "temperature": "Шығармашылық (Температура)",
        "top_p": "Top-p (ықтималдық негізі)",
        "do_sample": "Үлгіні қолдану (Do Sample)",
        "generate_button": "Мәтінді генерациялау",
        "generated_text": "Генерацияланған мәтін",
        "language": "Интерфейс тілін таңдаңыз"
    }
}
62
+
63
# Process-wide caches so each checkpoint/tokenizer pair is loaded at most
# once per Space worker; keyed by the MODELS display name.
loaded_models = {}
loaded_tokenizers = {}
65
+
66
+
67
def load_model_and_tokenizer(model_key):
    """Lazily load and cache the model and tokenizer for *model_key*.

    Looks up the repo ids in ``MODELS``, downloads both artifacts with the
    gated-repo token, and stores them in the module-level caches
    ``loaded_models`` / ``loaded_tokenizers``. Subsequent calls with the
    same key are no-ops.

    Raises whatever ``from_pretrained`` raises (auth, network, missing
    repo); on failure the caches are left untouched.
    """
    if model_key in loaded_models:
        return

    model_info = MODELS[model_key]
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Load the tokenizer first: it is cheap, and nothing is cached yet if it
    # fails. The original cached the model *before* loading the tokenizer,
    # so a tokenizer error left loaded_models populated without a matching
    # loaded_tokenizers entry — every later call then skipped loading and
    # crashed with a KeyError in generate_text.
    tokenizer = AutoTokenizer.from_pretrained(
        model_info["tokenizer_name"],
        use_fast=True,
        token=HF_TOKEN,
    )
    if tokenizer.pad_token is None:
        # Causal LMs often ship without a pad token; reuse EOS for padding.
        tokenizer.pad_token = tokenizer.eos_token

    model = AutoModelForCausalLM.from_pretrained(
        model_info["model_name"],
        token=HF_TOKEN,
        torch_dtype=torch.float16,  # half precision to fit the large checkpoints
    ).to(device)

    # Publish to the caches only after BOTH loads succeeded, so a partial
    # failure can never leave a model without its tokenizer.
    loaded_models[model_key] = model
    loaded_tokenizers[model_key] = tokenizer
86
+
87
+
88
def generate_text(model_choice, prompt, max_length, temperature, top_p, do_sample):
    """Generate a continuation of *prompt* with the selected model.

    Parameters mirror the Gradio controls:
        model_choice: key into MODELS selecting which cached model to use.
        prompt: user text to continue.
        max_length: maximum total token count (prompt included) — comes from
            a gr.Slider, so it may arrive as a float.
        temperature / top_p: sampling hyper-parameters.
        do_sample: whether to sample (True) or decode greedily (False).

    Returns the decoded text; note the prompt itself is included in the
    returned string because the full output sequence is decoded.
    """
    load_model_and_tokenizer(model_choice)

    model = loaded_models[model_choice]
    tokenizer = loaded_tokenizers[model_choice]
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    inputs = tokenizer(prompt, return_tensors="pt", truncation=True, padding=True).to(device)

    outputs = model.generate(
        input_ids=inputs["input_ids"],
        attention_mask=inputs["attention_mask"],
        # Gradio sliders can deliver floats; generate() expects an int here.
        max_length=int(max_length),
        temperature=temperature,
        top_p=top_p,
        repetition_penalty=1.2,   # discourage verbatim repetition
        no_repeat_ngram_size=2,   # forbid repeating any bigram
        do_sample=do_sample,
        # pad_token was remapped to EOS at load time; pass it explicitly so
        # generate() does not warn and pads deterministically.
        pad_token_id=tokenizer.pad_token_id,
    )

    return tokenizer.decode(outputs[0], skip_special_tokens=True)
111
+
112
+
113
def update_settings(model_choice):
    """Push the selected model's default generation settings into the UI.

    Returns four Gradio value-updates in the order the controls are wired:
    (max_length, temperature, top_p, do_sample).
    """
    cfg = MODELS[model_choice]["defaults"]
    control_keys = ("max_length", "temperature", "top_p", "do_sample")
    return tuple(gr.update(value=cfg[key]) for key in control_keys)
121
+
122
+
123
def update_language(selected_language):
    """Relabel and retitle every widget for the chosen interface language.

    Returns ten Gradio updates matching the outputs wired to the language
    dropdown: title, description, model selector, prompt box, the four
    generation controls, the generate button, and the output box.
    """
    text = LANGUAGES[selected_language]

    # Markdown widgets take new *values*; input widgets take new *labels*.
    head = (
        gr.update(value=text["title"]),
        gr.update(value=text["description"]),
    )
    relabeled = tuple(
        gr.update(label=text[key])
        for key in ("select_model", "enter_prompt", "max_length",
                    "temperature", "top_p", "do_sample")
    )
    tail = (
        gr.update(value=text["generate_button"]),
        gr.update(label=text["generated_text"]),
    )
    return head + relabeled + tail
137
+
138
+
139
# Reserve a ZeroGPU slot for the duration of a generation call.
# NOTE(review): 180s matches the 70B entry's "duration" in MODELS, but the
# per-model durations are not consulted here — confirm whether the 8B model
# was meant to request a shorter slot.
@spaces.GPU(duration=180)
def wrapped_generate_text(model_choice, prompt, max_length, temperature, top_p, do_sample):
    """GPU-decorated entry point for the UI; delegates to generate_text."""
    return generate_text(model_choice, prompt, max_length, temperature, top_p, do_sample)
142
+
143
+
144
# ---------------------------------------------------------------------------
# UI layout. All widgets are created with Russian labels because the language
# dropdown defaults to "Русский"; update_language() swaps the labels when the
# user picks another language.
# ---------------------------------------------------------------------------
with gr.Blocks() as iface:
    # Interface-language selector.
    with gr.Row():
        language_dropdown = gr.Dropdown(
            choices=list(LANGUAGES.keys()),
            value="Русский",
            label=LANGUAGES["Русский"]["language"]
        )

    title = gr.Markdown(LANGUAGES["Русский"]["title"])
    description = gr.Markdown(LANGUAGES["Русский"]["description"])

    # Model selector; the 70B AWQ model is the default, so the generation
    # controls below are seeded from its "defaults" entry.
    with gr.Row():
        model_dropdown = gr.Dropdown(
            choices=list(MODELS.keys()),
            value="V-2: LLama-3.1-KazLLM-70B-AWQ4",
            label=LANGUAGES["Русский"]["select_model"]
        )

    with gr.Row():
        prompt_input = gr.Textbox(
            lines=4,
            placeholder="Введите ваш запрос здесь...",
            label=LANGUAGES["Русский"]["enter_prompt"]
        )

    # Generation controls (max length, temperature, top-p, do-sample),
    # initialised with the default model's settings.
    with gr.Row():
        max_length_slider = gr.Slider(
            minimum=50,
            maximum=1000,
            step=10,
            value=MODELS["V-2: LLama-3.1-KazLLM-70B-AWQ4"]["defaults"]["max_length"],
            label=LANGUAGES["Русский"]["max_length"]
        )
        temperature_slider = gr.Slider(
            minimum=0.1,
            maximum=2.0,
            step=0.1,
            value=MODELS["V-2: LLama-3.1-KazLLM-70B-AWQ4"]["defaults"]["temperature"],
            label=LANGUAGES["Русский"]["temperature"]
        )

    with gr.Row():
        top_p_slider = gr.Slider(
            minimum=0.1,
            maximum=1.0,
            step=0.05,
            value=MODELS["V-2: LLama-3.1-KazLLM-70B-AWQ4"]["defaults"]["top_p"],
            label=LANGUAGES["Русский"]["top_p"]
        )
        do_sample_checkbox = gr.Checkbox(
            value=MODELS["V-2: LLama-3.1-KazLLM-70B-AWQ4"]["defaults"]["do_sample"],
            label=LANGUAGES["Русский"]["do_sample"]
        )

    generate_button = gr.Button(LANGUAGES["Русский"]["generate_button"])

    output_text = gr.Textbox(
        label=LANGUAGES["Русский"]["generated_text"],
        lines=10
    )

    # Switching models re-seeds the four generation controls with that
    # model's defaults (order must match update_settings' return tuple).
    model_dropdown.change(
        fn=update_settings,
        inputs=[model_dropdown],
        outputs=[max_length_slider, temperature_slider, top_p_slider, do_sample_checkbox]
    )

    # Switching languages relabels every widget (order must match
    # update_language's return tuple).
    language_dropdown.change(
        fn=update_language,
        inputs=[language_dropdown],
        outputs=[title, description, model_dropdown, prompt_input, max_length_slider, temperature_slider, top_p_slider,
                 do_sample_checkbox, generate_button, output_text]
    )

    # Generation goes through the @spaces.GPU-decorated wrapper so the call
    # runs on a ZeroGPU slot.
    generate_button.click(
        fn=wrapped_generate_text,
        inputs=[
            model_dropdown,
            prompt_input,
            max_length_slider,
            temperature_slider,
            top_p_slider,
            do_sample_checkbox
        ],
        outputs=output_text
    )

if __name__ == "__main__":
    iface.launch()