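# NOTE(review): the six lines below look like file-viewer page residue rather
# than part of the script; kept verbatim as comments so the module can parse.
# lmzjms's picture
# Upload 1162 files
# 0b32ad6 verified
# raw
# history blame
# 1.07 kB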
import argparse
from pathlib import Path
import torch
from s3prl import Dataset, Output, Task
from s3prl.base.object import Object
# Default to the CPU and switch to CUDA only when torch reports it is available.
device = "cpu"
if torch.cuda.is_available():
    device = "cuda"
def parse_args(argv: "list[str] | None" = None) -> argparse.Namespace:
    """Parse the command-line arguments of this script.

    Args:
        argv: Argument strings to parse, without the program name. When
            ``None`` (the default) argparse falls back to ``sys.argv[1:]``,
            so existing zero-argument calls behave as before.

    Returns:
        The parsed namespace. Its ``load_from`` attribute holds the single
        positional argument as a string.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "load_from", help="The directory containing all the checkpoints"
    )
    # Forwarding ``argv`` lets callers and tests supply arguments explicitly
    # instead of having to mutate ``sys.argv``.
    return parser.parse_args(argv)
def main():
    """Run the checkpointed task over the checkpointed test set and print predictions.

    Reads the checkpoint directory from the command line, loads ``task.ckpt``
    and ``test_dataset.ckpt`` from it, then feeds every batch of the test
    dataloader to the task with gradient tracking disabled. For each batch,
    every name is printed next to its prediction.
    """
    checkpoint_dir = Path(parse_args().load_from)

    # Load the task first, move it to the selected device, and switch it to
    # evaluation mode before any data is touched.
    model: Task = Object.load_checkpoint(checkpoint_dir / "task.ckpt")
    model = model.to(device)
    model.eval()

    dataset: Dataset = Object.load_checkpoint(checkpoint_dir / "test_dataset.ckpt")
    loader = dataset.to_dataloader(batch_size=1, num_workers=6)

    with torch.no_grad():
        for raw_batch in loader:
            moved: Output = raw_batch.to(device)
            # Only the "x" and "x_len" entries are passed to the task, as
            # keyword arguments.
            model_inputs = moved.subset("x", "x_len", as_type="dict")
            output = model(**model_inputs)
            for pair in zip(moved.name, output.prediction):
                print(*pair)
# Run the inference loop only when this file is executed as a script,
# not when it is imported.
if __name__ == "__main__":
    main()