# trl-4-dnd / quickstart.py
# Commit 5dcd105 by vishaljoshi24: "Changed LLM to gpt2"
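from trl import SFTTrainer
from datasets import load_dataset

# Load the "train" split of the trl-lib/Capybara dataset from the Hugging Face Hub.
dataset = load_dataset("trl-lib/Capybara", split="train")

# Supervised fine-tuning: SFTTrainer loads the GPT-2 checkpoint from its Hub ID
# and, because no `args` are passed, falls back to its default training config.
trainer = SFTTrainer(
    model="openai-community/gpt2",
    train_dataset=dataset,
)
trainer.train()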