zetavg committed: show loss/epoch chart on finetune ui
Changed files:
- llama_lora/ui/finetune/finetune_ui.py +16 -3
- llama_lora/ui/finetune/style.css +27 -1
- llama_lora/ui/finetune/training.py +59 -5
- requirements.txt +5 -3
llama_lora/ui/finetune/finetune_ui.py CHANGED

@@ -28,7 +28,8 @@ from .previewing import (
 )
 from .training import (
     do_train,
-    render_training_status
+    render_training_status,
+    render_loss_plot
 )
 
 register_css_style('finetune', relative_read_file(__file__, "style.css"))
@@ -773,10 +774,15 @@ def finetune_ui():
                 )
 
                 train_status = gr.HTML(
-                    "
+                    "",
                     label="Train Output",
                     elem_id="finetune_training_status")
 
+                with gr.Column(visible=False, elem_id="finetune_loss_plot_container") as loss_plot_container:
+                    loss_plot = gr.Plot(
+                        visible=False, show_label=False,
+                        elem_id="finetune_loss_plot")
+
                 training_indicator = gr.HTML(
                     "training_indicator", visible=False, elem_id="finetune_training_indicator")
 
@@ -787,7 +793,8 @@ def finetune_ui():
                 continue_from_model,
                 continue_from_checkpoint,
             ]),
-            outputs=[train_status, training_indicator],
+            outputs=[train_status, training_indicator,
+                     loss_plot_container, loss_plot]
         )
 
         # controlled by JS, shows the confirm_abort_button
@@ -803,6 +810,12 @@ def finetune_ui():
             outputs=[train_status, training_indicator],
             every=0.2
         )
+        loss_plot_updates = finetune_ui_blocks.load(
+            fn=render_loss_plot,
+            inputs=None,
+            outputs=[loss_plot_container, loss_plot],
+            every=10
+        )
     finetune_ui_blocks.load(_js=relative_read_file(__file__, "script.js"))
 
     # things_that_might_timeout.append(training_status_updates)
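Note: the `loss_plot_updates` wiring above relies on Gradio's `every=` polling, which re-runs the callback on a fixed interval while the page is open (and requires the queue to be enabled). Below is a minimal self-contained sketch of the same pattern, assuming Gradio 3.x; the `poll_status` callback and the component it updates are illustrative, not part of this repo.

# Minimal sketch of Gradio's `every=` polling pattern (assumes Gradio 3.x).
# `Blocks.load(..., every=N)` re-runs the callback every N seconds.
import time

import gradio as gr

started_at = time.time()


def poll_status():
    # Return an update for the output component on each polling tick,
    # analogous to how render_loss_plot returns gr.Plot updates above.
    elapsed = int(time.time() - started_at)
    return gr.HTML.update(value="<p>Running for %ds</p>" % elapsed)


with gr.Blocks() as demo:
    status = gr.HTML("")
    # Same shape as the loss_plot_updates wiring: fn + outputs + every.
    demo.load(fn=poll_status, inputs=None, outputs=[status], every=2)

if __name__ == "__main__":
    demo.queue().launch()  # `every=` needs the queue enabled in Gradio 3.x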
llama_lora/ui/finetune/style.css CHANGED

@@ -255,7 +255,9 @@
   display: none;
 }
 
-#finetune_training_status > .wrap {
+#finetune_training_status > .wrap,
+#finetune_loss_plot_container > .wrap,
+#finetune_loss_plot > .wrap {
   border: 0;
   background: transparent;
   pointer-events: none;
@@ -264,6 +266,17 @@
   left: 0;
   right: 0;
 }
+#finetune_training_status > .wrap:not(.generating)::after {
+  content: "Refresh the page if this takes too long.";
+  position: absolute;
+  top: 0;
+  left: 0;
+  right: 0;
+  bottom: 0;
+  padding-top: 64px;
+  opacity: 0.5;
+  text-align: center;
+}
 #finetune_training_status > .wrap .meta-text-center {
   transform: none !important;
 }
@@ -383,5 +396,18 @@
   /* background: var(--error-background-fill) !important; */
   border: 1px solid var(--error-border-color) !important;
 }
+#finetune_loss_plot {
+  padding: var(--block-padding);
+}
+#finetune_loss_plot .altair {
+  overflow: auto !important;
+}
+#finetune_loss_plot .altair > * {
+  margin: auto !important;
+}
+#finetune_loss_plot .vega-embed summary {
+  border: 0;
+  box-shadow: none;
+}
 
 #finetune_training_indicator { display: none; }
llama_lora/ui/finetune/training.py CHANGED

@@ -1,11 +1,14 @@
 import os
 import json
 import time
+import math
 import datetime
 import pytz
 import socket
 import threading
 import traceback
+import altair as alt
+import pandas as pd
 import gradio as gr
 
 from huggingface_hub import try_to_load_from_cache, snapshot_download
@@ -71,7 +74,7 @@ def do_train(
     progress=gr.Progress(track_tqdm=False),
 ):
     if Global.is_training or Global.is_train_starting:
-        return render_training_status()
+        return render_training_status() + render_loss_plot()
 
     reset_training_status()
     Global.is_train_starting = True
@@ -206,6 +209,9 @@ def do_train(
             print(message)
 
             total_steps = 300
+            log_history = []
+            initial_loss = 2
+            loss_decay_rate = 0.8
             for i in range(300):
                 if (Global.should_stop_training):
                     break
@@ -213,11 +219,14 @@ def do_train(
                 current_step = i + 1
                 total_epochs = 3
                 current_epoch = i / 100
-                log_history = []
 
                 if (i > 20):
-                    loss =
-                    log_history
+                    loss = initial_loss * math.exp(-loss_decay_rate * current_epoch)
+                    log_history.append({
+                        'loss': loss,
+                        'learning_rate': 0.0001,
+                        'epoch': current_epoch
+                    })
 
                 update_training_states(
                     total_steps=total_steps,
@@ -295,7 +304,7 @@ def do_train(
     finally:
         Global.is_train_starting = False
 
-    return render_training_status()
+    return render_training_status() + render_loss_plot()
 
 
 def render_training_status():
@@ -411,6 +420,51 @@ def render_training_status():
     return (gr.HTML.update(value=html_content), gr.HTML.update(visible=True))
 
 
+def render_loss_plot():
+    if len(Global.training_log_history) <= 2:
+        return (gr.Column.update(visible=False), gr.Plot.update(visible=False))
+
+    training_log_history = Global.training_log_history
+
+    loss_data = [
+        {
+            'type': 'train_loss' if 'loss' in item else 'eval_loss',
+            'loss': item.get('loss') or item.get('eval_loss'),
+            'epoch': item.get('epoch')
+        } for item in training_log_history
+        if ('loss' in item or 'eval_loss' in item)
+        and 'epoch' in item
+    ]
+
+    source = pd.DataFrame(loss_data)
+
+    highlight = alt.selection(
+        type='single',  # type: ignore
+        on='mouseover', fields=['type'], nearest=True
+    )
+
+    base = alt.Chart(source).encode(  # type: ignore
+        x='epoch:Q',
+        y='loss:Q',
+        color='type:N',
+        tooltip=['type:N', 'loss:Q', 'epoch:Q']
+    )
+
+    points = base.mark_circle().encode(
+        opacity=alt.value(0)
+    ).add_selection(
+        highlight
+    ).properties(
+        width=640
+    )
+
+    lines = base.mark_line().encode(
+        size=alt.condition(~highlight, alt.value(1), alt.value(3))
+    )
+
+    return (gr.Column.update(visible=True), gr.Plot.update(points + lines, visible=True))
+
+
 def format_time(seconds):
     hours, remainder = divmod(seconds, 3600)
     minutes, seconds = divmod(remainder, 60)
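Note: `render_loss_plot` uses Altair's nearest-point mouseover pattern: invisible circles carry the selection, and the hovered series' line is drawn thicker. Below is a self-contained sketch of the same chart outside the app, assuming Altair 4.x (where `alt.selection` and `add_selection` are available; Altair 5 renames these to `selection_point`/`add_params`), fed with data simulated by the same exponential decay used in the demo loop above.

# Standalone sketch of the chart built by render_loss_plot, with sample
# data from the demo loop's formula:
#   loss = initial_loss * exp(-loss_decay_rate * epoch)
import math

import altair as alt
import pandas as pd

initial_loss = 2
loss_decay_rate = 0.8

# Columns mirror the loss_data records built in render_loss_plot.
source = pd.DataFrame([
    {'type': 'train_loss',
     'loss': initial_loss * math.exp(-loss_decay_rate * (i / 100)),
     'epoch': i / 100}
    for i in range(21, 300)
])

# Nearest-point mouseover selection keyed on the series name.
highlight = alt.selection(
    type='single', on='mouseover', fields=['type'], nearest=True
)

base = alt.Chart(source).encode(
    x='epoch:Q', y='loss:Q', color='type:N',
    tooltip=['type:N', 'loss:Q', 'epoch:Q']
)

# Invisible points exist only to carry the hover selection.
points = base.mark_circle().encode(
    opacity=alt.value(0)
).add_selection(highlight).properties(width=640)

# The hovered series is drawn thicker than the others.
lines = base.mark_line().encode(
    size=alt.condition(~highlight, alt.value(1), alt.value(3))
)

chart = points + lines
chart.save('loss_plot.html')  # or display the chart inline in a notebook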
requirements.txt CHANGED

@@ -1,4 +1,5 @@
 accelerate
+altair
 appdirs
 bitsandbytes
 black
@@ -7,10 +8,11 @@ datasets
 fire
 git+https://github.com/huggingface/peft.git
 git+https://github.com/huggingface/transformers.git
+gradio
 huggingface_hub
+loralib
 numba
 nvidia-ml-py3
-
-loralib
-sentencepiece
+pandas
 random-word
+sentencepiece