# latent_preview.py

#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Latent preview for Blissful Tuner extension
License: Apache 2.0
Created on Mon Mar 10 16:47:29 2025

@author: blyss
"""
import os
import torch
import av
from PIL import Image
from .taehv import TAEHV
from .utils import load_torch_file
from blissful_tuner.utils import BlissfulLogger

logger = BlissfulLogger(__name__, "#8e00ed")


class LatentPreviewer:
    @torch.inference_mode()
    def __init__(self, args, original_latents, timesteps, device, dtype, model_type="hunyuan"):
        self.mode = "latent2rgb" if getattr(args, "preview_vae", None) is None else "taehv"
        logger.info(f"Initializing latent previewer with mode {self.mode}...")
        # Noise subtraction applies to every model type; preview() simply passes latents
        # through when original_latents/timesteps were not provided (e.g. framepack).
        self.subtract_noise = True
        self.args = args
        self.model_type = model_type
        self.device = device
        # FP8 latents are upcast to FP16 for preview decoding
        self.dtype = dtype if dtype != torch.float8_e4m3fn else torch.float16
        if model_type != "framepack" and original_latents is not None and timesteps is not None:
            self.original_latents = original_latents.to(self.device)
            self.timesteps_percent = timesteps / 1000  # timestep values (0-1000) as a 0-1 fraction

        if self.model_type not in ["hunyuan", "wan", "framepack"]:
            raise ValueError(f"Unsupported model type: {self.model_type}")

        if self.mode == "taehv":
            logger.info(f"Loading TAEHV: {args.preview_vae}...")
            if not os.path.exists(args.preview_vae):
                raise FileNotFoundError(f"{args.preview_vae} was not found!")
            tae_sd = load_torch_file(args.preview_vae, safe_load=True, device=args.device)
            self.taehv = TAEHV(tae_sd).to("cpu", self.dtype)  # Keep offloaded to save VRAM; match datatype
            self.decoder = self.decode_taehv
            self.scale_factor = None
            self.fps = args.fps
        elif self.mode == "latent2rgb":
            self.decoder = self.decode_latent2rgb
            self.scale_factor = 8
            # Latent frames are temporally compressed (roughly 4x), so quarter the FPS for the preview, minimum 1
            self.fps = int(args.fps / 4) if args.fps > 4 else 1


    @torch.inference_mode()
    def preview(self, noisy_latents, current_step=None, preview_suffix=None):
        if torch.device(self.device).type == "cuda":
            torch.cuda.empty_cache()
        if self.model_type == "wan":
            noisy_latents = noisy_latents.unsqueeze(0)  # add the batch dimension wan latents lack
        # hunyuan and framepack latents are already batched, so they pass through unchanged

        # Subtract the remaining noise only when enabled and the required inputs are available;
        # otherwise pass the latents through untouched (e.g. framepack skips that setup).
        if self.subtract_noise and hasattr(self, 'original_latents') and hasattr(self, 'timesteps_percent') and current_step is not None:
            denoisy_latents = self.subtract_original_and_normalize(noisy_latents, current_step)
        else:
            denoisy_latents = noisy_latents


        decoded = self.decoder(denoisy_latents)  # decoder returns frames as F, C, H, W

        # Upscale latent2rgb output so the preview matches the full output resolution
        if self.scale_factor is not None:
            upscaled = torch.nn.functional.interpolate(
                decoded,
                scale_factor=self.scale_factor,
                mode="bicubic",
                align_corners=False
            )
        else:
            upscaled = decoded

        _, _, h, w = upscaled.shape
        self.write_preview(upscaled, w, h, preview_suffix=preview_suffix)

    @torch.inference_mode()
    def subtract_original_and_normalize(self, noisy_latents, current_step):
        # Ensure original_latents and timesteps_percent were initialized
        if not hasattr(self, 'original_latents') or not hasattr(self, 'timesteps_percent'):
            logger.warning("Cannot subtract noise: original_latents or timesteps_percent not initialized.")
            return noisy_latents  # nothing to subtract against; return the latents unchanged

        # Compute what percent of original noise is remaining
        noise_remaining = self.timesteps_percent[current_step].to(device=noisy_latents.device)
        # Subtract the portion of original latents
        denoisy_latents = noisy_latents - (self.original_latents.to(device=noisy_latents.device) * noise_remaining)
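        # Assuming a rectified-flow style schedule (x_t ≈ (1 - t) * x0 + t * noise), removing the
        # remaining fraction of the initial noise leaves a scaled estimate of the clean latents;
        # the normalization below compensates for that scale.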

        # Normalize
        normalized_denoisy_latents = (denoisy_latents - denoisy_latents.mean()) / (denoisy_latents.std() + 1e-8)
        return normalized_denoisy_latents

    @torch.inference_mode()
    def write_preview(self, frames, width, height, preview_suffix=None):
        suffix_str = f"_{preview_suffix}" if preview_suffix else ""
        base_name = f"latent_preview{suffix_str}"
        target = os.path.join(self.args.save_path, f"{base_name}.mp4")
        target_img = os.path.join(self.args.save_path, f"{base_name}.png")
        # Check if we only have a single frame.
        if frames.shape[0] == 1:
            # Clamp, scale, convert to byte and move to CPU
            frame = frames[0].clamp(0, 1).mul(255).byte().cpu()
            # Permute from (3, H, W) to (H, W, 3) for PIL.
            frame_np = frame.permute(1, 2, 0).numpy()
            Image.fromarray(frame_np).save(target_img)
            #logger.info(f"Saved single frame preview to {target_img}") # Add log
            return

        # Otherwise, write out as a video.
        # Make sure fps is at least 1
        output_fps = max(1, self.fps)
        logger.info(f"Writing preview video to {target} at {output_fps} FPS")
        try:
            container = av.open(target, mode="w")
            stream = container.add_stream("libx264", rate=output_fps)
            stream.pix_fmt = "yuv420p"
            stream.width = width
            stream.height = height
            # Optional: stream.options = {'crf': '18'} for a higher-quality (larger) preview

            # Loop through each frame.
            for frame_idx, frame in enumerate(frames):
                # Clamp to [0,1], scale, convert to byte and move to CPU.
                frame = frame.clamp(0, 1).mul(255).byte().cpu()
                # Permute from (3, H, W) -> (H, W, 3) for AV.
                frame_np = frame.permute(1, 2, 0).numpy()
                try:
                    video_frame = av.VideoFrame.from_ndarray(frame_np, format="rgb24")
                    for packet in stream.encode(video_frame):
                        container.mux(packet)
                except Exception as e:
                    logger.error(f"Error encoding frame {frame_idx}: {e}")
                    break  # stop writing this preview if a frame fails to encode

            # Flush out any remaining packets and close.
            try:
                for packet in stream.encode():
                    container.mux(packet)
                container.close()
                logger.info(f"Finished writing preview video: {target}")
            except Exception as e:
                logger.error(f"Error finalizing preview video: {e}")
                # Clean up the container if possible
                try:
                    container.close()
                except Exception:
                    pass
        except Exception as e:
            logger.error(f"Error opening or writing to preview container {target}: {e}")

    @torch.inference_mode()
    def decode_taehv(self, latents):
        """
        Decodes latents with the TAEHV model, returns shape (F, C, H, W).
        """
        self.taehv.to(self.device)  # Onload for decoding
        # Framepack latents pass through as-is; for hunyuan/wan the second and third axes are
        # swapped (permute 0, 2, 1, 3, 4) so the layout matches what decode_video expects.
        if self.model_type == "framepack":
            latents_permuted = latents
        else:
            latents_permuted = latents.permute(0, 2, 1, 3, 4)

        latents_permuted = latents_permuted.to(device=self.device, dtype=self.dtype)
        decoded = self.taehv.decode_video(latents_permuted, parallel=False, show_progress_bar=False)
        self.taehv.to("cpu")  # Offload again to free VRAM
        return decoded.squeeze(0)  # drop the batch dimension -> F, C, H, W

    @torch.inference_mode()
    def decode_latent2rgb(self, latents):
        """
        Decodes latents to RGB using a per-channel linear transform, returns shape (F, 3, H, W).
        Expects latents in B, C, F, H, W layout (framepack's B, C, T, H, W is equivalent).
        """
        model_params = {
            "hunyuan": {
                "rgb_factors": [
                    [-0.0395, -0.0331,  0.0445], [ 0.0696,  0.0795,  0.0518],
                    [ 0.0135, -0.0945, -0.0282], [ 0.0108, -0.0250, -0.0765],
                    [-0.0209,  0.0032,  0.0224], [-0.0804, -0.0254, -0.0639],
                    [-0.0991,  0.0271, -0.0669], [-0.0646, -0.0422, -0.0400],
                    [-0.0696, -0.0595, -0.0894], [-0.0799, -0.0208, -0.0375],
                    [ 0.1166,  0.1627,  0.0962], [ 0.1165,  0.0432,  0.0407],
                    [-0.2315, -0.1920, -0.1355], [-0.0270,  0.0401, -0.0821],
                    [-0.0616, -0.0997, -0.0727], [ 0.0249, -0.0469, -0.1703]
                ],
                "bias": [0.0259, -0.0192, -0.0761],
            },
            "wan": {
                "rgb_factors": [
                    [-0.1299, -0.1692,  0.2932], [ 0.0671,  0.0406,  0.0442],
                    [ 0.3568,  0.2548,  0.1747], [ 0.0372,  0.2344,  0.1420],
                    [ 0.0313,  0.0189, -0.0328], [ 0.0296, -0.0956, -0.0665],
                    [-0.3477, -0.4059, -0.2925], [ 0.0166,  0.1902,  0.1975],
                    [-0.0412,  0.0267, -0.1364], [-0.1293,  0.0740,  0.1636],
                    [ 0.0680,  0.3019,  0.1128], [ 0.0032,  0.0581,  0.0639],
                    [-0.1251,  0.0927,  0.1699], [ 0.0060, -0.0633,  0.0005],
                    [ 0.3477,  0.2275,  0.2950], [ 0.1984,  0.0913,  0.1861]
                ],
                "bias": [-0.1835, -0.0868, -0.3360],
            },
            # No 'framepack' key needed, will map to 'hunyuan' below
        }

        # Framepack shares the hunyuan latent space, so reuse the hunyuan factors for it
        params_key = "hunyuan" if self.model_type == "framepack" else self.model_type
        if params_key not in model_params:
            logger.error(f"Unsupported model type '{self.model_type}' (key '{params_key}') for latent2rgb.")
            # Return a black preview of the expected shape rather than crashing mid-generation
            _, _, num_frames, h, w = latents.shape  # frames sit on dim 2 for every supported layout
            return torch.zeros((num_frames, 3, h * self.scale_factor, w * self.scale_factor), device='cpu')

        latent_rgb_factors_data = model_params[params_key]["rgb_factors"]
        latent_rgb_factors_bias_data = model_params[params_key]["bias"]

        # Prepare linear transform
        latent_rgb_factors = torch.tensor(
            latent_rgb_factors_data, # Use data fetched with correct key
            device=latents.device,
            dtype=latents.dtype
        ).transpose(0, 1)
        latent_rgb_factors_bias = torch.tensor(
            latent_rgb_factors_bias_data, # Use data fetched with correct key
            device=latents.device,
            dtype=latents.dtype
        )
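        # Each latent pixel's channels map to RGB via a linear combination:
        #   rgb[h, w] = latent[h, w, :] @ rgb_factors + bias
        # followed by per-clip min-max normalization further below.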

        # Both supported layouts arrive here as B, C, F, H, W, so frames sit on dim 2
        num_frames = latents.shape[2]

        latent_images = []
        for t in range(num_frames):
            # Extract frame t: B, C, H, W -> squeeze batch -> C, H, W -> permute -> H, W, C
            extracted = latents[:, :, t, :, :].squeeze(0).permute(1, 2, 0)

            # extracted is now (H, W, C), ready for the channel-wise linear transform
            rgb = torch.nn.functional.linear(extracted, latent_rgb_factors, bias=latent_rgb_factors_bias) # shape = (H, W, 3)
            latent_images.append(rgb)

        # Stack frames into (F, H, W, 3)
        if not latent_images:  # nothing to stack (zero-frame input); return a black placeholder
            logger.warning("No latent images generated in decode_latent2rgb.")
            _, _, num_frames, h, w = latents.shape
            return torch.zeros((num_frames, 3, h * self.scale_factor, w * self.scale_factor), device='cpu')

        latent_images_stacked = torch.stack(latent_images, dim=0)

        # Normalize to [0..1]
        latent_images_min = latent_images_stacked.min()
        latent_images_max = latent_images_stacked.max()
        if latent_images_max > latent_images_min:
            normalized_images = (latent_images_stacked - latent_images_min) / (latent_images_max - latent_images_min)
        else:
            # Handle case where max == min (e.g., all black image)
            normalized_images = torch.zeros_like(latent_images_stacked)

        # Permute to (F, 3, H, W) before returning
        final_images = normalized_images.permute(0, 3, 1, 2)
        return final_images
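

# ---------------------------------------------------------------------------
# Hypothetical usage sketch (not part of the previewer itself): it assumes a SimpleNamespace
# stands in for the parsed CLI args (preview_vae, fps, save_path, device) and that the caller
# already has the initial noise latents and the sampler's timestep schedule. Shapes, step
# counts, and the module path are illustrative only.
#
#     from types import SimpleNamespace
#     import torch
#     from blissful_tuner.latent_preview import LatentPreviewer
#
#     args = SimpleNamespace(preview_vae=None, fps=24, save_path="previews", device="cuda")  # latent2rgb mode
#     original_latents = torch.randn(16, 9, 60, 104)    # wan-style latents: C, F, H, W
#     timesteps = torch.linspace(1000, 0, steps=30)      # sampler schedule in [0, 1000]
#     previewer = LatentPreviewer(args, original_latents, timesteps,
#                                 device=torch.device("cuda"), dtype=torch.bfloat16, model_type="wan")
#
#     # inside the denoising loop, e.g. every few steps:
#     #     previewer.preview(current_latents, current_step=step, preview_suffix=str(step))
# ---------------------------------------------------------------------------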