# Page banner captured with this file ("Spaces: Runtime error"); not part of the program.
import argparse
import os
from os import PathLike

from rich.progress import (
    BarColumn,
    MofNCompleteColumn,
    Progress,
    TextColumn,
    TimeElapsedColumn,
)

from model import DecoderBase, make_model
def construct_contract_prompt(prompt: str, contract_type: str, contract: str) -> str:
    """Return ``prompt`` with ``contract`` embedded as ``contract_type`` dictates.

    Args:
        prompt: The original task prompt.
        contract_type: One of ``"none"`` (prompt returned unchanged),
            ``"docstring"`` (contract appended to the text that follows the
            first triple-quote delimiter found in the prompt) or ``"code"``
            (contract appended after the prompt).
        contract: Contract source. On every line, everything from the first
            ``#`` onward is dropped before embedding.

    Returns:
        The prompt with the contract embedded.

    Raises:
        ValueError: If ``contract_type`` is not a supported value, or if
            ``"docstring"`` is requested but the prompt contains no
            triple-quote delimiter.
    """
    if contract_type == "none":
        return prompt

    if contract_type not in ("docstring", "code"):
        # Previously this fell through and implicitly returned None.
        raise ValueError(f"unknown contract_type: {contract_type!r}")

    # Cut each line at its first '#'. This is a plain text split, so a '#'
    # inside a string literal is cut as well.
    stripped = "\n".join(line.split("#")[0] for line in contract.splitlines())

    if contract_type == "code":
        # Placed at the beginning of the function body.
        return prompt + stripped

    # contract_type == "docstring": embed within the docstring.
    if '"""' in prompt:
        sep = '"""'
    elif "'''" in prompt:
        sep = "'''"
    else:
        raise ValueError("prompt has no triple-quoted docstring to embed into")

    parts = prompt.split(sep)
    # Padding placed before the closing delimiter: one space fewer than the
    # contract's leading whitespace (a negative count yields an empty string).
    padding = " " * (len(stripped) - len(stripped.lstrip()) - 1)
    parts[1] = parts[1] + stripped + "\n" + padding
    return sep.join(parts)
def code_generate(args, workdir: PathLike, model: DecoderBase, id_range=None):
    """Generate code samples for every task of the selected dataset.

    For each task, the model is queried until ``args.n_samples`` files named
    ``<index>.py`` have been written under ``workdir/<task directory>``.

    Args:
        args: Parsed CLI namespace. Reads ``dataset``, ``contract_type``,
            ``resume``, ``n_samples`` and ``greedy``.
        workdir: Directory that receives one sub-directory per task.
        model: Decoder providing ``codegen(prompt, do_sample=..., num_samples=...)``
            and a ``conversational`` attribute.
        id_range: Optional ``(low, high)`` pair. Tasks whose numeric id is
            outside ``[low, high)`` are skipped.

    Raises:
        ValueError: If ``args.dataset`` is not ``"humaneval"`` or ``"mbpp"``.
        AssertionError: If the model returns no outputs for a task.
    """
    with Progress(
        TextColumn(
            f"{args.dataset} •" + "[progress.percentage]{task.percentage:>3.0f}%"
        ),
        BarColumn(),
        MofNCompleteColumn(),
        TextColumn("•"),
        TimeElapsedColumn(),
    ) as p:
        # Dataset loaders are imported lazily so only the selected one is needed.
        if args.dataset == "humaneval":
            from evalplus.data import get_human_eval_plus

            dataset = get_human_eval_plus()
        elif args.dataset == "mbpp":
            from evalplus.data import get_mbpp_plus

            dataset = get_mbpp_plus()
        else:
            # Without this branch `dataset` would be unbound below (NameError).
            raise ValueError(f"unsupported dataset: {args.dataset!r}")

        for task_id, task in p.track(dataset.items()):
            if id_range is not None:
                id_num = int(task_id.split("/")[1])
                low, high = id_range
                if id_num < low or id_num >= high:
                    p.console.print(f"Skipping {task_id} as it is not in {id_range}")
                    continue

            p_name = task_id.replace("/", "_")
            # A contract was requested but this task has none: nothing to do.
            if args.contract_type != "none" and task["contract"] == "":
                continue

            task_dir = os.path.join(workdir, p_name)
            os.makedirs(task_dir, exist_ok=True)
            log = f"Codegen: {p_name} @ {model}"

            n_existing = 0
            if args.resume:
                # Count existing .py files; sample numbering continues from there.
                n_existing = sum(
                    name.endswith(".py") for name in os.listdir(task_dir)
                )
                if n_existing > 0:
                    log += f" (resuming from {n_existing})"
            p.console.print(log)

            sidx = n_existing
            while sidx < args.n_samples:
                outputs = model.codegen(
                    construct_contract_prompt(
                        task["prompt"], args.contract_type, task["contract"]
                    ),
                    do_sample=not args.greedy,
                    num_samples=args.n_samples - sidx,
                )
                # Explicit raise (not `assert`) so the check survives `python -O`;
                # an empty result would otherwise spin in this loop forever.
                if not outputs:
                    raise AssertionError("No outputs from model!")

                for impl in outputs:
                    try:
                        with open(
                            os.path.join(task_dir, f"{sidx}.py"),
                            "w",
                            encoding="utf-8",
                        ) as f:
                            if model.conversational:
                                f.write(impl)
                            else:
                                f.write(task["prompt"] + impl)
                    except UnicodeEncodeError:
                        # Best effort: drop the sample. `sidx` is not advanced,
                        # so the next sample overwrites this file.
                        p.console.print(
                            f"Dropping a sample for {p_name}: not encodable as UTF-8"
                        )
                        continue
                    sidx += 1
def main():
    """Parse command-line arguments, build the model, and run code generation.

    Creates ``<root>/<dataset>/<model>_temp_<temperature>[-contract-<type>]``,
    records the parsed arguments in ``args.txt`` inside it, and then calls
    ``code_generate`` with that directory.

    Invalid ``--id-range`` values are reported through ``parser.error``, which
    prints a usage message and exits.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--model", required=True, type=str)
    parser.add_argument("--bs", default=1, type=int)
    parser.add_argument("--temperature", default=0.0, type=float)
    parser.add_argument(
        "--dataset", required=True, type=str, choices=["humaneval", "mbpp"]
    )
    parser.add_argument("--root", type=str, required=True)
    parser.add_argument("--n_samples", default=1, type=int)
    parser.add_argument("--resume", action="store_true")
    parser.add_argument(
        "--contract-type",
        default="none",
        type=str,
        choices=["none", "code", "docstring"],
    )
    parser.add_argument("--greedy", action="store_true")
    # id_range is a list: LOW HIGH
    parser.add_argument("--id-range", default=None, nargs="+", type=int)
    args = parser.parse_args()

    if args.greedy and (args.temperature != 0 or args.bs != 1 or args.n_samples != 1):
        # NOTE: the integer 0 is kept as-is because the temperature value is
        # embedded in the output directory name built below.
        args.temperature = 0
        args.bs = 1
        args.n_samples = 1
        print("Greedy decoding ON (--greedy): setting bs=1, n_samples=1, temperature=0")

    if args.id_range is not None:
        # Validate with parser.error rather than `assert`, which is stripped
        # under `python -O` and would let bad input through.
        if len(args.id_range) != 2:
            parser.error("id_range must be a list of length 2")
        if args.id_range[0] >= args.id_range[1]:
            parser.error("id_range must be increasing")
        args.id_range = tuple(args.id_range)

    # Make project dir
    os.makedirs(args.root, exist_ok=True)
    # Make dataset dir
    os.makedirs(os.path.join(args.root, args.dataset), exist_ok=True)
    # Make dir for codes generated by each model
    args.model = args.model.lower()
    model = make_model(
        name=args.model, batch_size=args.bs, temperature=args.temperature
    )
    workdir = os.path.join(
        args.root,
        args.dataset,
        args.model
        + f"_temp_{args.temperature}"
        + ("" if args.contract_type == "none" else f"-contract-{args.contract_type}"),
    )
    os.makedirs(workdir, exist_ok=True)

    # Record the effective arguments next to the generated samples.
    with open(os.path.join(workdir, "args.txt"), "w", encoding="utf-8") as f:
        f.write(str(args))

    code_generate(args, workdir=workdir, model=model, id_range=args.id_range)
# Run the CLI only when executed as a script, not when imported.
if __name__ == "__main__":
    main()