Spaces:
Running
Running
import torch | |
from utils.encode.quantizer import LinearQuantizer | |
import math | |
from scipy.special import gamma as Gamma | |
import numpy as np | |
import dask.array as da | |
class DeepShape:
    """Fit a generalized Gaussian distribution (GGD) to model weights and reshape them.

    Both public methods flatten every parameter of a model into one vector,
    quantize it with ``LinearQuantizer(params, 13)`` and treat a quantized
    level ``q`` as the weight ``q / 2**13``.
    NOTE(review): the exact rounding/clipping done by ``LinearQuantizer`` is
    defined in ``utils.encode.quantizer`` and is not visible here.
    """

    # Fractional bits of the fixed-point grid shared by both public methods.
    _FRAC_BITS = 13

    def __init__(self, gamma_table_path='utils/gamma_table.pt',
                 rho_table_path='utils/rho_table.pt'):
        """Load the lookup tables used to invert the GGD moment ratio.

        Args:
            gamma_table_path: File holding the candidate shape values.
            rho_table_path: File holding the ratio matched against the
                measured ``rho``; presumably entry ``i`` corresponds to
                ``gamma_table[i]`` (the tables are looked up with the same
                index) -- confirm against the table generator.
        """
        self.gamma_table = torch.load(gamma_table_path)
        self.rho_table = torch.load(rho_table_path)

    @staticmethod
    def _flatten_parameters(model):
        """Return all parameters of ``model`` as one detached 1-D tensor."""
        return torch.cat([p.flatten() for p in model.parameters()]).detach()

    @staticmethod
    def _sparse_level_threshold(levels, adj_minnum):
        """Return the smallest positive level occurring at most ``adj_minnum`` times.

        Returns ``None`` (meaning "do not truncate") when ``adj_minnum <= 0``
        or when no positive level is that sparse; the latter case used to
        crash inside ``torch.min`` on an empty selection.
        """
        if adj_minnum <= 0:
            return None
        elements, counts = torch.unique(levels, return_counts=True)
        sparse = elements[(counts <= adj_minnum) & (elements > 0)]
        if sparse.numel() == 0:
            return None
        return int(sparse.min().item())

    @classmethod
    def _central_weights(cls, weights, threshold):
        """Keep the weights whose magnitude is within ``threshold`` levels of zero."""
        if threshold is None:
            return weights
        return weights[torch.abs(weights) <= threshold / (2 ** cls._FRAC_BITS)]

    def _fit_ggd(self, samples):
        """Estimate ``(mu, shape, beta)`` of ``samples`` by moment matching.

        ``rho = n * sum(x^2) / sum(|x|)^2`` is matched against ``rho_table``
        and the shape is read from ``gamma_table`` at the closest entry.
        The scale follows from the raw second moment:
        ``beta = sqrt(Gamma(1/shape) / Gamma(3/shape)) * sqrt(sum(x^2) / n)``.
        """
        n = samples.numel()
        sq_sum = torch.sum(torch.pow(samples, 2))
        abs_sum = torch.sum(torch.abs(samples))
        # Keep the tables on the device of the data so the lookup is device-local.
        self.gamma_table = self.gamma_table.to(samples.device)
        self.rho_table = self.rho_table.to(samples.device)
        rho = n * sq_sum / abs_sum ** 2
        pos = torch.argmin(torch.abs(rho - self.rho_table)).item()
        shape = self.gamma_table[pos].item()
        std = torch.sqrt(sq_sum / n)
        beta = math.sqrt(Gamma(1 / shape) / Gamma(3 / shape)) * std
        mu = torch.mean(samples)
        return mu, shape, beta

    @staticmethod
    def _remap_levels(levels, mu_level, half_width, new_beta, new_shape):
        """Redistribute quantized levels so their histogram follows a new GGD.

        Args:
            levels: 1-D tensor of quantized weights.
            mu_level: Centre of the target distribution, in levels.
            half_width: Only levels in ``[mu_level - half_width,
                mu_level + half_width]`` are remapped. ``None`` remaps every
                level and widens the grid to cover all of them.
            new_beta: Target scale, in levels.
            new_shape: Target shape exponent.

        Returns:
            A copy of ``levels`` in which the selected entries, taken in
            ascending order of their original value, are assigned ascending
            grid levels with per-level counts proportional to
            ``exp(-(|x - mu_level| / new_beta) ** new_shape)``. Counts are
            truncated to integers, so a few of the largest selected entries
            may keep their original level.
        """
        if half_width is None:
            # Cover every level; a zero-width grid would collapse all
            # weights onto the single value mu_level.
            half_width = int(math.ceil(torch.max(torch.abs(levels - mu_level)).item()))
            order = torch.argsort(levels, descending=False)
        else:
            in_window = (levels >= mu_level - half_width) & (levels <= mu_level + half_width)
            # as_tuple indexing stays 1-D even when exactly one entry matches.
            window_idx = torch.nonzero(in_window, as_tuple=True)[0]
            order = window_idx[torch.argsort(levels[in_window], descending=False)]
        total = order.numel()

        grid = torch.arange(mu_level - half_width, mu_level + half_width + 1,
                            device=levels.device)
        density = torch.exp(-torch.pow(torch.abs(grid.float() - mu_level) / new_beta, new_shape))
        density = density / torch.sum(density)
        per_level = (total * density).long()

        # One target level per remapped entry, in ascending order; clip in
        # case float rounding makes the counts sum past ``total``.
        targets = torch.repeat_interleave(grid, per_level)[:total]
        remapped = levels.clone()
        remapped[order[:targets.numel()]] = targets.to(remapped.dtype)
        return remapped

    @staticmethod
    def _write_back(model, flat_weights):
        """Copy consecutive slices of ``flat_weights`` into the parameters of ``model``."""
        offset = 0
        for param in model.parameters():
            count = param.numel()
            chunk = flat_weights[offset:offset + count]
            # Preserve the parameter's dtype (no-op for float32 models).
            param.data = chunk.reshape(param.shape).to(param.dtype)
            offset += count

    def Calc_GG_params(self, model, adj_minnum=0):
        """Estimate the GGD parameters of the weights of ``model``.

        Args:
            model: Object exposing ``parameters()``.
            adj_minnum: When positive, weights whose magnitude exceeds the
                smallest positive quantized level that occurs at most
                ``adj_minnum`` times are excluded from the fit. If no such
                level exists, all weights are used.

        Returns:
            ``(mu, shape, beta)``: mean (tensor), shape (float) and scale
            (tensor) of the fitted distribution, in weight units.
        """
        weights = self._flatten_parameters(model)
        weights_org = weights.clone()
        levels = LinearQuantizer(weights, self._FRAC_BITS).quant(weights)

        threshold = self._sparse_level_threshold(levels, adj_minnum)
        samples = self._central_weights(weights_org, threshold)
        mu, shape, beta = self._fit_ggd(samples)

        print("mu:", mu)
        print('shape:', shape)
        print('beta', beta)
        return mu, shape, beta

    def GGD_deepshape(self, model, shape_scale=0.8, std_scale=0.6, adj_minnum=1000):
        """Reshape the weights of ``model`` in place towards a narrower GGD.

        The current weights are fitted with a GGD, the fitted shape and
        scale are multiplied by ``shape_scale`` and ``std_scale``, and the
        quantized weights are redistributed (order-preserving) so their
        histogram follows the new distribution. The result is written back
        into the parameters of ``model``.

        Args:
            model: Object exposing ``parameters()``; modified in place.
            shape_scale: Multiplier applied to the fitted shape.
            std_scale: Multiplier applied to the fitted scale.
            adj_minnum: When positive, only weights within the smallest
                positive quantized level that occurs at most ``adj_minnum``
                times (around the estimated mean) are fitted and remapped.
                Otherwise, or if no such level exists, all weights are.

        Returns:
            ``(mu, new_shape, new_beta)`` of the target distribution, with
            ``mu`` and ``new_beta`` as Python floats in weight units.
        """
        scale = 2 ** self._FRAC_BITS
        weights = self._flatten_parameters(model)
        weights_org = weights.clone()
        levels = LinearQuantizer(weights, self._FRAC_BITS).quant(weights)

        # Fit the original distribution.
        threshold = self._sparse_level_threshold(levels, adj_minnum)
        samples = self._central_weights(weights_org, threshold)
        mu_est, fitted_shape, beta = self._fit_ggd(samples)
        print("org mu:", mu_est)
        print('org shape:', fitted_shape)
        print('org beta', beta)

        # Work in quantized-level units from here on.
        beta = beta * scale
        mu_est = int(mu_est * scale)
        # Guard BEFORE deriving new_beta so a zero scale cannot reach the
        # division inside the remap.
        if beta <= 0:
            beta = 1
        new_shape = fitted_shape * shape_scale
        new_beta = beta * std_scale

        new_levels = self._remap_levels(levels, mu_est, threshold, new_beta, new_shape)
        self._write_back(model, new_levels.float() / scale)

        print("new mu:", float(mu_est)/(2**13))
        print('new_shape:', new_shape)
        print('new beta', float(new_beta)/(2**13))
        return float(mu_est) / scale, new_shape, float(new_beta) / scale