Guetat Youssef commited on
Commit
e4256df
·
1 Parent(s): 3d011a9
Files changed (2) hide show
  1. app.py +289 -0
  2. main.py +0 -0
app.py ADDED
@@ -0,0 +1,289 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from flask import Flask, jsonify, request
2
+ import threading
3
+ import time
4
+ import os
5
+ import torch
6
+ from datasets import load_dataset
7
+ from huggingface_hub import login
8
+ from transformers import (
9
+ AutoConfig,
10
+ AutoModelForCausalLM,
11
+ AutoTokenizer,
12
+ BitsAndBytesConfig,
13
+ TrainingArguments,
14
+ pipeline,
15
+ logging,
16
+ DataCollatorForLanguageModeling,
17
+ )
18
+ from peft import (
19
+ LoraConfig,
20
+ PeftModel,
21
+ prepare_model_for_kbit_training,
22
+ get_peft_model,
23
+ )
24
+ from trl import SFTTrainer, setup_chat_format
25
+ import uuid
26
+ from datetime import datetime, timedelta
27
+
28
+ # ============== CONFIGURATION ==============
29
+ app = Flask(__name__)
30
+
31
+ # Global variables to track training progress
32
+ training_jobs = {}
33
+
34
+ class TrainingProgress:
35
+ def __init__(self, job_id):
36
+ self.job_id = job_id
37
+ self.status = "initializing"
38
+ self.progress = 0
39
+ self.current_step = 0
40
+ self.total_steps = 0
41
+ self.start_time = time.time()
42
+ self.estimated_finish_time = None
43
+ self.message = "Starting training..."
44
+ self.error = None
45
+
46
+ def update_progress(self, current_step, total_steps, message=""):
47
+ self.current_step = current_step
48
+ self.total_steps = total_steps
49
+ self.progress = (current_step / total_steps) * 100 if total_steps > 0 else 0
50
+ self.message = message
51
+
52
+ # Calculate estimated finish time
53
+ if current_step > 0:
54
+ elapsed_time = time.time() - self.start_time
55
+ time_per_step = elapsed_time / current_step
56
+ remaining_steps = total_steps - current_step
57
+ estimated_remaining_time = remaining_steps * time_per_step
58
+ self.estimated_finish_time = datetime.now() + timedelta(seconds=estimated_remaining_time)
59
+
60
+ def to_dict(self):
61
+ return {
62
+ "job_id": self.job_id,
63
+ "status": self.status,
64
+ "progress": round(self.progress, 2),
65
+ "current_step": self.current_step,
66
+ "total_steps": self.total_steps,
67
+ "message": self.message,
68
+ "estimated_finish_time": self.estimated_finish_time.isoformat() if self.estimated_finish_time else None,
69
+ "error": self.error
70
+ }
71
+
72
+ def train_model_background(job_id):
73
+ """Background training function with progress tracking"""
74
+ progress = training_jobs[job_id]
75
+
76
+ try:
77
+ # === Authentication ===
78
+ import os
79
+ from huggingface_hub import login
80
+
81
+ hf_token = os.getenv('HF_TOKEN')
82
+
83
+ if not hf_token:
84
+ raise ValueError("HF_TOKEN is not set. Please define it as an environment variable or secret.")
85
+
86
+ login(token=hf_token)
87
+
88
+
89
+ progress.status = "loading_model"
90
+ progress.message = "Loading base model and tokenizer..."
91
+
92
+ # === Configuration ===
93
+ base_model = "meta-llama/Llama-3.2-1B"
94
+ dataset_name = "ruslanmv/ai-medical-chatbot"
95
+ new_model = f"Llama-3.2-3B-chat-doctor-{job_id}"
96
+ torch_dtype = torch.float16
97
+ attn_implementation = "eager"
98
+
99
+ # === QLoRA Config ===
100
+ bnb_config = BitsAndBytesConfig(
101
+ load_in_4bit=True,
102
+ bnb_4bit_quant_type="nf4",
103
+ bnb_4bit_compute_dtype=torch_dtype,
104
+ bnb_4bit_use_double_quant=True,
105
+ )
106
+
107
+ # === Load Model and Tokenizer ===
108
+ model = AutoModelForCausalLM.from_pretrained(
109
+ base_model,
110
+ quantization_config=bnb_config,
111
+ device_map="auto",
112
+ attn_implementation=attn_implementation
113
+ )
114
+ tokenizer = AutoTokenizer.from_pretrained(base_model)
115
+ model, tokenizer = setup_chat_format(model, tokenizer)
116
+
117
+ progress.status = "preparing_model"
118
+ progress.message = "Setting up LoRA configuration..."
119
+
120
+ # === LoRA Config ===
121
+ peft_config = LoraConfig(
122
+ r=16,
123
+ lora_alpha=32,
124
+ lora_dropout=0.05,
125
+ bias="none",
126
+ task_type="CAUSAL_LM",
127
+ target_modules=[
128
+ 'up_proj', 'down_proj', 'gate_proj',
129
+ 'k_proj', 'q_proj', 'v_proj', 'o_proj'
130
+ ]
131
+ )
132
+ model = get_peft_model(model, peft_config)
133
+
134
+ progress.status = "loading_dataset"
135
+ progress.message = "Loading and preparing dataset..."
136
+
137
+ # === Load & Prepare Dataset ===
138
+ dataset = load_dataset(dataset_name, split="all")
139
+ dataset = dataset.shuffle(seed=65).select(range(1000)) # Use 1000 samples
140
+
141
+ def format_chat_template(row, tokenizer):
142
+ row_json = [
143
+ {"role": "user", "content": row["Patient"]},
144
+ {"role": "assistant", "content": row["Doctor"]}
145
+ ]
146
+ row["text"] = tokenizer.apply_chat_template(row_json, tokenize=False)
147
+ return row
148
+
149
+ dataset = dataset.map(
150
+ format_chat_template,
151
+ fn_kwargs={"tokenizer": tokenizer},
152
+ num_proc=4
153
+ )
154
+
155
+ dataset = dataset.train_test_split(test_size=0.1)
156
+
157
+ # Calculate total training steps
158
+ train_size = len(dataset["train"])
159
+ batch_size = 1
160
+ gradient_accumulation_steps = 2
161
+ num_epochs = 1
162
+
163
+ steps_per_epoch = train_size // (batch_size * gradient_accumulation_steps)
164
+ total_steps = steps_per_epoch * num_epochs
165
+
166
+ progress.total_steps = total_steps
167
+ progress.status = "training"
168
+ progress.message = "Starting training..."
169
+
170
+ # === Training Arguments ===
171
+ training_args = TrainingArguments(
172
+ output_dir=new_model,
173
+ per_device_train_batch_size=batch_size,
174
+ per_device_eval_batch_size=1,
175
+ gradient_accumulation_steps=gradient_accumulation_steps,
176
+ optim="paged_adamw_32bit",
177
+ num_train_epochs=num_epochs,
178
+ eval_steps=0.2,
179
+ logging_steps=1,
180
+ warmup_steps=10,
181
+ logging_strategy="steps",
182
+ learning_rate=2e-5,
183
+ fp16=False,
184
+ bf16=False,
185
+ group_by_length=True,
186
+ save_steps=50,
187
+ save_total_limit=2,
188
+ report_to=None # Disable wandb for HF Spaces
189
+ )
190
+
191
+ # === Data Collator ===
192
+ tokenizer.model_max_length = 512
193
+
194
+ # Custom callback to track progress
195
+ class ProgressCallback:
196
+ def __init__(self, progress_tracker):
197
+ self.progress_tracker = progress_tracker
198
+ self.last_update = time.time()
199
+
200
+ def on_log(self, args, state, control, model=None, logs=None, **kwargs):
201
+ current_time = time.time()
202
+ # Update every 10 seconds or on significant step changes
203
+ if current_time - self.last_update >= 10 or state.global_step % 10 == 0:
204
+ self.progress_tracker.update_progress(
205
+ state.global_step,
206
+ state.max_steps,
207
+ f"Training step {state.global_step}/{state.max_steps}"
208
+ )
209
+ self.last_update = current_time
210
+
211
+ # === Trainer Initialization ===
212
+ trainer = SFTTrainer(
213
+ model=model,
214
+ train_dataset=dataset["train"],
215
+ eval_dataset=dataset["test"],
216
+ peft_config=peft_config,
217
+ args=training_args,
218
+ callbacks=[ProgressCallback(progress)]
219
+ )
220
+
221
+ # === Train & Save ===
222
+ trainer.train()
223
+ trainer.save_model(new_model)
224
+
225
+ progress.status = "completed"
226
+ progress.progress = 100
227
+ progress.message = f"Training completed! Model saved as {new_model}"
228
+
229
+ except Exception as e:
230
+ progress.status = "error"
231
+ progress.error = str(e)
232
+ progress.message = f"Training failed: {str(e)}"
233
+
234
+ # ============== API ROUTES ==============
235
+ @app.route('/api/train', methods=['POST'])
236
+ def start_training():
237
+ """Start training and return job ID for tracking"""
238
+ try:
239
+ job_id = str(uuid.uuid4())[:8] # Short UUID
240
+ progress = TrainingProgress(job_id)
241
+ training_jobs[job_id] = progress
242
+
243
+ # Start training in background thread
244
+ training_thread = threading.Thread(target=train_model_background, args=(job_id,))
245
+ training_thread.daemon = True
246
+ training_thread.start()
247
+
248
+ return jsonify({
249
+ "status": "started",
250
+ "job_id": job_id,
251
+ "message": "Training started. Use /api/status/<job_id> to track progress."
252
+ })
253
+
254
+ except Exception as e:
255
+ return jsonify({"status": "error", "message": str(e)}), 500
256
+
257
+ @app.route('/api/status/<job_id>', methods=['GET'])
258
+ def get_training_status(job_id):
259
+ """Get training progress and estimated completion time"""
260
+ if job_id not in training_jobs:
261
+ return jsonify({"status": "error", "message": "Job not found"}), 404
262
+
263
+ progress = training_jobs[job_id]
264
+ return jsonify(progress.to_dict())
265
+
266
+ @app.route('/api/jobs', methods=['GET'])
267
+ def list_jobs():
268
+ """List all training jobs"""
269
+ jobs = {job_id: progress.to_dict() for job_id, progress in training_jobs.items()}
270
+ return jsonify({"jobs": jobs})
271
+
272
+ @app.route('/')
273
+ def home():
274
+ return jsonify({
275
+ "message": "Welcome to LLaMA Fine-tuning API!",
276
+ "endpoints": {
277
+ "POST /api/train": "Start training",
278
+ "GET /api/status/<job_id>": "Get training status",
279
+ "GET /api/jobs": "List all jobs"
280
+ }
281
+ })
282
+
283
+ @app.route('/health')
284
+ def health():
285
+ return jsonify({"status": "healthy"})
286
+
287
+ if __name__ == '__main__':
288
+ port = int(os.environ.get('PORT', 7860)) # HF Spaces uses port 7860
289
+ app.run(host='0.0.0.0', port=port, debug=False)
main.py DELETED
File without changes