art-manuh committed
Commit 25ebe24 · verified · 1 Parent(s): a6835cd

Changed model

Files changed (1): app.py (+12 -11)
app.py CHANGED
@@ -14,7 +14,7 @@ dataset = load_dataset("mwitiderrick/swahili")
 print(f"Dataset columns: {dataset['train'].column_names}")

 # Initialize the tokenizer and model
-model_name = "gpt2"  # Use GPT-2 for text generation
+model_name = "distilgpt2"  # Use a smaller GPT-2 variant (DistilGPT-2) for efficiency
 tokenizer = GPT2Tokenizer.from_pretrained(model_name)
 model = GPT2LMHeadModel.from_pretrained(model_name)

@@ -24,21 +24,21 @@ tokenizer.pad_token_id = tokenizer.convert_tokens_to_ids(tokenizer.pad_token)

 # Preprocess the dataset
 def preprocess_function(examples):
-    # Tokenize and format the dataset
     encodings = tokenizer(
-        examples['text'],  # Use 'text' column from your dataset
+        examples['text'],
         truncation=True,
-        padding='max_length',  # Ensure consistent length
+        padding='max_length',
         max_length=512
     )
-    encodings['labels'] = encodings['input_ids']  # Use input_ids directly as labels
+    encodings['labels'] = encodings['input_ids']
     return encodings

 # Tokenize the dataset
 try:
     tokenized_datasets = dataset.map(
         preprocess_function,
-        batched=True
+        batched=True,
+        batch_size=1000  # Adjust batch size for efficiency
     )
 except Exception as e:
     print(f"Error during tokenization: {e}")
@@ -46,13 +46,14 @@ except Exception as e:
 # Define training arguments
 training_args = TrainingArguments(
     output_dir='./results',
-    per_device_train_batch_size=4,
+    per_device_train_batch_size=2,  # Lowered batch size to prevent OOM errors
     num_train_epochs=1,
     logging_dir='./logs',
-    logging_steps=500,  # Log every 500 steps
-    evaluation_strategy="steps",  # Use evaluation strategy
-    save_steps=10_000,  # Save checkpoint every 10,000 steps
-    save_total_limit=2,  # Keep only the last 2 checkpoints
+    logging_steps=500,
+    evaluation_strategy="steps",
+    save_steps=5000,  # Save checkpoints more frequently
+    save_total_limit=2,
+    gradient_accumulation_steps=8,  # Accumulate gradients to simulate a larger batch size
 )

 # Define Trainer
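The diff stops at the "# Define Trainer" comment, so the Trainer wiring itself is not part of this commit. A minimal sketch of what typically follows, assuming the standard transformers Trainer API and a held-out split carved from the train set (the split and variable names below are illustrative, not taken from app.py):

from transformers import Trainer

# Illustrative only: app.py's actual Trainer setup is not shown in this diff.
split = tokenized_datasets["train"].train_test_split(test_size=0.1)  # assumed eval split

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=split["train"],
    eval_dataset=split["test"],  # needed because evaluation_strategy="steps"
)

trainer.train()

With per_device_train_batch_size=2 and gradient_accumulation_steps=8, the effective batch size per device works out to 16, which is what the gradient-accumulation comment in the diff is aiming for.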