from typing import Dict, Optional
import torch
import torch.nn as nn
from einops import repeat, rearrange
from ...util import append_dims, instantiate_from_config
from .denoiser_scaling import DenoiserScaling


class DenoiserDub(nn.Module):
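    """Denoiser for video dubbing: ground-truth frames ("gt" in `cond`) are kept outside the
    inpainting masks and only the masked region is denoised."""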
def __init__(self, scaling_config: Dict, mask_input: bool = True):
super().__init__()
self.scaling: DenoiserScaling = instantiate_from_config(scaling_config)
self.mask_input = mask_input
def possibly_quantize_sigma(self, sigma: torch.Tensor) -> torch.Tensor:
return sigma
def possibly_quantize_c_noise(self, c_noise: torch.Tensor) -> torch.Tensor:
return c_noise
def forward(
self,
network: nn.Module,
input: torch.Tensor,
sigma: torch.Tensor,
cond: Dict,
num_overlap_frames: int = 1,
num_frames: int = 14,
n_skips: int = 1,
        chunk_size: Optional[int] = None,
**additional_model_inputs,
) -> torch.Tensor:
sigma = self.possibly_quantize_sigma(sigma)
if input.ndim == 5:
T = input.shape[2]
input = rearrange(input, "b c t h w -> (b t) c h w")
if sigma.shape[0] != input.shape[0]:
sigma = repeat(sigma, "b ... -> b t ...", t=T)
sigma = rearrange(sigma, "b t ... -> (b t) ...")
sigma_shape = sigma.shape
sigma = append_dims(sigma, input.ndim)
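        # Preconditioning coefficients: c_in scales the network input, c_noise is the noise-level
        # conditioning, and c_skip / c_out blend the skip connection with the network output.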
c_skip, c_out, c_in, c_noise = self.scaling(sigma)
c_noise = self.possibly_quantize_c_noise(c_noise.reshape(sigma_shape))
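        # Ground-truth frames and inpainting masks provided by the conditioning dict.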
gt = cond.get("gt", torch.Tensor([]).type_as(input))
if gt.dim() == 5:
gt = rearrange(gt, "b c t h w -> (b t) c h w")
masks = cond.get("masks", None)
if masks.dim() == 5:
masks = rearrange(masks, "b c t h w -> (b t) c h w")
if self.mask_input:
input = input * masks + gt * (1.0 - masks)
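        # Optionally run the network in fixed-size chunks along the flattened (b t) batch
        # dimension to limit peak memory.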
if chunk_size is not None:
            assert chunk_size % num_frames == 0, (
                "chunk_size must be a multiple of num_frames"
            )
out = chunk_network(
network,
input,
c_in,
c_noise,
cond,
additional_model_inputs,
chunk_size,
num_frames=num_frames,
)
else:
out = network(input * c_in, c_noise, cond, **additional_model_inputs)
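        # Denoised prediction: D(x) = c_skip * x + c_out * F(c_in * x, c_noise).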
out = out * c_out + input * c_skip
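        # Paste the ground-truth content back outside the mask.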
out = out * masks + gt * (1.0 - masks)
return out


class DenoiserTemporalMultiDiffusion(nn.Module):
def __init__(self, scaling_config: Dict, is_dub: bool = False):
super().__init__()
self.scaling: DenoiserScaling = instantiate_from_config(scaling_config)
self.is_dub = is_dub
def possibly_quantize_sigma(self, sigma: torch.Tensor) -> torch.Tensor:
return sigma
def possibly_quantize_c_noise(self, c_noise: torch.Tensor) -> torch.Tensor:
return c_noise
def forward(
self,
network: nn.Module,
input: torch.Tensor,
sigma: torch.Tensor,
cond: Dict,
num_overlap_frames: int,
num_frames: int,
n_skips: int,
        chunk_size: Optional[int] = None,
**additional_model_inputs,
) -> torch.Tensor:
"""
Args:
network: Denoising network
input: Noisy input
sigma: Noise level
cond: Dictionary containing additional information
num_overlap_frames: Number of overlapping frames
additional_model_inputs: Additional inputs for the denoising network
Returns:
out: Denoised output
This function assumes the input is of shape (B, C, T, H, W) with the B dimension being the number of segments in video.
The num_overlap_frames is the number of overlapping frames between the segments to be able to handle the temporal overlap.
"""
sigma = self.possibly_quantize_sigma(sigma)
T = num_frames
if input.ndim == 5:
T = input.shape[2]
input = rearrange(input, "b c t h w -> (b t) c h w")
if sigma.shape[0] != input.shape[0]:
sigma = repeat(sigma, "b ... -> b t ...", t=T)
sigma = rearrange(sigma, "b t ... -> (b t) ...")
n_skips = n_skips * input.shape[0] // T
sigma_shape = sigma.shape
sigma = append_dims(sigma, input.ndim)
c_skip, c_out, c_in, c_noise = self.scaling(sigma)
c_noise = self.possibly_quantize_c_noise(c_noise.reshape(sigma_shape))
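        # In dubbing mode, replace the unmasked region of the noisy input with the ground-truth frames.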
if self.is_dub:
gt = cond.get("gt", torch.Tensor([]).type_as(input))
if gt.dim() == 5:
gt = rearrange(gt, "b c t h w -> (b t) c h w")
masks = cond.get("masks", None)
if masks.dim() == 5:
masks = rearrange(masks, "b c t h w -> (b t) c h w")
input = input * masks + gt * (1.0 - masks)
        # Average the overlapping frames between neighbouring segments (multi-diffusion).
        input = rearrange(input, "(b t) c h w -> b c t h w", t=T)
        # Overlapping frames sit at the beginning and end of each segment; their count is num_overlap_frames.
for i in range(input.shape[0] - n_skips):
average_frame = torch.stack(
[
input[i, :, -num_overlap_frames:],
input[i + 1, :, :num_overlap_frames],
]
).mean(0)
input[i, :, -num_overlap_frames:] = average_frame
input[i + n_skips, :, :num_overlap_frames] = average_frame
input = rearrange(input, "b c t h w -> (b t) c h w")
if chunk_size is not None:
            assert chunk_size % num_frames == 0, (
                "chunk_size must be a multiple of num_frames"
            )
out = chunk_network(
network,
input,
c_in,
c_noise,
cond,
additional_model_inputs,
chunk_size,
num_frames=num_frames,
)
else:
out = network(input * c_in, c_noise, cond, **additional_model_inputs)
out = out * c_out + input * c_skip
if self.is_dub:
out = out * masks + gt * (1.0 - masks)
return out


def chunk_network(
network,
input,
c_in,
c_noise,
cond,
additional_model_inputs,
chunk_size,
num_frames=1,
):
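    """Evaluate `network` on `input` in chunks of `chunk_size` along the batch dimension.

    Tensors with one entry per sample (c_in, c_noise, entries of `cond`) are sliced per chunk,
    tensors with one entry per segment are sliced at chunk // num_frames granularity, and
    `additional_model_inputs` tensors are tiled across frames to match the chunk. The per-chunk
    outputs are concatenated along the batch dimension.
    """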
out = []
for i in range(0, input.shape[0], chunk_size):
start_idx = i
end_idx = i + chunk_size
input_chunk = input[start_idx:end_idx]
c_in_chunk = (
c_in[start_idx:end_idx]
if c_in.shape[0] == input.shape[0]
else c_in[start_idx // num_frames : end_idx // num_frames]
)
c_noise_chunk = (
c_noise[start_idx:end_idx]
if c_noise.shape[0] == input.shape[0]
else c_noise[start_idx // num_frames : end_idx // num_frames]
)
cond_chunk = {}
for k, v in cond.items():
if isinstance(v, torch.Tensor) and v.shape[0] == input.shape[0]:
cond_chunk[k] = v[start_idx:end_idx]
elif isinstance(v, torch.Tensor):
cond_chunk[k] = v[start_idx // num_frames : end_idx // num_frames]
else:
cond_chunk[k] = v
additional_model_inputs_chunk = {}
for k, v in additional_model_inputs.items():
if isinstance(v, torch.Tensor):
or_size = v.shape[0]
additional_model_inputs_chunk[k] = repeat(
v,
"b c -> (b t) c",
t=(input_chunk.shape[0] // num_frames // or_size) + 1,
)[: cond_chunk["concat"].shape[0]]
else:
additional_model_inputs_chunk[k] = v
out.append(
network(
input_chunk * c_in_chunk,
c_noise_chunk,
cond_chunk,
**additional_model_inputs_chunk,
)
)
return torch.cat(out, dim=0)


class KarrasTemporalMultiDiffusion(DenoiserTemporalMultiDiffusion):
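    """Multi-diffusion denoiser that splits the batch in two: the first half is denoised by a
    separate `bad_network` and the second half by the main network (e.g. for guidance with a
    weaker version of the model)."""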
def __init__(self, scaling_config: Dict):
super().__init__(scaling_config)
self.bad_network = None
def set_bad_network(self, bad_network: nn.Module):
self.bad_network = bad_network
    def split_inputs(
        self, input: torch.Tensor, cond: Dict, additional_model_inputs: Dict
    ) -> tuple:
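        """Split the noisy input, conditioning dict, and additional model inputs into first/second
        batch halves so that `forward` can route them to the bad and main networks."""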
half_input = input.shape[0] // 2
first_cond_half = {}
second_cond_half = {}
for k, v in cond.items():
if isinstance(v, torch.Tensor):
half_cond = v.shape[0] // 2
first_cond_half[k] = v[:half_cond]
second_cond_half[k] = v[half_cond:]
elif isinstance(v, list):
half_add = v[0].shape[0] // 2
first_cond_half[k] = [v[i][:half_add] for i in range(len(v))]
second_cond_half[k] = [v[i][half_add:] for i in range(len(v))]
else:
first_cond_half[k] = v
second_cond_half[k] = v
add_good = {}
add_bad = {}
for k, v in additional_model_inputs.items():
if isinstance(v, torch.Tensor):
half_add = v.shape[0] // 2
add_good[k] = v[:half_add]
add_bad[k] = v[half_add:]
elif isinstance(v, list):
half_add = v[0].shape[0] // 2
add_good[k] = [v[i][:half_add] for i in range(len(v))]
add_bad[k] = [v[i][half_add:] for i in range(len(v))]
else:
add_good[k] = v
add_bad[k] = v
return (
input[:half_input],
input[half_input:],
first_cond_half,
second_cond_half,
add_good,
add_bad,
)
def forward(
self,
network: nn.Module,
input: torch.Tensor,
sigma: torch.Tensor,
cond: Dict,
num_overlap_frames: int,
num_frames: int,
n_skips: int,
        chunk_size: Optional[int] = None,
**additional_model_inputs,
) -> torch.Tensor:
"""
Args:
network: Denoising network
input: Noisy input
sigma: Noise level
cond: Dictionary containing additional information
num_overlap_frames: Number of overlapping frames
additional_model_inputs: Additional inputs for the denoising network
Returns:
out: Denoised output
This function assumes the input is of shape (B, C, T, H, W) with the B dimension being the number of segments in video.
The num_overlap_frames is the number of overlapping frames between the segments to be able to handle the temporal overlap.
"""
sigma = self.possibly_quantize_sigma(sigma)
T = num_frames
if input.ndim == 5:
T = input.shape[2]
input = rearrange(input, "b c t h w -> (b t) c h w")
if sigma.shape[0] != input.shape[0]:
sigma = repeat(sigma, "b ... -> b t ...", t=T)
sigma = rearrange(sigma, "b t ... -> (b t) ...")
n_skips = n_skips * input.shape[0] // T
sigma_shape = sigma.shape
sigma = append_dims(sigma, input.ndim)
c_skip, c_out, c_in, c_noise = self.scaling(sigma)
c_noise = self.possibly_quantize_c_noise(c_noise.reshape(sigma_shape))
if self.is_dub:
gt = cond.get("gt", torch.Tensor([]).type_as(input))
if gt.dim() == 5:
gt = rearrange(gt, "b c t h w -> (b t) c h w")
masks = cond.get("masks", None)
if masks.dim() == 5:
masks = rearrange(masks, "b c t h w -> (b t) c h w")
input = input * masks + gt * (1.0 - masks)
        # Average the overlapping frames between neighbouring segments (multi-diffusion).
        input = rearrange(input, "(b t) c h w -> b c t h w", t=T)
        # Overlapping frames sit at the beginning and end of each segment; their count is num_overlap_frames.
for i in range(input.shape[0] - n_skips):
average_frame = torch.stack(
[
input[i, :, -num_overlap_frames:],
input[i + 1, :, :num_overlap_frames],
]
).mean(0)
input[i, :, -num_overlap_frames:] = average_frame
input[i + n_skips, :, :num_overlap_frames] = average_frame
input = rearrange(input, "b c t h w -> (b t) c h w")
half = c_in.shape[0] // 2
in_bad, in_good, cond_bad, cond_good, add_inputs_good, add_inputs_bad = (
self.split_inputs(input, cond, additional_model_inputs)
)
if chunk_size is not None:
            assert chunk_size % num_frames == 0, (
                "chunk_size must be a multiple of num_frames"
            )
out = chunk_network(
network,
in_good,
c_in[half:],
c_noise[half:],
cond_good,
add_inputs_good,
chunk_size,
num_frames=num_frames,
)
bad_out = chunk_network(
self.bad_network,
in_bad,
c_in[:half],
c_noise[:half],
cond_bad,
add_inputs_bad,
chunk_size,
num_frames=num_frames,
)
else:
out = network(
in_good * c_in[half:], c_noise[half:], cond_good, **add_inputs_good
)
bad_out = self.bad_network(
in_bad * c_in[:half], c_noise[:half], cond_bad, **add_inputs_bad
)
out = torch.cat([bad_out, out], dim=0)
out = out * c_out + input * c_skip
if self.is_dub:
out = out * masks + gt * (1.0 - masks)
return out