import matplotlib.pyplot as plt
import numpy as np
from PIL import Image

import torch
from torchvision.utils import make_grid


def plt_batch(
    photos: torch.Tensor,
    sketch: torch.Tensor,
    step: int,
    prompt: str,
    save_path: str,
    name: str,
    dpi: int = 300,
):
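    """Plot the generated photo grid next to the rendered sketch grid.

    The wrapped prompt is used as the figure suptitle and the figure is
    saved to ``{save_path}/{name}.png``.
    """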
    if photos.shape != sketch.shape:
        raise ValueError(
            "photos and sketch must have the same shape, "
            f"got {tuple(photos.shape)} and {tuple(sketch.shape)}"
        )

    plt.figure()
    plt.subplot(1, 2, 1)
    grid = make_grid(photos, normalize=True, pad_value=2)
    ndarr = grid.mul(255).add_(0.5).clamp_(0, 255).permute(1, 2, 0).to("cpu", torch.uint8).numpy()
    plt.imshow(ndarr)
    plt.axis("off")
    plt.title("Generated sample")

    plt.subplot(1, 2, 2)
    grid = make_grid(sketch, normalize=False, pad_value=2)
    ndarr = grid.mul(255).add_(0.5).clamp_(0, 255).permute(1, 2, 0).to("cpu", torch.uint8).numpy()
    plt.imshow(ndarr)
    plt.axis("off")
    plt.title(f"Rendering result - {step} steps")

    plt.suptitle(insert_newline(prompt), fontsize=10)

    plt.tight_layout()
    plt.savefig(f"{save_path}/{name}.png", dpi=dpi)
    plt.close()


def plt_triplet(
    photos: torch.Tensor,
    sketch: torch.Tensor,
    style: torch.Tensor,
    step: int,
    prompt: str,
    save_path: str,
    name: str,
    dpi: int = 300,
):
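    """Plot the generated photos, the style image, and the rendered sketch
    side by side, then save the figure to ``{save_path}/{name}.png``.
    """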
    if photos.shape != sketch.shape:
        raise ValueError(
            "photos and sketch must have the same shape, "
            f"got {tuple(photos.shape)} and {tuple(sketch.shape)}"
        )

    plt.figure()
    plt.subplot(1, 3, 1)
    grid = make_grid(photos, normalize=True, pad_value=2)
    ndarr = grid.mul(255).add_(0.5).clamp_(0, 255).permute(1, 2, 0).to("cpu", torch.uint8).numpy()
    plt.imshow(ndarr)
    plt.axis("off")
    plt.title("Generated sample")

    plt.subplot(1, 3, 2)
    grid = make_grid(style, normalize=False, pad_value=2)
    ndarr = grid.mul(255).add_(0.5).clamp_(0, 255).permute(1, 2, 0).to("cpu", torch.uint8).numpy()
    plt.imshow(ndarr)
    plt.axis("off")
    plt.title("Style")

    plt.subplot(1, 3, 3)
    grid = make_grid(sketch, normalize=False, pad_value=2)
    ndarr = grid.mul(255).add_(0.5).clamp_(0, 255).permute(1, 2, 0).to("cpu", torch.uint8).numpy()
    plt.imshow(ndarr)
    plt.axis("off")
    plt.title(f"Rendering result - {step} steps")

    plt.suptitle(insert_newline(prompt), fontsize=10)

    plt.tight_layout()
    plt.savefig(f"{save_path}/{name}.png", dpi=dpi)
    plt.close()


def insert_newline(string, point=9):
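    """Insert a newline after every ``point`` words so long prompts wrap."""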
    words = string.split()
    if len(words) <= point:
        return string

    word_chunks = [words[i:i + point] for i in range(0, len(words), point)]
    new_string = "\n".join(" ".join(chunk) for chunk in word_chunks)
    return new_string


def log_tensor_img(inputs, output_dir, output_prefix="input", norm=False, dpi=300):
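    """Save a grid of the input tensors to ``{output_dir}/{output_prefix}.png``."""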
    grid = make_grid(inputs, normalize=norm, pad_value=2)
    ndarr = grid.mul(255).add_(0.5).clamp_(0, 255).permute(1, 2, 0).to("cpu", torch.uint8).numpy()
    plt.imshow(ndarr)
    plt.axis("off")
    plt.tight_layout()
    plt.savefig(f"{output_dir}/{output_prefix}.png", dpi=dpi)
    plt.close()


def plt_tensor_img(tensor, title, save_path, name, dpi=500):
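    """Save a normalized grid of ``tensor`` with ``title`` as the plot title."""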
    grid = make_grid(tensor, normalize=True, pad_value=2)
    ndarr = grid.mul(255).add_(0.5).clamp_(0, 255).permute(1, 2, 0).to("cpu", torch.uint8).numpy()
    plt.imshow(ndarr)
    plt.axis("off")
    plt.title(f"{title}")
    plt.savefig(f"{save_path}/{name}.png", dpi=dpi)
    plt.close()


def save_tensor_img(tensor, save_path, name, dpi=500):
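    """Save a normalized grid of ``tensor`` without any title."""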
    grid = make_grid(tensor, normalize=True, pad_value=2)
    ndarr = grid.mul(255).add_(0.5).clamp_(0, 255).permute(1, 2, 0).to("cpu", torch.uint8).numpy()
    plt.imshow(ndarr)
    plt.axis("off")
    plt.tight_layout()
    plt.savefig(f"{save_path}/{name}.png", dpi=dpi)
    plt.close()


def plt_attn(attn, threshold_map, inputs, inds, output_path):
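    """Plot the input image with the sampled points ``inds``, the attention
    map, and the min-max normalized threshold map, then save the figure to
    ``output_path``.
    """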
    plt.figure(figsize=(10, 5))

    plt.subplot(1, 3, 1)
    main_im = make_grid(inputs, normalize=True, pad_value=2)
    main_im = np.transpose(main_im.cpu().numpy(), (1, 2, 0))
    plt.imshow(main_im, interpolation="nearest")
    plt.scatter(inds[:, 1], inds[:, 0], s=10, c="red", marker="o")
    plt.title("input img")
    plt.axis("off")

    plt.subplot(1, 3, 2)
    plt.imshow(attn, interpolation="nearest", vmin=0, vmax=1)
    plt.title("attn map")
    plt.axis("off")

    plt.subplot(1, 3, 3)
    threshold_map_ = (threshold_map - threshold_map.min()) / (
        threshold_map.max() - threshold_map.min()
    )
    plt.imshow(np.nan_to_num(threshold_map_), interpolation="nearest", vmin=0, vmax=1)
    plt.title("prob softmax")
    plt.scatter(inds[:, 1], inds[:, 0], s=10, c="red", marker="o")
    plt.axis("off")

    plt.tight_layout()
    plt.savefig(output_path)
    plt.close()


def fix_image_scale(im):
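    """Center a PIL image on a slightly larger square white canvas.

    Assumes an RGB input; the returned ``PIL.Image`` is square with side
    ``max(height, width) + 20``.
    """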
    im_np = np.array(im) / 255
    height, width = im_np.shape[0], im_np.shape[1]
    max_len = max(height, width) + 20
    new_background = np.ones((max_len, max_len, 3))
    y, x = max_len // 2 - height // 2, max_len // 2 - width // 2
    new_background[y: y + height, x: x + width] = im_np
    new_background = (new_background / new_background.max() * 255).astype(np.uint8)
    new_im = Image.fromarray(new_background)
    return new_im
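

if __name__ == "__main__":
    # Minimal usage sketch, assuming torch and matplotlib are installed.
    # The tensor shapes, prompt, and output location ("." / "debug_batch")
    # are placeholders for illustration, not values from any real run.
    fake_photos = torch.rand(4, 3, 224, 224)
    fake_sketch = torch.rand(4, 3, 224, 224)
    plt_batch(
        fake_photos,
        fake_sketch,
        step=0,
        prompt="a quick smoke test of the plotting helpers",
        save_path=".",
        name="debug_batch",
    )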