Spaces:
Running
Running
import torch | |
rdo = 2e3 | |
clip = 1 | |
gamma_table = torch.load("data/gamma_table.pt") | |
r_gamma_table = torch.load("data/r_gamma_table.pt") | |
def backslash(model): | |
with torch.no_grad(): | |
device = torch.device("cuda:0") | |
# Evaluate the shape parameter | |
n, var, mean = 0, 0, 0 | |
for param in model.parameters(): | |
param = param.flatten().detach() | |
n += param.shape[0] | |
var += torch.sum((param ** 2).to(device)) | |
mean += torch.sum(torch.abs(param).to(device)) | |
r_gamma = (n * var / mean ** 2).to(device=torch.device("cpu")) | |
pos = torch.argmin(torch.abs(r_gamma - r_gamma_table)) | |
shape = gamma_table[pos] | |
std = torch.sqrt(var / n) | |
n = torch.tensor(n) | |
# Rate Constrained Optimization | |
for param in model.parameters(): | |
constant = rdo * shape / n * torch.sign(param.data) | |
param_reg = torch.pow( | |
torch.abs(param.data) + clip, shape - 1) | |
param.data -= constant * param_reg | |
distribution = {"shape": shape, "standard": std} | |
return distribution |