import math
import os
import random

import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, set_seed
from transformers.generation.logits_process import LogitsProcessorList

import comfy.model_management as model_management
from comfy.model_patcher import ModelPatcher

from .util import remove_empty_str

# Directory containing the GPT-2 based expansion model and its word list.
fooocus_expansion_path = "prompt_expansion"

# Seeds must fit in an unsigned 32-bit integer (numpy's RNG limit).
SEED_LIMIT_NUMPY = 2**32

# Large negative logit used to effectively ban tokens during sampling.
neg_inf = -8192.0


def safe_str(x):
    x = str(x)
    for _ in range(16):
        # Collapse runs of repeated spaces into single spaces.
        x = x.replace("  ", " ")
    return x.strip(",. \r\n")
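
# For example, safe_str("  a,,  cat ,. ") returns "a,, cat": runs of spaces
# are collapsed and surrounding spaces, commas, and periods are stripped.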


class FooocusExpansion:
    # Cache the tokenizer and model at class level so they are only loaded
    # from disk once, no matter how many instances are created.
    tokenizer = None
    model = None

    def __init__(self):
        self.load_model_and_tokenizer(fooocus_expansion_path)
        self.offload_device = model_management.text_encoder_offload_device()
        # Wrap the model in a ModelPatcher so ComfyUI's model management can
        # move it between the load and offload devices as memory allows.
        self.patcher = ModelPatcher(
            self.model,
            load_device=self.model.device,
            offload_device=self.offload_device,
        )

    @classmethod
    def load_model_and_tokenizer(cls, model_path):
        # Populate the class-level cache on first use only.
        if cls.tokenizer is None or cls.model is None:
            cls.tokenizer = AutoTokenizer.from_pretrained(model_path)
            cls.model = AutoModelForCausalLM.from_pretrained(model_path)
            # Keep the model on the CPU until it is needed for generation.
            cls.model.to("cpu")

    def __call__(self, prompt, seed):
        seed = int(seed) % SEED_LIMIT_NUMPY
        set_seed(seed)

        # Load the list of "positive" words the expansion may append.
        with open(
            os.path.join(fooocus_expansion_path, "positive.txt"), encoding="utf-8"
        ) as f:
            positive_words = f.read().splitlines()
        # "Ġ" is the GPT-2 byte-level BPE marker for a leading space, so each
        # entry matches the word as it appears after other tokens.
        positive_words = ["Ġ" + x.lower() for x in positive_words if x != ""]

        # Start from a bias that bans every token, then re-enable only the
        # tokens corresponding to positive words.
        self.logits_bias = (
            torch.zeros((1, len(self.tokenizer.vocab)), dtype=torch.float32) + neg_inf
        )

        debug_list = []
        for k, v in self.tokenizer.vocab.items():
            if k in positive_words:
                self.logits_bias[0, v] = 0
                debug_list.append(k[1:])
        # print(f'Expansion: Vocab with {len(debug_list)} words.')

        text = safe_str(prompt) + ","
        tokenized_kwargs = self.tokenizer(text, return_tensors="pt")
        tokenized_kwargs.data["input_ids"] = tokenized_kwargs.data["input_ids"].to(
            self.patcher.load_device
        )
        tokenized_kwargs.data["attention_mask"] = tokenized_kwargs.data[
            "attention_mask"
        ].to(self.patcher.load_device)

        # Round the total length up to the next multiple of 75 tokens, matching
        # the 75-token chunks used by CLIP prompt encoders.
        current_token_length = int(tokenized_kwargs.data["input_ids"].shape[1])
        max_token_length = 75 * int(math.ceil(float(current_token_length) / 75.0))
        max_new_tokens = max_token_length - current_token_length

        features = self.model.generate(
            **tokenized_kwargs,
            top_k=100,
            max_new_tokens=max_new_tokens,
            do_sample=True,
            logits_processor=LogitsProcessorList([self.logits_processor]),
        )

        response = self.tokenizer.batch_decode(features, skip_special_tokens=True)
        result = safe_str(response[0])
        return result

    def logits_processor(self, input_ids, scores):
        assert scores.ndim == 2 and scores.shape[0] == 1
        self.logits_bias = self.logits_bias.to(scores)

        bias = self.logits_bias.clone()
        # Ban every token already present in the sequence so the expansion
        # does not repeat itself.
        bias[0, input_ids[0].to(bias.device).long()] = neg_inf
        # Token 11 is "," in the GPT-2 vocabulary; always allow it so the
        # generated words stay comma-separated.
        bias[0, 11] = 0

        return scores + bias


class PromptExpansion:
    # Define the expected input types for the node.
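    # The declaration this comment announces is not present in this copy of
    # the file. A typical ComfyUI definition (an assumption based on the
    # standard custom-node interface, not the author's original code) is:
    #
    #     @classmethod
    #     def INPUT_TYPES(cls):
    #         return {"required": {"text": ("STRING", {"multiline": True})}}
    #
    #     RETURN_TYPES = ("STRING",)
    #     FUNCTION = "expand_prompt"
    #     CATEGORY = "conditioning"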
    @staticmethod
    @torch.no_grad()
    def expand_prompt(text):
        expansion = FooocusExpansion()

        prompt = remove_empty_str([safe_str(text)], default="")[0]

        # Draw a random seed in [1, 2**30); randint never returns a negative
        # value here, so no extra sign or modulo handling is needed.
        max_seed = int(1024 * 1024 * 1024)
        seed = random.randint(1, max_seed)

        expansion_text = expansion(prompt, seed)
        final_prompt = expansion_text

        return final_prompt
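
# Example usage outside a ComfyUI graph (a sketch: it assumes the comfy
# package is importable and the expansion model files exist under
# "prompt_expansion"):
#
#     expanded = PromptExpansion.expand_prompt("a photo of a cat")
#     print(expanded)
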
# Define a mapping of node class names to their respective classes
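# The registration block itself is cut off in this copy; a standard ComfyUI
# registration (names assumed, not taken from the original source) would be:
NODE_CLASS_MAPPINGS = {
    "PromptExpansion": PromptExpansion,
}

# Optional human-readable display name for the node (assumed).
NODE_DISPLAY_NAME_MAPPINGS = {
    "PromptExpansion": "Prompt Expansion",
}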