|
|
|
import argparse |
|
import subprocess |
|
import os |
|
|
|
# Command-line interface. --model, --flavor and --dataset are all needed to
# locate the training config and build the `trl sft` command, so argparse
# enforces them up front. Otherwise a missing value would surface later as a
# misleading "configs/None_None.yaml is not supported" error or as a TypeError
# from subprocess.run (None inside the argument list).
parser = argparse.ArgumentParser(description="Demo script for the model.")
parser.add_argument(
    "--model",
    type=str,
    required=True,
    help="Model name; selects configs/<model>_<flavor>.yaml.",
)
parser.add_argument(
    "--dataset",
    type=str,
    required=True,
    help="Dataset passed to `trl sft` as --dataset_name.",
)
parser.add_argument(
    "--flavor",
    type=str,
    required=True,
    help="Config flavor; selects configs/<model>_<flavor>.yaml.",
)
# NOTE(review): --token is parsed but never used in this script. Confirm
# whether it should be forwarded to the trl subprocess (e.g. via its env).
parser.add_argument("--token", type=str)

args = parser.parse_args()

# Each supported (model, flavor) pair has its own YAML config. A missing file
# means the combination is unsupported. The path is relative, so it is
# resolved against the current working directory, not this script's location.
config_file = f"configs/{args.model}_{args.flavor}.yaml"
if not os.path.isfile(config_file):
    raise RuntimeError(f"Training model {args.model} with flavor {args.flavor} is not supported.")

# Launch training through the `trl` CLI. Propagate the child's exit status so
# a failed training run makes this script fail too, instead of exiting 0.
result = subprocess.run(["trl", "sft", "--config", config_file, "--dataset_name", args.dataset])
raise SystemExit(result.returncode)
|
|