Spaces:
Running
Running
File size: 2,734 Bytes
fcc02a2 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 |
import torch
# Lazily built lookup table of per-timestep multipliers, filled by get_multiplier().
# NOTE: the identifier keeps its original spelling so existing references still work.
cached_multipier = None


def get_multiplier(timesteps, num_timesteps=1000):
    """Return a bell-curve multiplier for every timestep in a batch.

    The lookup table has ``num_timesteps`` entries following
    ``exp(-2 * ((x - N / 2) / N) ** 2)``, shifted so its minimum is 0 and
    scaled so its mean is 1. It is built on first use and rebuilt whenever
    ``num_timesteps`` differs from the cached table's length.

    Args:
        timesteps: Tensor of timestep values. Each value ``t`` is truncated to
            an integer and mapped to table index ``t - 1``, clamped into
            ``[0, num_timesteps - 1]``.
        num_timesteps: Number of entries in the lookup table.

    Returns:
        A float32 tensor of shape ``(batch, 1, 1, 1)`` holding one multiplier
        per timestep. It lives on the device of the cached table (created with
        ``torch.arange`` defaults); callers move it where they need it.
    """
    global cached_multipier
    if cached_multipier is None or cached_multipier.shape[0] != num_timesteps:
        # Bell curve centred on the middle of the schedule.
        x = torch.arange(num_timesteps, dtype=torch.float32)
        y = torch.exp(-2 * ((x - num_timesteps / 2) / num_timesteps) ** 2)
        # Shift the minimum to 0 ...
        y_shifted = y - y.min()
        # ... and scale so the mean over all timesteps is 1.
        cached_multipier = y_shifted * (num_timesteps / y_shifted.sum())

    # Timestep t uses table entry t - 1. Clamp instead of wrapping so that
    # t == 0 (index -1) and out-of-range values still yield one entry each.
    indices = (timesteps.reshape(-1).long() - 1).clamp(0, num_timesteps - 1)
    scales = cached_multipier[indices.to(cached_multipier.device)]
    return scales.view(-1, 1, 1, 1)
def get_blended_blur_noise(latents, noise, timestep):
    """Blend a randomly weighted blur residual of ``latents`` into ``noise``.

    The latents are shrunk to a quarter of their spatial size and scaled back
    up with bilinear interpolation, which acts as a blur. Only the difference
    between that blurred copy and the original latents is kept. The difference
    is multiplied by a per-sample random strength drawn uniformly from
    ``[0, 2)`` and added to ``noise``.

    Args:
        latents: 4-D tensor; dimensions 2 and 3 are used as the spatial size
            (bilinear ``interpolate`` requires a 4-D input).
        noise: Tensor the blur residual is added to; it must be broadcastable
            with ``latents``. It is not modified in place.
        timestep: Currently unused; kept so existing callers keep working.

    Returns:
        A new tensor equal to ``noise + (blur(latents) - latents) * strength``.
    """
    shrink_factor = 0.25
    height, width = latents.shape[2], latents.shape[3]
    # Keep at least one pixel per side so latents smaller than 4 px do not
    # request an empty output from interpolate.
    reduced_size = (
        max(1, int(height * shrink_factor)),
        max(1, int(width * shrink_factor)),
    )

    # Interpolation works per sample, so the whole batch is blurred at once.
    blurred = torch.nn.functional.interpolate(
        latents,
        size=reduced_size,
        mode='bilinear',
        align_corners=False,
    )
    blurred = torch.nn.functional.interpolate(
        blurred,
        size=(height, width),
        mode='bilinear',
        align_corners=False,
    )
    # Only the difference of the blur from the ground-truth latents.
    blur_residual = blurred - latents

    # One random strength per batch element, uniform in [0, 2).
    blur_strength = torch.rand(
        (latents.shape[0], 1, 1, 1),
        device=latents.device,
        dtype=latents.dtype,
    ) * 2
    return noise + blur_residual * blur_strength
|