Update app.py
app.py
CHANGED
@@ -1,17 +1,26 @@
 import os
 import torch
+import torch.distributed as dist
 from unsloth import FastLanguageModel, is_bfloat16_supported
 from trl import SFTTrainer
 from transformers import TrainingArguments
 from datasets import load_dataset
 import gradio as gr
 
-
 max_seq_length = 4096
 dtype = None
 load_in_4bit = True
 hf_token = os.getenv("Token")
 
+def setup_distributed_training():
+    dist.init_process_group(backend='nccl')
+    torch.cuda.set_device(dist.get_rank())
+
+def cleanup_distributed_training():
+    dist.destroy_process_group()
+
+setup_distributed_training()
+
 print("Starting model and tokenizer loading...")
 
 # Load the model and tokenizer
@@ -121,8 +130,10 @@ trainer = SFTTrainer(
         weight_decay=0.01,
         lr_scheduler_type="linear",
         seed=3407,
-        local_rank=4,
         output_dir="outputs",
+        # Distributed training arguments
+        deepspeed=None,  # If using deepspeed for further optimizations
+        local_rank=dist.get_rank(),  # Add this line
     ),
 )
 print("Trainer initialized.")
@@ -142,4 +153,6 @@ model.push_to_hub_merged(
     save_method="merged_16bit",
     token=True
 )
-print("Model pushed to hub successfully.")
+print("Model pushed to hub successfully.")
+
+cleanup_distributed_training()
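A note on how this change has to be launched: dist.init_process_group(backend='nccl') uses the env:// rendezvous by default, so the script must be started through a launcher such as torchrun (e.g. torchrun --nproc_per_node=4 app.py) rather than plain python app.py. Separately, torch.cuda.set_device(dist.get_rank()) is only correct on a single node, where the global rank happens to coincide with the GPU index; on multi-node runs the node-local LOCAL_RANK is the right device index, and the same caveat applies to local_rank=dist.get_rank() in TrainingArguments (recent transformers versions fill this in from the environment on their own). A minimal rank-aware sketch, assuming a torchrun launch, which exports RANK, LOCAL_RANK, and WORLD_SIZE:

import os
import torch
import torch.distributed as dist

def setup_distributed_training():
    # torchrun exports RANK, LOCAL_RANK, and WORLD_SIZE; the default
    # env:// rendezvous of init_process_group reads them.
    dist.init_process_group(backend="nccl")
    # LOCAL_RANK is the rank within this node, so it is always a valid
    # GPU index; the global rank is not once a second node is involved.
    local_rank = int(os.environ.get("LOCAL_RANK", dist.get_rank()))
    torch.cuda.set_device(local_rank)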
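One more caveat: under a multi-process launcher every rank runs the whole script, so the push_to_hub_merged call at the end would upload the merged model once per GPU. A common pattern, sketched below rather than taken from this commit (the repo id is hypothetical), is to gate the upload on rank 0 and hold the other ranks at a barrier until it completes:

if dist.get_rank() == 0:
    model.push_to_hub_merged(
        "your-username/your-model",  # hypothetical repo id
        tokenizer,
        save_method="merged_16bit",
        token=True,
    )
    print("Model pushed to hub successfully.")

dist.barrier()  # keep other ranks alive until rank 0 finishes uploading
cleanup_distributed_training()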