File size: 572 Bytes
6d314be
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
import math
import torch
from torch.optim import Optimizer
import numpy as np

class ClampOptimizer(Optimizer):
    """Wrap another optimizer and clamp every parameter to [0, 1] after each step.

    Useful when the optimized tensors must stay in a normalized range
    (e.g. latent codes). NOTE(review): this subclasses ``Optimizer`` but never
    calls ``Optimizer.__init__``, so inherited machinery (``param_groups``,
    ``state_dict``, ...) is unusable — callers appear to rely only on
    ``step``/``zero_grad``; confirm before extending.

    Args:
        optimizer: the optimizer *class* to instantiate (e.g. ``torch.optim.Adam``).
        params: iterable of parameters to optimize; these are also the tensors
            clamped after each step. Must be re-iterable (a list, not a
            one-shot generator), since it is traversed on every ``step``.
        **kwargs: forwarded verbatim to ``optimizer`` (e.g. ``lr``).
    """

    def __init__(self, optimizer, params, **kwargs):
        self.opt = optimizer(params, **kwargs)
        self.params = params

    @torch.no_grad()
    def step(self, closure=None):
        """Run one step of the wrapped optimizer, then clamp params into [0, 1].

        Args:
            closure: optional closure re-evaluating the loss, passed through
                to the wrapped optimizer's ``step``.

        Returns:
            Whatever the wrapped optimizer's ``step`` returns (the loss, or
            ``None``).
        """
        loss = self.opt.step(closure)

        # In-place clamp. Equivalent to the original
        # ``param.data.add_(torch.clamp(param.data, 0, 1) - param.data)``
        # but without the two intermediate tensor allocations.
        for param in self.params:
            param.data.clamp_(0, 1)

        return loss

    def zero_grad(self, *args, **kwargs):
        """Delegate to the wrapped optimizer, forwarding e.g. ``set_to_none``."""
        self.opt.zero_grad(*args, **kwargs)