File size: 1,065 Bytes
0b32ad6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
import argparse
from pathlib import Path

import torch

from s3prl import Dataset, Output, Task
from s3prl.base.object import Object

# Pick the compute device once at import time: use the GPU when CUDA is
# usable, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"


def parse_args():
    """Parse the command-line arguments of this script.

    Returns:
        The ``argparse.Namespace`` produced from ``sys.argv``. It carries a
        single attribute, ``load_from``, holding the positional argument as
        a string.
    """
    cli = argparse.ArgumentParser()
    cli.add_argument(
        "load_from",
        help="The directory containing all the checkpoints",
    )
    return cli.parse_args()


def main():
    """Run the saved task over the saved test dataset and print predictions.

    Reads ``task.ckpt`` and ``test_dataset.ckpt`` from the directory given on
    the command line, moves the task to ``device``, switches it to eval mode,
    and then iterates the test dataloader with gradient tracking disabled.
    For every batch, each entry of ``batch.name`` is printed next to the
    matching entry of ``result.prediction``.
    """
    ckpt_dir = Path(parse_args().load_from)

    model: Task = Object.load_checkpoint(ckpt_dir / "task.ckpt")
    model = model.to(device)
    model.eval()

    dataset: Dataset = Object.load_checkpoint(ckpt_dir / "test_dataset.ckpt")
    loader = dataset.to_dataloader(batch_size=1, num_workers=6)

    # Inference only: no gradients are needed for the forward passes below.
    with torch.no_grad():
        for loaded in loader:
            batch: Output = loaded.to(device)
            # Only the "x" and "x_len" fields are forwarded to the task.
            inputs = batch.subset("x", "x_len", as_type="dict")
            output = model(**inputs)
            for pair in zip(batch.name, output.prediction):
                print(*pair)

# Run the prediction loop only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()