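"""DeepShape: fit a generalized Gaussian distribution (GGD) to a network's
weights and remap them onto a narrower GGD.

The GGD density is f(x) = s / (2 * b * Gamma(1/s)) * exp(-(|x - mu| / b)**s),
with shape s and scale b; s = 2 recovers the Gaussian and s = 1 the Laplacian.
"""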
import math

import torch
from scipy.special import gamma as Gamma

from utils.encode.quantizer import LinearQuantizer
# import dask.array as da  # only needed for the chunked torch.unique fallback in Calc_GG_params


class DeepShape:
    def __init__(self):
        # Precomputed lookup tables: gamma_table is a grid of candidate GGD
        # shape values; rho_table holds the matching moment ratio rho(shape).
        self.gamma_table = torch.load('utils/gamma_table.pt')
        self.rho_table = torch.load('utils/rho_table.pt')
    
    """estimate GGD parameters"""
    def Calc_GG_params(self, model, adj_minnum = 0):
        #get parameters
        params = []
        for param in model.parameters():
            params.append(param.flatten())
        params = torch.cat(params).detach()
        params_org = params.clone()
        
        # Quantization
        lq = LinearQuantizer(params, 13)
        params = lq.quant(params)   
        
        #sorting
        elements, counts = torch.unique(params, return_counts=True)   
        # dask_params = da.from_array(params.numpy(), chunks=int(1e8))  #if param's size is big
        # elements, counts = da.unique(dask_params, return_counts=True)
        # elements = torch.from_numpy(elements.compute())
        # counts = torch.from_numpy(counts.compute())
        indices = torch.argsort(counts, descending=True)
        elements = elements[indices]
        counts = counts[indices]

        if adj_minnum > 0:
            param_max = torch.min(elements[(counts<=adj_minnum) & (elements>0)]).long()
            # print("param_max", (param_max/(2**13)))
            # print('max_param, num_max_param', (elements[0]/(2**13)), counts[0])
            elements_cut = params_org[torch.abs(params_org)<=(param_max.float()/(2**13))]
        else:
            elements_cut = params_org

        #estimate
        n = len(elements_cut)
        var = torch.sum(torch.pow(elements_cut, 2))
        mean = torch.sum(torch.abs(elements_cut))
        self.gamma_table = self.gamma_table.to(elements_cut.device)
        self.rho_table = self.rho_table.to(elements_cut.device)
        rho = n * var / mean ** 2
        pos = torch.argmin(torch.abs(rho - self.rho_table)).item()
        shape = self.gamma_table[pos].item()
        std = torch.sqrt(var / n)
        beta = math.sqrt(Gamma(1/shape) / Gamma(3/shape))* std
        mu = torch.mean(elements_cut)
        print("mu:", mu)
        print('shape:', shape)
        print('beta',(beta))
        
        return mu, shape, beta
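
    # A minimal sketch (an assumption, not from this repo) of how the two
    # lookup tables used above could be generated: for a GGD with shape s the
    # moment ratio rho(s) = Gamma(1/s) * Gamma(3/s) / Gamma(2/s)**2 (e.g.
    # rho(2) = pi/2 for a Gaussian, rho(1) = 2 for a Laplacian) decreases
    # monotonically in s, so argmin |rho - rho_table| inverts it numerically.
    @staticmethod
    def build_tables(num=10000, lo=0.1, hi=3.0):
        shapes = torch.linspace(lo, hi, num, dtype=torch.float64)
        rhos = torch.tensor([Gamma(1.0 / s) * Gamma(3.0 / s) / Gamma(2.0 / s) ** 2
                             for s in shapes.tolist()])
        return shapes, rhos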
        
    
    """GGD deepshape remap"""
    def GGD_deepshape(self, model, shape_scale=0.8, std_scale=0.6, adj_minnum = 1000): 
        #get parameters
        params = []
        for param in model.parameters():
            params.append(param.flatten())
        params = torch.cat(params).detach()
        params_org = params.clone()
        
        # Quantization
        lq = LinearQuantizer(params, 13)
        params = lq.quant(params)   
        
        #sorting
        elements, counts = torch.unique(params, return_counts=True)
        indices = torch.argsort(counts, descending=True)
        elements = elements[indices]
        counts = counts[indices]
             
        if adj_minnum > 0:
            param_max = torch.min(elements[(counts<=adj_minnum) & (elements>0)]).long()
            elements_cut = params_org[torch.abs(params_org)<=(param_max.float()/(2**13))]
        else:
            elements_cut = params_org
            param_max=0
        
        #estimate org GGD    
        n = len(elements_cut)
        var = torch.sum(torch.pow(elements_cut, 2))
        mean = torch.sum(torch.abs(elements_cut))
        self.gamma_table = self.gamma_table.to(elements_cut.device)
        self.rho_table = self.rho_table.to(elements_cut.device)
        rho = n * var / mean ** 2
        pos = torch.argmin(torch.abs(rho - self.rho_table)).item()
        shape = self.gamma_table[pos].item()
        std = torch.sqrt(var / n)
        beta = math.sqrt(Gamma(1/shape) / Gamma(3/shape))* std
        mu_est = torch.mean(elements_cut)
    
        print("org mu:", mu_est)
        print('org shape:', shape)
        print('org beta',beta)
        
        beta = (beta * (2**13))
        mu_est = int(mu_est*(2**13))
        
        #sorting params in [-param_pax, param_max]  
        if adj_minnum>0:
            adj_indices = torch.nonzero((params>=mu_est-param_max)&(params<=mu_est+param_max), as_tuple=False).squeeze()
            adj_indices = adj_indices[torch.argsort(params[(params>=mu_est-param_max)&(params<=mu_est+param_max)], descending=False)]
            adj_num = len(adj_indices)
        else:
            adj_indices = torch.argsort(params, descending=False)
            adj_num = len(adj_indices)
            
        #remape new GGD
        new_params = params.clone()
        new_shape = shape * shape_scale
        new_beta = beta * std_scale
        if(beta<=0):
            beta=1
        
        x = torch.arange(mu_est-param_max, mu_est+param_max+1, device=params.device)
        new_ratio = -torch.pow(torch.abs(x.float()-mu_est)/new_beta, new_shape)
        new_ratio = torch.exp(new_ratio)
        new_ratio = new_ratio / torch.sum(new_ratio)
        new_num = (adj_num * new_ratio).long()
        num_temp = 0
        for i in range(0, 2*param_max+1):
            new_params[adj_indices[num_temp : num_temp+new_num[i]]]=i+mu_est-param_max
            num_temp += new_num[i]
        new_params=new_params.float()/(2**13)
        
        #modify model parameters
        j=0
        for name, param in model.named_parameters():
            shape=param.data.shape
            param_flatten = torch.flatten(param.data)
            param_flatten = new_params[j: j+len(param_flatten)]
            j+=len(param_flatten)
            param_flatten = param_flatten.reshape(shape)
            param.data= param_flatten
        
        print("new mu:", float(mu_est)/(2**13))
        print('new_shape:', new_shape)
        print('new beta', float(new_beta)/(2**13))
        return float(mu_est)/(2**13), new_shape, float(new_beta)/(2**13)
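

# Illustrative usage (a minimal sketch; it assumes utils/gamma_table.pt,
# utils/rho_table.pt and utils.encode.quantizer.LinearQuantizer are present
# as in this repo; the toy MLP below is not from the source):
if __name__ == '__main__':
    import torch.nn as nn

    model = nn.Sequential(nn.Linear(64, 64), nn.ReLU(), nn.Linear(64, 10))
    ds = DeepShape()
    mu, shape, beta = ds.Calc_GG_params(model)   # fit only; model unchanged
    # On a full-size network (where each quantized value occurs many times),
    # remap the weights in place:
    # ds.GGD_deepshape(model, shape_scale=0.8, std_scale=0.6, adj_minnum=1000)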