Spaces:
Running
Running
import random | |
import torch | |
import numpy as np | |
import os | |
from ...common.log import logger | |
def set_random_seed(seed: int) -> None:
    """Seed the RNGs of `random`, `numpy` and `torch` for reproducibility.

    Call this once, before any code that draws random numbers.

    Besides seeding, this puts cuDNN into deterministic mode
    (``torch.backends.cudnn.deterministic = True``) and disables its
    autotuner (``torch.backends.cudnn.benchmark = False``).

    Note:
        ``PYTHONHASHSEED`` is only read when a Python interpreter starts.
        Setting it here affects child Python processes started afterwards.
        It does NOT make ``hash()`` of ``str``/``bytes`` deterministic in
        the current process. For that, export ``PYTHONHASHSEED`` before
        launching Python.

    Args:
        seed (int): Random seed.
    """
    random.seed(seed)
    np.random.seed(seed)
    # Only takes effect for interpreters started after this point (see Note).
    os.environ['PYTHONHASHSEED'] = str(seed)
    # torch.manual_seed seeds the CPU generator; manual_seed_all explicitly
    # seeds every visible GPU. A separate torch.cuda.manual_seed call (current
    # GPU only) would be a strict subset of this and is therefore omitted.
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    # Trade cuDNN speed for run-to-run determinism.
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
def create_tbwriter(log_dir, launch_tbboard=False):
    """Create a TensorBoard ``SummaryWriter`` for *log_dir*, optionally serving it.

    Args:
        log_dir: Directory the writer writes event files into.
        launch_tbboard (bool): If True, also start an in-process TensorBoard
            server watching *log_dir* and log the URL it is served at.

    Returns:
        torch.utils.tensorboard.SummaryWriter: The single writer for
        *log_dir*. The caller owns it and should ``close()`` it when done.
    """
    # Function-scope import: tensorboard is only needed when this is called.
    from torch.utils.tensorboard import SummaryWriter

    # Create the one and only writer up front. This also creates log_dir
    # before the TensorBoard server below starts watching it. Opening a second
    # writer here would leave an unused, never-closed writer and a stray
    # events file in log_dir.
    tb_writer = SummaryWriter(log_dir)

    if launch_tbboard:
        from tensorboard import program

        tb = program.TensorBoard()
        tb.configure(argv=[None, '--logdir', log_dir])
        url = tb.launch()
        logger.info(f'launch tensorboard in {url}')
    # Hint for starting TensorBoard manually against the same directory.
    logger.info(f'tensorboard --logdir="{log_dir}"')
    return tb_writer