Shriti09 committed
Commit b3e9d02 · verified · 1 parent: 6302ce8

Update app.py

Files changed (1)
app.py +8 -6
app.py CHANGED
@@ -24,24 +24,26 @@ base_model = AutoModelForCausalLM.from_pretrained(
     base_model_id,
     quantization_config=None, # Load base normally first
     torch_dtype=torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16, # Use appropriate dtype
-    device_map="auto", # Let accelerate handle device mapping
+    # device_map="auto", # <--- REMOVE THIS LINE
+    device_map=device, # <--- CHANGE TO THIS (load directly to device)
     trust_remote_code=True
 )
 base_model.config.use_cache = True # Enable cache for inference speed
+print(f"Base model loaded to device: {device}")
 
+# --- Load PEFT Adapter ---
 print(f"Loading PEFT adapter from: {adapter_path}")
 # Load the PEFT model (adapter) on top of the base model
+# Ensure the base_model is on the correct device before loading PEFT
 model = PeftModel.from_pretrained(base_model, adapter_path)
 print("Adapter loaded.")
 
+# --- Merge Adapter ---
 print("Merging adapter weights...")
-# Merge the adapter weights into the base model
-# This creates a new model that doesn't need PEFT library for inference
-# Note: This might require significant RAM during the merge process
 model = model.merge_and_unload()
-print("Adapter merged.")
+print("Adapter merged.") # Model should now be on the device specified earlier
 
-# Load the tokenizer associated with the base model
+# --- Load Tokenizer ---
 print("Loading tokenizer...")
 tokenizer = AutoTokenizer.from_pretrained(base_model_id, trust_remote_code=True)
 
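
For context, below is a minimal sketch of how the loading flow might look after this commit. The definitions of base_model_id, adapter_path, and device live elsewhere in app.py and are not part of this diff, so the placeholder values here are assumptions for illustration only. The likely motivation for the change is that device_map="auto" lets accelerate shard the model across devices, whereas passing a single device string keeps all weights on one device, which keeps the subsequent merge_and_unload() step straightforward.

# Minimal sketch (assumptions noted): reconstructs the post-commit loading flow.
# base_model_id, adapter_path, and device are hypothetical placeholders here;
# in app.py they are defined earlier in the file and are not shown in the diff.
import torch
from peft import PeftModel
from transformers import AutoModelForCausalLM, AutoTokenizer

base_model_id = "some-org/some-base-model"   # hypothetical placeholder
adapter_path = "some-org/some-peft-adapter"  # hypothetical placeholder
device = "cuda" if torch.cuda.is_available() else "cpu"  # assumed definition of `device`

base_model = AutoModelForCausalLM.from_pretrained(
    base_model_id,
    quantization_config=None,  # Load base normally first
    torch_dtype=torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16,
    device_map=device,  # Pin all weights to one device instead of letting accelerate shard
    trust_remote_code=True,
)
base_model.config.use_cache = True  # Enable cache for inference speed

# Apply the adapter, then fold its weights into the base model so that
# inference no longer needs the PEFT wrapper.
model = PeftModel.from_pretrained(base_model, adapter_path)
model = model.merge_and_unload()

tokenizer = AutoTokenizer.from_pretrained(base_model_id, trust_remote_code=True)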