Spaces:
Sleeping
Sleeping
Update train_model.py
Browse files- train_model.py +17 -29
train_model.py
CHANGED
@@ -12,10 +12,10 @@ from transformers import (
|
|
12 |
DataCollatorForLanguageModeling,
|
13 |
DataCollatorWithPadding,
|
14 |
)
|
15 |
-
from datasets import load_dataset
|
16 |
import torch
|
17 |
import os
|
18 |
-
from huggingface_hub import login, HfApi
|
19 |
import logging
|
20 |
|
21 |
from torch.optim import AdamW # Import PyTorch's AdamW
|
@@ -34,10 +34,9 @@ def setup_logging(log_file_path):
|
|
34 |
f_handler.setLevel(logging.INFO)
|
35 |
|
36 |
# Create formatters and add to handlers
|
37 |
-
|
38 |
-
|
39 |
-
|
40 |
-
f_handler.setFormatter(f_format)
|
41 |
|
42 |
# Add handlers to the logger
|
43 |
logger.addHandler(c_handler)
|
@@ -66,30 +65,18 @@ def load_and_prepare_dataset(task, dataset_name, tokenizer, sequence_length):
|
|
66 |
"""
|
67 |
logging.info(f"Loading dataset '{dataset_name}' for task '{task}'...")
|
68 |
try:
|
69 |
-
if
|
70 |
-
|
71 |
-
|
72 |
-
dataset, config = dataset_name.split('/', 1)
|
73 |
-
dataset = load_dataset("Salesforce/wikitext", "wikitext-103-raw-v1", split='train')
|
74 |
-
else:
|
75 |
-
dataset = load_dataset("Salesforce/wikitext", "wikitext-103-raw-v1", split='train')
|
76 |
-
logging.info("Dataset loaded successfully for generation task.")
|
77 |
-
def tokenize_function(examples):
|
78 |
-
return tokenizer(examples['text'], truncation=True, max_length=sequence_length)
|
79 |
-
elif task == "classification":
|
80 |
-
if '/' in dataset_name:
|
81 |
-
dataset, config = dataset_name.split('/', 1)
|
82 |
-
dataset = load_dataset(dataset, config, split='train')
|
83 |
-
else:
|
84 |
-
dataset = load_dataset(dataset_name, split='train')
|
85 |
-
logging.info("Dataset loaded successfully for classification task.")
|
86 |
-
# Assuming the dataset has 'text' and 'label' columns
|
87 |
-
def tokenize_function(examples):
|
88 |
-
return tokenizer(examples['text'], truncation=True, max_length=sequence_length)
|
89 |
else:
|
90 |
-
|
91 |
|
92 |
-
|
|
|
|
|
|
|
|
|
|
|
93 |
tokenized_datasets = dataset.shuffle(seed=42).select(range(500)).map(tokenize_function, batched=True)
|
94 |
logging.info("Dataset tokenization complete.")
|
95 |
return tokenized_datasets
|
@@ -186,7 +173,7 @@ def main():
|
|
186 |
logging.info("Setting pad_token to eos_token.")
|
187 |
tokenizer.pad_token = tokenizer.eos_token
|
188 |
logging.info(f"Tokenizer pad_token set to: {tokenizer.pad_token}")
|
189 |
-
#
|
190 |
model = initialize_model(
|
191 |
task=args.task,
|
192 |
model_name=args.model_name,
|
@@ -315,3 +302,4 @@ def main():
|
|
315 |
|
316 |
if __name__ == "__main__":
|
317 |
main()
|
|
|
|
12 |
DataCollatorForLanguageModeling,
|
13 |
DataCollatorWithPadding,
|
14 |
)
|
15 |
+
from datasets import load_dataset
|
16 |
import torch
|
17 |
import os
|
18 |
+
from huggingface_hub import login, HfApi
|
19 |
import logging
|
20 |
|
21 |
from torch.optim import AdamW # Import PyTorch's AdamW
|
|
|
34 |
f_handler.setLevel(logging.INFO)
|
35 |
|
36 |
# Create formatters and add to handlers
|
37 |
+
formatter = logging.Formatter('%(asctime)s - %(levelname)s - %(message)s')
|
38 |
+
c_handler.setFormatter(formatter)
|
39 |
+
f_handler.setFormatter(formatter)
|
|
|
40 |
|
41 |
# Add handlers to the logger
|
42 |
logger.addHandler(c_handler)
|
|
|
65 |
"""
|
66 |
logging.info(f"Loading dataset '{dataset_name}' for task '{task}'...")
|
67 |
try:
|
68 |
+
if '/' in dataset_name:
|
69 |
+
dataset, config = dataset_name.split('/', 1)
|
70 |
+
dataset = load_dataset("Salesforce/wikitext", "wikitext-103-raw-v1", split='train')
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
71 |
else:
|
72 |
+
dataset = load_dataset("Salesforce/wikitext", "wikitext-103-raw-v1", split='train')
|
73 |
|
74 |
+
logging.info("Dataset loaded successfully.")
|
75 |
+
|
76 |
+
def tokenize_function(examples):
|
77 |
+
return tokenizer(examples['text'], truncation=True, max_length=sequence_length)
|
78 |
+
|
79 |
+
# Tokenize the dataset
|
80 |
tokenized_datasets = dataset.shuffle(seed=42).select(range(500)).map(tokenize_function, batched=True)
|
81 |
logging.info("Dataset tokenization complete.")
|
82 |
return tokenized_datasets
|
|
|
173 |
logging.info("Setting pad_token to eos_token.")
|
174 |
tokenizer.pad_token = tokenizer.eos_token
|
175 |
logging.info(f"Tokenizer pad_token set to: {tokenizer.pad_token}")
|
176 |
+
# Initialize model after setting pad_token
|
177 |
model = initialize_model(
|
178 |
task=args.task,
|
179 |
model_name=args.model_name,
|
|
|
302 |
|
303 |
if __name__ == "__main__":
|
304 |
main()
|
305 |
+
|