huaweilin's picture
update
14ce5a9
raw
history blame
6.13 kB
"""This file contains the definition of the sampling function."""
import math
from typing import Optional, Tuple, List, Text

import torch
import tqdm

from .masking import get_masking_ratio
from .factorization import combine_factorized_tokens
# ImageNet class ids used when no labels are given: goldfish, chicken,
# tiger cat, hourglass, ship, dog, race car, airliner, teddy bear.
# A tenth, randomly drawn class id completes each cycle of ten.
_DEFAULT_LABELS = (1, 7, 282, 604, 724, 179, 751, 404, 850)


def _default_labels(num_samples: int, device) -> torch.Tensor:
    """Build exactly ``num_samples`` class labels by cycling the default classes.

    The cycle is ``_DEFAULT_LABELS`` followed by one random class id. That id is
    drawn once per call, so every "random" slot shares the same class.
    """
    # torch.randint's upper bound is exclusive: 1000 makes all of 0..999 reachable.
    random_label = int(torch.randint(0, 1000, size=(1,)).item())
    cycle = list(_DEFAULT_LABELS) + [random_label]
    labels = [cycle[k % len(cycle)] for k in range(num_samples)]
    return torch.tensor(labels, dtype=torch.long, device=device)


def _guidance_step_scale(
    step: int, num_steps: int, annealing: Text, scale_pow: float
) -> float:
    """Return the multiplier applied to ``guidance_scale`` at ``step``.

    Raises:
        ValueError: If ``annealing`` is not "none", "linear" or "cosine".
    """
    if annealing == "none":
        return 1.0
    if annealing == "linear":
        return step / num_steps
    if annealing == "cosine":
        # Power-cosine ramp from 0 towards 1 as step approaches num_steps.
        return (1.0 - math.cos(((step / num_steps) ** scale_pow) * math.pi)) / 2.0
    raise ValueError(f"Unknown guidance_annealing: {annealing!r}")


def _lowest_confidence_mask(confidence: torch.Tensor, num_tokens: int) -> torch.Tensor:
    """Mark, per sample, the ``num_tokens`` positions with the lowest confidence.

    Positions tied with the threshold value are all marked. When ``num_tokens``
    is not positive nothing is marked. Indexing the sorted scores with
    ``num_tokens - 1`` would otherwise wrap around to the largest score and
    mark every position.
    """
    if num_tokens <= 0:
        return torch.zeros_like(confidence, dtype=torch.bool)
    batch_size = confidence.shape[0]
    sorted_confidence = torch.sort(confidence.reshape(batch_size, -1), dim=-1).values
    threshold = sorted_confidence[:, num_tokens - 1]
    broadcast_shape = (batch_size,) + (1,) * (confidence.dim() - 1)
    return confidence <= threshold.view(broadcast_shape)


@torch.no_grad()
def sample(
    model,
    vqgan_model,
    num_samples: int = 10,
    labels: Optional[torch.Tensor] = None,
    softmax_temperature: float = 1.0,
    randomize_temperature: float = 4.5,
    mask_schedule_strategy: Text = "linear",
    num_steps: int = 12,
    guidance_scale: float = 3.0,
    mask_token: int = 1024,
    patch_size: int = 16,
    guidance_annealing: Text = "none",
    use_sampling_annealing: bool = False,
    scale_pow: float = 4.0,
    codebook_size: int = 1024,
    codebook_splits: int = 1,
    use_tqdm: bool = False,
) -> Tuple[torch.Tensor, List[torch.Tensor]]:
    """Sample images from the model by iterative masked-token decoding.

    Every token starts as ``mask_token``. At each step the model predicts all
    positions. The predictions replace the masked ones, and the lowest-confidence
    newly predicted tokens (with Gumbel noise added) are masked again. The number
    re-masked follows the ``get_masking_ratio`` schedule.

    Args:
        model -> torch.nn.Module: The model to sample from. It must expose ``.device``.
        vqgan_model -> torch.nn.Module: The VQGAN model used to decode tokens.
        num_samples -> int: The number of samples to generate.
        labels -> Optional[torch.Tensor]: Class labels, one per sample. If None,
            ``num_samples`` labels are built by cycling nine fixed ImageNet
            classes plus one random class.
        softmax_temperature -> float: The temperature for the softmax.
        randomize_temperature -> float: The scale of the Gumbel noise added to
            the confidences. It decays linearly to zero over the steps.
        mask_schedule_strategy -> Text: The strategy for the mask schedule.
        num_steps -> int: The number of decoding steps (must be >= 1).
        guidance_scale -> float: The classifier-free guidance scale; 0 disables it.
        mask_token -> int: The token id used for masked positions.
        patch_size -> int: Side length of the token grid. ``patch_size**2``
            positions are sampled.
        guidance_annealing -> Text: One of "none", "linear" or "cosine".
        use_sampling_annealing -> bool: If True, ignore ``softmax_temperature``
            and anneal the temperature from 1.3 towards 0.5.
        scale_pow -> float: The power used by the "cosine" guidance annealing.
        codebook_size -> int: The size of the codebook.
        codebook_splits -> int: The number of splits for the codebook.
        use_tqdm -> bool: Whether to show a progress bar over the steps.

    Returns:
        Tuple[torch.Tensor, List[torch.Tensor]]: The decoded images, and the
        token predictions of each step. The predictions are in factorized form,
        before ``combine_factorized_tokens``.

    Raises:
        ValueError: If ``num_steps`` < 1 or ``guidance_annealing`` is unknown.
    """
    if num_steps < 1:
        raise ValueError(f"num_steps must be at least 1, got {num_steps}")

    device = model.device
    model.eval()
    vqgan_model.eval()

    if labels is None:
        labels = _default_labels(num_samples, device)
    else:
        labels = torch.as_tensor(labels, device=device)
    drop_labels = torch.ones(num_samples, dtype=torch.bool, device=device)

    spatial_size = int(patch_size**2)
    num_splits = int(codebook_splits)
    num_maskable = spatial_size * num_splits
    masked_tokens = torch.full(
        (num_samples, spatial_size, num_splits), mask_token, device=device
    )
    mask = masked_tokens == mask_token

    l_full_tokens = []
    gumbel = torch.distributions.Gumbel(loc=0.0, scale=1.0)
    steps = range(num_steps)
    if use_tqdm:
        steps = tqdm.tqdm(steps, desc="Sampling steps", position=1)

    for i in steps:
        progress = (i + 1) / num_steps

        if guidance_scale != 0.0:
            # Classifier-free guidance: the first half of the batch keeps the
            # class condition, the second half drops it.
            logits = model(
                torch.cat([masked_tokens.clone(), masked_tokens.clone()], dim=0),
                torch.cat([labels, labels], dim=0),
                torch.cat([~drop_labels, drop_labels], dim=0),
            )
            logits_with_class, logits_without_class = torch.chunk(logits, 2, dim=0)
            scale = guidance_scale * _guidance_step_scale(
                i, num_steps, guidance_annealing, scale_pow
            )
            logits = logits_with_class + scale * (
                logits_with_class - logits_without_class
            )
        else:
            logits = model(masked_tokens.clone(), labels, ~drop_labels)

        if use_sampling_annealing:
            temperature = 0.5 + 0.8 * (1 - progress)
        else:
            temperature = softmax_temperature
        probabilities = torch.softmax(logits / temperature, dim=-1)
        predicted_tokens = torch.distributions.Categorical(probabilities).sample()

        # All samples mask the same count, so sample 0 is representative.
        num_masked = int(mask.sum(dim=(1, 2))[0])
        # Keep tokens that were already decoded in earlier steps.
        predicted_tokens = torch.where(mask, predicted_tokens, masked_tokens)
        confidence = torch.gather(
            probabilities, -1, predicted_tokens.unsqueeze(-1)
        ).squeeze(-1)
        # Already-decoded tokens get infinite confidence so they are never re-masked.
        confidence = torch.where(mask, confidence, torch.inf)
        noise = (
            gumbel.sample(predicted_tokens.size())
            * randomize_temperature
            * (1 - progress)
        )
        confidence = torch.log(confidence) + noise.to(device)

        mask_ratio = get_masking_ratio(progress, mode=mask_schedule_strategy)
        mask_len = int(torch.floor(mask_ratio * num_maskable))
        # Re-mask at least 1 and at most num_masked - 1 tokens so each step
        # decodes something. When num_masked <= 1 this is <= 0: re-mask nothing.
        num_tokens_to_mask = min(max(mask_len, 1), num_masked - 1)
        should_mask = _lowest_confidence_mask(confidence, num_tokens_to_mask)

        masked_tokens = torch.where(should_mask, mask_token, predicted_tokens)
        mask = masked_tokens == mask_token
        l_full_tokens.append(predicted_tokens)

    predicted_tokens = combine_factorized_tokens(
        predicted_tokens, codebook_size, codebook_splits
    )
    generated_image = vqgan_model.decode_tokens(predicted_tokens)
    return generated_image, l_full_tokens