Guetat Youssef committed
Commit 8f8763e · 1 Parent(s): 0e7f220
Files changed (1):
  1. app.py +65 -331
app.py CHANGED
@@ -112,272 +112,7 @@ def detect_qa_columns(dataset):
     return question_col, answer_col
 
 def train_model_background(job_id, dataset_name, base_model_name=None):
-    """Background training function with progress tracking"""
-    progress = training_jobs[job_id]
-
-    try:
-        # Create a temporary directory for this job
-        temp_dir = tempfile.mkdtemp(prefix=f"train_{job_id}_")
-
-        # Set environment variables for caching
-        os.environ['HF_HOME'] = temp_dir
-        os.environ['TRANSFORMERS_CACHE'] = temp_dir
-        os.environ['HF_DATASETS_CACHE'] = temp_dir
-        os.environ['TORCH_HOME'] = temp_dir
-
-        progress.status = "loading_libraries"
-        progress.message = "Loading required libraries..."
-
-        # Import heavy libraries after setting cache paths
-        import torch
-        from datasets import load_dataset, Dataset
-        from huggingface_hub import login
-        from transformers import (
-            AutoModelForCausalLM,
-            AutoTokenizer,
-            TrainingArguments,
-            Trainer,
-            TrainerCallback,
-        )
-        from peft import (
-            LoraConfig,
-            get_peft_model,
-        )
-
-        # === Authentication ===
-        hf_token = os.getenv('HF_TOKEN')
-        if hf_token:
-            login(token=hf_token)
-
-        progress.status = "loading_model"
-        progress.message = "Loading base model and tokenizer..."
-
-        # === Configuration ===
-        base_model = base_model_name or "microsoft/DialoGPT-small"
-        new_model = f"trained-model-{job_id}"
-        max_length = 256
-
-        # === Load Model and Tokenizer ===
-        model = AutoModelForCausalLM.from_pretrained(
-            base_model,
-            cache_dir=temp_dir,
-            torch_dtype=torch.float32,
-            device_map="auto" if torch.cuda.is_available() else "cpu",
-            trust_remote_code=True
-        )
-
-        tokenizer = AutoTokenizer.from_pretrained(
-            base_model,
-            cache_dir=temp_dir,
-            trust_remote_code=True
-        )
-
-        # Add padding token if not present
-        if tokenizer.pad_token is None:
-            tokenizer.pad_token = tokenizer.eos_token
-
-        # Resize token embeddings if needed
-        model.resize_token_embeddings(len(tokenizer))
-
-        progress.status = "preparing_model"
-        progress.message = "Setting up LoRA configuration..."
-
-        # === LoRA Config ===
-        peft_config = LoraConfig(
-            r=8,
-            lora_alpha=16,
-            lora_dropout=0.1,
-            bias="none",
-            task_type="CAUSAL_LM",
-        )
-        model = get_peft_model(model, peft_config)
-
-        progress.status = "loading_dataset"
-        progress.message = "Loading and preparing dataset..."
-
-        # === Load & Prepare Dataset ===
-        dataset = load_dataset(
-            dataset_name,
-            split="train" if "train" in load_dataset(dataset_name, cache_dir=temp_dir).keys() else "all",
-            cache_dir=temp_dir,
-            trust_remote_code=True
-        )
-
-        # Automatically detect question and answer columns
-        question_col, answer_col = detect_qa_columns(dataset)
-
-        if not question_col or not answer_col:
-            raise ValueError("Could not automatically detect question and answer columns in the dataset")
-
-        progress.detected_columns = {"question": question_col, "answer": answer_col}
-        progress.message = f"Detected columns - Question: {question_col}, Answer: {answer_col}"
-
-        # Use subset for faster testing (can be made configurable)
-        dataset = dataset.shuffle(seed=65).select(range(min(100, len(dataset))))
-
-        # Custom dataset class for proper handling
-        class CustomDataset(torch.utils.data.Dataset):
-            def __init__(self, texts, tokenizer, max_length):
-                self.texts = texts
-                self.tokenizer = tokenizer
-                self.max_length = max_length
-
-            def __len__(self):
-                return len(self.texts)
-
-            def __getitem__(self, idx):
-                text = self.texts[idx]
-
-                # Tokenize the text
-                encoding = self.tokenizer(
-                    text,
-                    truncation=True,
-                    padding='max_length',
-                    max_length=self.max_length,
-                    return_tensors='pt'
-                )
-
-                # Flatten the tensors (remove batch dimension)
-                input_ids = encoding['input_ids'].squeeze()
-                attention_mask = encoding['attention_mask'].squeeze()
-
-                # For causal language modeling, labels are the same as input_ids
-                labels = input_ids.clone()
-
-                # Set labels to -100 for padding tokens (they won't contribute to loss)
-                labels[attention_mask == 0] = -100
-
-                return {
-                    'input_ids': input_ids,
-                    'attention_mask': attention_mask,
-                    'labels': labels
-                }
-
-        # Prepare texts using detected columns
-        texts = []
-        for item in dataset:
-            question = str(item[question_col]).strip()
-            answer = str(item[answer_col]).strip()
-            text = f"Question: {question}\nAnswer: {answer}{tokenizer.eos_token}"
-            texts.append(text)
-
-        # Create custom dataset
-        train_dataset = CustomDataset(texts, tokenizer, max_length)
-
-        # Calculate total training steps
-        batch_size = 2
-        gradient_accumulation_steps = 1
-        num_epochs = 1
-
-        steps_per_epoch = len(train_dataset) // (batch_size * gradient_accumulation_steps)
-        total_steps = steps_per_epoch * num_epochs
-
-        progress.total_steps = total_steps
-        progress.status = "training"
-        progress.message = "Starting training..."
-
-        # === Training Arguments ===
-        output_dir = os.path.join(temp_dir, new_model)
-        os.makedirs(output_dir, exist_ok=True)
-
-        training_args = TrainingArguments(
-            output_dir=output_dir,
-            per_device_train_batch_size=batch_size,
-            gradient_accumulation_steps=gradient_accumulation_steps,
-            num_train_epochs=num_epochs,
-            logging_steps=1,
-            save_steps=max(1, total_steps // 2),
-            save_total_limit=1,
-            learning_rate=5e-5,
-            warmup_steps=2,
-            logging_strategy="steps",
-            save_strategy="steps",
-            fp16=False,
-            bf16=False,
-            dataloader_num_workers=0,
-            remove_unused_columns=False,
-            report_to=None,
-            prediction_loss_only=True,
-        )
-
-        # Custom callback to track progress
-        class ProgressCallback(TrainerCallback):
-            def __init__(self, progress_tracker):
-                self.progress_tracker = progress_tracker
-                self.last_update = time.time()
-
-            def on_log(self, args, state, control, model=None, logs=None, **kwargs):
-                current_time = time.time()
-                # Update every 3 seconds
-                if current_time - self.last_update >= 3:
-                    self.progress_tracker.update_progress(
-                        state.global_step,
-                        state.max_steps,
-                        f"Training step {state.global_step}/{state.max_steps}"
-                    )
-                    self.last_update = current_time
-
-                # Log training metrics if available
-                if logs:
-                    loss = logs.get('train_loss', logs.get('loss', 'N/A'))
-                    self.progress_tracker.message = f"Step {state.global_step}/{state.max_steps}, Loss: {loss}"
-
-            def on_train_begin(self, args, state, control, **kwargs):
-                self.progress_tracker.status = "training"
-                self.progress_tracker.message = "Training started..."
-
-            def on_train_end(self, args, state, control, **kwargs):
-                self.progress_tracker.status = "saving"
-                self.progress_tracker.message = "Training complete, saving model..."
-
-        # === Trainer Initialization ===
-        trainer = Trainer(
-            model=model,
-            args=training_args,
-            train_dataset=train_dataset,
-            callbacks=[ProgressCallback(progress)],
-            tokenizer=tokenizer,
-        )
-
-        # === Train & Save ===
-        trainer.train()
-        trainer.save_model(output_dir)
-        tokenizer.save_pretrained(output_dir)
-
-        # Save model info
-        progress.model_path = output_dir
-        progress.status = "completed"
-        progress.progress = 100
-        progress.message = f"Training completed! Model ready for download."
-
-        # Keep the temp directory for download (cleanup after 1 hour)
-        def cleanup_temp_dir():
-            time.sleep(3600)  # Wait 1 hour before cleanup
-            try:
-                shutil.rmtree(temp_dir)
-                # Remove from training_jobs after cleanup
-                if job_id in training_jobs:
-                    del training_jobs[job_id]
-            except:
-                pass
-
-        cleanup_thread = threading.Thread(target=cleanup_temp_dir)
-        cleanup_thread.daemon = True
-        cleanup_thread.start()
-
-    except Exception as e:
-        progress.status = "error"
-        progress.error = str(e)
-        progress.message = f"Training failed: {str(e)}"
-
-        # Clean up on error
-        try:
-            if 'temp_dir' in locals():
-                shutil.rmtree(temp_dir)
-        except:
-            pass
-def train_model_background(job_id, dataset_name, base_model_name=None):
-    """Background training function with improved configuration"""
+    """Background training function with fixed tokenization"""
     progress = training_jobs[job_id]
 
     try:
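The removed path above did static padding and loss masking by hand inside `CustomDataset`: every example was padded to `max_length` and the padded label positions were overwritten with -100. A minimal standalone sketch of that per-example masking, assuming only `transformers` (with torch) is installed and using the same base model as the diff:

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("microsoft/DialoGPT-small")
tokenizer.pad_token = tokenizer.eos_token  # DialoGPT ships without a pad token

enc = tokenizer(
    "Question: hi\nAnswer: hello",
    truncation=True,
    padding="max_length",  # static padding, as the removed CustomDataset did
    max_length=16,
    return_tensors="pt",
)
labels = enc["input_ids"].clone()
labels[enc["attention_mask"] == 0] = -100  # padded positions are ignored by the loss

The rest of the commit replaces this machinery with `Dataset.map` plus a data collator, which pads per batch instead of per dataset.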
@@ -419,11 +154,10 @@ def train_model_background(job_id, dataset_name, base_model_name=None):
     progress.status = "loading_model"
     progress.message = "Loading base model and tokenizer..."
 
-        # === Better Model Selection ===
-        # Use a more suitable model for medical conversations
-        base_model = base_model_name or "microsoft/DialoGPT-medium"  # Better than small
+        # === Model Configuration ===
+        base_model = base_model_name or "microsoft/DialoGPT-medium"
         new_model = f"trained-model-{job_id}"
-        max_length = 512  # Increased for better context
+        max_length = 512
 
         # === Load Model and Tokenizer ===
         model = AutoModelForCausalLM.from_pretrained(
@@ -439,7 +173,7 @@ def train_model_background(job_id, dataset_name, base_model_name=None):
             base_model,
             cache_dir=temp_dir,
             trust_remote_code=True,
-            padding_side="right"  # Important for causal LM
+            padding_side="right"
         )
 
         # Add padding token if not present
@@ -451,16 +185,16 @@ def train_model_background(job_id, dataset_name, base_model_name=None):
         model.resize_token_embeddings(len(tokenizer))
 
         progress.status = "preparing_model"
-        progress.message = "Setting up improved LoRA configuration..."
+        progress.message = "Setting up LoRA configuration..."
 
-        # === Better LoRA Config ===
+        # === LoRA Config ===
         peft_config = LoraConfig(
-            r=16,  # Increased rank for better learning
-            lora_alpha=32,  # Increased alpha
-            lora_dropout=0.05,  # Reduced dropout
+            r=16,
+            lora_alpha=32,
+            lora_dropout=0.05,
             bias="none",
             task_type=TaskType.CAUSAL_LM,
-            target_modules=["c_attn", "c_proj"],  # Target specific modules for DialoGPT
+            target_modules=["c_attn", "c_proj"],
         )
         model = get_peft_model(model, peft_config)
 
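`target_modules=["c_attn", "c_proj"]` is specific to GPT-2-style blocks, which DialoGPT uses; other architectures expose different module names (for example `q_proj`/`v_proj` in LLaMA-style models). A hedged one-off check, assuming only `transformers` is installed, to confirm those names exist on the chosen base model:

from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained("microsoft/DialoGPT-medium")
leaf_names = {name.split(".")[-1] for name, _ in model.named_modules()}
print(sorted(n for n in leaf_names if n.startswith("c_")))
# a GPT-2-style model is expected to list 'c_attn' and 'c_proj' here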
@@ -487,38 +221,61 @@ def train_model_background(job_id, dataset_name, base_model_name=None):
         progress.detected_columns = {"question": question_col, "answer": answer_col}
         progress.message = f"Detected columns - Question: {question_col}, Answer: {answer_col}"
 
-        # Use more data for better training
-        dataset_size = min(1000, len(dataset))  # Increased from 100 to 1000
+        # Use subset for faster training
+        dataset_size = min(500, len(dataset))
         dataset = dataset.shuffle(seed=42).select(range(dataset_size))
 
-        # === Better Text Formatting ===
+        # === Fixed Text Formatting ===
         def format_conversation(example):
             question = str(example[question_col]).strip()
             answer = str(example[answer_col]).strip()
 
-            # Better formatting with special tokens
-            conversation = f"<|user|>{question}<|assistant|>{answer}<|endoftext|>"
+            # Simple format that works well with tokenizer
+            conversation = f"Question: {question}\nAnswer: {answer}{tokenizer.eos_token}"
             return {"text": conversation}
 
         # Apply formatting
-        dataset = dataset.map(format_conversation, remove_columns=dataset.column_names)
+        formatted_dataset = dataset.map(format_conversation, remove_columns=dataset.column_names)
 
         # Filter out very short or very long examples
-        dataset = dataset.filter(lambda x: 10 < len(x["text"]) < max_length * 2)
+        formatted_dataset = formatted_dataset.filter(lambda x: 10 < len(x["text"]) < max_length * 3)
+
+        # === Fixed Tokenization Function ===
+        def tokenize_function(examples):
+            # Tokenize the text
+            model_inputs = tokenizer(
+                examples["text"],
+                truncation=True,
+                padding=False,  # Will be handled by data collator
+                max_length=max_length,
+                return_tensors=None,
+            )
+
+            # For causal LM, labels are the same as input_ids
+            model_inputs["labels"] = model_inputs["input_ids"].copy()
+            return model_inputs
 
-        # === Improved Training Arguments ===
+        # Tokenize dataset
+        tokenized_dataset = formatted_dataset.map(
+            tokenize_function,
+            batched=True,
+            remove_columns=formatted_dataset.column_names,
+            desc="Tokenizing dataset",
+        )
+
+        # === Training Configuration ===
         batch_size = 4 if torch.cuda.is_available() else 2
         gradient_accumulation_steps = 2
-        num_epochs = 3  # Increased epochs
-        learning_rate = 2e-4  # Better learning rate
+        num_epochs = 2
+        learning_rate = 2e-4
 
-        steps_per_epoch = len(dataset) // (batch_size * gradient_accumulation_steps)
+        steps_per_epoch = len(tokenized_dataset) // (batch_size * gradient_accumulation_steps)
         total_steps = steps_per_epoch * num_epochs
-        warmup_steps = max(10, total_steps // 10)  # 10% warmup
+        warmup_steps = max(10, total_steps // 10)
 
         progress.total_steps = total_steps
         progress.status = "training"
-        progress.message = "Starting improved training..."
+        progress.message = "Starting training..."
 
         output_dir = os.path.join(temp_dir, new_model)
         os.makedirs(output_dir, exist_ok=True)
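This hunk is the heart of the "fixed tokenization" commit: tokenization moves from a hand-rolled `torch.utils.data.Dataset` into `Dataset.map`, with `padding=False` so examples keep their natural lengths, and labels become a plain copy of `input_ids`. A toy sketch of what `tokenize_function` returns (the inputs are illustrative, not from the dataset):

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("microsoft/DialoGPT-medium")
tokenizer.pad_token = tokenizer.eos_token

texts = [
    "Question: hi\nAnswer: hello",
    "Question: what causes flu?\nAnswer: influenza viruses",
]
out = tokenizer(texts, truncation=True, padding=False, max_length=512)
out["labels"] = [ids.copy() for ids in out["input_ids"]]
print([len(ids) for ids in out["input_ids"]])  # unequal lengths; padding is deferred to the collator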
@@ -552,32 +309,9 @@ def train_model_background(job_id, dataset_name, base_model_name=None):
         # === Data Collator ===
         data_collator = DataCollatorForLanguageModeling(
             tokenizer=tokenizer,
-            mlm=False,  # We're doing causal LM, not masked LM
+            mlm=False,
             return_tensors="pt",
-            pad_to_multiple_of=8,
-        )
-
-        # Custom tokenization function
-        def tokenize_function(examples):
-            # Tokenize the text
-            tokenized = tokenizer(
-                examples["text"],
-                truncation=True,
-                padding=False,  # Will be handled by data collator
-                max_length=max_length,
-                return_tensors=None,
-            )
-
-            # For causal LM, labels are the same as input_ids
-            tokenized["labels"] = tokenized["input_ids"].copy()
-            return tokenized
-
-        # Tokenize dataset
-        tokenized_dataset = dataset.map(
-            tokenize_function,
-            batched=True,
-            remove_columns=dataset.column_names,
-            desc="Tokenizing dataset",
+            pad_to_multiple_of=8 if torch.cuda.is_available() else None,
         )
 
         # Custom callback to track progress
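With `mlm=False`, `DataCollatorForLanguageModeling` pads each batch dynamically and sets the labels at padded positions to -100, which replaces the manual masking the deleted `CustomDataset` performed. One caveat: since the `pad_token` is reused as the `eos_token` here, genuine end-of-text positions in the labels are masked as well, a known trade-off of this setup. A minimal sketch with toy token ids:

from transformers import AutoTokenizer, DataCollatorForLanguageModeling

tokenizer = AutoTokenizer.from_pretrained("microsoft/DialoGPT-medium")
tokenizer.pad_token = tokenizer.eos_token

collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)
features = [{"input_ids": [10, 20, 30]}, {"input_ids": [10]}]
batch = collator(features)
print(batch["labels"])  # the padded tail of the shorter example comes back as -100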
@@ -588,7 +322,6 @@ def train_model_background(job_id, dataset_name, base_model_name=None):
 
             def on_log(self, args, state, control, model=None, logs=None, **kwargs):
                 current_time = time.time()
-                # Update every 5 seconds
                 if current_time - self.last_update >= 5:
                     self.progress_tracker.update_progress(
                         state.global_step,
@@ -597,19 +330,20 @@ def train_model_background(job_id, dataset_name, base_model_name=None):
                     )
                     self.last_update = current_time
 
-                # Log training metrics if available
                 if logs:
                     loss = logs.get('train_loss', logs.get('loss', 'N/A'))
                     lr = logs.get('learning_rate', 'N/A')
-                    self.progress_tracker.message = f"Step {state.global_step}/{state.max_steps}, Loss: {loss:.4f}, LR: {lr}"
+                    if isinstance(loss, (int, float)):
+                        loss = f"{loss:.4f}"
+                    self.progress_tracker.message = f"Step {state.global_step}/{state.max_steps}, Loss: {loss}, LR: {lr}"
 
             def on_train_begin(self, args, state, control, **kwargs):
                 self.progress_tracker.status = "training"
-                self.progress_tracker.message = "Training started with improved configuration..."
+                self.progress_tracker.message = "Training started..."
 
             def on_train_end(self, args, state, control, **kwargs):
                 self.progress_tracker.status = "saving"
-                self.progress_tracker.message = "Training complete, saving improved model..."
+                self.progress_tracker.message = "Training complete, saving model..."
 
         # === Trainer Initialization ===
         trainer = Trainer(
@@ -628,15 +362,14 @@ def train_model_background(job_id, dataset_name, base_model_name=None):
         trainer.save_model(output_dir)
         tokenizer.save_pretrained(output_dir)
 
-        # Also save the base model name for inference
+        # Save additional info
         with open(os.path.join(output_dir, "base_model.txt"), "w") as f:
             f.write(base_model)
 
-        # Save training info
         training_info = {
             "base_model": base_model,
             "dataset_name": dataset_name,
-            "dataset_size": len(dataset),
+            "dataset_size": len(tokenized_dataset),
             "max_length": max_length,
             "batch_size": batch_size,
             "learning_rate": learning_rate,
@@ -649,18 +382,17 @@ def train_model_background(job_id, dataset_name, base_model_name=None):
             import json
             json.dump(training_info, f, indent=2)
 
-        # Save model info
+        # Update progress
         progress.model_path = output_dir
         progress.status = "completed"
         progress.progress = 100
-        progress.message = f"Improved training completed! Model ready for download."
+        progress.message = f"Training completed successfully! Model ready for download."
 
-        # Keep the temp directory for download (cleanup after 2 hours for larger model)
+        # Keep the temp directory for download
         def cleanup_temp_dir():
             time.sleep(7200)  # Wait 2 hours before cleanup
             try:
                 shutil.rmtree(temp_dir)
-                # Remove from training_jobs after cleanup
                 if job_id in training_jobs:
                     del training_jobs[job_id]
             except:
@@ -681,6 +413,7 @@ def train_model_background(job_id, dataset_name, base_model_name=None):
                 shutil.rmtree(temp_dir)
         except:
             pass
+
 def create_model_zip(model_path, job_id):
     """Create a zip file containing the trained model"""
     memory_file = io.BytesIO()
@@ -694,6 +427,7 @@ def create_model_zip(model_path, job_id):
 
     memory_file.seek(0)
     return memory_file
+
 # ============== API ROUTES ==============
 @app.route('/api/train', methods=['POST'])
 def start_training():
@@ -701,9 +435,9 @@ def start_training():
     try:
         data = request.get_json() if request.is_json else {}
         dataset_name = data.get('dataset_name', 'ruslanmv/ai-medical-chatbot')
-        base_model_name = data.get('base_model', 'microsoft/DialoGPT-small')
+        base_model_name = data.get('base_model', 'microsoft/DialoGPT-medium')
 
-        job_id = str(uuid.uuid4())[:8]  # Short UUID
+        job_id = str(uuid.uuid4())[:8]
         progress = TrainingProgress(job_id)
         training_jobs[job_id] = progress
 
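A hedged usage sketch for this route, assuming the server runs locally on port 7860 and `requests` is installed; the exact response payload is not shown in this diff, so the printed fields are an assumption:

import requests

resp = requests.post(
    "http://localhost:7860/api/train",
    json={
        "dataset_name": "ruslanmv/ai-medical-chatbot",
        "base_model": "microsoft/DialoGPT-medium",
    },
)
print(resp.json())  # assumed to include the short job_id used to poll training progress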
@@ -797,7 +531,7 @@ def home():
                 "url": "/api/train",
                 "body": {
                     "dataset_name": "your-dataset-name",
-                    "base_model": "microsoft/DialoGPT-small"
+                    "base_model": "microsoft/DialoGPT-medium"
                 }
             }
         }
@@ -808,5 +542,5 @@ def health():
     return jsonify({"status": "healthy"})
 
 if __name__ == '__main__':
-    port = int(os.environ.get('PORT', 7860))  # HF Spaces uses port 7860
+    port = int(os.environ.get('PORT', 7860))
     app.run(host='0.0.0.0', port=port, debug=False)
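For completeness, a hedged sketch of loading the downloaded artifact for inference: the output directory holds the tokenizer, the LoRA adapter weights, and the `base_model.txt` written earlier in this commit. The directory name below is hypothetical; `PeftModel.from_pretrained` is the standard peft entry point for attaching a saved adapter:

from peft import PeftModel
from transformers import AutoModelForCausalLM, AutoTokenizer

adapter_dir = "trained-model-abc12345"  # hypothetical name of the unzipped download

with open(f"{adapter_dir}/base_model.txt") as f:
    base_model = f.read().strip()

tokenizer = AutoTokenizer.from_pretrained(adapter_dir)
model = AutoModelForCausalLM.from_pretrained(base_model)
model.resize_token_embeddings(len(tokenizer))  # mirror the resize done before training
model = PeftModel.from_pretrained(model, adapter_dir)
model.eval()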
 