#!/usr/bin/env python3
"""
Tiny AutoEncoder for Hunyuan Video
(DNN for encoding / decoding videos to Hunyuan Video's latent space)
From: https://github.com/madebyollin/taehv
MIT License
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
from tqdm.auto import tqdm
from collections import namedtuple
DecoderResult = namedtuple("DecoderResult", ("frame", "memory"))
TWorkItem = namedtuple("TWorkItem", ("input_tensor", "block_index"))
def conv(n_in, n_out, **kwargs):
return nn.Conv2d(n_in, n_out, 3, padding=1, **kwargs)
class Clamp(nn.Module):
def forward(self, x):
return torch.tanh(x / 3) * 3
class MemBlock(nn.Module):
def __init__(self, n_in, n_out):
super().__init__()
self.conv = nn.Sequential(conv(n_in * 2, n_out), nn.ReLU(inplace=True), conv(n_out, n_out), nn.ReLU(inplace=True), conv(n_out, n_out))
self.skip = nn.Conv2d(n_in, n_out, 1, bias=False) if n_in != n_out else nn.Identity()
self.act = nn.ReLU(inplace=True)
def forward(self, x, past):
return self.act(self.conv(torch.cat([x, past], 1)) + self.skip(x))
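# Illustrative shape check for MemBlock (hypothetical helper, not used by the
# model): each frame's features are fused with the previous frame's features
# via channel concatenation, plus a learned residual skip.
def _memblock_shape_example():
    blk = MemBlock(64, 64)
    cur = torch.randn(2, 64, 32, 32)  # features for the current frame
    past = torch.zeros_like(cur)      # zero history, as used for the first frame
    assert blk(cur, past).shape == (2, 64, 32, 32)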
class TPool(nn.Module):
def __init__(self, n_f, stride):
super().__init__()
self.stride = stride
        self.conv = nn.Conv2d(n_f * stride, n_f, 1, bias=False)
def forward(self, x):
_NT, C, H, W = x.shape
return self.conv(x.reshape(-1, self.stride * C, H, W))
class TGrow(nn.Module):
def __init__(self, n_f, stride):
super().__init__()
self.stride = stride
self.conv = nn.Conv2d(n_f, n_f*stride, 1, bias=False)
def forward(self, x):
_NT, C, H, W = x.shape
x = self.conv(x)
return x.reshape(-1, C, H, W)
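# Illustrative shape check for the temporal resampling blocks (hypothetical
# helper, not used by the model): TPool folds `stride` consecutive frames into
# one along the flattened batch*time axis, and TGrow expands one frame into
# `stride` frames.
def _tpool_tgrow_shape_example():
    pool, grow = TPool(64, 2), TGrow(64, 2)
    x = torch.randn(8, 64, 16, 16)            # N*T = 8 frames of features
    assert pool(x).shape == (4, 64, 16, 16)   # T halved: 8 -> 4
    assert grow(x).shape == (16, 64, 16, 16)  # T doubled: 8 -> 16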
def apply_model_with_memblocks(model, x, parallel, show_progress_bar):
"""
Apply a sequential model with memblocks to the given input.
Args:
- model: nn.Sequential of blocks to apply
- x: input data, of dimensions NTCHW
- parallel: if True, parallelize over timesteps (fast but uses O(T) memory)
if False, each timestep will be processed sequentially (slow but uses O(1) memory)
- show_progress_bar: if True, enables tqdm progressbar display
Returns NTCHW tensor of output data.
"""
assert x.ndim == 5, f"TAEHV operates on NTCHW tensors, but got {x.ndim}-dim tensor"
N, T, C, H, W = x.shape
if parallel:
x = x.reshape(N*T, C, H, W)
# parallel over input timesteps, iterate over blocks
for b in tqdm(model, disable=not show_progress_bar):
if isinstance(b, MemBlock):
NT, C, H, W = x.shape
T = NT // N
_x = x.reshape(N, T, C, H, W)
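                # shift features one step back in time (zero-padding at t=0) so each
                # frame is paired with its predecessor's features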
mem = F.pad(_x, (0,0,0,0,0,0,1,0), value=0)[:,:T].reshape(x.shape)
x = b(x, mem)
else:
x = b(x)
NT, C, H, W = x.shape
T = NT // N
x = x.view(N, T, C, H, W)
else:
# TODO(oboerbohan): at least on macos this still gradually uses more memory during decode...
# need to fix :(
out = []
# iterate over input timesteps and also iterate over blocks.
# because of the cursed TPool/TGrow blocks, this is not a nested loop,
# it's actually a ***graph traversal*** problem! so let's make a queue
work_queue = [TWorkItem(xt, 0) for t, xt in enumerate(x.reshape(N, T * C, H, W).chunk(T, dim=1))]
# in addition to manually managing our queue, we also need to manually manage our progressbar.
# we'll update it for every source node that we consume.
progress_bar = tqdm(range(T), disable=not show_progress_bar)
# we'll also need a separate addressable memory per node as well
mem = [None] * len(model)
while work_queue:
xt, i = work_queue.pop(0)
if i == 0:
# new source node consumed
progress_bar.update(1)
if i == len(model):
# reached end of the graph, append result to output list
out.append(xt)
else:
# fetch the block to process
b = model[i]
if isinstance(b, MemBlock):
# mem blocks are simple since we're visiting the graph in causal order
if mem[i] is None:
xt_new = b(xt, xt * 0)
mem[i] = xt
else:
xt_new = b(xt, mem[i])
mem[i].copy_(xt) # inplace might reduce mysterious pytorch memory allocations? doesn't help though
# add successor to work queue
work_queue.insert(0, TWorkItem(xt_new, i+1))
elif isinstance(b, TPool):
# pool blocks are miserable
if mem[i] is None:
mem[i] = [] # pool memory is itself a queue of inputs to pool
mem[i].append(xt)
if len(mem[i]) > b.stride:
# pool mem is in invalid state, we should have pooled before this
raise ValueError("???")
elif len(mem[i]) < b.stride:
# pool mem is not yet full, go back to processing the work queue
pass
else:
# pool mem is ready, run the pool block
N, C, H, W = xt.shape
xt = b(torch.cat(mem[i], 1).view(N*b.stride, C, H, W))
# reset the pool mem
mem[i] = []
# add successor to work queue
work_queue.insert(0, TWorkItem(xt, i+1))
elif isinstance(b, TGrow):
xt = b(xt)
NT, C, H, W = xt.shape
# each tgrow has multiple successor nodes
for xt_next in reversed(xt.view(N, b.stride*C, H, W).chunk(b.stride, 1)):
# add successor to work queue
work_queue.insert(0, TWorkItem(xt_next, i+1))
else:
# normal block with no funny business
xt = b(xt)
# add successor to work queue
work_queue.insert(0, TWorkItem(xt, i+1))
progress_bar.close()
x = torch.stack(out, 1)
return x
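# Hedged sanity-check sketch (hypothetical helper, not called by this module):
# both execution modes of apply_model_with_memblocks should agree up to float
# tolerance; parallel mode holds all T timesteps in memory, sequential mode
# streams them one at a time through the work queue above.
def _check_parallel_matches_sequential(model, video):
    y_fast = apply_model_with_memblocks(model, video, parallel=True, show_progress_bar=False)
    y_slow = apply_model_with_memblocks(model, video, parallel=False, show_progress_bar=False)
    assert torch.allclose(y_fast, y_slow, atol=1e-5), "parallel and sequential outputs diverged"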
class TAEHV(nn.Module):
latent_channels = 16
image_channels = 3
def __init__(self, state_dict, decoder_time_upscale=(True, True), decoder_space_upscale=(True, True, True)):
"""Initialize pretrained TAEHV from the given checkpoint.
Arg:
checkpoint_path: path to weight file to load. taehv.pth for Hunyuan, taew2_1.pth for Wan 2.1.
decoder_time_upscale: whether temporal upsampling is enabled for each block. upsampling can be disabled for a cheaper preview.
decoder_space_upscale: whether spatial upsampling is enabled for each block. upsampling can be disabled for a cheaper preview.
"""
super().__init__()
self.encoder = nn.Sequential(
conv(TAEHV.image_channels, 64), nn.ReLU(inplace=True),
TPool(64, 2), conv(64, 64, stride=2, bias=False), MemBlock(64, 64), MemBlock(64, 64), MemBlock(64, 64),
TPool(64, 2), conv(64, 64, stride=2, bias=False), MemBlock(64, 64), MemBlock(64, 64), MemBlock(64, 64),
TPool(64, 1), conv(64, 64, stride=2, bias=False), MemBlock(64, 64), MemBlock(64, 64), MemBlock(64, 64),
conv(64, TAEHV.latent_channels),
)
n_f = [256, 128, 64, 64]
self.frames_to_trim = 2**sum(decoder_time_upscale) - 1
self.decoder = nn.Sequential(
Clamp(), conv(TAEHV.latent_channels, n_f[0]), nn.ReLU(inplace=True),
MemBlock(n_f[0], n_f[0]), MemBlock(n_f[0], n_f[0]), MemBlock(n_f[0], n_f[0]), nn.Upsample(scale_factor=2 if decoder_space_upscale[0] else 1), TGrow(n_f[0], 1), conv(n_f[0], n_f[1], bias=False),
MemBlock(n_f[1], n_f[1]), MemBlock(n_f[1], n_f[1]), MemBlock(n_f[1], n_f[1]), nn.Upsample(scale_factor=2 if decoder_space_upscale[1] else 1), TGrow(n_f[1], 2 if decoder_time_upscale[0] else 1), conv(n_f[1], n_f[2], bias=False),
MemBlock(n_f[2], n_f[2]), MemBlock(n_f[2], n_f[2]), MemBlock(n_f[2], n_f[2]), nn.Upsample(scale_factor=2 if decoder_space_upscale[2] else 1), TGrow(n_f[2], 2 if decoder_time_upscale[1] else 1), conv(n_f[2], n_f[3], bias=False),
nn.ReLU(inplace=True), conv(n_f[3], TAEHV.image_channels),
)
if state_dict is not None:
self.load_state_dict(self.patch_tgrow_layers(state_dict))
def patch_tgrow_layers(self, sd):
"""Patch TGrow layers to use a smaller kernel if needed.
Args:
sd: state dict to patch
"""
new_sd = self.state_dict()
for i, layer in enumerate(self.decoder):
if isinstance(layer, TGrow):
key = f"decoder.{i}.conv.weight"
if sd[key].shape[0] > new_sd[key].shape[0]:
# take the last-timestep output channels
sd[key] = sd[key][-new_sd[key].shape[0]:]
return sd
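    # Example of the patching above (assuming the default n_f): with
    # decoder_time_upscale=(True, False), the last TGrow is built with stride 1,
    # so its conv weight has shape [64, 64, 1, 1] while the full checkpoint
    # stores [128, 64, 1, 1]; the slice keeps the last 64 output channels, i.e.
    # the ones that produce the final frame of each temporally-upsampled pair.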
def encode_video(self, x, parallel=True, show_progress_bar=True):
"""Encode a sequence of frames.
Args:
x: input NTCHW RGB (C=3) tensor with values in [0, 1].
parallel: if True, all frames will be processed at once.
(this is faster but may require more memory).
if False, frames will be processed sequentially.
Returns NTCHW latent tensor with ~Gaussian values.
"""
return apply_model_with_memblocks(self.encoder, x, parallel, show_progress_bar)
def decode_video(self, x, parallel=True, show_progress_bar=True):
"""Decode a sequence of frames.
Args:
            x: input NTCHW latent (C=16) tensor with ~Gaussian values.
parallel: if True, all frames will be processed at once.
(this is faster but may require more memory).
if False, frames will be processed sequentially.
Returns NTCHW RGB tensor with ~[0, 1] values.
"""
        x = apply_model_with_memblocks(self.decoder, x, parallel, show_progress_bar)
return x[:, self.frames_to_trim:]
    def forward(self, x):
        # round-trip: encode the video to latents, then decode back to RGB
        return self.decode_video(self.encode_video(x))
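if __name__ == "__main__":
    # Minimal round-trip smoke test (a sketch: the checkpoint filename follows
    # the __init__ docstring; adjust path, device, and video size as needed).
    model = TAEHV(torch.load("taehv.pth", map_location="cpu")).eval()
    # NTCHW RGB in [0, 1]; T should be a multiple of 4 for the encoder's temporal pooling
    video = torch.rand(1, 8, TAEHV.image_channels, 64, 64)
    with torch.no_grad():
        latents = model.encode_video(video, show_progress_bar=False)   # -> (1, 2, 16, 8, 8)
        decoded = model.decode_video(latents, show_progress_bar=False) # -> (1, 5, 3, 64, 64) after trimming warmup frames
    print(latents.shape, decoded.shape)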