Framepack-H111 / blissful_tuner /prompt_weighting.py
rahul7star's picture
Upload 303 files
e0336bc verified
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Created on Sun Apr 20 12:51:05 2025
Prompt weighting for WanVideo
Adapted and heavily modified from https://github.com/xhinker/sd_embed
License: Apache 2.0
@author: blyss
"""
from transformers import T5Model
import torch
import re
from typing import Tuple, List, Union
from blissful_tuner.utils import BlissfulLogger
# Module-level logger shared by everything in this file.
# NOTE(review): the second argument looks like a hex colour used to tint this
# module's log output — confirm against BlissfulLogger's signature.
logger = BlissfulLogger(__name__, "#8e00ed")
class MiniT5Wrapper:
    """A mini wrapper for the T5 to make managing prompt weighting in Musubi easier.

    A prompt may contain ``(text:weight)`` chunks, e.g. ``"a (red:1.3) car"``,
    and may hold several sub-prompts separated by ``|``. The markup is stripped
    before the text is handed to the wrapped encoder, and the extracted weights
    are then used to scale the resulting embeddings.
    """

    # Matches "(text:weight)" where weight is a well-formed, unsigned decimal
    # ("1", "1.", "1.25" or ".5"). Malformed numbers such as "1.2.3" or "." are
    # deliberately NOT matched, so float() can never fail on a captured weight;
    # such chunks are simply left in the prompt as literal text.
    _WEIGHT_PATTERN = re.compile(r'\((.*?):(\d+(?:\.\d*)?|\.\d+)\)')

    def __init__(self, device: torch.device, dtype: torch.dtype, t5: "T5Model"):
        """Store the wrapped encoder and its placement.

        Args:
            device: Device recorded on the wrapper (``__call__`` uses the
                device passed to it, not this attribute).
            dtype: Dtype recorded on the wrapper.
            t5: The text encoder to wrap. It must expose a ``model`` attribute
                and be callable as ``t5(prompts, device)``.
        """
        self.device = device
        self.dtype = dtype
        self.t5 = t5
        self.model = t5.model
        # Used only to emit the "Weighting prompts..." notice a single time.
        self.times_called = 0

    def __call__(
        self,
        prompt: Union[str, List[str]],
        device: torch.device,
        max_len: Union[int, None] = None
    ) -> List[torch.Tensor]:
        """Encode a (possibly weighted) prompt with the wrapped T5.

        Args:
            prompt: A single prompt string, or a list containing exactly one
                prompt string.
            device: Forwarded unchanged to the wrapped encoder.
            max_len: Accepted for interface compatibility; not used here.

        Returns:
            The encoder output for the ``|``-separated sub-prompts, indexed per
            sub-prompt, with each entry multiplied by every weight that was
            found in that sub-prompt.

        Raises:
            ValueError: If ``prompt`` is a list whose length is not exactly 1.
        """
        if isinstance(prompt, list):
            if len(prompt) != 1:
                raise ValueError("MiniT5Wrapper expects a single prompt at a time (wrapped as a list). Got multiple prompts.")
            prompt = prompt[0]

        if self.times_called == 0:  # Only print this notice once even if called multiple times
            logger.info("Weighting prompts...")

        # Split into sub-prompts and strip the weight markup from each one.
        prompts = []
        all_weights = []
        for raw_prompt in prompt.split('|'):
            cleaned_prompt, weights = self.parse_prompt_weights(raw_prompt.strip())
            prompts.append(cleaned_prompt)
            all_weights.append(weights)

        context = self.t5(prompts, device)

        # Apply the extracted weights. Note that each weight scales the whole
        # embedding of its sub-prompt (not only the tokens of the weighted
        # chunk), so several weights in one sub-prompt multiply together.
        for i, weights in enumerate(all_weights):
            for text, weight in weights.items():
                logger.info(f"Applying weight ({weight}) to promptchunk: '{text}'")
                context[i] = context[i] * weight

        self.times_called += 1
        return context

    def parse_prompt_weights(self, prompt: str) -> Tuple[str, dict]:
        """Extract text and weights from prompts with (text:weight) format.

        Args:
            prompt: Prompt text that may contain ``(text:weight)`` chunks.

        Returns:
            A tuple ``(cleaned_prompt, weights)`` where ``cleaned_prompt`` has
            every well-formed ``(text:weight)`` replaced by just ``text`` and
            ``weights`` maps each extracted ``text`` to its weight as a float.
            If the same text appears more than once, the last weight wins.
            Chunks whose weight is not a valid decimal number are left
            untouched and produce no entry.
        """
        weights = {}

        def _strip_markup(match):
            # Record the weight and keep only the bare text in the prompt.
            text, weight = match.group(1), match.group(2)
            weights[text] = float(weight)
            return text

        cleaned_prompt = self._WEIGHT_PATTERN.sub(_strip_markup, prompt)
        return cleaned_prompt, weights