Twelve2five commited on
Commit
2784605
·
verified ·
1 Parent(s): c6225e3

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +32 -30
app.py CHANGED
@@ -141,57 +141,59 @@ def load_model():
141
  bnb_4bit_use_double_quant=True,
142
  )
143
 
144
- # Use DeepSpeed if available
145
- try:
146
- from deepspeed.runtime.zero.stage3 import estimate_zero3_model_states_mem_needs_all_live
147
- use_deepspeed = True
148
- print("DeepSpeed available, will use ZeRO-3")
149
- except ImportError:
150
- use_deepspeed = False
151
- print("DeepSpeed not available, falling back to standard distribution")
152
 
153
- # Calculate per-GPU reserved memory (be very conservative)
154
- n_gpus = max(1, torch.cuda.device_count())
155
- max_memory = {i: f"{int(torch.cuda.get_device_properties(i).total_memory / 1e9) - 4}GB" for i in range(n_gpus)}
156
- max_memory["cpu"] = "32GB"
157
 
158
- print(f"Using {n_gpus} GPUs with memory configuration: {max_memory}")
 
159
 
160
- # Load model with proper device distribution
161
  model = AutoModelForCausalLM.from_pretrained(
162
  hf_model_repo_id,
163
  quantization_config=bnb_config,
164
- device_map="balanced_low_0", # Distribute evenly with priority to minimize GPU 0 usage
165
- max_memory=max_memory,
166
  trust_remote_code=True,
167
  use_cache=False,
168
  torch_dtype=torch.float16,
169
  low_cpu_mem_usage=True,
170
  )
171
- print(f"Loaded model vocab size: {model.config.vocab_size}")
172
- print(f"Input embedding shape: {model.get_input_embeddings().weight.shape}")
173
 
174
- # --- Prepare for K-bit Training & Apply LoRA ---
175
- model = prepare_model_for_kbit_training(model)
176
 
 
 
 
 
 
177
  lora_config = LoraConfig(
178
- task_type=TaskType.CAUSAL_LM,
179
- r=16,
180
  lora_alpha=32,
181
  lora_dropout=0.05,
182
  bias="none",
183
- target_modules=["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"]
 
184
  )
185
 
186
- peft_model = get_peft_model(model, lora_config)
187
- peft_model.print_trainable_parameters()
188
 
189
- # Cleanup
190
- gc.collect()
191
- if torch.cuda.is_available():
192
- torch.cuda.empty_cache()
 
193
 
194
- return peft_model
195
 
196
  def load_dataset():
197
  # --- Download the dataset repository files ---
 
141
  bnb_4bit_use_double_quant=True,
142
  )
143
 
144
+ # For 4-bit training, we need to load on a single device
145
+ # Choose GPU with most available memory
146
+ free_memory = []
147
+ for i in range(torch.cuda.device_count()):
148
+ total_memory = torch.cuda.get_device_properties(i).total_memory
149
+ reserved_memory = torch.cuda.memory_reserved(i)
150
+ free_memory.append((total_memory - reserved_memory) / 1e9) # Convert to GB
 
151
 
152
+ # Choose the GPU with the most free memory
153
+ target_gpu = free_memory.index(max(free_memory))
154
+ print(f"Loading model on GPU {target_gpu} with {free_memory[target_gpu]:.2f}GB free memory")
 
155
 
156
+ # Use target GPU for model loading (crucial for 4-bit training)
157
+ device_map = {'': target_gpu}
158
 
159
+ # Load model on the single target GPU
160
  model = AutoModelForCausalLM.from_pretrained(
161
  hf_model_repo_id,
162
  quantization_config=bnb_config,
163
+ device_map=device_map, # Place entire model on one GPU
 
164
  trust_remote_code=True,
165
  use_cache=False,
166
  torch_dtype=torch.float16,
167
  low_cpu_mem_usage=True,
168
  )
 
 
169
 
170
+ # Add print statement to check which device the model is on
171
+ print(f"Model loaded on device: {next(model.parameters()).device}")
172
 
173
+ # Continue with the LoRA config as before
174
+ print(f"Loaded model vocab size: {model.get_input_embeddings().weight.shape[0]}")
175
+ print(f"Input embedding shape: {model.get_input_embeddings().weight.shape}")
176
+
177
+ # --- Configure PEFT/LoRA ---
178
  lora_config = LoraConfig(
179
+ r=16, # rank
 
180
  lora_alpha=32,
181
  lora_dropout=0.05,
182
  bias="none",
183
+ task_type=TaskType.CAUSAL_LM,
184
+ target_modules=["q_proj", "v_proj", "k_proj", "o_proj", "gate_proj", "up_proj", "down_proj"],
185
  )
186
 
187
+ # Prepare model for k-bit training
188
+ model = prepare_model_for_kbit_training(model)
189
 
190
+ # Add LoRA adapters
191
+ model = get_peft_model(model, lora_config)
192
+
193
+ # Log number of trainable parameters
194
+ model.print_trainable_parameters()
195
 
196
+ return model
197
 
198
  def load_dataset():
199
  # --- Download the dataset repository files ---