|
import argparse |
|
from pathlib import Path |
|
|
|
import torch |
|
|
|
from s3prl import Dataset, Output, Task |
|
from s3prl.base.object import Object |
|
|
|
device = "cuda" if torch.cuda.is_available() else "cpu" |
|
|
|
|
|
def parse_args(): |
|
parser = argparse.ArgumentParser() |
|
parser.add_argument( |
|
"load_from", help="The directory containing all the checkpoints" |
|
) |
|
args = parser.parse_args() |
|
return args |
|
|
|
|
|
def main(): |
|
args = parse_args() |
|
load_from = Path(args.load_from) |
|
|
|
task: Task = Object.load_checkpoint(load_from / "task.ckpt").to(device) |
|
task.eval() |
|
|
|
test_dataset: Dataset = Object.load_checkpoint(load_from / "test_dataset.ckpt") |
|
test_dataloader = test_dataset.to_dataloader(batch_size=1, num_workers=6) |
|
|
|
with torch.no_grad(): |
|
for batch in test_dataloader: |
|
batch: Output = batch.to(device) |
|
result = task(**batch.subset("x", "x_len", as_type="dict")) |
|
for name, prediction in zip(batch.name, result.prediction): |
|
print(name, prediction) |
|
|
|
|
|
if __name__ == "__main__": |
|
main() |
|
|