Spaces:
Build error
Build error
File size: 572 Bytes
6d314be |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 |
import math
import torch
from torch.optim import Optimizer
import numpy as np
class ClampOptimizer(Optimizer):
    """Wrap another optimizer and clamp every parameter into a range after each step.

    The wrapped optimizer performs the actual update; afterwards every parameter it
    manages is clamped in place to ``[clamp_min, clamp_max]`` (``[0, 1]`` by default).

    Base ``Optimizer.__init__`` is intentionally not called: ``param_groups``,
    ``state``, ``defaults`` and (de)serialization are delegated to the wrapped
    optimizer so that code expecting a regular optimizer keeps working.

    Args:
        optimizer: Optimizer class (e.g. ``torch.optim.Adam``), instantiated as
            ``optimizer(params, **kwargs)``.
        params: Iterable of tensors or of param-group dicts, forwarded unchanged to
            ``optimizer``. May be a one-shot iterator such as ``model.parameters()``.
        clamp_min: Lower clamp bound (keyword-only, default ``0``).
        clamp_max: Upper clamp bound (keyword-only, default ``1``).
        **kwargs: Remaining keyword arguments forwarded to ``optimizer``.

    Raises:
        ValueError: If ``clamp_min > clamp_max``.
    """

    def __init__(self, optimizer, params, *, clamp_min=0, clamp_max=1, **kwargs):
        if clamp_min > clamp_max:
            raise ValueError(
                f"clamp_min ({clamp_min}) must not exceed clamp_max ({clamp_max})"
            )
        self.opt = optimizer(params, **kwargs)
        # Kept only for backward compatibility with callers reading ``.params``.
        # If ``params`` was an iterator it has already been consumed above, so
        # step() takes the tensors from the wrapped optimizer instead.
        self.params = params
        self.clamp_min = clamp_min
        self.clamp_max = clamp_max

    @property
    def param_groups(self):
        """Parameter groups of the wrapped optimizer (shared, not copied)."""
        return self.opt.param_groups

    @property
    def state(self):
        """Per-parameter state of the wrapped optimizer."""
        return self.opt.state

    @property
    def defaults(self):
        """Default hyper-parameters of the wrapped optimizer."""
        return self.opt.defaults

    @torch.no_grad()
    def step(self, closure=None):
        """Run one step of the wrapped optimizer, then clamp all of its parameters.

        Args:
            closure: Optional callable that re-evaluates the model and returns the
                loss; forwarded unchanged to the wrapped optimizer's ``step``.

        Returns:
            Whatever the wrapped optimizer's ``step`` returns.
        """
        loss = self.opt.step(closure)
        # Iterate the wrapped optimizer's groups rather than ``self.params``: this
        # works for one-shot iterators and param-group dicts, and also covers
        # groups added later through add_param_group().
        for group in self.opt.param_groups:
            for param in group["params"]:
                # Clamp directly: computing p + (clamp(p) - p) rounds to 0 for
                # very large |p| and yields NaN for +/-inf.
                param.clamp_(self.clamp_min, self.clamp_max)
        return loss

    def zero_grad(self, *args, **kwargs):
        """Reset gradients via the wrapped optimizer.

        Any arguments (e.g. ``set_to_none``) are forwarded unchanged.
        """
        self.opt.zero_grad(*args, **kwargs)

    def add_param_group(self, param_group):
        """Add a parameter group to the wrapped optimizer; it is clamped from then on."""
        self.opt.add_param_group(param_group)

    def state_dict(self):
        """Return the wrapped optimizer's state dict."""
        return self.opt.state_dict()

    def load_state_dict(self, state_dict):
        """Load a state dict into the wrapped optimizer."""
        self.opt.load_state_dict(state_dict)
|