Guetat Youssef committed
Commit 3349c56 · 1 Parent(s): e4256df
Files changed (2):
  1. Dockerfile +13 -4
  2. app.py +109 -89
Dockerfile CHANGED
@@ -7,8 +7,21 @@ RUN apt-get update && apt-get install -y \
     git \
     curl \
     build-essential \
+    wget \
     && rm -rf /var/lib/apt/lists/*
 
+# Create cache directory with proper permissions
+RUN mkdir -p /app/cache && chmod 777 /app/cache
+RUN mkdir -p /app/models && chmod 777 /app/models
+
+# Set environment variables for caching
+ENV HF_HOME=/app/cache
+ENV TRANSFORMERS_CACHE=/app/cache
+ENV HF_DATASETS_CACHE=/app/cache
+ENV TORCH_HOME=/app/cache
+ENV PYTHONPATH=/app
+ENV FLASK_APP=app.py
+
 # Copy requirements and install Python dependencies
 COPY requirements.txt .
 RUN pip install --no-cache-dir -r requirements.txt
@@ -19,9 +32,5 @@ COPY . .
 # Expose port
 EXPOSE 7860
 
-# Set environment variables
-ENV PYTHONPATH=/app
-ENV FLASK_APP=app.py
-
 # Run the application
 CMD ["python", "app.py"]
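The cache setup addresses a common Hugging Face Spaces failure: the container runs as a non-root user, so the default ~/.cache location is not writable and model downloads die with a PermissionError. Pointing HF_HOME and the related variables at a world-writable /app/cache avoids that. A minimal sketch of how the variable is resolved, assuming standard huggingface_hub behavior on a recent release:

    import os

    # The cache variables are read once, when the library is first imported,
    # so they must be set before the import (or in the Dockerfile, as here).
    os.environ["HF_HOME"] = "/app/cache"

    from huggingface_hub import constants

    # With HF_HOME=/app/cache, hub downloads land under /app/cache/hub.
    print(constants.HF_HUB_CACHE)

Note that recent transformers releases deprecate TRANSFORMERS_CACHE in favor of HF_HOME; setting both, as this commit does, is harmless.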
app.py CHANGED
@@ -2,30 +2,11 @@ from flask import Flask, jsonify, request
 import threading
 import time
 import os
-import torch
-from datasets import load_dataset
-from huggingface_hub import login
-from transformers import (
-    AutoConfig,
-    AutoModelForCausalLM,
-    AutoTokenizer,
-    BitsAndBytesConfig,
-    TrainingArguments,
-    pipeline,
-    logging,
-    DataCollatorForLanguageModeling,
-)
-from peft import (
-    LoraConfig,
-    PeftModel,
-    prepare_model_for_kbit_training,
-    get_peft_model,
-)
-from trl import SFTTrainer, setup_chat_format
+import tempfile
+import shutil
 import uuid
 from datetime import datetime, timedelta
 
-# ============== CONFIGURATION ==============
 app = Flask(__name__)
 
 # Global variables to track training progress
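Moving the torch/transformers/peft/trl imports out of module scope means the Flask server no longer pays their import cost at startup; they are loaded inside the background training thread instead (next hunk). The effect is easy to see in isolation (illustrative timing, not from this repo):

    import time

    t0 = time.time()
    import torch  # heavy: loads shared libraries, registers ops
    print(f"import torch took {time.time() - t0:.1f}s")

On a CPU-only Space this import alone can take several seconds, which would otherwise delay server startup and health checks.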
@@ -74,60 +55,78 @@ def train_model_background(job_id):
     progress = training_jobs[job_id]
 
     try:
-        # === Authentication ===
-        import os
+        # Create a temporary directory for this job
+        temp_dir = tempfile.mkdtemp(prefix=f"train_{job_id}_")
+
+        # Set environment variables for caching
+        os.environ['HF_HOME'] = temp_dir
+        os.environ['TRANSFORMERS_CACHE'] = temp_dir
+        os.environ['HF_DATASETS_CACHE'] = temp_dir
+        os.environ['TORCH_HOME'] = temp_dir
+
+        progress.status = "loading_libraries"
+        progress.message = "Loading required libraries..."
+
+        # Import heavy libraries after setting cache paths
+        import torch
+        from datasets import load_dataset
         from huggingface_hub import login
-
+        from transformers import (
+            AutoConfig,
+            AutoModelForCausalLM,
+            AutoTokenizer,
+            BitsAndBytesConfig,
+            TrainingArguments,
+            logging,
+        )
+        from peft import (
+            LoraConfig,
+            get_peft_model,
+        )
+        from trl import SFTTrainer, setup_chat_format
+
+        # === Authentication ===
         hf_token = os.getenv('HF_TOKEN')
-
-        if not hf_token:
-            raise ValueError("HF_TOKEN is not set. Please define it as an environment variable or secret.")
-
-        login(token=hf_token)
-
+        if hf_token:
+            login(token=hf_token)
 
         progress.status = "loading_model"
         progress.message = "Loading base model and tokenizer..."
 
         # === Configuration ===
-        base_model = "meta-llama/Llama-3.2-1B"
+        base_model = "microsoft/DialoGPT-small"  # Smaller model for testing
         dataset_name = "ruslanmv/ai-medical-chatbot"
-        new_model = f"Llama-3.2-3B-chat-doctor-{job_id}"
-        torch_dtype = torch.float16
-        attn_implementation = "eager"
-
-        # === QLoRA Config ===
-        bnb_config = BitsAndBytesConfig(
-            load_in_4bit=True,
-            bnb_4bit_quant_type="nf4",
-            bnb_4bit_compute_dtype=torch_dtype,
-            bnb_4bit_use_double_quant=True,
-        )
-
-        # === Load Model and Tokenizer ===
+        new_model = f"trained-model-{job_id}"
+
+        # === Load Model and Tokenizer (without quantization for simplicity) ===
         model = AutoModelForCausalLM.from_pretrained(
             base_model,
-            quantization_config=bnb_config,
-            device_map="auto",
-            attn_implementation=attn_implementation
+            cache_dir=temp_dir,
+            torch_dtype=torch.float32,  # Use float32 for compatibility
+            device_map="auto" if torch.cuda.is_available() else "cpu",
+            trust_remote_code=True
         )
-        tokenizer = AutoTokenizer.from_pretrained(base_model)
-        model, tokenizer = setup_chat_format(model, tokenizer)
+
+        tokenizer = AutoTokenizer.from_pretrained(
+            base_model,
+            cache_dir=temp_dir,
+            trust_remote_code=True
+        )
+
+        # Add padding token if not present
+        if tokenizer.pad_token is None:
+            tokenizer.pad_token = tokenizer.eos_token
 
         progress.status = "preparing_model"
         progress.message = "Setting up LoRA configuration..."
 
-        # === LoRA Config ===
+        # === LoRA Config (simplified) ===
         peft_config = LoraConfig(
-            r=16,
-            lora_alpha=32,
-            lora_dropout=0.05,
+            r=8,  # Smaller rank
+            lora_alpha=16,
+            lora_dropout=0.1,
             bias="none",
             task_type="CAUSAL_LM",
-            target_modules=[
-                'up_proj', 'down_proj', 'gate_proj',
-                'k_proj', 'q_proj', 'v_proj', 'o_proj'
-            ]
         )
         model = get_peft_model(model, peft_config)
 
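The ordering inside the try block is deliberate: huggingface_hub and transformers capture HF_HOME and the other cache variables at import time, so os.environ must be updated before the heavy imports run. A minimal sketch of the pattern, with a per-job temp directory standing in for the cache root:

    import os
    import tempfile

    cache_dir = tempfile.mkdtemp(prefix="train_job_")
    # Point every HF/torch cache at the writable job directory
    # *before* the libraries that read these variables are imported.
    for var in ("HF_HOME", "TRANSFORMERS_CACHE", "HF_DATASETS_CACHE", "TORCH_HOME"):
        os.environ[var] = cache_dir

    # Only now is it safe to import the heavy libraries.
    import torch
    from transformers import AutoModelForCausalLM

Two smaller observations: making the login conditional means a missing HF_TOKEN no longer aborts the job, which works here because microsoft/DialoGPT-small is a public checkpoint, unlike the gated meta-llama/Llama-3.2-1B it replaces; and the import block still pulls in BitsAndBytesConfig and setup_chat_format even though the quantized path and chat formatting were dropped.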
@@ -135,29 +134,26 @@ def train_model_background(job_id):
         progress.message = "Loading and preparing dataset..."
 
         # === Load & Prepare Dataset ===
-        dataset = load_dataset(dataset_name, split="all")
-        dataset = dataset.shuffle(seed=65).select(range(1000))  # Use 1000 samples
-
-        def format_chat_template(row, tokenizer):
-            row_json = [
-                {"role": "user", "content": row["Patient"]},
-                {"role": "assistant", "content": row["Doctor"]}
-            ]
-            row["text"] = tokenizer.apply_chat_template(row_json, tokenize=False)
-            return row
-
-        dataset = dataset.map(
-            format_chat_template,
-            fn_kwargs={"tokenizer": tokenizer},
-            num_proc=4
+        dataset = load_dataset(
+            dataset_name,
+            split="all",
+            cache_dir=temp_dir,
+            trust_remote_code=True
         )
+        dataset = dataset.shuffle(seed=65).select(range(100))  # Use only 100 samples for testing
+
+        def format_chat_template(row):
+            # Simple formatting without chat template
+            text = f"Patient: {row['Patient']}\nDoctor: {row['Doctor']}"
+            return {"text": text}
 
+        dataset = dataset.map(format_chat_template, num_proc=1)
         dataset = dataset.train_test_split(test_size=0.1)
 
         # Calculate total training steps
         train_size = len(dataset["train"])
         batch_size = 1
-        gradient_accumulation_steps = 2
+        gradient_accumulation_steps = 1
         num_epochs = 1
 
         steps_per_epoch = train_size // (batch_size * gradient_accumulation_steps)
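The rewritten format_chat_template no longer depends on the tokenizer: each row becomes a plain "Patient: ...\nDoctor: ..." string in a new text column, which SFTTrainer can consume directly. The smaller sample count also makes the step arithmetic easy to verify, using the values from this hunk:

    # 100 samples with test_size=0.1 -> 10 test rows, 90 train rows.
    train_size = 90
    batch_size = 1
    gradient_accumulation_steps = 1
    num_epochs = 1

    steps_per_epoch = train_size // (batch_size * gradient_accumulation_steps)
    total_steps = steps_per_epoch * num_epochs
    print(steps_per_epoch, total_steps)  # 90 90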
@@ -168,29 +164,33 @@ def train_model_background(job_id):
         progress.message = "Starting training..."
 
         # === Training Arguments ===
+        output_dir = os.path.join(temp_dir, new_model)
+        os.makedirs(output_dir, exist_ok=True)
+
         training_args = TrainingArguments(
-            output_dir=new_model,
+            output_dir=output_dir,
             per_device_train_batch_size=batch_size,
             per_device_eval_batch_size=1,
             gradient_accumulation_steps=gradient_accumulation_steps,
-            optim="paged_adamw_32bit",
+            optim="adamw_torch",  # Use standard optimizer
             num_train_epochs=num_epochs,
-            eval_steps=0.2,
+            eval_steps=0.5,
             logging_steps=1,
-            warmup_steps=10,
+            warmup_steps=5,
             logging_strategy="steps",
-            learning_rate=2e-5,
+            learning_rate=5e-5,
             fp16=False,
             bf16=False,
             group_by_length=True,
-            save_steps=50,
-            save_total_limit=2,
-            report_to=None  # Disable wandb for HF Spaces
+            save_steps=10,
+            save_total_limit=1,
+            report_to=None,
+            dataloader_num_workers=0,
+            remove_unused_columns=False,
+            load_best_model_at_end=False,
+            evaluation_strategy="no"  # Disable evaluation for simplicity
         )
 
-        # === Data Collator ===
-        tokenizer.model_max_length = 512
-
         # Custom callback to track progress
         class ProgressCallback:
             def __init__(self, progress_tracker):
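Two of these arguments are version-sensitive. In most transformers releases report_to=None is treated as "report to all installed integrations" (with a warning); the reliable way to disable logging is the string "none". And evaluation_strategy was renamed to eval_strategy around transformers 4.41, with the old keyword later removed. A hedged sketch that works on either side of the rename (output path is illustrative):

    from transformers import TrainingArguments

    common = dict(output_dir="/tmp/out", report_to="none")  # "none" disables all loggers
    try:
        args = TrainingArguments(eval_strategy="no", **common)        # newer releases
    except TypeError:
        args = TrainingArguments(evaluation_strategy="no", **common)  # older releases

With evaluation disabled, eval_steps=0.5 and per_device_eval_batch_size=1 are inert but harmless.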
@@ -200,7 +200,7 @@ def train_model_background(job_id):
            def on_log(self, args, state, control, model=None, logs=None, **kwargs):
                current_time = time.time()
                # Update every 10 seconds or on significant step changes
-                if current_time - self.last_update >= 10 or state.global_step % 10 == 0:
+                if current_time - self.last_update >= 10 or state.global_step % 5 == 0:
                    self.progress_tracker.update_progress(
                        state.global_step,
                        state.max_steps,
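One caveat with ProgressCallback, which the surrounding context shows as a plain class: transformers expects callbacks to subclass TrainerCallback, which supplies no-op defaults for every event the trainer dispatches (on_init_end, on_step_end, ...), not just on_log. A sketch of the same throttling logic as a proper subclass; update_progress belongs to this app's job tracker and its full argument list is elided in the diff context:

    import time
    from transformers import TrainerCallback

    class ProgressCallback(TrainerCallback):
        def __init__(self, progress_tracker):
            self.progress_tracker = progress_tracker
            self.last_update = 0.0

        def on_log(self, args, state, control, model=None, logs=None, **kwargs):
            now = time.time()
            # Same throttle as the commit: every 10 s, or every 5th step.
            if now - self.last_update >= 10 or state.global_step % 5 == 0:
                # state.global_step / state.max_steps drive the percentage;
                # the tracker update call is unchanged from the code above.
                self.last_update = now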
@@ -212,24 +212,44 @@ def train_model_background(job_id):
         trainer = SFTTrainer(
             model=model,
             train_dataset=dataset["train"],
-            eval_dataset=dataset["test"],
             peft_config=peft_config,
             args=training_args,
-            callbacks=[ProgressCallback(progress)]
+            callbacks=[ProgressCallback(progress)],
+            tokenizer=tokenizer,
+            max_seq_length=256,  # Shorter sequences
         )
 
         # === Train & Save ===
         trainer.train()
-        trainer.save_model(new_model)
+        trainer.save_model(output_dir)
 
         progress.status = "completed"
         progress.progress = 100
-        progress.message = f"Training completed! Model saved as {new_model}"
+        progress.message = f"Training completed! Model saved to {output_dir}"
+
+        # Clean up temporary directory after a delay
+        def cleanup_temp_dir():
+            time.sleep(300)  # Wait 5 minutes before cleanup
+            try:
+                shutil.rmtree(temp_dir)
+            except:
+                pass
+
+        cleanup_thread = threading.Thread(target=cleanup_temp_dir)
+        cleanup_thread.daemon = True
+        cleanup_thread.start()
 
     except Exception as e:
         progress.status = "error"
         progress.error = str(e)
         progress.message = f"Training failed: {str(e)}"
+
+        # Clean up on error
+        try:
+            if 'temp_dir' in locals():
+                shutil.rmtree(temp_dir)
+        except:
+            pass
 
 # ============== API ROUTES ==============
 @app.route('/api/train', methods=['POST'])
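Dropping eval_dataset is consistent with evaluation_strategy="no" above. The SFTTrainer call, however, pins this code to older trl releases that still accept tokenizer= and max_seq_length= directly; newer releases (around trl 0.12, following transformers 4.46) moved max_seq_length into SFTConfig and renamed tokenizer to processing_class. A hedged sketch of the equivalent call for those versions, reusing the objects defined in the hunks above (exact keyword names may have shifted again since, so check the installed version):

    from trl import SFTConfig, SFTTrainer

    sft_config = SFTConfig(
        output_dir=output_dir,
        max_seq_length=256,
        report_to="none",
    )
    trainer = SFTTrainer(
        model=model,
        train_dataset=dataset["train"],
        peft_config=peft_config,
        args=sft_config,
        processing_class=tokenizer,
        callbacks=[ProgressCallback(progress)],
    )

Two smaller points on the cleanup logic: the bare except: swallows every error, where except OSError: pass would be narrower, and because the cleanup thread is a daemon it dies silently if the process exits during the 5-minute sleep, so the temp directory can outlive the job.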