Spaces:
Sleeping
Sleeping
File size: 3,937 Bytes
6fc683c |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 |
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import os
import glob
import argparse
import pprint
import omegaconf
from omegaconf import OmegaConf
from torch.utils.data import DataLoader
from mmpt.utils import load_config, set_seed
from mmpt.evaluators import Evaluator
from mmpt.evaluators import predictor as predictor_path
from mmpt.tasks import Task
from mmpt import processors
from mmpt.datasets import MMDataset
def get_dataloader(config, num_workers=6):
    """Build the evaluation ``DataLoader`` described by ``config.dataset``.

    The classes named by ``config.dataset.meta_processor``,
    ``video_processor``, ``text_processor`` and ``aligner`` are looked up on
    the ``mmpt.processors`` module. Each is instantiated with
    ``config.dataset``, and the four instances are combined into an
    ``MMDataset``. The dataset length is printed. When the dataset is
    non-empty, its first example is also printed as a sanity check.

    Args:
        config: loaded task config. Reads ``config.dataset`` and
            ``config.fairseq.dataset.batch_size``.
        num_workers: number of ``DataLoader`` worker processes. Defaults to 6,
            the value that was previously hard-coded.

    Returns:
        A non-shuffling ``DataLoader`` over the dataset that collates batches
        with ``MMDataset.collater``.
    """
    dataset_config = config.dataset

    # Resolve every class first, then instantiate in the original order.
    meta_processor_cls = getattr(processors, dataset_config.meta_processor)
    video_processor_cls = getattr(processors, dataset_config.video_processor)
    text_processor_cls = getattr(processors, dataset_config.text_processor)
    aligner_cls = getattr(processors, dataset_config.aligner)

    meta_processor = meta_processor_cls(dataset_config)
    video_processor = video_processor_cls(dataset_config)
    text_processor = text_processor_cls(dataset_config)
    aligner = aligner_cls(dataset_config)

    test_data = MMDataset(
        meta_processor,
        video_processor,
        text_processor,
        aligner,
    )
    print("test_len", len(test_data))
    # Indexing an empty dataset would raise just to print a preview.
    # Only show the first example when one exists.
    if len(test_data) > 0:
        output = test_data[0]
        test_data.print_example(output)

    return DataLoader(
        test_data,
        batch_size=config.fairseq.dataset.batch_size,
        shuffle=False,  # keep evaluation order stable
        num_workers=num_workers,
        collate_fn=test_data.collater,
    )
def main(args):
    """Evaluate or visualize a trained MMPT model according to a task config.

    The basename of ``args.taskconfig`` selects the mode:

    * ``test*``: evaluate checkpoints with ``Evaluator``. Prints each
      checkpoint's metrics, then the metrics of the best-scoring one.
      Unless ``config.fairseq.common_eval.path`` contains ``"best"``, every
      ``checkpoint*`` file in the directory of ``config.eval.save_path`` is
      evaluated first. The checkpoint named by
      ``config.fairseq.common_eval.path`` is always evaluated last.
    * ``vis*``: load ``config.fairseq.common_eval.path`` and run the
      predictor class named by ``config.predictor`` over the test data.

    Args:
        args: parsed CLI namespace. ``args.taskconfig`` is the path of the
            task config file.

    Raises:
        ValueError: if the config filename starts with neither ``test`` nor
            ``vis``.
    """
    config = load_config(args)
    if isinstance(config, omegaconf.dictconfig.DictConfig):
        print(OmegaConf.to_yaml(config))
    else:
        # PrettyPrinter's method is ``pprint``.
        # The former ``pp.print(...)`` raised AttributeError.
        pp = pprint.PrettyPrinter(indent=4)
        pp.pprint(config)

    mmtask = Task.config_task(config)
    mmtask.build_model()
    test_dataloader = get_dataloader(config)
    checkpoint_search_path = os.path.dirname(config.eval.save_path)
    results = []

    prefix = os.path.basename(args.taskconfig)
    if prefix.startswith("test"):
        # Loop over all checkpoints for datasets without a validation set.
        if "best" not in config.fairseq.common_eval.path:
            print("eval each epoch.")
            # os.path.join avoids globbing "/checkpoint*" (the filesystem
            # root) when save_path has no directory part.
            # sorted() makes the order deterministic; glob's order is arbitrary.
            pattern = os.path.join(checkpoint_search_path, "checkpoint*")
            for checkpoint in sorted(glob.glob(pattern)):
                model = mmtask.load_checkpoint(checkpoint)
                ckpt = os.path.basename(checkpoint)
                evaluator = Evaluator(config)
                output = evaluator.evaluate(
                    model, test_dataloader, ckpt + "_merged")
                results.append((checkpoint, output))

        # Use the checkpoint specified by the config last.
        model = mmtask.load_checkpoint(config.fairseq.common_eval.path)
        evaluator = Evaluator(config)
        output = evaluator.evaluate(model, test_dataloader)
        results.append((config.fairseq.common_eval.path, output))

        # `results` is never empty here, because the configured checkpoint
        # was just appended. Seeding from the first result (instead of a
        # 0.0 threshold) guarantees best_result is set even when every
        # score is <= 0.
        best_result = None
        best_metric = None
        for checkpoint, result in results:
            print(checkpoint)
            evaluator.metric.print_computed_metrics(result)
            best_score = evaluator.metric.best_metric(result)
            if best_result is None or best_score > best_metric:
                best_result = (checkpoint, result)
                best_metric = best_score
        print("best results:")
        print(best_result[0])
        evaluator.metric.print_computed_metrics(best_result[1])
    elif prefix.startswith("vis"):
        model = mmtask.load_checkpoint(config.fairseq.common_eval.path)
        predictor_cls = getattr(predictor_path, config.predictor)
        predictor = predictor_cls(config)
        predictor.predict_loop(model, test_dataloader, mmtask, None)
    else:
        raise ValueError("unknown prefix of the config file", args.taskconfig)
if __name__ == "__main__":
    # Command-line entry point: the only argument is the task config path,
    # whose filename prefix ("test"/"vis") selects the mode inside main().
    cli = argparse.ArgumentParser()
    cli.add_argument("taskconfig", type=str)
    main(cli.parse_args())
|