- app.py +6 -6
- train_llama4.py +29 -35
app.py CHANGED

@@ -68,12 +68,12 @@ print("Loading model with: quantization_config=", quant_config, ", device_map=",
 # Load model with 8-bit quantization and CPU offloading
 try:
     model = Llama4ForConditionalGeneration.from_pretrained(
-
-
-
-
-
-
+        MODEL_ID,
+        device_map="auto",
+        torch_dtype=torch.float16,
+        quantization_config=quant_config,
+        offload_folder="./offload"
+    )
 except Exception as e:
     print(f"Model loading failed: {str(e)}")
     raise
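The rewritten from_pretrained call relies on a quant_config object defined earlier in app.py (it appears in the hunk's context line but not in this diff). As a point of reference only, here is a minimal sketch of such a config, assuming it mirrors the 8-bit-plus-CPU-offload settings that train_llama4.py now builds; MODEL_ID is likewise assumed to match the training script.

# Sketch only: an 8-bit quantization config with fp32 CPU offload, assumed to
# mirror the quant_config that app.py defines outside this hunk.
import torch
from transformers import BitsAndBytesConfig, Llama4ForConditionalGeneration

quant_config = BitsAndBytesConfig(
    load_in_8bit=True,                      # quantize weights to 8-bit via bitsandbytes
    llm_int8_enable_fp32_cpu_offload=True,  # keep overflow modules in fp32 on the CPU
)

MODEL_ID = "meta-llama/Llama-4-Maverick-17B-128E-Instruct"  # assumed, matching train_llama4.py

model = Llama4ForConditionalGeneration.from_pretrained(
    MODEL_ID,
    device_map="auto",                 # let accelerate place layers across GPU and CPU
    torch_dtype=torch.float16,         # dtype for the non-quantized modules
    quantization_config=quant_config,
    offload_folder="./offload",        # spill weights that fit on neither device to disk
)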
train_llama4.py CHANGED

@@ -1,7 +1,6 @@
 # train_llama4.py
-# Script to fine-tune Llama 4 Maverick for healthcare fraud detection

-from transformers import AutoTokenizer, Llama4ForConditionalGeneration
+from transformers import AutoTokenizer, Llama4ForConditionalGeneration, BitsAndBytesConfig
 import datasets
 import torch
 from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training

@@ -9,55 +8,50 @@ from accelerate import Accelerator
 import huggingface_hub
 import os

-
-print("Running train_llama4.py with CPU offloading (version: 2025-04-21 v2)")
+print("Running train_llama4.py with CPU offloading (version: 2025-04-22 v1)")

+# ──────────────────────────
 # Authenticate with Hugging Face
-
-if not
-    raise ValueError("LLama token not found. Set it in
-huggingface_hub.login(token=
+LLAMA = os.getenv("LLama")
+if not LLAMA:
+    raise ValueError("LLama token not found. Set it in environment as 'LLama'.")
+huggingface_hub.login(token=LLAMA)

-#
+# ──────────────────────────
+# Tokenizer
 MODEL_ID = "meta-llama/Llama-4-Maverick-17B-128E-Instruct"
 tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, trust_remote_code=True)
-
 if tokenizer.pad_token is None:
     tokenizer.add_special_tokens({'pad_token': '[PAD]'})

-#
-
-
-
-
-
-    "lm_head": 0
-}
-
-# Debug: Confirm offloading settings
-print("Loading model with CPU offloading: llm_int8_enable_fp32_cpu_offload=True, device_map=", device_map)
+# ──────────────────────────
+# Quantization + CPU offload config
+quant_config = BitsAndBytesConfig(
+    load_in_8bit=True,
+    llm_int8_enable_fp32_cpu_offload=True
+)

-
+print("Loading model with 8-bit quantization, CPU offload, and automatic device mapping")
 model = Llama4ForConditionalGeneration.from_pretrained(
     MODEL_ID,
     torch_dtype=torch.bfloat16,
-    device_map=
-    quantization_config=
-
-    attn_implementation="flex_attention"
+    device_map="auto",
+    quantization_config=quant_config,
+    offload_folder="./offload"
 )

-# Resize
+# Resize embeddings if we added [PAD]
 model.resize_token_embeddings(len(tokenizer))

-#
+# ──────────────────────────
+# Prepare for training
 accelerator = Accelerator()
 model = accelerator.prepare(model)

-# Load
+# Load training data
 dataset = datasets.load_dataset('json', data_files="Bingaman_training_data.json")['train']

-# LoRA
+# LoRA setup
 lora_config = LoraConfig(
     r=16,
     lora_alpha=32,

@@ -67,7 +61,6 @@ lora_config = LoraConfig(
     task_type="CAUSAL_LM"
 )

-# Prepare model for fine-tuning
 model = prepare_model_for_kbit_training(model)
 model = get_peft_model(model, lora_config)

@@ -87,16 +80,17 @@ training_args = {
     "lr_scheduler_type": "cosine"
 }

-# Initialize
+# Initialize Trainer via Accelerate
 trainer = accelerator.prepare(
     datasets.Trainer(
         model=model,
         args=datasets.TrainingArguments(**training_args),
-        train_dataset=dataset
+        train_dataset=dataset
     )
 )

-#
+# ──────────────────────────
+# Run training
 trainer.train()
 model.save_pretrained("./fine_tuned_model")
-print("Training completed!")
+print("Training completed!")
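One caveat with the new training block: Trainer and TrainingArguments are classes from the transformers library, not from datasets, so datasets.Trainer(...) and datasets.TrainingArguments(...) will fail with an AttributeError at runtime. The JSON dataset is also handed to the trainer untokenized and without a collator. Below is a minimal sketch of how that section could look instead; it assumes the training examples carry a "text" field (the actual keys in Bingaman_training_data.json are not shown in this diff) and reuses tokenizer, model, dataset, and training_args from earlier in the script.

# Sketch, not the committed code: Trainer/TrainingArguments imported from
# transformers (the datasets library has no Trainer), plus a tokenization pass
# and a causal-LM collator. The "text" column name is an assumption.
from transformers import Trainer, TrainingArguments, DataCollatorForLanguageModeling

def tokenize(batch):
    # Truncate long records; adjust max_length to the available memory.
    return tokenizer(batch["text"], truncation=True, max_length=1024)

tokenized = dataset.map(tokenize, batched=True, remove_columns=dataset.column_names)

trainer = Trainer(
    model=model,
    args=TrainingArguments(**training_args),
    train_dataset=tokenized,
    data_collator=DataCollatorForLanguageModeling(tokenizer, mlm=False),
)

trainer.train()
model.save_pretrained("./fine_tuned_model")
print("Training completed!")

Wrapping the whole Trainer in accelerator.prepare() is also not a documented pattern (prepare() expects models, optimizers, and dataloaders, and Trainer drives Accelerate internally), so the sketch drops that wrapper.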