|
import math |
|
|
|
import torch |
|
|
|
|
|
def identity(t, *_args, **_kwargs):
    """Pass-through helper: hand back ``t`` untouched.

    Any additional positional or keyword arguments are accepted and
    ignored, so this can stand in for a callable with a richer signature.
    """
    return t
|
|
|
|
|
def exists(x):
    """Return ``True`` when ``x`` is anything other than ``None``.

    Falsy values such as ``0``, ``""`` or ``False`` still count as existing.
    """
    return not (x is None)
|
|
|
|
|
def default(val, d):
    """Return ``val`` unless it is ``None``, otherwise fall back to ``d``.

    When the fallback ``d`` is callable it is invoked with no arguments and
    its result is returned; this lets callers defer building the fallback
    until it is actually needed. A non-callable ``d`` is returned as-is.
    """
    if not exists(val):
        # Fallback path: call factories, return plain values directly.
        if callable(d):
            return d()
        return d
    return val
|
|
|
|
|
def has_int_squareroot(num):
    """Return ``True`` if ``num`` is a perfect square (its root is an integer).

    The check is done in exact integer arithmetic with ``math.isqrt``
    instead of squaring a floating-point square root. The float round-trip
    is unreliable: large perfect squares that are not exactly representable
    as a float (e.g. ``(10**10 + 1) ** 2``) were rejected, and values with a
    non-integer root such as ``2.25`` were accepted.

    Args:
        num: A non-negative integer (anything supporting ``__index__``), or
            a float. A float is only considered when it holds an integral
            value; any other float (including ``inf``/``nan``) yields
            ``False``.

    Returns:
        bool: whether an integer ``r`` with ``r * r == num`` exists.

    Raises:
        ValueError: if ``num`` is negative.
    """
    if isinstance(num, float):
        # A fractional, infinite or NaN value cannot be a perfect square.
        if not num.is_integer():
            return False
        num = int(num)
    root = math.isqrt(num)  # exact floor of the square root, no rounding error
    return root * root == num
|
|
|
|
|
def num_to_groups(num, divisor):
    """Split ``num`` into chunk sizes of at most ``divisor``.

    The result holds ``num // divisor`` entries equal to ``divisor``,
    followed by one final entry with the remainder when that remainder is
    positive. For example ``num_to_groups(10, 3)`` gives ``[3, 3, 3, 1]``.
    """
    full_chunks, leftover = divmod(num, divisor)
    tail = [leftover] if leftover > 0 else []
    return [divisor] * full_chunks + tail
|
|
|
|
|
|
|
|
|
|
|
|
|
def sum_params(model: torch.nn.Module, eps: float = 1e6):
    """Count the elements of every parameter in ``model``, scaled by ``eps``.

    Args:
        model: module whose ``parameters()`` are counted.
        eps: divisor applied to the raw count; the default of ``1e6``
            reports the size in millions of parameters.

    Returns:
        The total number of parameter elements divided by ``eps``.
    """
    total = 0
    for param in model.parameters():
        total += param.numel()
    return total / eps
|
|
|
|
|
|
|
|
|
|
|
|
|
def cycle(dl):
    """Yield items from ``dl`` endlessly, restarting it whenever it runs out.

    ``dl`` is re-iterated from scratch on every pass (nothing is cached), so
    an iterable that produces a different order each time it is iterated
    keeps doing so.
    """
    while True:
        yield from dl
|
|
|
|
|
|
|
|
|
|
|
|
|
def extract(a, t, x_shape):
    """Look up ``a`` at indices ``t`` and shape the result for broadcasting.

    Values are gathered from the last dimension of ``a`` using ``t`` as the
    index, then reshaped to ``(b, 1, ..., 1)`` where ``b`` is the first
    dimension of ``t`` and the total number of dimensions equals
    ``len(x_shape)``.

    Args:
        a: tensor to gather from (presumably a 1-D per-timestep table —
            confirm against callers).
        t: index tensor whose first dimension is the batch size.
        x_shape: shape of the tensor the result will be broadcast against;
            its first entry must equal the batch size of ``t``.

    Raises:
        AssertionError: if ``x_shape[0]`` differs from the batch size.
    """
    batch_size, *_ = t.shape
    assert x_shape[0] == batch_size
    gathered = a.gather(-1, t)
    # One leading batch axis followed by singleton axes for broadcasting.
    trailing_ones = (1,) * (len(x_shape) - 1)
    target_shape = (batch_size,) + trailing_ones
    return gathered.reshape(*target_shape)
|
|
|
|
|
def unnormalize(x):
    """Map values from the ``[-1, 1]`` range to ``[0, 1]``.

    Computes ``(x + 1) * 0.5`` and clamps the result to ``[0.0, 1.0]``, so
    inputs outside ``[-1, 1]`` saturate at the bounds.
    """
    return torch.clamp((x + 1) * 0.5, min=0.0, max=1.0)
|
|
|
|
|
def normalize(x):
    """Map values from the ``[0, 1]`` range to ``[-1, 1]``.

    Computes ``x * 2 - 1`` and clamps the result to ``[-1.0, 1.0]``, so
    inputs outside ``[0, 1]`` saturate at the bounds.
    """
    return torch.clamp(x * 2 - 1, min=-1.0, max=1.0)
|
|