Cylanoid committed · verified
Commit a16809f · 1 Parent(s): 19103d4

Update train_llama4.py

Files changed (1):
  1. train_llama4.py +12 -18
train_llama4.py CHANGED
@@ -1,7 +1,7 @@
  # train_llama4.py
- # Script to fine-tune Llama 4 Maverick for healthcare fraud detection
+ # Script to fine-tune Llama 4 Maverick for healthcare fraud detection (text-only)

- from transformers import AutoProcessor, Llama4ForConditionalGeneration, Trainer, TrainingArguments
+ from transformers import AutoTokenizer, Llama4ForConditionalGeneration, Trainer, TrainingArguments
  from transformers import BitsAndBytesConfig
  import datasets
  import torch
@@ -22,9 +22,13 @@ if not LLama:
      raise ValueError("LLama token not found. Set it in Hugging Face Space secrets as 'LLama'.")
  huggingface_hub.login(token=LLama)

- # Load Llama 4 model and processor
+ # Load Llama 4 model and tokenizer
  MODEL_ID = "meta-llama/Llama-4-Maverick-17B-128E-Instruct"
- processor = AutoProcessor.from_pretrained(MODEL_ID, trust_remote_code=True)
+ tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, trust_remote_code=True)
+
+ # Add padding token if it doesn't exist
+ if tokenizer.pad_token is None:
+     tokenizer.add_special_tokens({'pad_token': '[PAD]'})

  # Quantization config for A100 80 GB VRAM
  quantization_config = BitsAndBytesConfig(load_in_8bit=True)
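
A note on the new pad-token branch: add_special_tokens grows the tokenizer's vocabulary, so the model's embedding table has to grow with it or training will index out of range. A minimal sketch of the two usual fixes, assuming model is the Llama4ForConditionalGeneration instance loaded later in this script (neither line appears in the diff itself):

    # Option 1: reuse the existing EOS token as padding; no new embedding rows needed.
    tokenizer.pad_token = tokenizer.eos_token

    # Option 2: keep the brand-new '[PAD]' token, but resize the embedding matrix
    # to the enlarged vocabulary (standard transformers API).
    model.resize_token_embeddings(len(tokenizer))
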
@@ -56,18 +60,8 @@ print("First example from dataset:", dataset["train"][0])

  # Tokenization
  def tokenize_data(example):
-     messages = [
-         {
-             "role": "user",
-             "content": [{"type": "text", "text": example['input']}]
-         },
-         {
-             "role": "assistant",
-             "content": [{"type": "text", "text": example['output']}]
-         }
-     ]
-     formatted_text = processor.apply_chat_template(messages, add_generation_prompt=False)
-     inputs = processor(formatted_text, padding="max_length", truncation=True, max_length=4096, return_tensors="pt")
+     formatted_text = f"<s>[INST] {example['input']} [/INST] {example['output']}</s>"
+     inputs = tokenizer(formatted_text, padding="max_length", truncation=True, max_length=4096, return_tensors="pt")
      input_ids = inputs["input_ids"].squeeze(0).tolist()
      attention_mask = inputs["attention_mask"].squeeze(0).tolist()
      labels = input_ids.copy()
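
On the rewritten tokenize_data: the hand-rolled <s>[INST] ... [/INST] string is the Llama 2 chat format, which the Llama 4 instruct checkpoints do not necessarily share, and labels = input_ids.copy() combined with padding="max_length" also puts loss on every pad position. A hedged alternative body, assuming the Llama 4 tokenizer ships its own chat template (apply_chat_template is available on tokenizers, not just processors):

    def tokenize_data(example):
        # Let the tokenizer's built-in chat template format the pair,
        # instead of hard-coding Llama 2 style [INST] tags.
        messages = [
            {"role": "user", "content": example["input"]},
            {"role": "assistant", "content": example["output"]},
        ]
        formatted_text = tokenizer.apply_chat_template(messages, tokenize=False)
        inputs = tokenizer(formatted_text, padding="max_length", truncation=True,
                           max_length=4096, return_tensors="pt")
        input_ids = inputs["input_ids"].squeeze(0).tolist()
        attention_mask = inputs["attention_mask"].squeeze(0).tolist()
        # Mask padding out of the loss instead of copying input_ids verbatim.
        labels = [tok if mask == 1 else -100
                  for tok, mask in zip(input_ids, attention_mask)]
        # Return shape assumed; the original function's return is outside the diff context.
        return {"input_ids": input_ids, "attention_mask": attention_mask, "labels": labels}
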
@@ -124,5 +118,5 @@ trainer = Trainer(
  # Start training
  trainer.train()
  model.save_pretrained("./fine_tuned_llama4_healthcare")
- processor.save_pretrained("./fine_tuned_llama4_healthcare")
- print("Training complete. Model and processor saved to ./fine_tuned_llama4_healthcare")
+ tokenizer.save_pretrained("./fine_tuned_llama4_healthcare")
+ print("Training complete. Model and tokenizer saved to ./fine_tuned_llama4_healthcare")