dineshsai07's picture
Add files using upload-large-folder tool
46a8d8a verified
raw
history blame
1.34 kB
import torch.distributed as dist
import torch.multiprocessing as mp
import os
import os.path as osp
import sys
import numpy as np
import copy
from lib.cfg_holder import cfg_unique_holder as cfguh
from lib.cfg_helper import \
get_command_line_args, \
cfg_initiates
from lib.model_zoo.sd import version
from lib.utils import get_obj_from_str
if __name__ == "__main__":
    # Build the run configuration: the command-line result is passed
    # through cfg_initiates before anything else reads it.
    cfg = cfg_initiates(get_command_line_args())

    # A 'train' section selects training mode; otherwise the script runs
    # a standalone evaluation. Both modes resolve their runner and stage
    # classes from the dotted strings stored in the config.
    if 'train' in cfg:
        runner = get_obj_from_str(cfg.train.main)(cfg)
        stage = get_obj_from_str(cfg.train.stage)()
        if 'eval' in cfg:
            # When an 'eval' section is also present, its stage is
            # attached to the training stage as `nested_eval_stage`.
            stage.nested_eval_stage = get_obj_from_str(cfg.eval.stage)()
    else:
        runner = get_obj_from_str(cfg.eval.main)(cfg)
        stage = get_obj_from_str(cfg.eval.stage)()

    runner.register_stage(stage)

    # The runner is a callable taking a process index. With one GPU it is
    # invoked directly in this process with index 0; otherwise mp.spawn
    # starts one process per GPU and supplies the index itself.
    gpu_count = cfg.env.gpu_count
    if gpu_count == 1:
        runner(0)
    else:
        mp.spawn(
            runner,
            args=(),
            nprocs=gpu_count,
            join=True,  # block until every spawned worker has exited
        )

    # Runs in the launching process once the work above has returned.
    runner.destroy()