Update app.py
app.py
CHANGED
@@ -24,24 +24,26 @@ base_model = AutoModelForCausalLM.from_pretrained(
     base_model_id,
     quantization_config=None, # Load base normally first
     torch_dtype=torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16, # Use appropriate dtype
-    device_map="auto", #
+    # device_map="auto", # <--- REMOVE THIS LINE
+    device_map=device, # <--- CHANGE TO THIS (load directly to device)
     trust_remote_code=True
 )
 base_model.config.use_cache = True # Enable cache for inference speed
+print(f"Base model loaded to device: {device}")

+# --- Load PEFT Adapter ---
 print(f"Loading PEFT adapter from: {adapter_path}")
 # Load the PEFT model (adapter) on top of the base model
+# Ensure the base_model is on the correct device before loading PEFT
 model = PeftModel.from_pretrained(base_model, adapter_path)
 print("Adapter loaded.")

+# --- Merge Adapter ---
 print("Merging adapter weights...")
-# Merge the adapter weights into the base model
-# This creates a new model that doesn't need PEFT library for inference
-# Note: This might require significant RAM during the merge process
 model = model.merge_and_unload()
-print("Adapter merged.")
+print("Adapter merged.") # Model should now be on the device specified earlier

-# Load
+# --- Load Tokenizer ---
 print("Loading tokenizer...")
 tokenizer = AutoTokenizer.from_pretrained(base_model_id, trust_remote_code=True)
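For context, the hunk assumes that device, base_model_id, and adapter_path are defined earlier in app.py; those definitions are not part of this diff. Below is a minimal sketch of that surrounding setup and of inference after the merge. The placeholder values (org/base-model, org/peft-adapter, the prompt) are assumptions for illustration, not the Space's actual configuration:

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from peft import PeftModel

# Assumed setup preceding the hunk; names match the diff, values are placeholders.
device = "cuda" if torch.cuda.is_available() else "cpu"
base_model_id = "org/base-model"   # placeholder for the Space's actual base model id
adapter_path = "org/peft-adapter"  # placeholder for the Space's actual adapter repo/path

base_model = AutoModelForCausalLM.from_pretrained(
    base_model_id,
    quantization_config=None,  # Load base normally first
    torch_dtype=torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16,
    device_map=device,  # Single device string, as in the change above
    trust_remote_code=True,
)
base_model.config.use_cache = True

model = PeftModel.from_pretrained(base_model, adapter_path)
model = model.merge_and_unload()  # Plain transformers model afterwards; peft not needed for inference

tokenizer = AutoTokenizer.from_pretrained(base_model_id, trust_remote_code=True)

# Example inference with the merged model
inputs = tokenizer("Hello", return_tensors="pt").to(device)
with torch.no_grad():
    outputs = model.generate(**inputs, max_new_tokens=32)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))

Passing a single device string instead of device_map="auto" keeps accelerate from sharding the model across GPU, CPU, and disk, which is presumably the point of the change: with all base and adapter weights on one device, merge_and_unload() can run without device-mismatch issues.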