"""Gradio demo for the BRIA-4B-Adapt ControlNet-Union pipeline.

The script loads a Depth-Anything-V2 depth estimator, an OpenPose detector and
the BRIA ControlNet pipeline at import time, then serves a Gradio UI in which a
user either uploads a ready-made control image or lets the app derive one
(canny / depth / pose / colorgrid / recolor / tile) from a reference image.
"""
import os
import random
import subprocess
import sys
from functools import partial

# `gradio` and `spaces` stay ahead of `torch`, as in the original import order.
# NOTE(review): ZeroGPU Spaces expect `spaces` to be imported before any
# CUDA-initialising package -- confirm before reordering.
import gradio as gr
import spaces
import cv2
import numpy as np
import torch
from PIL import Image
from diffusers.utils import load_image
from huggingface_hub import hf_hub_download
from huggingface_hub import login

# Make modules that live next to this script importable.
sys.path.append('./')

# Best-effort install of the vendored `controlnet_aux` package (argument list,
# no shell). A failed install is not fatal here, matching the previous
# `os.system` call; it surfaces at the import below instead.
subprocess.run(
    [sys.executable, "-m", "pip", "install", "-e", "./controlnet_aux"],
    check=False,
)

from controlnet_aux import OpenposeDetector  # , CannyDetector
from depth_anything_v2.dpt import DepthAnythingV2
from controlnet_bria import BriaControlNetModel, BriaMultiControlNetModel
from pipeline_bria_controlnet import BriaControlNetPipeline

# Only authenticate when a token is actually configured; `login(token=None)`
# would otherwise fall back to an interactive prompt.
hf_token = os.environ.get("HF_TOKEN")
if hf_token:
    login(token=hf_token)

MAX_SEED = np.iinfo(np.int32).max


def randomize_seed_fn(seed: int, randomize_seed: bool) -> int:
    """Return a new random seed in [0, MAX_SEED] if requested, else `seed`."""
    if randomize_seed:
        seed = random.randint(0, MAX_SEED)
    return seed


DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'

# Constructor keyword arguments for each DepthAnythingV2 encoder variant.
model_configs = {
    'vits': {'encoder': 'vits', 'features': 64, 'out_channels': [48, 96, 192, 384]},
    'vitb': {'encoder': 'vitb', 'features': 128, 'out_channels': [96, 192, 384, 768]},
    'vitl': {'encoder': 'vitl', 'features': 256, 'out_channels': [256, 512, 1024, 1024]},
    'vitg': {'encoder': 'vitg', 'features': 384, 'out_channels': [1536, 1536, 1536, 1536]},
}

# Supported output resolutions, keyed by (rounded) width / height ratio.
ratios_map = {
    0.5: {"width": 704, "height": 1408},
    0.57: {"width": 768, "height": 1344},
    0.68: {"width": 832, "height": 1216},
    0.72: {"width": 832, "height": 1152},
    0.78: {"width": 896, "height": 1152},
    0.82: {"width": 896, "height": 1088},
    0.88: {"width": 960, "height": 1088},
    0.94: {"width": 960, "height": 1024},
    1.00: {"width": 1024, "height": 1024},
    1.13: {"width": 1088, "height": 960},
    1.21: {"width": 1088, "height": 896},
    1.29: {"width": 1152, "height": 896},
    1.38: {"width": 1152, "height": 832},
    1.46: {"width": 1216, "height": 832},
    1.67: {"width": 1280, "height": 768},
    1.75: {"width": 1344, "height": 768},
    2.00: {"width": 1408, "height": 704},
}
ratios = np.array(list(ratios_map.keys()))

# --- Depth estimator -------------------------------------------------------
encoder = 'vitl'
model = DepthAnythingV2(**model_configs[encoder])
filepath = hf_hub_download(
    repo_id="depth-anything/Depth-Anything-V2-Large",
    filename="depth_anything_v2_vitl.pth",
    repo_type="model",
)
# NOTE(review): torch.load unpickles the checkpoint; consider passing
# weights_only=True once the installed torch version is confirmed to support it.
state_dict = torch.load(filepath, map_location="cpu")
model.load_state_dict(state_dict)
model = model.to(DEVICE).eval()

# --- ControlNet pipeline ---------------------------------------------------
base_model = 'briaai/BRIA-4B-Adapt'
controlnet_model = 'briaai/BRIA-4B-Adapt-ControlNet-Union'
controlnet = BriaControlNetModel.from_pretrained(controlnet_model, torch_dtype=torch.bfloat16)
controlnet = BriaMultiControlNetModel([controlnet])
pipe = BriaControlNetPipeline.from_pretrained(
    base_model,
    controlnet=controlnet,
    torch_dtype=torch.bfloat16,
    trust_remote_code=True,
)
pipe.to("cuda")

# Integer id handed to the pipeline as `control_mode` for each UI mode name.
mode_mapping = {
    "depth": 0,
    "canny": 1,
    "colorgrid": 2,
    "recolor": 3,
    "tile": 4,
    "pose": 5,
}
# Per-mode strengths; not referenced anywhere else in this script.
strength_mapping = {
    "depth": 1.0,
    "canny": 1.0,
    "colorgrid": 1.0,
    "recolor": 1.0,
    "tile": 1.0,
    "pose": 1.0,
}

open_pose = OpenposeDetector.from_pretrained("lllyasviel/Annotators")

torch.backends.cuda.matmul.allow_tf32 = True
pipe.enable_model_cpu_offload()  # for saving memory


def convert_from_image_to_cv2(img: Image.Image) -> np.ndarray:
    """Convert an RGB PIL image into a BGR array."""
    return cv2.cvtColor(np.array(img), cv2.COLOR_RGB2BGR)


def convert_from_cv2_to_image(img: np.ndarray) -> Image.Image:
    """Convert a BGR array into an RGB PIL image."""
    return Image.fromarray(cv2.cvtColor(img, cv2.COLOR_BGR2RGB))


def extract_depth(image):
    """Return the depth map of `image` as an 8-bit, 3-channel PIL image.

    The channel axis of the input is reversed before it is handed to the
    depth model, and the prediction is min-max rescaled to 0..255.
    """
    image = np.asarray(image)
    depth = model.infer_image(image[:, :, ::-1])
    depth_min = depth.min()
    depth_range = depth.max() - depth_min
    if depth_range > 0:
        depth = (depth - depth_min) / depth_range * 255.0
    else:
        # A constant depth map would divide by zero; emit an all-black map.
        depth = np.zeros_like(depth)
    depth = depth.astype(np.uint8)
    return Image.fromarray(depth).convert('RGB')


def extract_openpose(img):
    """Run the OpenPose detector (with hands and face) on `img`."""
    return open_pose(img, hand_and_face=True)


def extract_canny(input_image):
    """Return Canny edges (thresholds 100 / 200) as a 3-channel PIL image."""
    edges = cv2.Canny(np.array(input_image), 100, 200)
    # Replicate the single edge channel three times.
    return Image.fromarray(np.stack([edges] * 3, axis=2))


def convert_to_grayscale(image):
    """Return a grayscale copy of `image`, stored as a 3-channel RGB image.

    The single-channel result is expanded with GRAY2RGB; feeding it through a
    BGR->RGB conversion (as was done before) fails because that conversion
    requires a 3- or 4-channel input.
    """
    gray = cv2.cvtColor(convert_from_image_to_cv2(image), cv2.COLOR_BGR2GRAY)
    return Image.fromarray(cv2.cvtColor(gray, cv2.COLOR_GRAY2RGB))


def tile_old(input_image, resolution=768):
    """Resize so the short side is ~`resolution`, snapping both sides to x16.

    Not called anywhere in this script.
    """
    input_image = convert_from_image_to_cv2(input_image)
    H, W, C = input_image.shape
    H = float(H)
    W = float(W)
    k = float(resolution) / min(H, W)
    H *= k
    W *= k
    H = int(np.round(H / 16.0)) * 16
    W = int(np.round(W / 16.0)) * 16
    # Lanczos when enlarging, area averaging when shrinking.
    img = cv2.resize(
        input_image,
        (W, H),
        interpolation=cv2.INTER_LANCZOS4 if k > 1 else cv2.INTER_AREA,
    )
    return convert_from_cv2_to_image(img)


def tile(downscale_factor, input_image):
    """Downscale by `downscale_factor`, then upscale back with NEAREST.

    The result has the same size as `input_image`.
    """
    width, height = input_image.size
    # Clamp to one pixel so tiny inputs never request a zero-sized resize.
    small_size = (
        max(1, width // downscale_factor),
        max(1, height // downscale_factor),
    )
    return input_image.resize(small_size).resize(input_image.size, Image.NEAREST)


def get_size(init_image):
    """Return the (width, height) from `ratios_map` closest to the image ratio."""
    w, h = init_image.size
    curr_ratio = w / h
    ind = np.argmin(np.abs(curr_ratio - ratios))
    chosen_ratio = ratios_map[ratios[ind]]
    return chosen_ratio['width'], chosen_ratio['height']


def resize_img(image):
    """Convert `image` to RGB and resize it to the nearest supported size."""
    image = image.convert('RGB')
    return image.resize(get_size(image))


# Control-image builder for each UI mode, applied to the resized reference image.
_CONTROL_EXTRACTORS = {
    "canny": extract_canny,
    "depth": extract_depth,
    "pose": extract_openpose,
    "colorgrid": partial(tile, 64),
    "recolor": convert_to_grayscale,
    "tile": partial(tile, 16),
}


@spaces.GPU(duration=180)
def infer(cond_in, image_in, prompt, inference_steps, guidance_scale, control_mode, control_strength, seed, progress=gr.Progress(track_tqdm=True)):
    """Generate an image conditioned on a control image.

    Args:
        cond_in: File path of an already processed control image, or None.
        image_in: File path of a reference image from which the control image
            is derived when `cond_in` is None.
        prompt: Text prompt passed to the pipeline.
        inference_steps: Number of inference steps.
        guidance_scale: Guidance scale passed to the pipeline.
        control_mode: One of the keys of `mode_mapping`.
        control_strength: ControlNet conditioning scale.
        seed: Seed for `torch.manual_seed`.
        progress: Gradio progress tracker.

    Returns:
        A `(generated_image, control_image)` tuple, matching the two output
        components wired to the submit button.

    Raises:
        gr.Error: If neither a control image nor a reference image is given.
    """
    control_mode_num = mode_mapping[control_mode]

    if cond_in is not None:
        # A ready-made control image always wins over the reference image.
        control_image = resize_img(load_image(cond_in))
    elif image_in is not None:
        reference = resize_img(load_image(image_in))
        control_image = _CONTROL_EXTRACTORS[control_mode](reference)
    else:
        raise gr.Error(
            "Please upload a processed control image or a reference image."
        )

    width, height = control_image.size

    image = pipe(
        prompt,
        control_image=[control_image],
        control_mode=[control_mode_num],
        width=width,
        height=height,
        controlnet_conditioning_scale=[control_strength],
        num_inference_steps=int(inference_steps),
        guidance_scale=guidance_scale,
        # Sliders may deliver floats; torch.manual_seed needs an int.
        generator=torch.manual_seed(int(seed)),
    ).images[0]

    torch.cuda.empty_cache()

    return image, control_image


css = """
#col-container{
    margin: 0 auto;
    max-width: 1080px;
}
"""

with gr.Blocks(css=css) as demo:
    with gr.Column(elem_id="col-container"):
        gr.Markdown("""
        # BRIA-4B-Adapt-ControlNet-Union
        A unified ControlNet for BRIA-4B-Adapt model from Bria.ai.
        BRIA-4B-Adapt improve the generation of humans and illustrations compared to BRIA 2.3 while still trained on licensed data, and so provides full legal liability coverage for copyright and privacy infringement.
        Model card: [BRIA-4B-Adapt-ControlNet-Union](https://huggingface.co/briaai/BRIA-4B-Adapt-ControlNet-Union).
        """)

        with gr.Column():
            with gr.Row():
                # Left column: inputs and settings.
                with gr.Column():
                    with gr.Row(equal_height=True):
                        cond_in = gr.Image(label="Upload a processed control image", sources=["upload"], type="filepath")
                        image_in = gr.Image(label="Extract condition from a reference image (Optional)", sources=["upload"], type="filepath")

                    prompt = gr.Textbox(label="Prompt", value="best quality")

                    with gr.Accordion("Controlnet"):
                        control_mode = gr.Radio(
                            ["depth", "canny", "colorgrid", "recolor", "tile", "pose"],
                            label="Mode",
                            value="canny",
                            info="select the control mode, one for all",
                        )
                        control_strength = gr.Slider(
                            label="control strength",
                            minimum=0,
                            maximum=1.0,
                            step=0.05,
                            value=0.9,
                        )
                        seed = gr.Slider(
                            label="Seed",
                            minimum=0,
                            maximum=MAX_SEED,
                            step=1,
                            value=42,
                        )
                        randomize_seed = gr.Checkbox(label="Randomize seed", value=True)

                    with gr.Accordion("Advanced settings", open=False):
                        with gr.Column():
                            with gr.Row():
                                inference_steps = gr.Slider(label="Inference steps", minimum=1, maximum=50, step=1, value=24)
                                guidance_scale = gr.Slider(label="Guidance scale", minimum=1.0, maximum=10.0, step=0.1, value=3.5)

                    submit_btn = gr.Button("Submit")

                # Right column: outputs.
                with gr.Column():
                    result = gr.Image(label="Result")
                    processed_cond = gr.Image(label="Preprocessed Cond")

    # First settle the seed (optionally randomised), then run the generation.
    submit_btn.click(
        fn=randomize_seed_fn,
        inputs=[seed, randomize_seed],
        outputs=seed,
        queue=False,
        api_name=False,
    ).then(
        fn=infer,
        inputs=[cond_in, image_in, prompt, inference_steps, guidance_scale, control_mode, control_strength, seed],
        outputs=[result, processed_cond],
        show_api=False,
    )

demo.queue(api_open=False)
demo.launch()