Update train_model.py
train_model.py (CHANGED, +14 -21)
@@ -14,18 +14,20 @@ from huggingface_hub import HfApi, HfFolder
 import logging
 
 def main():
-    #
-
-
-
-
-
-
+    # ... existing code ...
+    if args.task == "generation":
+        dataset = load_dataset(args.dataset_name, split='train')  # Load dataset by name
+    elif args.task == "classification":
+        dataset = load_dataset(args.dataset_name, split='train')  # Adjust if necessary
+    else:
+        raise ValueError("Unsupported task type")
+    # ... existing code ...
+
 
     parser = argparse.ArgumentParser()
     parser.add_argument("--task", type=str, required=True, help="Task type: generation or classification")
     parser.add_argument("--model_name", type=str, required=True, help="Name of the model")
-    parser.add_argument("--
+    parser.add_argument("--dataset_name", type=str, required=True, help="Name of the Hugging Face dataset")
     parser.add_argument("--num_layers", type=int, default=12)
     parser.add_argument("--attention_heads", type=int, default=1)
     parser.add_argument("--hidden_size", type=int, default=64)
@@ -53,26 +55,17 @@ def main():
 
     # Load and prepare dataset
     if args.task == "generation":
-        dataset = load_dataset(
+        dataset = load_dataset(args.dataset_name, split='train')
         def tokenize_function(examples):
             return tokenizer(examples['text'], truncation=True, max_length=args.sequence_length)
     elif args.task == "classification":
-
-
-        lines = f.readlines()
-        texts = []
-        labels = []
-        for line in lines:
-            parts = line.strip().split("\t")
-            if len(parts) == 2:
-                texts.append(parts[0])
-                labels.append(int(parts[1]))
-        dataset = Dataset.from_dict({"text": texts, "label": labels})
+        dataset = load_dataset(args.dataset_name, split='train')
+        # Assuming the dataset has 'text' and 'label' columns
         def tokenize_function(examples):
             return tokenizer(examples['text'], truncation=True, max_length=args.sequence_length)
     else:
         raise ValueError("Unsupported task type")
-
+
     tokenized_datasets = dataset.map(tokenize_function, batched=True)
 
     if args.task == "generation":
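For reference, a minimal self-contained sketch of the loading-and-tokenizing path the script now uses. The dataset name "imdb" and tokenizer "bert-base-uncased" are illustrative stand-ins, not values taken from this commit:

from datasets import load_dataset
from transformers import AutoTokenizer

# Illustrative stand-ins for the script's argparse values.
dataset_name = "imdb"      # assumed: any Hub dataset with 'text'/'label' columns
sequence_length = 128

# Equivalent invocation of the updated script (illustrative values):
#   python train_model.py --task classification --model_name my-model --dataset_name imdb

dataset = load_dataset(dataset_name, split="train")
tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")

def tokenize_function(examples):
    # Mirrors the script's tokenize_function: truncate each text to max_length.
    return tokenizer(examples["text"], truncation=True, max_length=sequence_length)

tokenized_datasets = dataset.map(tokenize_function, batched=True)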
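Note that the classification branch now assumes the loaded dataset exposes 'text' and 'label' columns (per the "Assuming..." comment in the diff). For a Hub dataset with different column names, one possible adjustment, not part of this commit, is to rename columns before mapping; 'content' and 'category' below are hypothetical source names:

from datasets import load_dataset

dataset = load_dataset("ag_news", split="train")  # illustrative dataset

# Hypothetical: if the source columns were named 'content'/'category',
# rename them so the rest of the script can keep assuming 'text'/'label'.
for src, dst in (("content", "text"), ("category", "label")):
    if src in dataset.column_names:
        dataset = dataset.rename_column(src, dst)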