|
from enum import Enum |
|
|
|
import torch |
|
import torch.distributed as dist |
|
|
|
|
|
class Summary(Enum):
    """Selects which statistic ``AverageMeter.summary()`` reports."""

    # Report nothing: summary() yields an empty string.
    NONE = 0
    # Report the running mean (``avg``).
    AVERAGE = 1
    # Report the accumulated total (``sum``).
    SUM = 2
    # Report the number of samples seen (``count``).
    COUNT = 3
|
|
|
|
|
class AverageMeter:
    """Track the latest value, running sum, sample count and mean of a metric.

    Public attributes (all reset to 0 by ``reset()``):
        val:   the value passed to the most recent ``update()`` call.
        sum:   accumulated ``val * n`` over all updates.
        count: accumulated ``n`` over all updates.
        avg:   ``sum / count``, or 0 while ``count`` is 0.

    Args:
        name: Label printed in front of the numbers.
        fmt: Format spec (including the leading ``:``) applied to ``val`` and
            ``avg`` by ``__str__``, e.g. ``':6.2f'``.
        summary_type: ``Summary`` member choosing what ``summary()`` reports.
    """

    def __init__(self, name, fmt=':f', summary_type=Summary.AVERAGE):
        self.name = name
        self.fmt = fmt
        self.summary_type = summary_type
        self.reset()

    def reset(self):
        """Clear all accumulated statistics back to zero."""
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def _refresh_avg(self):
        """Recompute ``avg`` from ``sum`` and ``count``.

        With nothing counted yet (e.g. only empty batches were recorded) the
        mean is undefined; report 0, the same value ``reset()`` uses, instead
        of raising ZeroDivisionError.
        """
        self.avg = self.sum / self.count if self.count else 0

    def update(self, val, n=1):
        """Record ``val`` as the latest value, weighted by ``n`` samples.

        Args:
            val: Value observed for this step (typically a per-batch mean).
            n: Number of samples ``val`` represents; it is used as the weight
                in the running sum and added to the running count.
        """
        self.val = val
        self.sum += val * n
        self.count += n
        self._refresh_avg()

    def all_reduce(self):
        """Sum ``sum`` and ``count`` across all processes and refresh ``avg``.

        Calls ``torch.distributed.all_reduce``, so the default process group
        must already be initialised. Afterwards ``sum`` and ``count`` are
        Python floats, because they travel through a float32 tensor.
        """
        if torch.cuda.is_available():
            device = torch.device("cuda")
        elif torch.backends.mps.is_available():
            device = torch.device("mps")
        else:
            device = torch.device("cpu")

        # NOTE(review): float32 cannot represent every integer above 2**24,
        # so very large counts may be rounded here.
        total = torch.tensor(
            [self.sum, self.count], dtype=torch.float32, device=device
        )
        dist.all_reduce(total, dist.ReduceOp.SUM, async_op=False)
        self.sum, self.count = total.tolist()
        # Every rank may have recorded zero samples; avoid dividing by zero.
        self._refresh_avg()

    def __str__(self):
        """Return ``'<name> <val> (<avg>)'`` with ``fmt`` applied to both numbers."""
        fmtstr = '{name} {val' + self.fmt + '} ({avg' + self.fmt + '})'
        return fmtstr.format(**self.__dict__)

    def summary(self):
        """Return the one-line summary selected by ``summary_type``.

        Returns:
            An empty string for ``Summary.NONE``; otherwise the name followed
            by ``avg``, ``sum`` or ``count`` formatted with three decimals.

        Raises:
            ValueError: If ``summary_type`` is not a ``Summary`` member.
        """
        if self.summary_type is Summary.NONE:
            fmtstr = ''
        elif self.summary_type is Summary.AVERAGE:
            fmtstr = '{name} {avg:.3f}'
        elif self.summary_type is Summary.SUM:
            fmtstr = '{name} {sum:.3f}'
        elif self.summary_type is Summary.COUNT:
            fmtstr = '{name} {count:.3f}'
        else:
            # Wrap in a 1-tuple so a tuple-valued summary_type is reported
            # verbatim instead of being unpacked by the % operator.
            raise ValueError('invalid summary type %r' % (self.summary_type,))

        return fmtstr.format(**self.__dict__)
|
|