Update app.py
app.py
CHANGED
@@ -281,6 +281,73 @@ def train_model(progress=gr.Progress()):
     # Initialize trainer with memory-optimized settings
     progress(0.2, desc="Initializing trainer...")

+    # Setup DeepSpeed config if available
+    try:
+        from deepspeed.runtime.zero.stage3 import estimate_zero3_model_states_mem_needs_all_live
+        use_deepspeed = True
+        print("DeepSpeed available, will use ZeRO-3")
+
+        ds_config = {
+            "zero_optimization": {
+                "stage": 3,
+                "offload_optimizer": {
+                    "device": "cpu",
+                    "pin_memory": True
+                },
+                "offload_param": {
+                    "device": "cpu",
+                    "pin_memory": True
+                },
+                "overlap_comm": True,
+                "contiguous_gradients": True,
+                "reduce_bucket_size": 5e7,
+                "stage3_prefetch_bucket_size": 5e7,
+                "stage3_param_persistence_threshold": 1e5
+            },
+            "train_micro_batch_size_per_gpu": MICRO_BATCH_SIZE,
+            "gradient_accumulation_steps": GRAD_ACCUMULATION_STEPS,
+            "fp16": {"enabled": True},
+            "zero_allow_untested_optimizer": True,
+            "aio": {"block_size": 1048576, "queue_depth": 8, "thread_count": 1}
+        }
+    except ImportError:
+        use_deepspeed = False
+        print("DeepSpeed not available, falling back to standard distribution")
+        ds_config = None
+
+    # Define training arguments inside the function
+    training_args = TrainingArguments(
+        output_dir=OUTPUT_TRAINING_DIR,
+        logging_dir=LOGGING_DIR,
+        num_train_epochs=NUM_EPOCHS,
+        per_device_train_batch_size=BATCH_SIZE_PER_DEVICE,
+        gradient_accumulation_steps=GRAD_ACCUMULATION_STEPS,
+        learning_rate=LEARNING_RATE,
+        weight_decay=WEIGHT_DECAY,
+        warmup_ratio=WARMUP_RATIO,
+        lr_scheduler_type=LR_SCHEDULER,
+        report_to="tensorboard",
+        fp16=True,
+        bf16=False,
+
+        # Memory optimization
+        optim="adamw_torch_fused",
+        gradient_checkpointing=True,
+        gradient_checkpointing_kwargs={"use_reentrant": False},
+
+        # Explicit model distribution
+        ddp_find_unused_parameters=False,
+        deepspeed=ds_config if use_deepspeed else None,
+
+        # Other memory-saving settings
+        save_strategy="steps",
+        save_steps=50,
+        logging_steps=10,
+        dataloader_num_workers=0,  # Avoid extra memory usage with workers
+        group_by_length=True,      # Group samples of similar length
+        max_grad_norm=0.5,
+    )
+
     # Optional: try a custom data collator that explicitly caps sequence length
     def data_capped_collator(examples):
         # Call your existing collator
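Note: the DeepSpeed import above pulls in estimate_zero3_model_states_mem_needs_all_live purely as an availability probe and never calls it. If you also want a quick read on whether ZeRO-3 with CPU offload fits the Space's memory, the helper can be called on the loaded model before training. A minimal sketch, assuming `model` is the same object later handed to Trainer and a single GPU:

    # Sketch: print ZeRO-3 memory-needs estimates for the already-loaded model.
    # Assumes `model` is the transformers model used by Trainer below.
    try:
        from deepspeed.runtime.zero.stage3 import (
            estimate_zero3_model_states_mem_needs_all_live,
        )
        estimate_zero3_model_states_mem_needs_all_live(
            model, num_gpus_per_node=1, num_nodes=1
        )
    except ImportError:
        pass  # DeepSpeed not installed; the try/except in the diff already covers this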
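The hunk ends just as data_capped_collator begins, so its body is not part of this diff. For reference, a length-capping collator along the lines the comment describes might look like the sketch below, where base_collator and MAX_SEQ_LENGTH are hypothetical names standing in for the Space's existing collator and its sequence cap:

    def data_capped_collator(examples):
        # Call the existing collator (hypothetical base_collator), then truncate
        # every 2-D tensor along the sequence dimension so no batch exceeds
        # MAX_SEQ_LENGTH tokens.
        batch = base_collator(examples)
        for key in ("input_ids", "attention_mask", "labels"):
            if key in batch and batch[key].dim() >= 2:
                batch[key] = batch[key][:, :MAX_SEQ_LENGTH]
        return batch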