Update train_llama4.py
train_llama4.py (CHANGED: +12, -18)
@@ -1,7 +1,7 @@
 # train_llama4.py
-# Script to fine-tune Llama 4 Maverick for healthcare fraud detection
+# Script to fine-tune Llama 4 Maverick for healthcare fraud detection (text-only)
 
-from transformers import AutoProcessor, Llama4ForConditionalGeneration, Trainer, TrainingArguments
+from transformers import AutoTokenizer, Llama4ForConditionalGeneration, Trainer, TrainingArguments
 from transformers import BitsAndBytesConfig
 import datasets
 import torch
@@ -22,9 +22,13 @@ if not LLama:
     raise ValueError("LLama token not found. Set it in Hugging Face Space secrets as 'LLama'.")
 huggingface_hub.login(token=LLama)
 
-# Load Llama 4 model and processor
+# Load Llama 4 model and tokenizer
 MODEL_ID = "meta-llama/Llama-4-Maverick-17B-128E-Instruct"
-processor = AutoProcessor.from_pretrained(MODEL_ID)
+tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, trust_remote_code=True)
+
+# Add padding token if it doesn't exist
+if tokenizer.pad_token is None:
+    tokenizer.add_special_tokens({'pad_token': '[PAD]'})
 
 # Quantization config for A100 80 GB VRAM
 quantization_config = BitsAndBytesConfig(load_in_8bit=True)
@@ -56,18 +60,8 @@ print("First example from dataset:", dataset["train"][0])
 
 # Tokenization
 def tokenize_data(example):
-    messages = [
-        {
-            "role": "user",
-            "content": [{"type": "text", "text": example['input']}]
-        },
-        {
-            "role": "assistant",
-            "content": [{"type": "text", "text": example['output']}]
-        }
-    ]
-    formatted_text = processor.apply_chat_template(messages, add_generation_prompt=False)
-    inputs = processor(formatted_text, padding="max_length", truncation=True, max_length=4096, return_tensors="pt")
+    formatted_text = f"<s>[INST] {example['input']} [/INST] {example['output']}</s>"
+    inputs = tokenizer(formatted_text, padding="max_length", truncation=True, max_length=4096, return_tensors="pt")
     input_ids = inputs["input_ids"].squeeze(0).tolist()
     attention_mask = inputs["attention_mask"].squeeze(0).tolist()
     labels = input_ids.copy()
@@ -124,5 +118,5 @@ trainer = Trainer(
 # Start training
 trainer.train()
 model.save_pretrained("./fine_tuned_llama4_healthcare")
-processor.save_pretrained("./fine_tuned_llama4_healthcare")
-print("Training complete. Model and processor saved to ./fine_tuned_llama4_healthcare")
+tokenizer.save_pretrained("./fine_tuned_llama4_healthcare")
+print("Training complete. Model and tokenizer saved to ./fine_tuned_llama4_healthcare")
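One follow-up the pad-token change may still need: add_special_tokens({'pad_token': '[PAD]'}) grows the vocabulary by one, so the loaded model's embedding matrix has to be resized or the new id is out of range on the first forward pass. A minimal sketch of that step, reusing the names from this script; resize_token_embeddings is the standard transformers method, and the plain from_pretrained call here omits the quantization config for brevity:

from transformers import AutoTokenizer, Llama4ForConditionalGeneration

MODEL_ID = "meta-llama/Llama-4-Maverick-17B-128E-Instruct"
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, trust_remote_code=True)

# Add padding token if it doesn't exist (as in the diff above)
if tokenizer.pad_token is None:
    tokenizer.add_special_tokens({'pad_token': '[PAD]'})

model = Llama4ForConditionalGeneration.from_pretrained(MODEL_ID)
# '[PAD]' received an id one past the original vocab size, so the
# embedding table must grow to match before training
model.resize_token_embeddings(len(tokenizer))

An alternative that avoids the resize entirely is tokenizer.pad_token = tokenizer.eos_token, which reuses an id the model already knows.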
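On the formatting change itself: the hardcoded <s>[INST] ... [/INST] wrapper is the Llama 2 chat format, and Llama 4 instruct checkpoints ship their own chat template, so the text built here likely will not match what the model saw during instruction tuning. A text-only sketch that keeps the user/assistant structure the removed processor code built but lets the tokenizer render it; apply_chat_template is the standard tokenizer method, and the field names mirror the diff:

def tokenize_data(example):
    # Same two-turn structure as before, rendered through the model's
    # own chat template instead of a hardcoded Llama-2-style wrapper
    messages = [
        {"role": "user", "content": example["input"]},
        {"role": "assistant", "content": example["output"]},
    ]
    formatted_text = tokenizer.apply_chat_template(
        messages, tokenize=False, add_generation_prompt=False
    )
    inputs = tokenizer(formatted_text, padding="max_length",
                       truncation=True, max_length=4096, return_tensors="pt")
    input_ids = inputs["input_ids"].squeeze(0).tolist()
    attention_mask = inputs["attention_mask"].squeeze(0).tolist()
    labels = input_ids.copy()
    return {"input_ids": input_ids, "attention_mask": attention_mask, "labels": labels}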
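Separately, labels = input_ids.copy() together with padding="max_length" means every example is scored on up to 4096 positions, most of them padding. A common refinement, not part of this commit, is to mask pad positions with -100, the ignore_index that PyTorch's cross-entropy loss skips:

# Inside tokenize_data, in place of labels = input_ids.copy():
# positions where attention_mask == 0 are padding, so exclude them
# from the loss with the -100 ignore index
labels = [tok if mask == 1 else -100
          for tok, mask in zip(input_ids, attention_mask)]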
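For completeness, tokenize_data is presumably applied with datasets.map somewhere in the elided middle of the file. A sketch of that wiring; the raw column names are assumptions based on example['input'] and example['output'], and dropping them keeps string fields away from the Trainer's collator:

# Tokenize every split and drop the raw text columns so only
# input_ids / attention_mask / labels reach the Trainer
tokenized_dataset = dataset.map(
    tokenize_data,
    remove_columns=["input", "output"],
)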
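Last, a caveat on the 8-bit setup: weights loaded with load_in_8bit=True are effectively frozen by bitsandbytes, so this Trainer would not actually be updating the quantized base weights, and the usual single-A100 recipe is LoRA adapters over the frozen 8-bit base via peft. A sketch under that assumption; the target_modules names are guesses that would need checking against the Llama 4 module layout:

from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training

# Freeze the quantized base weights and cast norms/head for
# stable training on top of the int8 model
model = prepare_model_for_kbit_training(model)

lora_config = LoraConfig(
    r=16,
    lora_alpha=32,
    lora_dropout=0.05,
    target_modules=["q_proj", "v_proj"],  # assumed names, verify for Llama 4
    task_type="CAUSAL_LM",
)
model = get_peft_model(model, lora_config)
model.print_trainable_parameters()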