import torch

# Scale of the in-place update applied by ``backslash`` (multiplies the
# per-parameter step ``shape / n``).
rdo = 2e3
# Offset added to |w| before it is raised to ``shape - 1``; with clip >= 1 the
# base is >= 1, which presumably keeps the power bounded when shape - 1 < 0.
clip = 1


def _load_table(path):
    """Return the tensor stored at ``path``, or ``None`` if the file is missing.

    The tables used to be loaded with a hard failure at import time, which made
    this module unimportable from any working directory lacking ``data/`` even
    for callers that pass the tables explicitly. A missing table is now
    reported by ``backslash`` when it is actually needed.
    """
    try:
        return torch.load(path)
    except FileNotFoundError:
        return None


# Lookup tables, resolved relative to the current working directory.
# ``backslash`` finds the entry of ``r_gamma_table`` closest to the measured
# moment ratio and uses the entry of ``gamma_table`` at the same index as the
# shape value. Both are assumed to be 1-D tensors of equal length -- confirm.
gamma_table = _load_table("data/gamma_table.pt")
r_gamma_table = _load_table("data/r_gamma_table.pt")


def _pooled_moments(params, device):
    """Return ``(n, sum(w**2), sum(|w|))`` pooled over every element of ``params``.

    Each tensor is reduced where it lives and only the resulting scalar is moved
    to ``device``. Sums are computed in at least float32, so half-precision
    parameters are upcast to make overflow and precision loss less likely.
    When ``params`` is empty, the two sums stay the Python int ``0``.
    """
    n, sum_sq, sum_abs = 0, 0, 0
    for p in params:
        w = p.detach().flatten()
        w = w.to(torch.promote_types(w.dtype, torch.float32))
        n += w.numel()
        sum_sq = sum_sq + torch.sum(w ** 2).to(device)
        sum_abs = sum_abs + torch.sum(torch.abs(w)).to(device)
    return n, sum_sq, sum_abs


def backslash(model, *, device=None, shape_table=None, ratio_table=None):
    """Estimate a shape value for ``model``'s parameters and shrink them in place.

    Step 1 pools all parameter elements and computes the moment ratio
    ``E[w**2] / E[|w|]**2``. The index of the closest ``ratio_table`` entry
    selects ``shape = shape_table[index]``.

    Step 2 ("rate constrained optimization") updates every parameter ``w``
    in place::

        w -= rdo * shape / n * sign(w) * (|w| + clip) ** (shape - 1)

    Here ``n`` is the total number of parameter elements. Elements equal to
    zero are left unchanged, because ``sign(0) == 0``.

    Args:
        model: object whose ``parameters()`` yields tensors (e.g. an
            ``nn.Module``).
        device: where the moment sums and the returned ``"standard"`` live.
            Defaults to the device of the first parameter. (Previously this
            was hard-coded to ``cuda:0``.)
        shape_table: candidate shape values. Defaults to the module-level
            ``gamma_table``.
        ratio_table: moment ratios matching ``shape_table`` index by index.
            Defaults to the module-level ``r_gamma_table``.

    Returns:
        dict with two entries:

        * ``"shape"``: the selected ``shape_table`` entry.
        * ``"standard"``: ``sqrt(sum(w**2) / n)``, the root-mean-square of
          all elements (not mean-centred).

    Raises:
        FileNotFoundError: if a table was not given and could not be loaded
            from ``data/``.
        ValueError: if the model has no parameter elements.
    """
    if shape_table is None:
        shape_table = gamma_table
    if ratio_table is None:
        ratio_table = r_gamma_table
    if shape_table is None or ratio_table is None:
        raise FileNotFoundError(
            "backslash: lookup tables data/gamma_table.pt and "
            "data/r_gamma_table.pt were not found relative to the working "
            "directory; pass shape_table/ratio_table explicitly"
        )

    params = list(model.parameters())
    if device is None:
        device = params[0].device if params else torch.device("cpu")

    with torch.no_grad():
        # Evaluate the shape parameter.
        n, sum_sq, sum_abs = _pooled_moments(params, device)
        if n == 0:
            raise ValueError("backslash: model has no parameter elements")
        # NOTE(review): if every element is zero, sum_abs == 0 and the ratio is
        # NaN, so the table lookup is meaningless. The update is then a no-op,
        # though, because sign(0) == 0.
        r_gamma = (n * sum_sq / sum_abs ** 2).to(ratio_table.device)
        pos = torch.argmin(torch.abs(r_gamma - ratio_table))
        shape = shape_table[pos]
        std = torch.sqrt(sum_sq / n)

        # Rate-constrained update. ``step`` does not depend on the parameter,
        # so it is computed once. ``.data`` is used (as before) so that the
        # in-place edit does not touch autograd's version counters.
        step = rdo * shape / n
        for p in params:
            w = p.data
            w -= step * torch.sign(w) * torch.pow(torch.abs(w) + clip, shape - 1)

    return {"shape": shape, "standard": std}