Update train_model.py
train_model.py  CHANGED  (+14 -8)
@@ -16,7 +16,7 @@ import torch
 import os
 from huggingface_hub import login, HfApi
 import logging
-from torch.optim import AdamW
+from torch.optim import AdamW
 
 def setup_logging(log_file_path):
     """
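(The hunk above only touches the import line; for context, here is a minimal sketch of how torch.optim.AdamW is typically wired up during fine-tuning. The model and the hyperparameter values below are assumptions for illustration, not taken from train_model.py.)

import torch
from torch.optim import AdamW

# Stand-in model; the real script builds its own model elsewhere. (assumption)
model = torch.nn.Linear(768, 2)

# AdamW applies weight decay decoupled from the gradient update,
# the usual choice for transformer fine-tuning.
optimizer = AdamW(model.parameters(), lr=5e-5, weight_decay=0.01)

loss = model(torch.randn(4, 768)).sum()  # dummy forward pass
loss.backward()        # accumulate gradients
optimizer.step()       # apply the update
optimizer.zero_grad()  # clear gradients for the next step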
@@ -66,23 +66,28 @@ def load_and_prepare_dataset(task, dataset_name, tokenizer, sequence_length):
     dataset = load_dataset(dataset_name, split='train')
     logging.info("Dataset loaded successfully.")
 
+    # Log some examples to check dataset structure
+    logging.info(f"Example data from the dataset: {dataset[:5]}")
+
     def tokenize_function(examples):
         try:
-            # Tokenize with truncation
+            # Tokenize with truncation and padding
             tokens = tokenizer(
                 examples['text'],
                 truncation=True,
-                max_length=sequence_length,
-                padding=
-                return_tensors=None # Let the
+                max_length=sequence_length,
+                padding='max_length', # Force padding to max length for debugging
+                return_tensors=None # Let the collator handle tensor conversion
             )
+            # Log the tokens for debugging
+            logging.info(f"Tokenized example: {tokens}")
             return tokens
         except Exception as e:
             logging.error(f"Error during tokenization: {e}")
-            logging.error(f"
+            logging.error(f"Problematic example: {examples}")
             raise e
 
-    # Tokenize the dataset
+    # Tokenize the dataset
     tokenized_datasets = dataset.shuffle(seed=42).select(range(500)).map(tokenize_function, batched=True)
     logging.info("Dataset tokenization complete.")
     return tokenized_datasets
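(The key change in this hunk is padding='max_length': every row comes out of the tokenizer at exactly sequence_length tokens, which makes the logged examples uniform and easy to eyeball, at the cost of storing pad tokens for every row. Note also that with batched=True the logging call inside tokenize_function fires once per map batch, so each log entry dumps an entire batch of tokenized rows. Below is a standalone sketch of the padding difference; the checkpoint name, texts, and max_length are placeholders, not the script's actual arguments.)

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")  # placeholder checkpoint

texts = ["a short example", "a somewhat longer example sentence for comparison"]

# As in the updated tokenize_function: every row is padded out to max_length.
fixed = tokenizer(texts, truncation=True, max_length=16, padding='max_length')
print([len(ids) for ids in fixed["input_ids"]])   # [16, 16]

# Without forced padding, rows keep their natural lengths and the
# data collator pads each batch to its own longest row at batching time.
ragged = tokenizer(texts, truncation=True, max_length=16)
print([len(ids) for ids in ragged["input_ids"]])  # e.g. [5, 9]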
@@ -210,7 +215,7 @@ def main():
     if args.task == "generation":
         data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)
     elif args.task == "classification":
-        data_collator = DataCollatorWithPadding(tokenizer=tokenizer) #
+        data_collator = DataCollatorWithPadding(tokenizer=tokenizer) # Handle padding dynamically during batching
     else:
         logging.error("Unsupported task type for data collator.")
         raise ValueError("Unsupported task type for data collator.")
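(DataCollatorWithPadding pads each batch to the length of its longest member when the batch is assembled, so it only has work to do when rows arrive with different lengths; with padding='max_length' already forced in tokenize_function above, it effectively becomes a pass-through. A minimal sketch, with hand-made ragged features standing in for tokenized dataset rows.)

from transformers import AutoTokenizer, DataCollatorWithPadding

tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")  # placeholder checkpoint
collator = DataCollatorWithPadding(tokenizer=tokenizer)

# Hand-made features of different lengths. (assumption, for illustration)
features = [
    {"input_ids": [101, 7592, 102], "attention_mask": [1, 1, 1]},
    {"input_ids": [101, 7592, 2088, 999, 102], "attention_mask": [1, 1, 1, 1, 1]},
]

batch = collator(features)
# Both rows are padded to the batch's longest length (5 here).
print(batch["input_ids"].shape)  # torch.Size([2, 5])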
@@ -276,3 +281,4 @@ def main():
 
 if __name__ == "__main__":
     main()
+