Spaces:
Build error
Build error
import math | |
import torch | |
from torch.optim import Optimizer | |
import numpy as np | |
class ClampOptimizer(Optimizer):
    """Wrap a torch optimizer so its parameters are clamped into a range after every step.

    The wrapped optimizer performs the actual update. After each of its
    ``step`` calls, every parameter it manages is clamped in place into
    ``[clamp_min, clamp_max]`` (``[0, 1]`` by default).

    ``Optimizer.__init__`` is deliberately not called: it would build a second,
    disconnected copy of the parameter groups. Instead ``param_groups``,
    ``state`` and ``defaults`` are delegated to the wrapped optimizer. LR
    schedulers, ``repr`` and checkpointing therefore act on the real optimizer.

    Args:
        optimizer: Optimizer class (or factory) called as
            ``optimizer(params, **kwargs)``, e.g. ``torch.optim.Adam``.
        params: Iterable of tensors or of param-group dicts, in whatever form
            ``optimizer`` accepts. A one-shot iterable (such as a generator) is
            materialized first so the wrapped optimizer cannot exhaust it.
        clamp_min: Lower bound applied after each step (keyword-only).
        clamp_max: Upper bound applied after each step (keyword-only).
        **kwargs: Forwarded unchanged to ``optimizer`` (e.g. ``lr``).
    """

    def __init__(self, optimizer, params, *, clamp_min=0.0, clamp_max=1.0, **kwargs):
        # A generator would be consumed by the wrapped optimizer, leaving
        # self.params empty. A bare Tensor is passed through untouched so the
        # wrapped optimizer can reject it with its own error.
        if not isinstance(params, torch.Tensor):
            params = list(params)
        self.opt = optimizer(params, **kwargs)
        self.params = params
        self.clamp_min = clamp_min
        self.clamp_max = clamp_max

    # --- Attributes delegated to the wrapped optimizer --------------------
    # These are properties rather than copied references because the wrapped
    # optimizer may rebind them (e.g. inside load_state_dict).

    @property
    def param_groups(self):
        """Parameter groups of the wrapped optimizer (shared, not copied)."""
        return self.opt.param_groups

    @param_groups.setter
    def param_groups(self, value):
        self.opt.param_groups = value

    @property
    def state(self):
        """Per-parameter state of the wrapped optimizer (shared, not copied)."""
        return self.opt.state

    @state.setter
    def state(self, value):
        self.opt.state = value

    @property
    def defaults(self):
        """Default hyper-parameters of the wrapped optimizer."""
        return self.opt.defaults

    @defaults.setter
    def defaults(self, value):
        self.opt.defaults = value

    # --- Optimizer API ------------------------------------------------------

    def step(self, closure=None):
        """Run one step of the wrapped optimizer, then clamp all of its parameters.

        Args:
            closure: Optional loss-re-evaluating closure, forwarded unchanged.

        Returns:
            Whatever the wrapped optimizer's ``step`` returned.
        """
        loss = self.opt.step(closure)
        # Iterate the wrapped optimizer's groups rather than self.params so
        # that param-group dicts and later add_param_group calls are covered.
        for group in self.opt.param_groups:
            for param in group["params"]:
                # Exact in-place clamp. The old `p += clamp(p) - p` form
                # rounded badly for large |p| (float32 1e8 became 0.0).
                # `.data` keeps the original's autograd-bypassing semantics.
                param.data.clamp_(self.clamp_min, self.clamp_max)
        return loss

    def zero_grad(self, set_to_none=None):
        """Clear gradients through the wrapped optimizer.

        Args:
            set_to_none: Forwarded to the wrapped ``zero_grad`` when given.
                When ``None``, the wrapped optimizer's own default is used.
        """
        if set_to_none is None:
            self.opt.zero_grad()
        else:
            self.opt.zero_grad(set_to_none=set_to_none)

    def add_param_group(self, param_group):
        """Add ``param_group`` to the wrapped optimizer; it is clamped from then on."""
        self.opt.add_param_group(param_group)

    def state_dict(self):
        """Return the wrapped optimizer's state dict."""
        return self.opt.state_dict()

    def load_state_dict(self, state_dict):
        """Load ``state_dict`` into the wrapped optimizer."""
        self.opt.load_state_dict(state_dict)