Spaces:
Runtime error
Runtime error
Commit
Β·
c55fc6f
1
Parent(s):
a91989d
Delete train.py
Browse files
train.py
DELETED
|
@@ -1,103 +0,0 @@
|
|
| 1 |
-
"""
|
| 2 |
-
Copyright (c) 2022, salesforce.com, inc.
|
| 3 |
-
All rights reserved.
|
| 4 |
-
SPDX-License-Identifier: BSD-3-Clause
|
| 5 |
-
For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause
|
| 6 |
-
"""
|
| 7 |
-
|
| 8 |
-
import argparse
|
| 9 |
-
import os
|
| 10 |
-
import random
|
| 11 |
-
|
| 12 |
-
import numpy as np
|
| 13 |
-
import torch
|
| 14 |
-
import torch.backends.cudnn as cudnn
|
| 15 |
-
|
| 16 |
-
import minigpt4.tasks as tasks
|
| 17 |
-
from minigpt4.common.config import Config
|
| 18 |
-
from minigpt4.common.dist_utils import get_rank, init_distributed_mode
|
| 19 |
-
from minigpt4.common.logger import setup_logger
|
| 20 |
-
from minigpt4.common.optims import (
|
| 21 |
-
LinearWarmupCosineLRScheduler,
|
| 22 |
-
LinearWarmupStepLRScheduler,
|
| 23 |
-
)
|
| 24 |
-
from minigpt4.common.registry import registry
|
| 25 |
-
from minigpt4.common.utils import now
|
| 26 |
-
|
| 27 |
-
# imports modules for registration
|
| 28 |
-
from minigpt4.datasets.builders import *
|
| 29 |
-
from minigpt4.models import *
|
| 30 |
-
from minigpt4.processors import *
|
| 31 |
-
from minigpt4.runners import *
|
| 32 |
-
from minigpt4.tasks import *
|
| 33 |
-
|
| 34 |
-
|
| 35 |
-
def parse_args():
|
| 36 |
-
parser = argparse.ArgumentParser(description="Training")
|
| 37 |
-
|
| 38 |
-
parser.add_argument("--cfg-path", required=True, help="path to configuration file.")
|
| 39 |
-
parser.add_argument(
|
| 40 |
-
"--options",
|
| 41 |
-
nargs="+",
|
| 42 |
-
help="override some settings in the used config, the key-value pair "
|
| 43 |
-
"in xxx=yyy format will be merged into config file (deprecate), "
|
| 44 |
-
"change to --cfg-options instead.",
|
| 45 |
-
)
|
| 46 |
-
|
| 47 |
-
args = parser.parse_args()
|
| 48 |
-
# if 'LOCAL_RANK' not in os.environ:
|
| 49 |
-
# os.environ['LOCAL_RANK'] = str(args.local_rank)
|
| 50 |
-
|
| 51 |
-
return args
|
| 52 |
-
|
| 53 |
-
|
| 54 |
-
def setup_seeds(config):
|
| 55 |
-
seed = config.run_cfg.seed + get_rank()
|
| 56 |
-
|
| 57 |
-
random.seed(seed)
|
| 58 |
-
np.random.seed(seed)
|
| 59 |
-
torch.manual_seed(seed)
|
| 60 |
-
|
| 61 |
-
cudnn.benchmark = False
|
| 62 |
-
cudnn.deterministic = True
|
| 63 |
-
|
| 64 |
-
|
| 65 |
-
def get_runner_class(cfg):
|
| 66 |
-
"""
|
| 67 |
-
Get runner class from config. Default to epoch-based runner.
|
| 68 |
-
"""
|
| 69 |
-
runner_cls = registry.get_runner_class(cfg.run_cfg.get("runner", "runner_base"))
|
| 70 |
-
|
| 71 |
-
return runner_cls
|
| 72 |
-
|
| 73 |
-
|
| 74 |
-
def main():
|
| 75 |
-
# allow auto-dl completes on main process without timeout when using NCCL backend.
|
| 76 |
-
# os.environ["NCCL_BLOCKING_WAIT"] = "1"
|
| 77 |
-
|
| 78 |
-
# set before init_distributed_mode() to ensure the same job_id shared across all ranks.
|
| 79 |
-
job_id = now()
|
| 80 |
-
|
| 81 |
-
cfg = Config(parse_args())
|
| 82 |
-
|
| 83 |
-
init_distributed_mode(cfg.run_cfg)
|
| 84 |
-
|
| 85 |
-
setup_seeds(cfg)
|
| 86 |
-
|
| 87 |
-
# set after init_distributed_mode() to only log on master.
|
| 88 |
-
setup_logger()
|
| 89 |
-
|
| 90 |
-
cfg.pretty_print()
|
| 91 |
-
|
| 92 |
-
task = tasks.setup_task(cfg)
|
| 93 |
-
datasets = task.build_datasets(cfg)
|
| 94 |
-
model = task.build_model(cfg)
|
| 95 |
-
|
| 96 |
-
runner = get_runner_class(cfg)(
|
| 97 |
-
cfg=cfg, job_id=job_id, task=task, model=model, datasets=datasets
|
| 98 |
-
)
|
| 99 |
-
runner.train()
|
| 100 |
-
|
| 101 |
-
|
| 102 |
-
if __name__ == "__main__":
|
| 103 |
-
main()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|