Guetat Youssef committed
Commit 9774f95 · 1 Parent(s): aba82e3
Files changed (1):
  1. app.py +65 -43
app.py CHANGED
@@ -69,7 +69,7 @@ def train_model_background(job_id):
 
     # Import heavy libraries after setting cache paths
     import torch
-    from datasets import load_dataset
+    from datasets import load_dataset, Dataset
     from huggingface_hub import login
     from transformers import (
         AutoModelForCausalLM,
@@ -77,7 +77,6 @@ def train_model_background(job_id):
         TrainingArguments,
         Trainer,
         TrainerCallback,
-        DataCollatorForLanguageModeling
     )
     from peft import (
         LoraConfig,
@@ -93,9 +92,10 @@ def train_model_background(job_id):
     progress.message = "Loading base model and tokenizer..."
 
     # === Configuration ===
-    base_model = "microsoft/DialoGPT-small"  # Smaller model for testing
+    base_model = "microsoft/DialoGPT-small"
     dataset_name = "ruslanmv/ai-medical-chatbot"
     new_model = f"trained-model-{job_id}"
+    max_length = 256
 
     # === Load Model and Tokenizer ===
     model = AutoModelForCausalLM.from_pretrained(
@@ -115,6 +115,9 @@ def train_model_background(job_id):
     # Add padding token if not present
     if tokenizer.pad_token is None:
         tokenizer.pad_token = tokenizer.eos_token
+
+    # Resize token embeddings if needed
+    model.resize_token_embeddings(len(tokenizer))
 
     progress.status = "preparing_model"
     progress.message = "Setting up LoRA configuration..."
@@ -139,49 +142,62 @@ def train_model_background(job_id):
         cache_dir=temp_dir,
         trust_remote_code=True
     )
-    dataset = dataset.shuffle(seed=65).select(range(50))  # Use only 50 samples for faster testing
+    dataset = dataset.shuffle(seed=65).select(range(30))  # Use only 30 samples for faster testing
 
-    def tokenize_function(examples):
-        # Format the text
-        texts = []
-        for i in range(len(examples['Patient'])):
-            text = f"Patient: {examples['Patient'][i]}\nDoctor: {examples['Doctor'][i]}{tokenizer.eos_token}"
-            texts.append(text)
-
-        # Tokenize
-        tokenized = tokenizer(
-            texts,
-            truncation=True,
-            padding=False,
-            max_length=256,
-            return_tensors=None
-        )
-
-        # For causal LM, labels are the same as input_ids
-        tokenized["labels"] = tokenized["input_ids"].copy()
-        return tokenized
+    # Custom dataset class for proper handling
+    class CustomDataset(torch.utils.data.Dataset):
+        def __init__(self, texts, tokenizer, max_length):
+            self.texts = texts
+            self.tokenizer = tokenizer
+            self.max_length = max_length
 
-    # Tokenize dataset
-    tokenized_dataset = dataset.map(
-        tokenize_function,
-        batched=True,
-        remove_columns=dataset.column_names,
-        desc="Tokenizing dataset"
-    )
+        def __len__(self):
+            return len(self.texts)
 
-    # Data collator for language modeling
-    data_collator = DataCollatorForLanguageModeling(
-        tokenizer=tokenizer,
-        mlm=False,  # We're doing causal LM, not masked LM
-    )
+        def __getitem__(self, idx):
+            text = self.texts[idx]
+
+            # Tokenize the text
+            encoding = self.tokenizer(
+                text,
+                truncation=True,
+                padding='max_length',
+                max_length=self.max_length,
+                return_tensors='pt'
+            )
+
+            # Flatten the tensors (remove batch dimension)
+            input_ids = encoding['input_ids'].squeeze()
+            attention_mask = encoding['attention_mask'].squeeze()
+
+            # For causal language modeling, labels start as a copy of input_ids;
+            # the model shifts them internally so each position predicts the next token
+            labels = input_ids.clone()
+
+            # Set labels to -100 for padding tokens (they won't contribute to loss)
+            labels[attention_mask == 0] = -100
+
+            return {
+                'input_ids': input_ids,
+                'attention_mask': attention_mask,
+                'labels': labels
+            }
+
+    # Prepare texts
+    texts = []
+    for item in dataset:
+        text = f"Patient: {item['Patient']}\nDoctor: {item['Doctor']}{tokenizer.eos_token}"
+        texts.append(text)
+
+    # Create custom dataset
+    train_dataset = CustomDataset(texts, tokenizer, max_length)
 
     # Calculate total training steps
-    train_size = len(tokenized_dataset)
     batch_size = 2
     gradient_accumulation_steps = 1
    num_epochs = 1
 
-    steps_per_epoch = train_size // (batch_size * gradient_accumulation_steps)
+    steps_per_epoch = len(train_dataset) // (batch_size * gradient_accumulation_steps)
     total_steps = steps_per_epoch * num_epochs
 
     progress.total_steps = total_steps
@@ -198,10 +214,10 @@ def train_model_background(job_id):
         gradient_accumulation_steps=gradient_accumulation_steps,
         num_train_epochs=num_epochs,
         logging_steps=1,
-        save_steps=20,
+        save_steps=15,
         save_total_limit=1,
         learning_rate=5e-5,
-        warmup_steps=5,
+        warmup_steps=2,
         logging_strategy="steps",
         save_strategy="steps",
         fp16=False,
@@ -209,6 +225,7 @@ def train_model_background(job_id):
         dataloader_num_workers=0,
         remove_unused_columns=False,
         report_to=None,
+        prediction_loss_only=True,
     )
 
     # Custom callback to track progress
@@ -219,14 +236,19 @@ def train_model_background(job_id):
 
        def on_log(self, args, state, control, model=None, logs=None, **kwargs):
            current_time = time.time()
-            # Update every 5 seconds or on significant step changes
-            if current_time - self.last_update >= 5 or state.global_step % 2 == 0:
+            # Update every 3 seconds
+            if current_time - self.last_update >= 3:
                self.progress_tracker.update_progress(
                    state.global_step,
                    state.max_steps,
                    f"Training step {state.global_step}/{state.max_steps}"
                )
                self.last_update = current_time
+
+            # Log training metrics if available
+            if logs:
+                loss = logs.get('train_loss', logs.get('loss', 'N/A'))
+                self.progress_tracker.message = f"Step {state.global_step}/{state.max_steps}, Loss: {loss}"
 
        def on_train_begin(self, args, state, control, **kwargs):
            self.progress_tracker.status = "training"
@@ -240,9 +262,9 @@ def train_model_background(job_id):
     trainer = Trainer(
         model=model,
         args=training_args,
-        train_dataset=tokenized_dataset,
-        data_collator=data_collator,
+        train_dataset=train_dataset,
         callbacks=[ProgressCallback(progress)],
+        tokenizer=tokenizer,
     )
 
     # === Train & Save ===
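Note (not part of the commit): the core behavioral change here is replacing DataCollatorForLanguageModeling with explicit label masking in CustomDataset, where padded positions get label -100 so they are skipped by the loss (the default ignore_index of torch.nn.CrossEntropyLoss, which the Hugging Face causal-LM head uses). The sketch below is a minimal standalone check of that scheme under the same assumptions as the diff (microsoft/DialoGPT-small tokenizer, max_length=256, pad token aliased to EOS); the Patient/Doctor text is made up for illustration.

# Standalone sanity check of the -100 label masking used by CustomDataset.__getitem__
import torch
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("microsoft/DialoGPT-small")
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token  # same fallback as in app.py

max_length = 256
text = f"Patient: I have a mild headache.\nDoctor: Rest and stay hydrated.{tokenizer.eos_token}"

encoding = tokenizer(
    text,
    truncation=True,
    padding="max_length",
    max_length=max_length,
    return_tensors="pt",
)
input_ids = encoding["input_ids"].squeeze()
attention_mask = encoding["attention_mask"].squeeze()

labels = input_ids.clone()
labels[attention_mask == 0] = -100  # mask padding exactly as __getitem__ does

print(input_ids.shape, attention_mask.shape, labels.shape)  # all torch.Size([256])
print("real tokens:", int(attention_mask.sum()), "| positions ignored by the loss:", int((labels == -100).sum()))

Because the pad token is aliased to EOS, only the trailing padding (attention mask 0) is masked; the genuine EOS that ends the Doctor turn keeps its label and is still trained on.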