Spaces:
Running
on
Zero
Running
on
Zero
File size: 6,647 Bytes
90a9dd3 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 |
# Copyright (c) 2024 <Julius Erbach ETH Zurich>
#
# This file is part of the var_post_samp project and is licensed under the MIT License.
# See the LICENSE file in the project root for more information.
"""
Usage:
python run_image_inv.py --config <config.yaml> --target_file <image_or_folder> --result_folder <output_dir> [--mask_file <mask.npy>]
"""
import os
import sys
import time
import csv
import yaml
import torch
import random
import click
import numpy as np
import tqdm
import datetime
import torchvision
from flair.helper_functions import parse_click_context
from flair.pipelines import model_loader
from flair.utils import data_utils
from flair import var_post_samp
# Numeric precision used for the images fed to the model.
dtype = torch.bfloat16

# Enumerate every visible CUDA device; the list is empty when there is none.
num_gpus = torch.cuda.device_count()
devices = [f"cuda:{gpu_index}" for gpu_index in range(num_gpus)]

if not devices:
    # No GPU available: fall back to a single CPU "device".
    print("No CUDA devices found. Using CPU.")
    devices = ["cpu"]
    primary_device = "cpu"
else:
    # The first CUDA device is the one used for standalone tensor operations.
    primary_device = devices[0]
    print(f"Using devices: {devices}")
    print(f"Primary device for operations: {primary_device}")
def _save_batch_image(batch, path, clip=True):
    """Save the first element of ``batch`` as an image file at ``path``.

    The values are rescaled with ``x * 0.5 + 0.5`` before being converted to a
    PIL image; when ``clip`` is true they are clipped to ``[-1, 1]`` first.
    """
    image = batch.float()[0]
    if clip:
        image = image.clip(-1, 1)
    torchvision.transforms.ToPILImage()(image * 0.5 + 0.5).save(path)


@click.command(context_settings={"allow_extra_args": True, "ignore_unknown_options": True})
@click.option("--config", "config_file_arg", type=click.Path(exists=True), required=True, help="Path to the config file")
@click.option("--target_file", type=click.Path(exists=True), help="Path to the target file or folder")
@click.option("--result_folder", type=click.Path(file_okay=False, dir_okay=True, writable=True, resolve_path=True), help="Path to the output folder. It will be created if it doesn't exist.")
@click.option("--mask_file", type=click.Path(exists=True), default=None, help="Path to the mask file npy. Optional used for image inpainting. True pixels are observed.")
@click.pass_context
def main(ctx, config_file_arg, target_file, result_folder, mask_file=None):
    """Main entry point for image inversion and sampling.
    The user must provide either a caption_file (with per-image captions) OR a single prompt for all images in the config YAML file.

    For every target image the reconstruction, the degraded input and the
    ground-truth image are written below the result folder, together with a
    copy of the effective configuration.
    """
    with open(config_file_arg, "r", encoding="utf-8") as f:
        # An empty YAML file loads as None; use an empty mapping so that the
        # validation below reports the actual problem.
        config = yaml.safe_load(f) or {}
    # Values parsed from the extra command-line arguments are merged over the
    # file configuration.
    config.update(parse_click_context(ctx))

    # Read caption_file and prompt from config (falsy caption_file -> None).
    caption_file = config.get("caption_file") or None
    prompt = config.get("prompt", None)
    # Enforce mutually exclusive caption_file or prompt: exactly one is set.
    if bool(caption_file) == bool(prompt):
        raise ValueError("You must provide either 'caption_file' OR 'prompt' (not both) in the config file. See documentation.")

    # Snapshot of the top-level config, taken before "prompt" is replaced by
    # the per-image list further down; this is what gets written to disk.
    config_dict = dict(config)

    seed = config["seed"]
    torch.manual_seed(seed)
    np.random.seed(seed)
    random.seed(seed)

    resolution = config["resolution"]
    degradation_name = config["degradation"]["name"]
    noise_std = config["degradation"]["kwargs"]["noise_std"]
    # Shared suffix of the three output folder names.
    setting_tag = f"{degradation_name}_resolution_{resolution}_noise_{noise_std}"

    timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
    name = f'results_{config["model"]}_{setting_tag}_{timestamp}'
    # Pick a run-folder name that does not exist yet.
    # NOTE(review): the check is done below result_folder, assuming
    # generate_output_structure creates the run folder directly under it —
    # confirm against data_utils. Without a result folder the name is checked
    # relative to the current working directory.
    base_dir = result_folder or ""
    candidate = name
    counter = 1
    while os.path.exists(os.path.join(base_dir, candidate)):
        candidate = f"{name}_{counter}"
        counter += 1

    output_folders = data_utils.generate_output_structure(
        result_folder,
        [
            candidate,
            f"input_{setting_tag}",
            f"target_{setting_tag}",
        ],
    )
    # The entries of output_folders are per-image path templates (filled with
    # .format(idx) below); the config copy goes next to the result images.
    config_out = os.path.join(os.path.split(output_folders[0])[0], "config.yaml")
    with open(config_out, "w", encoding="utf-8") as f:
        yaml.safe_dump(config_dict, f)

    source_files = list(data_utils.find_files(target_file, ext="png"))
    num_images = len(source_files)
    print(f"Found {num_images} images.")

    # Load captions
    if caption_file:
        captions = data_utils.load_captions_from_file(caption_file, user_prompt="")
        if not captions:
            sys.exit("Error: No captions were loaded from the provided caption file.")
        if len(captions) != num_images:
            print("Warning: Number of captions does not match number of images.")
        # Captions are looked up by image file name; missing ones become "".
        prompts_in_order = [captions.get(os.path.basename(f), "") for f in source_files]
    else:
        # Use the single prompt for all images
        prompts_in_order = [prompt] * num_images
    if any(p == "" for p in prompts_in_order):
        print("Warning: Some images might not have corresponding captions or prompt is empty.")
    config["prompt"] = prompts_in_order

    model, inp_kwargs = model_loader.load_model(config, device=devices)
    # The mask is only forwarded for the inpainting degradation.
    if mask_file and config["degradation"]["name"] == "Inpainting":
        config["degradation"]["kwargs"]["mask"] = mask_file
    posterior_model = var_post_samp.VariationalPosterior(model, config)

    guidance_img_iterator = data_utils.yield_images(
        target_file, size=config["resolution"]
    )
    for idx, guidance_img in tqdm.tqdm(enumerate(guidance_img_iterator), total=num_images):
        # Use the selected primary device instead of a hard-coded .cuda() so
        # that the CPU fallback chosen at module level actually works.
        guidance_img = guidance_img.to(device=primary_device, dtype=dtype)
        y = posterior_model.forward_operator(guidance_img)

        tic = time.time()
        with torch.no_grad():
            result_dict = posterior_model.forward(y, inp_kwargs[idx])
        x_hat = result_dict["x_hat"]
        toc = time.time()
        print(f"Runtime: {toc - tic}")

        result_file = output_folders[0].format(idx)
        input_file = output_folders[1].format(idx)
        ground_truth_file = output_folders[2].format(idx)

        _save_batch_image(x_hat, result_file)

        # Saving the degraded input is best effort: a failure must not abort
        # the run, but the cause is reported instead of being swallowed.
        try:
            if config["degradation"]["name"] == "SuperRes":
                input_img = posterior_model.forward_operator.nn(y)
            else:
                input_img = posterior_model.forward_operator.pseudo_inv(y)
            _save_batch_image(input_img, input_file)
        except Exception as exc:
            print("Error in pseudo-inverse operation. Skipping input image save.")
            print(f"Reason: {type(exc).__name__}: {exc}")

        # The ground truth is saved without clipping, as loaded.
        _save_batch_image(guidance_img, ground_truth_file, clip=False)
if __name__ == "__main__":
    # Run the click command only when executed as a script, not on import.
    main()
|