Guetat Youssef committed on
Commit c2215d0 · 1 Parent(s): 8f8763e
Files changed (1)
  1. app.py +85 -123
app.py CHANGED
@@ -112,7 +112,7 @@ def detect_qa_columns(dataset):
     return question_col, answer_col
 
 def train_model_background(job_id, dataset_name, base_model_name=None):
-    """Background training function with fixed tokenization"""
+    """Background training function with progress tracking"""
    progress = training_jobs[job_id]
 
     try:
@@ -138,12 +138,10 @@ def train_model_background(job_id, dataset_name, base_model_name=None):
             TrainingArguments,
             Trainer,
             TrainerCallback,
-            DataCollatorForLanguageModeling
         )
         from peft import (
             LoraConfig,
             get_peft_model,
-            TaskType
         )
 
         # === Authentication ===
@@ -154,32 +152,29 @@ def train_model_background(job_id, dataset_name, base_model_name=None):
         progress.status = "loading_model"
         progress.message = "Loading base model and tokenizer..."
 
-        # === Model Configuration ===
-        base_model = base_model_name or "microsoft/DialoGPT-medium"
+        # === Configuration ===
+        base_model = base_model_name or "microsoft/DialoGPT-small"
         new_model = f"trained-model-{job_id}"
-        max_length = 512
+        max_length = 256
 
         # === Load Model and Tokenizer ===
         model = AutoModelForCausalLM.from_pretrained(
             base_model,
             cache_dir=temp_dir,
-            torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
+            torch_dtype=torch.float32,
             device_map="auto" if torch.cuda.is_available() else "cpu",
-            trust_remote_code=True,
-            low_cpu_mem_usage=True
+            trust_remote_code=True
         )
 
         tokenizer = AutoTokenizer.from_pretrained(
             base_model,
             cache_dir=temp_dir,
-            trust_remote_code=True,
-            padding_side="right"
+            trust_remote_code=True
         )
 
         # Add padding token if not present
         if tokenizer.pad_token is None:
             tokenizer.pad_token = tokenizer.eos_token
-            tokenizer.pad_token_id = tokenizer.eos_token_id
 
         # Resize token embeddings if needed
         model.resize_token_embeddings(len(tokenizer))
@@ -189,17 +184,13 @@ def train_model_background(job_id, dataset_name, base_model_name=None):
 
         # === LoRA Config ===
         peft_config = LoraConfig(
-            r=16,
-            lora_alpha=32,
-            lora_dropout=0.05,
+            r=8,
+            lora_alpha=16,
+            lora_dropout=0.1,
             bias="none",
-            task_type=TaskType.CAUSAL_LM,
-            target_modules=["c_attn", "c_proj"],
+            task_type="CAUSAL_LM",
         )
         model = get_peft_model(model, peft_config)
-
-        # Print trainable parameters
-        model.print_trainable_parameters()
 
         progress.status = "loading_dataset"
         progress.message = "Loading and preparing dataset..."
@@ -221,62 +212,71 @@ def train_model_background(job_id, dataset_name, base_model_name=None):
         progress.detected_columns = {"question": question_col, "answer": answer_col}
         progress.message = f"Detected columns - Question: {question_col}, Answer: {answer_col}"
 
-        # Use subset for faster training
-        dataset_size = min(500, len(dataset))
-        dataset = dataset.shuffle(seed=42).select(range(dataset_size))
+        # Use subset for faster testing (can be made configurable)
+        dataset = dataset.shuffle(seed=65).select(range(min(1000, len(dataset))))
 
-        # === Fixed Text Formatting ===
-        def format_conversation(example):
-            question = str(example[question_col]).strip()
-            answer = str(example[answer_col]).strip()
-
-            # Simple format that works well with tokenizer
-            conversation = f"Question: {question}\nAnswer: {answer}{tokenizer.eos_token}"
-            return {"text": conversation}
-
-        # Apply formatting
-        formatted_dataset = dataset.map(format_conversation, remove_columns=dataset.column_names)
-
-        # Filter out very short or very long examples
-        formatted_dataset = formatted_dataset.filter(lambda x: 10 < len(x["text"]) < max_length * 3)
-
-        # === Fixed Tokenization Function ===
-        def tokenize_function(examples):
-            # Tokenize the text
-            model_inputs = tokenizer(
-                examples["text"],
-                truncation=True,
-                padding=False,  # Will be handled by data collator
-                max_length=max_length,
-                return_tensors=None,
-            )
-
-            # For causal LM, labels are the same as input_ids
-            model_inputs["labels"] = model_inputs["input_ids"].copy()
-            return model_inputs
-
-        # Tokenize dataset
-        tokenized_dataset = formatted_dataset.map(
-            tokenize_function,
-            batched=True,
-            remove_columns=formatted_dataset.column_names,
-            desc="Tokenizing dataset",
-        )
+        # Custom dataset class for proper handling
+        class CustomDataset(torch.utils.data.Dataset):
+            def __init__(self, texts, tokenizer, max_length):
+                self.texts = texts
+                self.tokenizer = tokenizer
+                self.max_length = max_length
+
+            def __len__(self):
+                return len(self.texts)
+
+            def __getitem__(self, idx):
+                text = self.texts[idx]
+
+                # Tokenize the text
+                encoding = self.tokenizer(
+                    text,
+                    truncation=True,
+                    padding='max_length',
+                    max_length=self.max_length,
+                    return_tensors='pt'
+                )
+
+                # Flatten the tensors (remove batch dimension)
+                input_ids = encoding['input_ids'].squeeze()
+                attention_mask = encoding['attention_mask'].squeeze()
+
+                # For causal language modeling, labels are the same as input_ids
+                labels = input_ids.clone()
+
+                # Set labels to -100 for padding tokens (they won't contribute to loss)
+                labels[attention_mask == 0] = -100
+
+                return {
+                    'input_ids': input_ids,
+                    'attention_mask': attention_mask,
+                    'labels': labels
+                }
+
+        # Prepare texts using detected columns
+        texts = []
+        for item in dataset:
+            question = str(item[question_col]).strip()
+            answer = str(item[answer_col]).strip()
+            text = f"Question: {question}\nAnswer: {answer}{tokenizer.eos_token}"
+            texts.append(text)
+
+        # Create custom dataset
+        train_dataset = CustomDataset(texts, tokenizer, max_length)
 
-        # === Training Configuration ===
-        batch_size = 4 if torch.cuda.is_available() else 2
-        gradient_accumulation_steps = 2
-        num_epochs = 2
-        learning_rate = 2e-4
+        # Calculate total training steps
+        batch_size = 2
+        gradient_accumulation_steps = 1
+        num_epochs = 1
 
-        steps_per_epoch = len(tokenized_dataset) // (batch_size * gradient_accumulation_steps)
+        steps_per_epoch = len(train_dataset) // (batch_size * gradient_accumulation_steps)
         total_steps = steps_per_epoch * num_epochs
-        warmup_steps = max(10, total_steps // 10)
 
         progress.total_steps = total_steps
         progress.status = "training"
         progress.message = "Starting training..."
 
+        # === Training Arguments ===
         output_dir = os.path.join(temp_dir, new_model)
         os.makedirs(output_dir, exist_ok=True)
 
@@ -285,33 +285,19 @@ def train_model_background(job_id, dataset_name, base_model_name=None):
             per_device_train_batch_size=batch_size,
             gradient_accumulation_steps=gradient_accumulation_steps,
             num_train_epochs=num_epochs,
-            learning_rate=learning_rate,
-            warmup_steps=warmup_steps,
-            logging_steps=5,
-            save_steps=max(10, total_steps // 4),
-            save_total_limit=2,
-            evaluation_strategy="no",
+            logging_steps=1,
+            save_steps=max(1, total_steps // 2),
+            save_total_limit=1,
+            learning_rate=5e-5,
+            warmup_steps=2,
             logging_strategy="steps",
             save_strategy="steps",
-            fp16=torch.cuda.is_available(),
+            fp16=False,
             bf16=False,
             dataloader_num_workers=0,
            remove_unused_columns=False,
             report_to=None,
             prediction_loss_only=True,
-            optim="adamw_torch",
-            weight_decay=0.01,
-            lr_scheduler_type="cosine",
-            gradient_checkpointing=True,
-            dataloader_pin_memory=False,
-        )
-
-        # === Data Collator ===
-        data_collator = DataCollatorForLanguageModeling(
-            tokenizer=tokenizer,
-            mlm=False,
-            return_tensors="pt",
-            pad_to_multiple_of=8 if torch.cuda.is_available() else None,
         )
 
         # Custom callback to track progress
@@ -322,7 +308,8 @@ def train_model_background(job_id, dataset_name, base_model_name=None):
 
             def on_log(self, args, state, control, model=None, logs=None, **kwargs):
                 current_time = time.time()
-                if current_time - self.last_update >= 5:
+                # Update every 3 seconds
+                if current_time - self.last_update >= 3:
                     self.progress_tracker.update_progress(
                         state.global_step,
                         state.max_steps,
@@ -330,12 +317,10 @@ def train_model_background(job_id, dataset_name, base_model_name=None):
                     )
                     self.last_update = current_time
 
+                # Log training metrics if available
                 if logs:
                     loss = logs.get('train_loss', logs.get('loss', 'N/A'))
-                    lr = logs.get('learning_rate', 'N/A')
-                    if isinstance(loss, (int, float)):
-                        loss = f"{loss:.4f}"
-                    self.progress_tracker.message = f"Step {state.global_step}/{state.max_steps}, Loss: {loss}, LR: {lr}"
+                    self.progress_tracker.message = f"Step {state.global_step}/{state.max_steps}, Loss: {loss}"
 
             def on_train_begin(self, args, state, control, **kwargs):
                 self.progress_tracker.status = "training"
@@ -349,50 +334,28 @@ def train_model_background(job_id, dataset_name, base_model_name=None):
         trainer = Trainer(
            model=model,
             args=training_args,
-            train_dataset=tokenized_dataset,
-            data_collator=data_collator,
+            train_dataset=train_dataset,
             callbacks=[ProgressCallback(progress)],
             tokenizer=tokenizer,
         )
 
         # === Train & Save ===
         trainer.train()
-
-        # Save the model properly
         trainer.save_model(output_dir)
         tokenizer.save_pretrained(output_dir)
 
-        # Save additional info
-        with open(os.path.join(output_dir, "base_model.txt"), "w") as f:
-            f.write(base_model)
-
-        training_info = {
-            "base_model": base_model,
-            "dataset_name": dataset_name,
-            "dataset_size": len(tokenized_dataset),
-            "max_length": max_length,
-            "batch_size": batch_size,
-            "learning_rate": learning_rate,
-            "num_epochs": num_epochs,
-            "total_steps": total_steps,
-            "detected_columns": progress.detected_columns
-        }
-
-        with open(os.path.join(output_dir, "training_info.json"), "w") as f:
-            import json
-            json.dump(training_info, f, indent=2)
-
-        # Update progress
+        # Save model info
         progress.model_path = output_dir
         progress.status = "completed"
         progress.progress = 100
-        progress.message = f"Training completed successfully! Model ready for download."
+        progress.message = f"Training completed! Model ready for download."
 
-        # Keep the temp directory for download
+        # Keep the temp directory for download (cleanup after 1 hour)
        def cleanup_temp_dir():
-            time.sleep(7200)  # Wait 2 hours before cleanup
+            time.sleep(3600)  # Wait 1 hour before cleanup
             try:
                 shutil.rmtree(temp_dir)
+                # Remove from training_jobs after cleanup
                 if job_id in training_jobs:
                     del training_jobs[job_id]
             except:
@@ -427,7 +390,6 @@ def create_model_zip(model_path, job_id):
 
     memory_file.seek(0)
     return memory_file
-
 # ============== API ROUTES ==============
 @app.route('/api/train', methods=['POST'])
 def start_training():
@@ -435,9 +397,9 @@ def start_training():
     try:
         data = request.get_json() if request.is_json else {}
         dataset_name = data.get('dataset_name', 'ruslanmv/ai-medical-chatbot')
-        base_model_name = data.get('base_model', 'microsoft/DialoGPT-medium')
+        base_model_name = data.get('base_model', 'microsoft/DialoGPT-small')
 
-        job_id = str(uuid.uuid4())[:8]
+        job_id = str(uuid.uuid4())[:8]  # Short UUID
         progress = TrainingProgress(job_id)
         training_jobs[job_id] = progress
 
@@ -531,7 +493,7 @@ def home():
             "url": "/api/train",
             "body": {
                 "dataset_name": "your-dataset-name",
-                "base_model": "microsoft/DialoGPT-medium"
+                "base_model": "microsoft/DialoGPT-small"
             }
         }
     }
@@ -542,5 +504,5 @@ def health():
     return jsonify({"status": "healthy"})
 
 if __name__ == '__main__':
-    port = int(os.environ.get('PORT', 7860))
+    port = int(os.environ.get('PORT', 7860))  # HF Spaces uses port 7860
     app.run(host='0.0.0.0', port=port, debug=False)
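Usage note: a minimal client sketch for exercising the /api/train route changed in this commit. The route, JSON body keys, defaults, and port 7860 all come from the diff above; the localhost base URL is an assumption (any deployment of this Space would work), and the exact response shape beyond a job id is not shown in this diff.

import requests

# Kick off a background training job; dataset_name and base_model mirror the
# defaults introduced in this commit.
resp = requests.post(
    "http://localhost:7860/api/train",  # assumed local deployment; HF Spaces serves on 7860
    json={
        "dataset_name": "ruslanmv/ai-medical-chatbot",
        "base_model": "microsoft/DialoGPT-small",
    },
    timeout=30,
)
print(resp.status_code, resp.json())  # expected to include the short job_id for later polling

Since the trainer wraps the base model with get_peft_model, the saved artifact is a LoRA adapter plus tokenizer files rather than full model weights, so loading it later requires the matching base model (e.g. via peft.PeftModel.from_pretrained).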