File size: 222 Bytes
a080fe0
 
 
 
 
 
5dcd105
a080fe0
 
 
1
2
3
4
5
6
7
8
9
10
from trl import SFTTrainer
from datasets import load_dataset

def main(
    model_name: str = "openai-community/gpt2",
    dataset_name: str = "trl-lib/Capybara",
    split: str = "train",
):
    """Fine-tune a model with TRL's ``SFTTrainer`` on a Hub dataset split.

    Loads *split* of *dataset_name* via ``datasets.load_dataset`` and trains
    *model_name* on it with ``SFTTrainer``. No ``args``/config object is
    passed, so the trainer runs with its library defaults.

    Args:
        model_name: Model identifier passed unchanged as ``SFTTrainer(model=...)``.
        dataset_name: Dataset identifier passed to ``load_dataset``.
        split: Name of the dataset split to train on.

    Returns:
        The ``SFTTrainer`` instance after ``train()`` has returned, so callers
        can save or evaluate the fine-tuned model.
    """
    dataset = load_dataset(dataset_name, split=split)
    trainer = SFTTrainer(
        model=model_name,
        train_dataset=dataset,
    )
    trainer.train()
    return trainer


# Guard the entry point so importing this module, or a worker process that
# re-imports ``__main__`` under the "spawn" start method, does not kick off
# a full download-and-train run as a side effect.
if __name__ == "__main__":
    main()