Guetat Youssef committed on
Commit aba82e3 · 1 Parent(s): 10b3fe6
Files changed (2)
  1. app.py +53 -35
  2. requirements.txt +5 -5
app.py CHANGED
@@ -72,19 +72,17 @@ def train_model_background(job_id):
     from datasets import load_dataset
     from huggingface_hub import login
     from transformers import (
-        AutoConfig,
         AutoModelForCausalLM,
         AutoTokenizer,
-        BitsAndBytesConfig,
         TrainingArguments,
-        logging,
-        TrainerCallback
+        Trainer,
+        TrainerCallback,
+        DataCollatorForLanguageModeling
     )
     from peft import (
         LoraConfig,
         get_peft_model,
     )
-    from trl import SFTTrainer, setup_chat_format
 
     # === Authentication ===
     hf_token = os.getenv('HF_TOKEN')
@@ -99,11 +97,11 @@ def train_model_background(job_id):
     dataset_name = "ruslanmv/ai-medical-chatbot"
     new_model = f"trained-model-{job_id}"
 
-    # === Load Model and Tokenizer (without quantization for simplicity) ===
+    # === Load Model and Tokenizer ===
     model = AutoModelForCausalLM.from_pretrained(
         base_model,
         cache_dir=temp_dir,
-        torch_dtype=torch.float32, # Use float32 for compatibility
+        torch_dtype=torch.float32,
         device_map="auto" if torch.cuda.is_available() else "cpu",
         trust_remote_code=True
     )
@@ -121,9 +119,9 @@ def train_model_background(job_id):
     progress.status = "preparing_model"
     progress.message = "Setting up LoRA configuration..."
 
-    # === LoRA Config (simplified) ===
+    # === LoRA Config ===
     peft_config = LoraConfig(
-        r=8, # Smaller rank
+        r=8,
         lora_alpha=16,
         lora_dropout=0.1,
         bias="none",
@@ -141,19 +139,45 @@ def train_model_background(job_id):
         cache_dir=temp_dir,
         trust_remote_code=True
     )
-    dataset = dataset.shuffle(seed=65).select(range(100)) # Use only 100 samples for testing
+    dataset = dataset.shuffle(seed=65).select(range(50)) # Use only 50 samples for faster testing
 
-    def format_chat_template(row):
-        # Simple formatting without chat template
-        text = f"Patient: {row['Patient']}\nDoctor: {row['Doctor']}"
-        return {"text": text}
+    def tokenize_function(examples):
+        # Format the text
+        texts = []
+        for i in range(len(examples['Patient'])):
+            text = f"Patient: {examples['Patient'][i]}\nDoctor: {examples['Doctor'][i]}{tokenizer.eos_token}"
+            texts.append(text)
+
+        # Tokenize
+        tokenized = tokenizer(
+            texts,
+            truncation=True,
+            padding=False,
+            max_length=256,
+            return_tensors=None
+        )
+
+        # For causal LM, labels are the same as input_ids
+        tokenized["labels"] = tokenized["input_ids"].copy()
+        return tokenized
 
-    dataset = dataset.map(format_chat_template, num_proc=1)
-    dataset = dataset.train_test_split(test_size=0.1)
+    # Tokenize dataset
+    tokenized_dataset = dataset.map(
+        tokenize_function,
+        batched=True,
+        remove_columns=dataset.column_names,
+        desc="Tokenizing dataset"
+    )
+
+    # Data collator for language modeling
+    data_collator = DataCollatorForLanguageModeling(
+        tokenizer=tokenizer,
+        mlm=False, # We're doing causal LM, not masked LM
+    )
 
     # Calculate total training steps
-    train_size = len(dataset["train"])
-    batch_size = 1
+    train_size = len(tokenized_dataset)
+    batch_size = 2
     gradient_accumulation_steps = 1
     num_epochs = 1
 
@@ -171,25 +195,20 @@ def train_model_background(job_id):
     training_args = TrainingArguments(
         output_dir=output_dir,
         per_device_train_batch_size=batch_size,
-        per_device_eval_batch_size=1,
         gradient_accumulation_steps=gradient_accumulation_steps,
-        optim="adamw_torch", # Use standard optimizer
         num_train_epochs=num_epochs,
-        eval_steps=0.5,
         logging_steps=1,
+        save_steps=20,
+        save_total_limit=1,
+        learning_rate=5e-5,
         warmup_steps=5,
         logging_strategy="steps",
-        learning_rate=5e-5,
+        save_strategy="steps",
         fp16=False,
         bf16=False,
-        group_by_length=True,
-        save_steps=10,
-        save_total_limit=1,
-        report_to=None,
         dataloader_num_workers=0,
         remove_unused_columns=False,
-        load_best_model_at_end=False,
-        # Remove evaluation_strategy parameter - not supported in this version
+        report_to=None,
     )
 
     # Custom callback to track progress
@@ -200,8 +219,8 @@ def train_model_background(job_id):
 
         def on_log(self, args, state, control, model=None, logs=None, **kwargs):
             current_time = time.time()
-            # Update every 10 seconds or on significant step changes
-            if current_time - self.last_update >= 10 or state.global_step % 5 == 0:
+            # Update every 5 seconds or on significant step changes
+            if current_time - self.last_update >= 5 or state.global_step % 2 == 0:
                 self.progress_tracker.update_progress(
                     state.global_step,
                     state.max_steps,
@@ -218,19 +237,18 @@ def train_model_background(job_id):
             self.progress_tracker.message = "Training complete, saving model..."
 
     # === Trainer Initialization ===
-    trainer = SFTTrainer(
+    trainer = Trainer(
         model=model,
-        train_dataset=dataset["train"],
-        peft_config=peft_config,
         args=training_args,
+        train_dataset=tokenized_dataset,
+        data_collator=data_collator,
         callbacks=[ProgressCallback(progress)],
-        tokenizer=tokenizer,
-        max_seq_length=256, # Shorter sequences
     )
 
     # === Train & Save ===
     trainer.train()
     trainer.save_model(output_dir)
+    tokenizer.save_pretrained(output_dir)
 
     progress.status = "completed"
     progress.progress = 100
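
Note (illustration only, not part of the commit): the change above swaps TRL's SFTTrainer for the plain transformers Trainer plus DataCollatorForLanguageModeling. Below is a minimal standalone sketch of that pattern; the tiny base model name is a placeholder, the dataset name and LoRA values come from the diff, and label creation is left to the collator (with mlm=False it copies input_ids into labels and masks padding), a simpler variant of what app.py does.

from datasets import load_dataset
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    Trainer,
    TrainingArguments,
    DataCollatorForLanguageModeling,
)
from peft import LoraConfig, get_peft_model

base_model = "sshleifer/tiny-gpt2"  # placeholder; any causal LM from the Hub works
dataset = load_dataset("ruslanmv/ai-medical-chatbot", split="train").shuffle(seed=65).select(range(50))

tokenizer = AutoTokenizer.from_pretrained(base_model)
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token  # GPT-2-style tokenizers ship without a pad token

def tokenize_function(examples):
    # Same "Patient: ... / Doctor: ..." formatting used in app.py
    texts = [
        f"Patient: {p}\nDoctor: {d}{tokenizer.eos_token}"
        for p, d in zip(examples["Patient"], examples["Doctor"])
    ]
    return tokenizer(texts, truncation=True, max_length=256)

tokenized_dataset = dataset.map(
    tokenize_function, batched=True, remove_columns=dataset.column_names
)

# Attach LoRA adapters with the same hyperparameters as the diff
model = AutoModelForCausalLM.from_pretrained(base_model)
model = get_peft_model(
    model,
    LoraConfig(r=8, lora_alpha=16, lora_dropout=0.1, bias="none", task_type="CAUSAL_LM"),
)

trainer = Trainer(
    model=model,
    args=TrainingArguments(
        output_dir="lora-out",
        per_device_train_batch_size=2,
        num_train_epochs=1,
        logging_steps=1,
        report_to="none",
    ),
    train_dataset=tokenized_dataset,
    # With mlm=False the collator pads each batch and builds causal-LM labels from input_ids
    data_collator=DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False),
)
trainer.train()
trainer.save_model("lora-out")
tokenizer.save_pretrained("lora-out")
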
requirements.txt CHANGED
@@ -1,9 +1,9 @@
 flask==2.3.3
-transformers>=4.36.0,<4.45.0
-datasets>=2.14.0
-accelerate>=0.24.0
-peft>=0.6.0,<0.8.0
-trl>=0.7.0
+transformers==4.44.2
+datasets==2.20.0
+accelerate==0.33.0
+peft==0.12.0
+trl==0.9.6
 bitsandbytes
 torch>=2.0.0
 torchvision