jonngan committed on
Commit 2a17ce7 · verified · 1 Parent(s): 74462d3

Upload 2 files

Files changed (2)
  1. lockinai.py +45 -0
  2. train.py +152 -0
lockinai.py ADDED
@@ -0,0 +1,45 @@
+ from transformers import AutoTokenizer, AutoModelForCausalLM
+
+ # Load your fine-tuned model and tokenizer
+ tokenizer = AutoTokenizer.from_pretrained("./lockin_model")
+ model = AutoModelForCausalLM.from_pretrained("./lockin_model")
+
+ # Function to generate yes/no questions
+ def generate_question(input_text, max_retries=20):
+     for _ in range(max_retries):
+         # Add padding and attention mask
+         inputs = tokenizer(
+             input_text,
+             return_tensors="pt",
+             padding=True,
+             truncation=True,
+             return_attention_mask=True
+         )
+
+         output = model.generate(
+             inputs["input_ids"],
+             attention_mask=inputs["attention_mask"],
+             max_new_tokens=100,
+             do_sample=True,
+             temperature=1.9,
+             top_p=0.8,
+             top_k=50,
+             pad_token_id=tokenizer.eos_token_id
+         )
+         generated_text = tokenizer.decode(output[0], skip_special_tokens=True)
+
+         # Remove the input text from the generated output
+         if generated_text.startswith(input_text):
+             generated_text = generated_text[len(input_text):].strip()
+
+         # If we got a non-empty response and it contains $LOCKIN, return it
+         if generated_text and "$LOCKIN" in generated_text:
+             return generated_text
+
+     # If all retries failed, return default question
+     return "Does $LOCKIN look great?"
+
+ # Example usage
+ prompt = "I need a yes/no question about $LOCKIN."
+ question = generate_question(prompt)
+ print("Generated Question:", question)
train.py ADDED
@@ -0,0 +1,152 @@
+ from transformers import Trainer, TrainingArguments, AutoModelForCausalLM, AutoTokenizer, TrainerCallback
+ from datasets import load_dataset
+ import torch
+ import os
+ import psutil
+ import gc
+
+ # Memory management and environment setup
+ def cleanup_memory():
+     gc.collect()
+     torch.mps.empty_cache()
+     if torch.cuda.is_available():
+         torch.cuda.empty_cache()
+
+ # Set MPS memory limits and environment variables
+ # Note: Changed watermark ratio to a more conservative value
+ os.environ['PYTORCH_MPS_HIGH_WATERMARK_RATIO'] = '0.7' # Changed from 0.8
+ os.environ['PYTORCH_MPS_LOW_WATERMARK_RATIO'] = '0.5' # Added explicit low watermark
+ os.environ['PYTORCH_MPS_ALLOCATOR_POLICY'] = 'garbage_collection_conservative'
+ os.environ['PYTORCH_ENABLE_MPS_FALLBACK'] = '1'
+
+ # Memory monitoring
+ def print_memory_stats():
+     process = psutil.Process()
+     print(f"RAM Memory usage: {process.memory_info().rss / 1024 / 1024:.2f} MB")
+     if hasattr(torch.mps, 'current_allocated_memory'):
+         print(f"MPS Memory allocated: {torch.mps.current_allocated_memory() / 1024 / 1024:.2f} MB")
+
+ # Custom callback for memory monitoring
+ class MemoryCallback(TrainerCallback):
+     def __init__(self, print_memory_stats_fn):
+         self.print_memory_stats_fn = print_memory_stats_fn
+
+     def on_step_end(self, args, state, control, **kwargs):
+         if state.global_step % 100 == 0:
+             print(f"\nStep {state.global_step}:")
+             self.print_memory_stats_fn()
+             cleanup_memory()
+
+ # Set device
+ device = torch.device('mps' if torch.backends.mps.is_available() else 'cpu')
+ print(f"Using device: {device}")
+
+ # Load model and tokenizer
+ model_name = "distilgpt2"
+ model = AutoModelForCausalLM.from_pretrained(
+     model_name,
+     use_cache=False,
+     torch_dtype=torch.float32
+ )
+ model.to(device) # Explicitly move model to device
+ tokenizer = AutoTokenizer.from_pretrained(model_name)
+
+ # Add pad token
+ tokenizer.pad_token = tokenizer.eos_token
+
+ # Load and filter dataset
+ train_data = load_dataset("json", data_files={"train": "data.json"})
+
+ def filter_dataset(example):
+     return len(example["prompt"]) + len(example["completion"]) <= 512
+
+ train_data = train_data.filter(filter_dataset)
+
+ # Preprocess function
+ def preprocess_function(examples):
+     inputs = [prompt + tokenizer.eos_token + completion
+               for prompt, completion in zip(examples["prompt"], examples["completion"])]
+
+     model_inputs = tokenizer(
+         inputs,
+         max_length=256,
+         truncation=True,
+         padding="max_length"
+     )
+
+     model_inputs["labels"] = model_inputs["input_ids"].copy()
+     return model_inputs
+
+ # Preprocess the dataset
+ train_dataset = train_data["train"].map(preprocess_function, batched=True)
+
+ # Training arguments
+ training_args = TrainingArguments(
+     output_dir="./results",
+     num_train_epochs=15,
+     per_device_train_batch_size=1,
+     gradient_accumulation_steps=8, # Reduced from 32
+     logging_dir="./logs",
+     fp16=False,
+     eval_strategy="no",
+     learning_rate=1e-5, # Reduced from 5e-5
+     save_steps=100,
+     save_total_limit=2,
+     gradient_checkpointing=True,
+     optim="adamw_torch",
+     dataloader_num_workers=0,
+     dataloader_pin_memory=False,
+     torch_compile=False,
+     max_grad_norm=1.0, # Increased from 0.5
+     logging_steps=5, # More frequent logging
+     max_steps=1000,
+     warmup_steps=300, # Increased warmup steps
+     weight_decay=0.2, # Increased from 0.01
+     logging_first_step=True,
+     lr_scheduler_type="cosine_with_restarts", # Changed to cosine with restarts
+     warmup_ratio=0.15, # Increased warmup ratio
+ )
+
+ # Clear cache before training
+ cleanup_memory()
+
+ # Initialize trainer
+ trainer = Trainer(
+     model=model,
+     args=training_args,
+     train_dataset=train_dataset,
+     callbacks=[MemoryCallback(print_memory_stats)]
+ )
+
+ # Monitor initial memory usage
+ print("Initial memory usage:")
+ print_memory_stats()
+
+ # Training with error handling
+ try:
+     trainer.train()
+ except Exception as e:
+     print(f"Training error: {str(e)}")
+     cleanup_memory()
+     try:
+         model.save_pretrained("./lockin_model_partial")
+         tokenizer.save_pretrained("./lockin_model_partial")
+         print("Saved partial progress")
+     except Exception:
+         print("Could not save partial progress")
+     raise e
+ finally:
+     cleanup_memory()
+
+ # Save the complete model
+ try:
+     model.save_pretrained("./lockin_model")
+     tokenizer.save_pretrained("./lockin_model")
+     print("Model saved successfully")
+ except Exception as e:
+     print(f"Error saving model: {str(e)}")
+
+ # Final cleanup
+ cleanup_memory()
+ print("\nFinal memory usage:")
+ print_memory_stats()