|
import copy

import torch
import torch.nn as nn

__all__ = ['EMA']
|
|
class EMA(nn.Module):
    """
    Implements exponential moving average (EMA) shadowing for your model.

    Uses an inverse decay schedule to manage longer-term training runs. By adjusting
    the power, you can control how fast the EMA decay ramps up to your specified beta.

    @crowsonkb's notes on EMA Warmup:
        If gamma=1 and power=1, implements a simple average. gamma=1, power=2/3 are
        good values for models you plan to train for a million or more steps (reaches
        decay factor 0.999 at 31.6K steps, 0.9999 at 1M steps); gamma=1, power=3/4 for
        models you plan to train for less (reaches decay factor 0.999 at 10K steps,
        0.9999 at 215.4K steps).

    Args:
        model (nn.Module): The online model to shadow.
        ema_model (nn.Module, optional): Pre-built EMA copy; deep-copied from ``model`` if None.
        beta (float): The target (maximum) EMA decay rate. Default: 0.9999.
        update_after_step (int): Steps during which weights are copied verbatim before averaging begins. Default: 100.
        update_every (int): Update the EMA weights once every this many calls to ``update()``. Default: 10.
        inv_gamma (float): Inverse multiplicative factor of EMA warmup. Default: 1.
        power (float): Exponential factor of EMA warmup. Default: 2/3.
        min_value (float): The minimum EMA decay rate. Default: 0.
        param_or_buffer_names_no_ema (set | list): Names copied verbatim instead of averaged. Default: empty set.
        ignore_names (set): Names to skip entirely. Default: empty set.
        ignore_startswith_names (set): Name prefixes to skip entirely. Default: empty set.
        include_online_model (bool): Register the online model as a submodule (saved in the
            state dict) instead of holding it only by reference. Default: True.
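
    As a rough illustration of the warmup schedule above (following the formula used in
    ``get_current_decay``: decay = 1 - (1 + step / inv_gamma) ** -power, where ``step``
    is counted after ``update_after_step`` and the result is clamped to [min_value, beta]):

        inv_gamma=1, power=2/3: ~0.99 at 1K steps, ~0.999 at 31.6K steps, ~0.9999 at 1M steps
        inv_gamma=1, power=3/4: ~0.999 at 10K steps, ~0.9999 at 215.4K steps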
|
""" |
|
|
|
    def __init__(
        self,
        model,
        ema_model=None,
        beta=0.9999,
        update_after_step=100,
        update_every=10,
        inv_gamma=1.0,
        power=2 / 3,
        min_value=0.0,
        param_or_buffer_names_no_ema=set(),
        ignore_names=set(),
        ignore_startswith_names=set(),
        include_online_model=True
    ):
        super().__init__()
        self.beta = beta

        self.include_online_model = include_online_model

        # keep the online model either as a registered submodule or, wrapped in a
        # list, as a plain reference that stays out of this wrapper's state dict
        if include_online_model:
            self.online_model = model
        else:
            self.online_model = [model]

        self.ema_model = ema_model

        if not exists(self.ema_model):
            try:
                self.ema_model = copy.deepcopy(model)
            except Exception:
                print('Your model was not copyable. Please make sure you are not using any LazyLinear')
                exit()

        self.ema_model.requires_grad_(False)

        # only floating point parameters and buffers take part in the moving average
        self.parameter_names = {name for name, param in self.ema_model.named_parameters() if is_float_dtype(param.dtype)}
        self.buffer_names = {name for name, buffer in self.ema_model.named_buffers() if is_float_dtype(buffer.dtype)}

        self.update_every = update_every
        self.update_after_step = update_after_step

        self.inv_gamma = inv_gamma
        self.power = power
        self.min_value = min_value

        assert isinstance(param_or_buffer_names_no_ema, (set, list))
        self.param_or_buffer_names_no_ema = param_or_buffer_names_no_ema

        self.ignore_names = ignore_names
        self.ignore_startswith_names = ignore_startswith_names

        self.register_buffer('initted', torch.tensor([False]))
        self.register_buffer('step', torch.tensor([0]))
|
|
|
    @property
    def model(self):
        return self.online_model if self.include_online_model else self.online_model[0]

    def restore_ema_model_device(self):
        device = self.initted.device
        self.ema_model.to(device)
|
|
|
    def get_params_iter(self, model):
        for name, param in model.named_parameters():
            if name not in self.parameter_names:
                continue
            yield name, param

    def get_buffers_iter(self, model):
        for name, buffer in model.named_buffers():
            if name not in self.buffer_names:
                continue
            yield name, buffer
|
|
|
    def copy_params_from_model_to_ema(self):
        for (_, ma_params), (_, current_params) in zip(self.get_params_iter(self.ema_model),
                                                       self.get_params_iter(self.model)):
            ma_params.data.copy_(current_params.data)

        for (_, ma_buffers), (_, current_buffers) in zip(self.get_buffers_iter(self.ema_model),
                                                         self.get_buffers_iter(self.model)):
            ma_buffers.data.copy_(current_buffers.data)
|
|
|
    def get_current_decay(self):
        epoch = clamp(self.step.item() - self.update_after_step - 1, min_value=0.)

        # inverse decay warmup: ramps from 0 toward beta as training progresses
        value = 1 - (1 + epoch / self.inv_gamma) ** - self.power

        if epoch <= 0:
            return 0.

        return clamp(value, min_value=self.min_value, max_value=self.beta)
|
|
|
    def update(self):
        step = self.step.item()
        self.step += 1

        # only touch the EMA weights every `update_every` calls
        if (step % self.update_every) != 0:
            return

        # during warmup, track the online weights exactly
        if step <= self.update_after_step:
            self.copy_params_from_model_to_ema()
            return

        if not self.initted.item():
            self.copy_params_from_model_to_ema()
            self.initted.data.copy_(torch.tensor([True]))

        self.update_moving_average(self.ema_model, self.model)
|
|
|
    @torch.no_grad()
    def update_moving_average(self, ma_model, current_model):
        current_decay = self.get_current_decay()

        for (name, current_params), (_, ma_params) in zip(self.get_params_iter(current_model),
                                                          self.get_params_iter(ma_model)):
            if name in self.ignore_names:
                continue

            if any(name.startswith(prefix) for prefix in self.ignore_startswith_names):
                continue

            if name in self.param_or_buffer_names_no_ema:
                ma_params.data.copy_(current_params.data)
                continue

            # ema <- decay * ema + (1 - decay) * online
            ma_params.data.lerp_(current_params.data, 1. - current_decay)

        for (name, current_buffer), (_, ma_buffer) in zip(self.get_buffers_iter(current_model),
                                                          self.get_buffers_iter(ma_model)):
            if name in self.ignore_names:
                continue

            if any(name.startswith(prefix) for prefix in self.ignore_startswith_names):
                continue

            if name in self.param_or_buffer_names_no_ema:
                ma_buffer.data.copy_(current_buffer.data)
                continue

            ma_buffer.data.lerp_(current_buffer.data, 1. - current_decay)
|
|
|
    def __call__(self, *args, **kwargs):
        return self.ema_model(*args, **kwargs)
|
|
|
|
|
def exists(val):
    return val is not None


def is_float_dtype(dtype):
    return dtype in (torch.float64, torch.float32, torch.float16, torch.bfloat16)


def clamp(value, min_value=None, max_value=None):
    assert exists(min_value) or exists(max_value)

    if exists(min_value):
        value = max(value, min_value)

    if exists(max_value):
        value = min(value, max_value)

    return value
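

# Minimal usage sketch (illustrative only, not part of the library API); `net`, `opt`,
# and the random data below are placeholders for any model, optimizer, and batches.
if __name__ == '__main__':
    net = nn.Sequential(nn.Linear(8, 16), nn.ReLU(), nn.Linear(16, 8))
    ema = EMA(net, beta=0.9999, update_after_step=10, update_every=1)

    opt = torch.optim.SGD(net.parameters(), lr=1e-2)

    for _ in range(100):
        x = torch.randn(4, 8)
        loss = net(x).pow(2).mean()
        loss.backward()
        opt.step()
        opt.zero_grad()

        # call once per training step; warmup and update frequency are handled internally
        ema.update()

    # inference through the shadowed (EMA) weights
    with torch.no_grad():
        print(ema(torch.randn(4, 8)).shape)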
|
|