# gg_prior / backslash.py
# wujun
# amend
# 8e64547
import torch
# Multiplier applied to the per-parameter regularization step in backslash().
rdo = 2e3
# Offset added to |param| before it is raised to the power (shape - 1) in
# backslash().
clip = 1
# Lookup tables, read at import time from paths relative to the current
# working directory. backslash() finds the entry of r_gamma_table closest to
# its measured ratio and uses the gamma_table entry at the same index as the
# shape parameter.
# NOTE(review): presumably both are 1-D, equal-length CPU tensors — confirm.
gamma_table = torch.load("data/gamma_table.pt")
r_gamma_table = torch.load("data/r_gamma_table.pt")
def backslash(model, *, rate_weight=None, abs_offset=None,
              shape_table=None, ratio_table=None):
    """Estimate the shape of the weight distribution and regularize in place.

    The function works in two stages:

    1. Evaluate the shape parameter. Over all parameters of ``model`` it
       computes ``n * sum(w ** 2) / sum(|w|) ** 2``, finds the closest entry
       in ``ratio_table`` and takes the entry of ``shape_table`` at the same
       index as the shape.
       NOTE(review): presumably this fits a generalized Gaussian prior (the
       file lives under ``gg_prior``) — confirm against the paper/caller.
    2. Rate constrained optimization. Every parameter is updated in place::

           w -= rate_weight * shape / n * sign(w) * (|w| + abs_offset) ** (shape - 1)

    Args:
        model: Object exposing ``parameters()`` that yields tensors
            (e.g. a ``torch.nn.Module``). Its parameters are modified in
            place.
        rate_weight: Multiplier of the update step. Defaults to the
            module-level ``rdo``.
        abs_offset: Value added to ``|w|`` before the power is taken.
            Defaults to the module-level ``clip``.
        shape_table: Tensor of candidate shape values. Defaults to the
            module-level ``gamma_table``.
        ratio_table: Tensor of ratios, index-aligned with ``shape_table``.
            Defaults to the module-level ``r_gamma_table``.

    Returns:
        dict: ``{"shape": shape, "standard": std}`` where ``shape`` is the
        selected ``shape_table`` entry and ``std`` is
        ``sqrt(sum(w ** 2) / n)`` as a 0-dim tensor.

    Raises:
        ValueError: If the model holds no parameter elements, so the
            statistics are undefined.
    """
    # Defaults are resolved at call time so the module-level settings remain
    # the single source of truth for callers that pass nothing.
    rate_weight = rdo if rate_weight is None else rate_weight
    abs_offset = clip if abs_offset is None else abs_offset
    shape_table = gamma_table if shape_table is None else shape_table
    ratio_table = r_gamma_table if ratio_table is None else ratio_table

    # Materialize once: the parameters are walked twice below.
    params = list(model.parameters())
    n = sum(param.numel() for param in params)
    if n == 0:
        raise ValueError(
            "backslash() requires a model with at least one parameter element")

    with torch.no_grad():
        # Accumulate on the device the model already lives on instead of a
        # hard-coded "cuda:0", so CPU-only hosts work too.
        device = params[0].device

        # Evaluate the shape parameter
        sum_sq, sum_abs = 0, 0
        for param in params:
            values = param.detach()
            # Reduce on the parameter's own device and move only the scalar;
            # this also copes with models spread over several devices.
            sum_sq = sum_sq + torch.sum(values ** 2).to(device)
            sum_abs = sum_abs + torch.sum(torch.abs(values)).to(device)

        ratio = (n * sum_sq / sum_abs ** 2).to(ratio_table.device)
        pos = torch.argmin(torch.abs(ratio - ratio_table))
        shape = shape_table[pos]
        std = torch.sqrt(sum_sq / n)

        # Rate Constrained Optimization
        # Both factors are identical for every parameter, so compute them once.
        scale = rate_weight * shape / n
        exponent = shape - 1
        for param in params:
            weights = param.data
            penalty = scale * torch.sign(weights) * torch.pow(
                torch.abs(weights) + abs_offset, exponent)
            weights.sub_(penalty)

    return {"shape": shape, "standard": std}