|
import torch |
|
import numpy as np |
|
from .processors import Processor_id |
|
|
|
|
|
class ControlNetConfigUnit:
    """Plain record describing one ControlNet unit before it is loaded.

    Attributes are stored exactly as passed in; no validation or loading
    happens here.
    """

    def __init__(self, processor_id: Processor_id, model_path, scale=1.0):
        """Store the configuration values.

        Args:
            processor_id: Identifier of the processor to use (annotated with
                ``Processor_id`` from ``.processors``).
            model_path: Location of the model; stored as-is.
            scale: Scale associated with this unit. Defaults to ``1.0``.
        """
        self.processor_id, self.model_path, self.scale = (
            processor_id,
            model_path,
            scale,
        )
|
|
|
|
|
class ControlNetUnit:
    """Bundle of a processor, a model and the scale that goes with them."""

    def __init__(self, processor, model, scale=1.0):
        """Store the three values unchanged.

        Args:
            processor: Object kept on ``self.processor``.
            model: Object kept on ``self.model``.
            scale: Value kept on ``self.scale``. Defaults to ``1.0``.
        """
        fields = (
            ("processor", processor),
            ("model", model),
            ("scale", scale),
        )
        for attribute, value in fields:
            setattr(self, attribute, value)
|
|
|
|
|
class MultiControlNetManager:
    """Drive several ControlNet units and sum their scaled outputs.

    The ``processor``, ``model`` and ``scale`` of every unit are kept in three
    parallel lists, so index ``i`` of each list belongs to the same unit.
    """

    def __init__(self, controlnet_units=None):
        """Split the units into parallel processor/model/scale lists.

        Args:
            controlnet_units: Iterable of objects exposing ``processor``,
                ``model`` and ``scale`` attributes (for example
                ``ControlNetUnit``). ``None`` (the default) means no units.
        """
        # ``None`` sentinel instead of a shared mutable ``[]`` default.
        # The units are materialized once because they are traversed three
        # times below; a one-shot iterable would otherwise be exhausted after
        # the first pass and leave ``models``/``scales`` empty.
        units = [] if controlnet_units is None else list(controlnet_units)
        self.processors = [unit.processor for unit in units]
        self.models = [unit.model for unit in units]
        self.scales = [unit.scale for unit in units]

    def process_image(self, image, processor_id=None):
        """Run ``image`` through the processors and stack the results.

        Args:
            image: Input passed unchanged to each selected processor.
            processor_id: Index used to pick a single entry of
                ``self.processors``. When ``None``, every processor is run in
                order.

        Returns:
            A tensor with one entry along dim 0 per processor that was run.
            Each processor output is converted to float32, divided by 255,
            has its last axis moved to the front and gets a leading axis of
            size 1 before concatenation.
        """
        if processor_id is None:
            selected = self.processors
        else:
            selected = [self.processors[processor_id]]

        # Run every processor first, then convert, so all processors are
        # invoked before any tensor conversion takes place.
        processed_images = [processor(image) for processor in selected]

        tensors = []
        for processed in processed_images:
            # NOTE(review): the /255 presumably maps 8-bit pixel values to
            # [0, 1], and permute(2, 0, 1) presumably turns HWC into CHW --
            # confirm against the processors' output format.
            scaled_pixels = np.array(processed, dtype=np.float32) / 255
            tensors.append(torch.Tensor(scaled_pixels).permute(2, 0, 1).unsqueeze(0))
        return torch.concat(tensors, dim=0)

    def __call__(
        self,
        sample, timestep, encoder_hidden_states, conditionings,
        tiled=False, tile_size=64, tile_stride=32
    ):
        """Call every model on its conditioning and sum the scaled results.

        Args:
            sample: Forwarded to each model as the first positional argument.
            timestep: Forwarded to each model.
            encoder_hidden_states: Forwarded to each model.
            conditionings: Iterable paired element-wise with ``self.models``
                and ``self.scales``; pairing stops at the shortest of the
                three.
            tiled: Forwarded to each model as a keyword argument.
            tile_size: Forwarded to each model as a keyword argument.
            tile_stride: Forwarded to each model as a keyword argument.

        Returns:
            A list whose ``k``-th element is the sum over all units of
            ``scale * model_output[k]``, or ``None`` when no
            (conditioning, model, scale) triple is available.
        """
        res_stack = None
        for conditioning, model, scale in zip(conditionings, self.models, self.scales):
            model_outputs = model(
                sample, timestep, encoder_hidden_states, conditioning,
                tiled=tiled, tile_size=tile_size, tile_stride=tile_stride
            )
            scaled_outputs = [res * scale for res in model_outputs]
            if res_stack is None:
                # First unit: nothing to accumulate onto yet.
                res_stack = scaled_outputs
            else:
                res_stack = [
                    accumulated + current
                    for accumulated, current in zip(res_stack, scaled_outputs)
                ]
        return res_stack
|
|