Guetat Youssef committed on
Commit
0e7f220
·
1 Parent(s): fbe7ca1
Files changed (1)
  1. app.py +304 -0
app.py CHANGED
@@ -376,7 +376,311 @@ def train_model_background(job_id, dataset_name, base_model_name=None):
                 shutil.rmtree(temp_dir)
         except:
             pass
+def train_model_background(job_id, dataset_name, base_model_name=None):
+    """Background training function with improved configuration"""
+    progress = training_jobs[job_id]
+
+    try:
+        # Create a temporary directory for this job
+        temp_dir = tempfile.mkdtemp(prefix=f"train_{job_id}_")
+
+        # Set environment variables for caching
+        os.environ['HF_HOME'] = temp_dir
+        os.environ['TRANSFORMERS_CACHE'] = temp_dir
+        os.environ['HF_DATASETS_CACHE'] = temp_dir
+        os.environ['TORCH_HOME'] = temp_dir
+
+        progress.status = "loading_libraries"
+        progress.message = "Loading required libraries..."
+
+        # Import heavy libraries after setting cache paths
+        import torch
+        from datasets import load_dataset, Dataset
+        from huggingface_hub import login
+        from transformers import (
+            AutoModelForCausalLM,
+            AutoTokenizer,
+            TrainingArguments,
+            Trainer,
+            TrainerCallback,
+            DataCollatorForLanguageModeling
+        )
+        from peft import (
+            LoraConfig,
+            get_peft_model,
+            TaskType
+        )
+
+        # === Authentication ===
+        hf_token = os.getenv('HF_TOKEN')
+        if hf_token:
+            login(token=hf_token)
+
+        progress.status = "loading_model"
+        progress.message = "Loading base model and tokenizer..."
+
+        # === Model Selection ===
+        # Default to a mid-sized model better suited to medical conversations
+        # than the small variant
+        base_model = base_model_name or "microsoft/DialoGPT-medium"
+        new_model = f"trained-model-{job_id}"
+        max_length = 512  # Larger window gives the model more conversational context
+
+        # === Load Model and Tokenizer ===
+        model = AutoModelForCausalLM.from_pretrained(
+            base_model,
+            cache_dir=temp_dir,
+            torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
+            device_map="auto" if torch.cuda.is_available() else "cpu",
+            trust_remote_code=True,
+            low_cpu_mem_usage=True
+        )
+
+        tokenizer = AutoTokenizer.from_pretrained(
+            base_model,
+            cache_dir=temp_dir,
+            trust_remote_code=True,
+            padding_side="right"  # Important for causal LM
+        )
+
+        # Add a padding token if not present
+        if tokenizer.pad_token is None:
+            tokenizer.pad_token = tokenizer.eos_token
+            tokenizer.pad_token_id = tokenizer.eos_token_id
+
+        # Resize token embeddings if needed
+        model.resize_token_embeddings(len(tokenizer))
+
+        progress.status = "preparing_model"
+        progress.message = "Setting up improved LoRA configuration..."
+
+        # === Better LoRA Config ===
+        peft_config = LoraConfig(
+            r=16,                 # Increased rank for better learning
+            lora_alpha=32,        # Increased alpha
+            lora_dropout=0.05,    # Reduced dropout
+            bias="none",
+            task_type=TaskType.CAUSAL_LM,
+            target_modules=["c_attn", "c_proj"],  # Target specific modules for DialoGPT
+        )
+        model = get_peft_model(model, peft_config)
+
+        # Print trainable parameters
+        model.print_trainable_parameters()
+
+        progress.status = "loading_dataset"
+        progress.message = "Loading and preparing dataset..."
+
+        # === Load & Prepare Dataset ===
+        # Load once, then train on the "train" split if present, else the first available split
+        raw_dataset = load_dataset(dataset_name, cache_dir=temp_dir, trust_remote_code=True)
+        split_name = "train" if "train" in raw_dataset else list(raw_dataset.keys())[0]
+        dataset = raw_dataset[split_name]
+
+        # Automatically detect question and answer columns
+        question_col, answer_col = detect_qa_columns(dataset)
+
+        if not question_col or not answer_col:
+            raise ValueError("Could not automatically detect question and answer columns in the dataset")
+
+        progress.detected_columns = {"question": question_col, "answer": answer_col}
+        progress.message = f"Detected columns - Question: {question_col}, Answer: {answer_col}"
+
+        # Use more data for better training (increased from 100 to 1000 examples)
+        dataset_size = min(1000, len(dataset))
+        dataset = dataset.shuffle(seed=42).select(range(dataset_size))
+
+        # === Text Formatting ===
+        def format_conversation(example):
+            question = str(example[question_col]).strip()
+            answer = str(example[answer_col]).strip()
+
+            # Mark speaker turns; <|endoftext|> is the model's end-of-sequence token
+            conversation = f"<|user|>{question}<|assistant|>{answer}<|endoftext|>"
+            return {"text": conversation}
+
+        # Apply formatting
+        dataset = dataset.map(format_conversation, remove_columns=dataset.column_names)
+
+        # Filter out very short or very long examples
+        dataset = dataset.filter(lambda x: 10 < len(x["text"]) < max_length * 2)
+
+        # === Improved Training Arguments ===
+        batch_size = 4 if torch.cuda.is_available() else 2
+        gradient_accumulation_steps = 2
+        num_epochs = 3        # Increased epochs
+        learning_rate = 2e-4  # Better learning rate
+
+        steps_per_epoch = len(dataset) // (batch_size * gradient_accumulation_steps)
+        total_steps = steps_per_epoch * num_epochs
+        warmup_steps = max(10, total_steps // 10)  # 10% warmup
+
+        progress.total_steps = total_steps
+        progress.status = "training"
+        progress.message = "Starting improved training..."
+
+        output_dir = os.path.join(temp_dir, new_model)
+        os.makedirs(output_dir, exist_ok=True)
+
+        training_args = TrainingArguments(
+            output_dir=output_dir,
+            per_device_train_batch_size=batch_size,
+            gradient_accumulation_steps=gradient_accumulation_steps,
+            num_train_epochs=num_epochs,
+            learning_rate=learning_rate,
+            warmup_steps=warmup_steps,
+            logging_steps=5,
+            save_steps=max(10, total_steps // 4),
+            save_total_limit=2,
+            evaluation_strategy="no",
+            logging_strategy="steps",
+            save_strategy="steps",
+            fp16=torch.cuda.is_available(),
+            bf16=False,
+            dataloader_num_workers=0,
+            remove_unused_columns=False,
+            report_to="none",  # "none" (not None) disables external logging integrations
+            prediction_loss_only=True,
+            optim="adamw_torch",
+            weight_decay=0.01,
+            lr_scheduler_type="cosine",
+            gradient_checkpointing=True,
+            dataloader_pin_memory=False,
+        )
+
+        # === Data Collator ===
+        data_collator = DataCollatorForLanguageModeling(
+            tokenizer=tokenizer,
+            mlm=False,  # We're doing causal LM, not masked LM
+            return_tensors="pt",
+            pad_to_multiple_of=8,
+        )
+
+        # Custom tokenization function
+        def tokenize_function(examples):
+            # Tokenize the text; padding is left to the data collator, which also
+            # derives the causal-LM labels from input_ids at batch time
+            return tokenizer(
+                examples["text"],
+                truncation=True,
+                padding=False,
+                max_length=max_length,
+                return_tensors=None,
+            )
+
+        # Tokenize dataset
+        tokenized_dataset = dataset.map(
+            tokenize_function,
+            batched=True,
+            remove_columns=dataset.column_names,
+            desc="Tokenizing dataset",
+        )
+
+        # Custom callback to track progress
+        class ProgressCallback(TrainerCallback):
+            def __init__(self, progress_tracker):
+                self.progress_tracker = progress_tracker
+                self.last_update = time.time()
+
+            def on_log(self, args, state, control, model=None, logs=None, **kwargs):
+                current_time = time.time()
+                # Update at most every 5 seconds
+                if current_time - self.last_update >= 5:
+                    self.progress_tracker.update_progress(
+                        state.global_step,
+                        state.max_steps,
+                        f"Training step {state.global_step}/{state.max_steps}"
+                    )
+                    self.last_update = current_time
+
+                # Log training metrics if available
+                if logs:
+                    loss = logs.get('train_loss', logs.get('loss'))
+                    loss_str = f"{loss:.4f}" if isinstance(loss, float) else "N/A"
+                    lr = logs.get('learning_rate', 'N/A')
+                    self.progress_tracker.message = f"Step {state.global_step}/{state.max_steps}, Loss: {loss_str}, LR: {lr}"
+
+            def on_train_begin(self, args, state, control, **kwargs):
+                self.progress_tracker.status = "training"
+                self.progress_tracker.message = "Training started with improved configuration..."
+
+            def on_train_end(self, args, state, control, **kwargs):
+                self.progress_tracker.status = "saving"
+                self.progress_tracker.message = "Training complete, saving improved model..."
+
+        # === Trainer Initialization ===
+        trainer = Trainer(
+            model=model,
+            args=training_args,
+            train_dataset=tokenized_dataset,
+            data_collator=data_collator,
+            callbacks=[ProgressCallback(progress)],
+            tokenizer=tokenizer,
+        )
+
+        # === Train & Save ===
+        trainer.train()
+
+        # Save the adapter weights and tokenizer
+        trainer.save_model(output_dir)
+        tokenizer.save_pretrained(output_dir)
+
+        # Also save the base model name for inference
+        with open(os.path.join(output_dir, "base_model.txt"), "w") as f:
+            f.write(base_model)
+
+        # Save training info
+        import json
+        training_info = {
+            "base_model": base_model,
+            "dataset_name": dataset_name,
+            "dataset_size": len(dataset),
+            "max_length": max_length,
+            "batch_size": batch_size,
+            "learning_rate": learning_rate,
+            "num_epochs": num_epochs,
+            "total_steps": total_steps,
+            "detected_columns": progress.detected_columns
+        }
+
+        with open(os.path.join(output_dir, "training_info.json"), "w") as f:
+            json.dump(training_info, f, indent=2)
+
+        # Record completion
+        progress.model_path = output_dir
+        progress.status = "completed"
+        progress.progress = 100
+        progress.message = "Improved training completed! Model ready for download."
+
658
+ # Keep the temp directory for download (cleanup after 2 hours for larger model)
659
+ def cleanup_temp_dir():
660
+ time.sleep(7200) # Wait 2 hours before cleanup
661
+ try:
662
+ shutil.rmtree(temp_dir)
663
+ # Remove from training_jobs after cleanup
664
+ if job_id in training_jobs:
665
+ del training_jobs[job_id]
666
+ except:
667
+ pass
668
+
669
+ cleanup_thread = threading.Thread(target=cleanup_temp_dir)
670
+ cleanup_thread.daemon = True
671
+ cleanup_thread.start()
672
+
673
+ except Exception as e:
674
+ progress.status = "error"
675
+ progress.error = str(e)
676
+ progress.message = f"Training failed: {str(e)}"
677
+
678
+ # Clean up on error
679
+ try:
680
+ if 'temp_dir' in locals():
681
+ shutil.rmtree(temp_dir)
682
+ except:
683
+ pass
 def create_model_zip(model_path, job_id):
     """Create a zip file containing the trained model"""
     memory_file = io.BytesIO()
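
The training_jobs[job_id] object this function mutates is defined elsewhere in app.py and is not part of this hunk. Judging only from the attributes the diff touches (status, message, progress, total_steps, detected_columns, model_path, error, and an update_progress method), a compatible tracker could look roughly like this hypothetical sketch:

class TrainingProgress:
    """Hypothetical job tracker matching the attributes used by train_model_background."""
    def __init__(self):
        self.status = "queued"
        self.message = ""
        self.progress = 0          # percentage, 0-100
        self.total_steps = 0
        self.detected_columns = None
        self.model_path = None
        self.error = None

    def update_progress(self, current_step, total_steps, message):
        # Convert absolute step counts into a percentage for polling clients
        if total_steps:
            self.progress = int(100 * current_step / total_steps)
        self.message = message

training_jobs = {}  # job_id -> TrainingProgress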
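
detect_qa_columns is likewise called but not defined in this hunk. A minimal sketch of the kind of column-name heuristic it presumably implements (the names checked here are assumptions, not the committed code):

def detect_qa_columns(dataset):
    # Hypothetical heuristic: match common question/answer column names
    question_names = {"question", "input", "prompt", "instruction", "query"}
    answer_names = {"answer", "output", "response", "completion", "reply"}
    question_col = answer_col = None
    for col in dataset.column_names:
        name = col.lower()
        if question_col is None and name in question_names:
            question_col = col
        elif answer_col is None and name in answer_names:
            answer_col = col
    return question_col, answer_col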
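
The surrounding app (an in-memory zip download, a training_jobs dict, daemon cleanup threads) suggests an HTTP service that launches this function in a background thread. Assuming a Flask app, which this hunk does not confirm, a launch endpoint might look like this sketch (the route name and payload fields are hypothetical):

import threading
import uuid
from flask import Flask, jsonify, request

app = Flask(__name__)

@app.route("/train", methods=["POST"])  # hypothetical route
def start_training():
    # Register a tracker, then hand the long-running job to a daemon thread
    job_id = uuid.uuid4().hex[:8]
    training_jobs[job_id] = TrainingProgress()
    worker = threading.Thread(
        target=train_model_background,
        args=(job_id, request.json["dataset_name"], request.json.get("base_model_name")),
    )
    worker.daemon = True
    worker.start()
    return jsonify({"job_id": job_id})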
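
Since trainer.save_model on a PEFT-wrapped model writes only the LoRA adapter, consumers of the downloaded zip need the base model plus the adapter. A loading sketch under that assumption, using the base_model.txt file the diff writes (the adapter_dir path and the prompt are placeholders):

import os
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from peft import PeftModel

adapter_dir = "trained-model-<job_id>"  # wherever the downloaded zip was extracted

with open(os.path.join(adapter_dir, "base_model.txt")) as f:
    base_model_name = f.read().strip()

tokenizer = AutoTokenizer.from_pretrained(adapter_dir)
base = AutoModelForCausalLM.from_pretrained(base_model_name)
model = PeftModel.from_pretrained(base, adapter_dir)  # attach the trained adapter
model.eval()

# Prompt format mirrors format_conversation above
prompt = "<|user|>What are common flu symptoms?<|assistant|>"
inputs = tokenizer(prompt, return_tensors="pt")
with torch.no_grad():
    out = model.generate(**inputs, max_new_tokens=100, pad_token_id=tokenizer.eos_token_id)
print(tokenizer.decode(out[0][inputs["input_ids"].shape[1]:], skip_special_tokens=True))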