Spaces:
Sleeping
Sleeping
| # Copyright (c) Facebook, Inc. and its affiliates. | |
| # | |
| # This source code is licensed under the MIT license found in the | |
| # LICENSE file in the root directory of this source tree. | |
| import os | |
| import glob | |
| import argparse | |
| import pprint | |
| import omegaconf | |
| from omegaconf import OmegaConf | |
| from torch.utils.data import DataLoader | |
| from mmpt.utils import load_config, set_seed | |
| from mmpt.evaluators import Evaluator | |
| from mmpt.evaluators import predictor as predictor_path | |
| from mmpt.tasks import Task | |
| from mmpt import processors | |
| from mmpt.datasets import MMDataset | |
def get_dataloader(config):
    """Build the test ``DataLoader`` described by ``config``.

    The meta processor, video processor, text processor and aligner classes
    named in ``config.dataset`` are looked up on ``processors``; each one is
    instantiated with ``config.dataset`` and the four instances are combined
    into an ``MMDataset``.  The dataset length and its first example are
    printed before the loader is created.

    Args:
        config: task configuration providing ``dataset.meta_processor``,
            ``dataset.video_processor``, ``dataset.text_processor``,
            ``dataset.aligner`` and ``fairseq.dataset.batch_size``.

    Returns:
        A ``DataLoader`` over the dataset with ``shuffle=False``,
        ``num_workers=6`` and the dataset's own ``collater`` as
        ``collate_fn``.
    """
    component_keys = (
        "meta_processor",
        "video_processor",
        "text_processor",
        "aligner",
    )
    # Resolve every class first, then instantiate, so that an unknown
    # processor name fails before any processor has been constructed.
    component_classes = [
        getattr(processors, getattr(config.dataset, key))
        for key in component_keys
    ]
    components = [
        component_cls(config.dataset) for component_cls in component_classes
    ]

    test_data = MMDataset(*components)
    print("test_len", len(test_data))
    # Show the first example as a quick sanity check of the pipeline.
    test_data.print_example(test_data[0])

    return DataLoader(
        test_data,
        batch_size=config.fairseq.dataset.batch_size,
        shuffle=False,
        num_workers=6,
        collate_fn=test_data.collater,
    )
def main(args):
    """Evaluate or visualize a trained model according to a task config.

    The mode is selected by the basename of ``args.taskconfig``:

    * ``test*``: evaluate the checkpoint named by
      ``config.fairseq.common_eval.path``.  When that path does not contain
      ``"best"``, every ``checkpoint*`` file in the directory of
      ``config.eval.save_path`` is evaluated first.  The metrics of each
      evaluated checkpoint are printed, followed by the best-scoring one.
    * ``vis*``: load the configured checkpoint and run the predictor class
      named by ``config.predictor``.

    Args:
        args: parsed command-line namespace; ``args.taskconfig`` is the path
            of the task config file.

    Raises:
        ValueError: if the config file name starts with neither ``test``
            nor ``vis``.
    """
    config = load_config(args)

    if isinstance(config, omegaconf.dictconfig.DictConfig):
        print(OmegaConf.to_yaml(config))
    else:
        pp = pprint.PrettyPrinter(indent=4)
        # PrettyPrinter exposes ``pprint``; it has no ``print`` method.
        pp.pprint(config)

    mmtask = Task.config_task(config)
    mmtask.build_model()

    test_dataloader = get_dataloader(config)
    checkpoint_search_path = os.path.dirname(config.eval.save_path)
    results = []

    prefix = os.path.basename(args.taskconfig)
    if prefix.startswith("test"):
        # Loop over all checkpoints for datasets without a validation set.
        if "best" not in config.fairseq.common_eval.path:
            print("eval each epoch.")
            # os.path.join keeps the search relative when the directory part
            # is empty (plain concatenation would yield "/checkpoint*").
            pattern = os.path.join(checkpoint_search_path, "checkpoint*")
            # glob returns matches in arbitrary order; sort so evaluation
            # order (and tie-breaking below) is reproducible.
            for checkpoint in sorted(glob.glob(pattern)):
                model = mmtask.load_checkpoint(checkpoint)
                ckpt = os.path.basename(checkpoint)
                evaluator = Evaluator(config)
                output = evaluator.evaluate(
                    model, test_dataloader, ckpt + "_merged")
                results.append((checkpoint, output))

        # Finally, evaluate the checkpoint specified by the config.
        model = mmtask.load_checkpoint(config.fairseq.common_eval.path)
        evaluator = Evaluator(config)
        output = evaluator.evaluate(model, test_dataloader)
        results.append((config.fairseq.common_eval.path, output))

        best_result = None
        best_metric = None
        for checkpoint, result in results:
            print(checkpoint)
            evaluator.metric.print_computed_metrics(result)
            best_score = evaluator.metric.best_metric(result)
            # Always accept the first result so that a run whose scores are
            # all <= 0 still reports a best checkpoint instead of crashing.
            if best_result is None or best_score > best_metric:
                best_result = (checkpoint, result)
                best_metric = best_score
        print("best results:")
        print(best_result[0])
        evaluator.metric.print_computed_metrics(best_result[1])

    elif prefix.startswith("vis"):
        model = mmtask.load_checkpoint(config.fairseq.common_eval.path)
        predictor_cls = getattr(predictor_path, config.predictor)
        predictor = predictor_cls(config)
        predictor.predict_loop(model, test_dataloader, mmtask, None)
    else:
        raise ValueError("unknown prefix of the config file", args.taskconfig)
if __name__ == "__main__":
    # Command-line entry point: the single positional argument is the task
    # config path, which main() reads as ``args.taskconfig``.
    arg_parser = argparse.ArgumentParser()
    arg_parser.add_argument("taskconfig", type=str)
    main(arg_parser.parse_args())