File size: 1,127 Bytes
223265d
8e64547
 
 
 
 
 
 
223265d
 
 
 
 
 
 
 
 
 
 
 
8e64547
 
223265d
 
 
 
 
8e64547
223265d
8e64547
223265d
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
import torch

# Rate-distortion weight: multiplies the per-parameter step applied in backslash().
rdo = 2e3
# Offset added to |w| before it is raised to the power (shape - 1) in backslash().
clip = 1
# Lookup tables, read once at import time. The paths are relative to the
# current working directory. backslash() finds the entry of r_gamma_table
# closest to the weights' moment ratio and takes gamma_table at that index.
gamma_table, r_gamma_table = map(
    torch.load, ("data/gamma_table.pt", "data/r_gamma_table.pt")
)


def backslash(
    model,
    *,
    rdo_weight=None,
    clip_offset=None,
    shape_table=None,
    ratio_table=None,
    device=None,
):
    """Select a shape value from the weight statistics and apply one rate-constrained step.

    All parameters of ``model`` are pooled to compute the moment ratio
    ``n * sum(w**2) / sum(|w|)**2`` (i.e. ``E[w^2] / E[|w|]^2``). The entry of
    ``ratio_table`` closest to that ratio selects ``shape`` from
    ``shape_table`` at the same index. Every parameter is then updated in
    place::

        w -= rdo_weight * shape / n * sign(w) * (|w| + clip_offset) ** (shape - 1)

    Args:
        model: Object exposing ``parameters()`` (e.g. a ``torch.nn.Module``).
            Its parameters are modified in place.
        rdo_weight: Rate-distortion weight. Defaults to the module-level ``rdo``.
        clip_offset: Offset added to ``|w|`` before the power. Defaults to the
            module-level ``clip``.
        shape_table: Tensor of candidate shape values, indexed by the position
            found in ``ratio_table``. Defaults to the module-level ``gamma_table``.
        ratio_table: Tensor of moment ratios aligned with ``shape_table``.
            Defaults to the module-level ``r_gamma_table``.
        device: Device on which the weight statistics are accumulated. Defaults
            to the device of the first parameter. The previous behaviour was
            always ``cuda:0``, which failed on CPU-only machines.

    Returns:
        dict with ``"shape"`` (the selected table entry) and ``"standard"``
        (root-mean-square of all weights, as a 0-d tensor on ``device``; no
        mean is subtracted).

    Raises:
        ValueError: if ``model`` has no parameter elements at all.
    """
    if rdo_weight is None:
        rdo_weight = rdo
    if clip_offset is None:
        clip_offset = clip
    if shape_table is None:
        shape_table = gamma_table
    if ratio_table is None:
        ratio_table = r_gamma_table

    with torch.no_grad():
        params = list(model.parameters())
        n = sum(p.numel() for p in params)
        if n == 0:
            raise ValueError(
                "backslash() requires a model with at least one non-empty parameter"
            )
        device = params[0].device if device is None else torch.device(device)

        # Evaluate the shape parameter from pooled second and first absolute moments.
        sq_sum, abs_sum = 0, 0
        for param in params:
            flat = param.detach().flatten()
            # Accumulate in at least float32. fp16 squares underflow for small
            # weights, and sums over many weights overflow fp16 (max 65504).
            # For float32/float64 weights this is a no-op (no copy).
            flat = flat.to(torch.promote_types(flat.dtype, torch.float32))
            # Reduce first, then move only the 0-d result to the target device.
            sq_sum = sq_sum + flat.pow(2).sum().to(device)
            abs_sum = abs_sum + flat.abs().sum().to(device)

        ratio = (n * sq_sum / abs_sum ** 2).to(ratio_table.device)
        pos = torch.argmin(torch.abs(ratio - ratio_table))
        shape = shape_table[pos]
        std = torch.sqrt(sq_sum / n)

        # Rate Constrained Optimization. The step factor is identical for every
        # parameter, so compute it once.
        step = rdo_weight * shape / n
        for param in params:
            param.sub_(
                step
                * torch.sign(param)
                * torch.pow(torch.abs(param) + clip_offset, shape - 1)
            )
    return {"shape": shape, "standard": std}