lmzjms's picture
Upload 1162 files
0b32ad6 verified
raw
history blame
1.19 kB
import argparse
from pathlib import Path
import torch
from torch.utils.data import DataLoader
from s3prl import Output
from s3prl.base.object import Object
from s3prl.dataset import Dataset
from s3prl.task import Task
# Pick the compute device once at import time: use CUDA when PyTorch reports
# it as available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
def parse_args(argv=None) -> argparse.Namespace:
    """Parse the command-line options of this script.

    Args:
        argv: Optional list of argument strings to parse. When ``None``
            (the default) ``argparse`` reads the arguments from
            ``sys.argv[1:]``, which is the behavior of a plain
            ``parse_args()`` call.

    Returns:
        The parsed namespace. It has a single attribute, ``load_from``
        (``str``, default ``"result/sv"``): the directory containing all
        the checkpoints.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--load_from",
        type=str,
        default="result/sv",
        help="The directory containing all the checkpoints",
    )
    # Forwarding ``argv`` lets callers (and tests) supply arguments
    # explicitly instead of depending on the process-wide ``sys.argv``.
    return parser.parse_args(argv)
def main():
    """Run the saved task over the saved test dataset and print output shapes.

    Reads ``task.ckpt`` and ``test_dataset.ckpt`` from the directory given by
    ``--load_from``, moves the task to the module-level ``device``, switches
    it to eval mode, and feeds the test set through it one item per batch
    with gradient tracking disabled. For every batch the shape of
    ``hidden_states`` in the task's result is printed.
    """
    ckpt_dir = Path(parse_args().load_from)

    # Restore the task and put it on the target device in inference mode.
    model: Task = Object.load_checkpoint(ckpt_dir / "task.ckpt")
    model = model.to(device)
    model.eval()

    # Restore the dataset; it supplies its own collate function.
    eval_set: Dataset = Object.load_checkpoint(ckpt_dir / "test_dataset.ckpt")
    loader = DataLoader(
        eval_set,
        batch_size=1,
        num_workers=6,
        collate_fn=eval_set.collate_fn,
    )

    with torch.no_grad():
        for raw_batch in loader:
            on_device: Output = raw_batch.to(device)
            # Only the "x" and "x_len" entries are passed to the task.
            model_inputs = on_device.subset("x", "x_len", as_type="dict")
            outputs = model(**model_inputs)
            print(outputs.hidden_states.shape)
# Run the inference loop only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()