from typing import Dict, Optional

import torch
import torch.nn as nn
from einops import repeat, rearrange

from ...util import append_dims, instantiate_from_config
from .denoiser_scaling import DenoiserScaling


class DenoiserDub(nn.Module):
    def __init__(self, scaling_config: Dict, mask_input: bool = True):
        super().__init__()
        self.scaling: DenoiserScaling = instantiate_from_config(scaling_config)
        self.mask_input = mask_input

    def possibly_quantize_sigma(self, sigma: torch.Tensor) -> torch.Tensor:
        return sigma

    def possibly_quantize_c_noise(self, c_noise: torch.Tensor) -> torch.Tensor:
        return c_noise

    def forward(
        self,
        network: nn.Module,
        input: torch.Tensor,
        sigma: torch.Tensor,
        cond: Dict,
        num_overlap_frames: int = 1,
        num_frames: int = 14,
        n_skips: int = 1,
        chunk_size: Optional[int] = None,
        **additional_model_inputs,
    ) -> torch.Tensor:
        sigma = self.possibly_quantize_sigma(sigma)
        if input.ndim == 5:
            T = input.shape[2]
            input = rearrange(input, "b c t h w -> (b t) c h w")
            if sigma.shape[0] != input.shape[0]:
                sigma = repeat(sigma, "b ... -> b t ...", t=T)
                sigma = rearrange(sigma, "b t ... -> (b t) ...")

        # Preconditioning coefficients from the configured scaling schedule.
        sigma_shape = sigma.shape
        sigma = append_dims(sigma, input.ndim)
        c_skip, c_out, c_in, c_noise = self.scaling(sigma)
        c_noise = self.possibly_quantize_c_noise(c_noise.reshape(sigma_shape))

        gt = cond.get("gt", torch.Tensor([]).type_as(input))
        if gt.dim() == 5:
            gt = rearrange(gt, "b c t h w -> (b t) c h w")
        masks = cond.get("masks", None)
        if masks.dim() == 5:
            masks = rearrange(masks, "b c t h w -> (b t) c h w")

        if self.mask_input:
            # Keep the noisy input inside the mask and the ground-truth latents outside it.
            input = input * masks + gt * (1.0 - masks)

        if chunk_size is not None:
            assert chunk_size % num_frames == 0, (
                "Chunk size should be a multiple of num_frames"
            )
            out = chunk_network(
                network,
                input,
                c_in,
                c_noise,
                cond,
                additional_model_inputs,
                chunk_size,
                num_frames=num_frames,
            )
        else:
            out = network(input * c_in, c_noise, cond, **additional_model_inputs)

        out = out * c_out + input * c_skip
        out = out * masks + gt * (1.0 - masks)
        return out
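

# Hedged usage sketch, not part of the original module: a minimal call to
# DenoiserDub with a stand-in identity network, showing the (b, c, t, h, w)
# shapes it expects and the "gt"/"masks" entries it reads from `cond`. The
# helper name and the identity network are illustrative assumptions only; the
# denoiser is passed in already constructed with a valid scaling_config.
def _example_denoiser_dub_usage(denoiser: "DenoiserDub") -> torch.Tensor:
    class _IdentityNet(nn.Module):
        # Stand-in for the real denoising UNet: same call signature, echoes its input.
        def forward(self, x, c_noise, cond, **kwargs):
            return x

    b, c, t, h, w = 1, 4, 14, 32, 32
    noisy = torch.randn(b, c, t, h, w)        # noisy latents, one segment of t frames
    sigma = torch.rand(b)                     # one noise level per batch element
    cond = {
        "gt": torch.randn(b, c, t, h, w),     # clean latents kept outside the mask
        "masks": torch.ones(b, 1, t, h, w),   # 1 = region to denoise, 0 = keep gt
    }
    # The 5D input is flattened internally, so the output has shape (b * t, c, h, w).
    return denoiser(_IdentityNet(), noisy, sigma, cond, num_frames=t)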


class DenoiserTemporalMultiDiffusion(nn.Module):
    def __init__(self, scaling_config: Dict, is_dub: bool = False):
        super().__init__()
        self.scaling: DenoiserScaling = instantiate_from_config(scaling_config)
        self.is_dub = is_dub

    def possibly_quantize_sigma(self, sigma: torch.Tensor) -> torch.Tensor:
        return sigma

    def possibly_quantize_c_noise(self, c_noise: torch.Tensor) -> torch.Tensor:
        return c_noise

    def forward(
        self,
        network: nn.Module,
        input: torch.Tensor,
        sigma: torch.Tensor,
        cond: Dict,
        num_overlap_frames: int,
        num_frames: int,
        n_skips: int,
        chunk_size: Optional[int] = None,
        **additional_model_inputs,
    ) -> torch.Tensor:
        """
        Args:
            network: Denoising network
            input: Noisy input
            sigma: Noise level
            cond: Dictionary containing additional information
            num_overlap_frames: Number of overlapping frames between consecutive segments
            num_frames: Number of frames per segment
            n_skips: Number of segments to skip when writing back the averaged overlap
            chunk_size: Optional chunk size for splitting the network call into smaller batches
            additional_model_inputs: Additional inputs for the denoising network

        Returns:
            out: Denoised output

        This function assumes the input has shape (B, C, T, H, W), where B is the number
        of segments in the video. num_overlap_frames is the number of frames shared by
        consecutive segments, which allows the temporal overlap to be blended consistently.
        """
        sigma = self.possibly_quantize_sigma(sigma)
        T = num_frames
        if input.ndim == 5:
            T = input.shape[2]
            input = rearrange(input, "b c t h w -> (b t) c h w")
            if sigma.shape[0] != input.shape[0]:
                sigma = repeat(sigma, "b ... -> b t ...", t=T)
                sigma = rearrange(sigma, "b t ... -> (b t) ...")

        n_skips = n_skips * input.shape[0] // T
        sigma_shape = sigma.shape
        sigma = append_dims(sigma, input.ndim)
        c_skip, c_out, c_in, c_noise = self.scaling(sigma)
        c_noise = self.possibly_quantize_c_noise(c_noise.reshape(sigma_shape))

        if self.is_dub:
            gt = cond.get("gt", torch.Tensor([]).type_as(input))
            if gt.dim() == 5:
                gt = rearrange(gt, "b c t h w -> (b t) c h w")
            masks = cond.get("masks", None)
            if masks.dim() == 5:
                masks = rearrange(masks, "b c t h w -> (b t) c h w")
            input = input * masks + gt * (1.0 - masks)

        # Find the overlapping frames between consecutive segments and average them.
        # Overlapping frames sit at the beginning and end of each segment; their count
        # is given by num_overlap_frames.
        input = rearrange(input, "(b t) c h w -> b c t h w", t=T)
        for i in range(input.shape[0] - n_skips):
            average_frame = torch.stack(
                [
                    input[i, :, -num_overlap_frames:],
                    input[i + 1, :, :num_overlap_frames],
                ]
            ).mean(0)
            input[i, :, -num_overlap_frames:] = average_frame
            input[i + n_skips, :, :num_overlap_frames] = average_frame
        input = rearrange(input, "b c t h w -> (b t) c h w")

        if chunk_size is not None:
            assert chunk_size % num_frames == 0, (
                "Chunk size should be a multiple of num_frames"
            )
            out = chunk_network(
                network,
                input,
                c_in,
                c_noise,
                cond,
                additional_model_inputs,
                chunk_size,
                num_frames=num_frames,
            )
        else:
            out = network(input * c_in, c_noise, cond, **additional_model_inputs)

        out = out * c_out + input * c_skip
        if self.is_dub:
            out = out * masks + gt * (1.0 - masks)
        return out
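

# Hedged sketch, illustrative only: the overlap handling above, isolated on a
# dummy segment tensor. Adjacent segments share num_overlap_frames frames, and
# averaging the trailing frames of segment i with the leading frames of segment
# i + 1 keeps the segments consistent where they meet.
def _example_overlap_averaging(
    segments: torch.Tensor, num_overlap_frames: int, n_skips: int = 1
) -> torch.Tensor:
    # segments: (b, c, t, h, w) with b video segments of t frames each.
    segments = segments.clone()
    for i in range(segments.shape[0] - n_skips):
        averaged = torch.stack(
            [
                segments[i, :, -num_overlap_frames:],
                segments[i + 1, :, :num_overlap_frames],
            ]
        ).mean(0)
        # Write the averaged overlap back into both segments, as in the forward above.
        segments[i, :, -num_overlap_frames:] = averaged
        segments[i + n_skips, :, :num_overlap_frames] = averaged
    return segments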


def chunk_network(
    network,
    input,
    c_in,
    c_noise,
    cond,
    additional_model_inputs,
    chunk_size,
    num_frames=1,
):
    out = []
    for i in range(0, input.shape[0], chunk_size):
        start_idx = i
        end_idx = i + chunk_size
        input_chunk = input[start_idx:end_idx]
        # Scalings are either per frame (same batch size as the input) or per
        # segment (batch size divided by num_frames); slice accordingly.
        c_in_chunk = (
            c_in[start_idx:end_idx]
            if c_in.shape[0] == input.shape[0]
            else c_in[start_idx // num_frames : end_idx // num_frames]
        )
        c_noise_chunk = (
            c_noise[start_idx:end_idx]
            if c_noise.shape[0] == input.shape[0]
            else c_noise[start_idx // num_frames : end_idx // num_frames]
        )
        cond_chunk = {}
        for k, v in cond.items():
            if isinstance(v, torch.Tensor) and v.shape[0] == input.shape[0]:
                cond_chunk[k] = v[start_idx:end_idx]
            elif isinstance(v, torch.Tensor):
                cond_chunk[k] = v[start_idx // num_frames : end_idx // num_frames]
            else:
                cond_chunk[k] = v

        # Tensors in additional_model_inputs are tiled to match the chunk and then
        # trimmed to the length of the "concat" conditioning.
        additional_model_inputs_chunk = {}
        for k, v in additional_model_inputs.items():
            if isinstance(v, torch.Tensor):
                or_size = v.shape[0]
                additional_model_inputs_chunk[k] = repeat(
                    v,
                    "b c -> (b t) c",
                    t=(input_chunk.shape[0] // num_frames // or_size) + 1,
                )[: cond_chunk["concat"].shape[0]]
            else:
                additional_model_inputs_chunk[k] = v

        out.append(
            network(
                input_chunk * c_in_chunk,
                c_noise_chunk,
                cond_chunk,
                **additional_model_inputs_chunk,
            )
        )
    return torch.cat(out, dim=0)
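

# Hedged sketch, illustrative only: calling chunk_network with a stand-in network
# and per-frame conditioning. Chunking works over the flattened (b * t, c, h, w)
# batch, which is why the callers above require chunk_size to be a multiple of
# num_frames; the shapes and the identity network here are assumptions.
def _example_chunked_call(chunk_size: int = 14, num_frames: int = 14) -> torch.Tensor:
    class _IdentityNet(nn.Module):
        def forward(self, x, c_noise, cond, **kwargs):
            return x

    n = 2 * num_frames                            # two segments of num_frames frames
    x = torch.randn(n, 4, 32, 32)                 # already flattened to (b * t, c, h, w)
    c_in = torch.ones(n, 1, 1, 1)                 # per-frame input scaling
    c_noise = torch.zeros(n)                      # per-frame noise conditioning
    cond = {"concat": torch.zeros(n, 4, 32, 32)}  # sliced per frame inside chunk_network
    return chunk_network(
        _IdentityNet(), x, c_in, c_noise, cond, {}, chunk_size, num_frames=num_frames
    )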


class KarrasTemporalMultiDiffusion(DenoiserTemporalMultiDiffusion):
    def __init__(self, scaling_config: Dict):
        super().__init__(scaling_config)
        self.bad_network = None

    def set_bad_network(self, bad_network: nn.Module):
        self.bad_network = bad_network

    def split_inputs(
        self, input: torch.Tensor, cond: Dict, additional_model_inputs: Dict
    ) -> tuple:
        half_input = input.shape[0] // 2
        first_cond_half = {}
        second_cond_half = {}
        for k, v in cond.items():
            if isinstance(v, torch.Tensor):
                half_cond = v.shape[0] // 2
                first_cond_half[k] = v[:half_cond]
                second_cond_half[k] = v[half_cond:]
            elif isinstance(v, list):
                half_add = v[0].shape[0] // 2
                first_cond_half[k] = [v[i][:half_add] for i in range(len(v))]
                second_cond_half[k] = [v[i][half_add:] for i in range(len(v))]
            else:
                first_cond_half[k] = v
                second_cond_half[k] = v

        add_good = {}
        add_bad = {}
        for k, v in additional_model_inputs.items():
            if isinstance(v, torch.Tensor):
                half_add = v.shape[0] // 2
                add_good[k] = v[:half_add]
                add_bad[k] = v[half_add:]
            elif isinstance(v, list):
                half_add = v[0].shape[0] // 2
                add_good[k] = [v[i][:half_add] for i in range(len(v))]
                add_bad[k] = [v[i][half_add:] for i in range(len(v))]
            else:
                add_good[k] = v
                add_bad[k] = v

        return (
            input[:half_input],
            input[half_input:],
            first_cond_half,
            second_cond_half,
            add_good,
            add_bad,
        )

    def forward(
        self,
        network: nn.Module,
        input: torch.Tensor,
        sigma: torch.Tensor,
        cond: Dict,
        num_overlap_frames: int,
        num_frames: int,
        n_skips: int,
        chunk_size: Optional[int] = None,
        **additional_model_inputs,
    ) -> torch.Tensor:
        """
        Args:
            network: Denoising network
            input: Noisy input
            sigma: Noise level
            cond: Dictionary containing additional information
            num_overlap_frames: Number of overlapping frames between consecutive segments
            num_frames: Number of frames per segment
            n_skips: Number of segments to skip when writing back the averaged overlap
            chunk_size: Optional chunk size for splitting the network call into smaller batches
            additional_model_inputs: Additional inputs for the denoising network

        Returns:
            out: Denoised output

        This function assumes the input has shape (B, C, T, H, W), where B is the number
        of segments in the video. num_overlap_frames is the number of frames shared by
        consecutive segments, which allows the temporal overlap to be blended consistently.
        The batch is split in half: the first half is denoised with the bad network and
        the second half with the main network, and the two outputs are concatenated.
        """
        sigma = self.possibly_quantize_sigma(sigma)
        T = num_frames
        if input.ndim == 5:
            T = input.shape[2]
            input = rearrange(input, "b c t h w -> (b t) c h w")
            if sigma.shape[0] != input.shape[0]:
                sigma = repeat(sigma, "b ... -> b t ...", t=T)
                sigma = rearrange(sigma, "b t ... -> (b t) ...")

        n_skips = n_skips * input.shape[0] // T
        sigma_shape = sigma.shape
        sigma = append_dims(sigma, input.ndim)
        c_skip, c_out, c_in, c_noise = self.scaling(sigma)
        c_noise = self.possibly_quantize_c_noise(c_noise.reshape(sigma_shape))

        if self.is_dub:
            gt = cond.get("gt", torch.Tensor([]).type_as(input))
            if gt.dim() == 5:
                gt = rearrange(gt, "b c t h w -> (b t) c h w")
            masks = cond.get("masks", None)
            if masks.dim() == 5:
                masks = rearrange(masks, "b c t h w -> (b t) c h w")
            input = input * masks + gt * (1.0 - masks)

        # Find the overlapping frames between consecutive segments and average them.
        # Overlapping frames sit at the beginning and end of each segment; their count
        # is given by num_overlap_frames.
        input = rearrange(input, "(b t) c h w -> b c t h w", t=T)
        for i in range(input.shape[0] - n_skips):
            average_frame = torch.stack(
                [
                    input[i, :, -num_overlap_frames:],
                    input[i + 1, :, :num_overlap_frames],
                ]
            ).mean(0)
            input[i, :, -num_overlap_frames:] = average_frame
            input[i + n_skips, :, :num_overlap_frames] = average_frame
        input = rearrange(input, "b c t h w -> (b t) c h w")

        half = c_in.shape[0] // 2
        in_bad, in_good, cond_bad, cond_good, add_inputs_good, add_inputs_bad = (
            self.split_inputs(input, cond, additional_model_inputs)
        )

        if chunk_size is not None:
            assert chunk_size % num_frames == 0, (
                "Chunk size should be a multiple of num_frames"
            )
            out = chunk_network(
                network,
                in_good,
                c_in[half:],
                c_noise[half:],
                cond_good,
                add_inputs_good,
                chunk_size,
                num_frames=num_frames,
            )
            bad_out = chunk_network(
                self.bad_network,
                in_bad,
                c_in[:half],
                c_noise[:half],
                cond_bad,
                add_inputs_bad,
                chunk_size,
                num_frames=num_frames,
            )
        else:
            out = network(
                in_good * c_in[half:], c_noise[half:], cond_good, **add_inputs_good
            )
            bad_out = self.bad_network(
                in_bad * c_in[:half], c_noise[:half], cond_bad, **add_inputs_bad
            )
        out = torch.cat([bad_out, out], dim=0)

        out = out * c_out + input * c_skip
        if self.is_dub:
            out = out * masks + gt * (1.0 - masks)
        return out
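

# Hedged usage sketch, illustrative only: wiring up KarrasTemporalMultiDiffusion
# with two stand-in identity networks. The batch stacks the "bad"-network inputs
# in the first half and the main-network inputs in the second half, matching
# split_inputs above; shapes, the identity networks, and the doubled batch layout
# are assumptions, and the denoiser is passed in already constructed.
def _example_karras_two_network_call(
    denoiser: "KarrasTemporalMultiDiffusion",
) -> torch.Tensor:
    class _IdentityNet(nn.Module):
        def forward(self, x, c_noise, cond, **kwargs):
            return x

    num_frames, num_overlap_frames = 14, 1
    segments = torch.randn(2, 4, num_frames, 32, 32)  # two segments of one clip
    noisy = torch.cat([segments, segments], dim=0)    # bad half + good half
    sigma = torch.rand(noisy.shape[0])                # one noise level per segment
    denoiser.set_bad_network(_IdentityNet())
    # Returns the concatenated [bad, good] outputs, flattened to (b * t, c, h, w).
    return denoiser(
        _IdentityNet(),
        noisy,
        sigma,
        cond={},
        num_overlap_frames=num_overlap_frames,
        num_frames=num_frames,
        n_skips=1,
    )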