import math
import os
import random
import sys

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, set_seed
from transformers.generation.logits_process import LogitsProcessorList

import comfy.model_management as model_management
from comfy.model_patcher import ModelPatcher

from .util import remove_empty_str

fooocus_expansion_path = "prompt_expansion"
SEED_LIMIT_NUMPY = 2**32  # numpy seeds must fit in an unsigned 32-bit integer
neg_inf = -8192.0  # large negative logit bias used to mask tokens out


def safe_str(x):
    """Collapse repeated spaces and strip stray punctuation and whitespace."""
    x = str(x)
    for _ in range(16):
        x = x.replace("  ", " ")
    return x.strip(",. \r\n")


class FooocusExpansion:
    # Cached on the class so every instance shares one tokenizer and model.
    tokenizer = None
    model = None

    def __init__(self):
        self.load_model_and_tokenizer(fooocus_expansion_path)
        self.offload_device = model_management.text_encoder_offload_device()
        self.patcher = ModelPatcher(
            self.model,
            load_device=self.model.device,
            offload_device=self.offload_device,
        )

    @classmethod
    def load_model_and_tokenizer(cls, model_path):
        if cls.tokenizer is None or cls.model is None:
            cls.tokenizer = AutoTokenizer.from_pretrained(model_path)
            cls.model = AutoModelForCausalLM.from_pretrained(model_path)
            cls.model.to("cpu")

    def __call__(self, prompt, seed):
        seed = int(seed) % SEED_LIMIT_NUMPY
        set_seed(seed)

        with open(
            os.path.join(fooocus_expansion_path, "positive.txt"), encoding="utf-8"
        ) as f:
            positive_words = f.read().splitlines()
        # "Ġ" is the GPT-2 BPE marker for a leading space, so these entries
        # match word-initial tokens in the vocabulary.
        positive_words = {"Ġ" + x.lower() for x in positive_words if x != ""}

        # Start with every token masked out, then re-enable the positive words.
        self.logits_bias = (
            torch.zeros((1, len(self.tokenizer.vocab)), dtype=torch.float32) + neg_inf
        )
        debug_list = []
        for k, v in self.tokenizer.vocab.items():
            if k in positive_words:
                self.logits_bias[0, v] = 0
                debug_list.append(k[1:])
        # print(f'Expansion: Vocab with {len(debug_list)} words.')

        text = safe_str(prompt) + ","
        tokenized_kwargs = self.tokenizer(text, return_tensors="pt")
        tokenized_kwargs.data["input_ids"] = tokenized_kwargs.data["input_ids"].to(
            self.patcher.load_device
        )
        tokenized_kwargs.data["attention_mask"] = tokenized_kwargs.data[
            "attention_mask"
        ].to(self.patcher.load_device)

        # Round the target length up to the next multiple of 75 tokens (CLIP's
        # chunk size) and generate only the remainder.
        current_token_length = int(tokenized_kwargs.data["input_ids"].shape[1])
        max_token_length = 75 * int(math.ceil(float(current_token_length) / 75.0))
        max_new_tokens = max_token_length - current_token_length
        if max_new_tokens == 0:
            # The prompt already fills a whole chunk; return it unexpanded.
            return text[:-1]

        features = self.model.generate(
            **tokenized_kwargs,
            top_k=100,
            max_new_tokens=max_new_tokens,
            do_sample=True,
            logits_processor=LogitsProcessorList([self.logits_processor]),
        )

        response = self.tokenizer.batch_decode(features, skip_special_tokens=True)
        result = safe_str(response[0])
        return result

    def logits_processor(self, input_ids, scores):
        assert scores.ndim == 2 and scores.shape[0] == 1
        self.logits_bias = self.logits_bias.to(scores)

        bias = self.logits_bias.clone()
        # Penalize tokens that were already generated to avoid repetition.
        bias[0, input_ids[0].to(bias.device).long()] = neg_inf
        # Token id 11 is "," in the GPT-2 vocabulary; always keep it allowed.
        bias[0, 11] = 0

        return scores + bias
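

# A minimal usage sketch, assuming the expansion model and positive.txt have
# been downloaded into ./prompt_expansion (the prompt and seed here are just
# illustrative):
#
#   expansion = FooocusExpansion()
#   expanded = expansion("a photo of a cat", seed=42)
#   print(expanded)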


class PromptExpansion:
    # Define the expected input types for the node.
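    # A minimal INPUT_TYPES sketch, assuming the standard ComfyUI node
    # convention; the exact widget options are an assumption, not confirmed
    # by the original source:
    @classmethod
    def INPUT_TYPES(cls):
        return {"required": {"text": ("STRING", {"multiline": True})}}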

    @staticmethod
    def expand_prompt(text):
        expansion = FooocusExpansion()
        prompt = remove_empty_str([safe_str(text)], default="")[0]

        # Pick a random seed; randint(1, max_seed) is always positive, so no
        # sign or modulo fix-up is needed.
        max_seed = int(1024 * 1024 * 1024)
        seed = random.randint(1, max_seed)

        final_prompt = expansion(prompt, seed)
        return final_prompt


# Define a mapping of node class names to their respective classes.
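# A minimal mapping sketch, assuming the standard ComfyUI registration
# convention (the display-name entry is an assumption):
NODE_CLASS_MAPPINGS = {"PromptExpansion": PromptExpansion}
NODE_DISPLAY_NAME_MAPPINGS = {"PromptExpansion": "Prompt Expansion"}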