Spaces:
Paused
Paused
File size: 1,429 Bytes
343e5a8 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 |
import torch
from PIL import Image
import torchvision.transforms.functional as F
from src.pix2pix_turbo import Pix2Pix_Turbo
import numpy as np
def process_sketch(
    sketch_path: str,
    output_path: str,
    prompt: str,
    val_r: float = 0.4,
    seed: int = 42,
) -> None:
    """Turn a sketch image into a generated image and save it to disk.

    The sketch is loaded as RGB, binarised at 0.5, and passed to the
    ``sketch_to_image_stochastic`` Pix2Pix-Turbo model together with the
    prompt and a seeded random noise map.  The first image of the model
    output is written to ``output_path`` and a confirmation is printed.

    Args:
        sketch_path: Path of the sketch image to read.
        output_path: Path the generated image is saved to.
        prompt: Text prompt passed to the model.
        val_r: Value forwarded to the model as ``r`` (presumably the
            sketch-guidance strength -- confirm against ``Pix2Pix_Turbo``).
        seed: Seed given to ``torch.manual_seed`` before the noise map is
            drawn, so repeated runs use the same noise.

    Raises:
        ValueError: If the sketch is smaller than 8 pixels on either side,
            which would leave an empty noise map.
    """
    model = Pix2Pix_Turbo("sketch_to_image_stochastic")

    # Seed after the model is built so the noise drawn below is reproducible.
    torch.manual_seed(seed)

    # Image.open is lazy and keeps the file open; convert() loads the pixels
    # into a new image, so the handle can be released right afterwards.
    with Image.open(sketch_path) as raw_sketch:
        image = raw_sketch.convert("RGB")

    # The noise map below is sized (H // 8, W // 8).  Round both sides down to
    # a multiple of 8 so the noise map corresponds exactly to the input
    # resolution; inputs that already comply are left untouched.
    width = image.width - image.width % 8
    height = image.height - image.height % 8
    if width == 0 or height == 0:
        raise ValueError(
            f"Sketch {sketch_path!r} is {image.width}x{image.height}; "
            "both sides must be at least 8 pixels"
        )
    if (width, height) != image.size:
        image = image.resize((width, height), Image.LANCZOS)

    # Binarise: values above 0.5 become True, everything else False.
    image_t = F.to_tensor(image) > 0.5

    with torch.no_grad():
        # Add a batch dimension, move to the GPU and cast the mask to float.
        c_t = image_t.unsqueeze(0).cuda().float()
        height_px, width_px = c_t.shape[-2:]
        noise = torch.randn(
            (1, 4, height_px // 8, width_px // 8), device=c_t.device
        )
        output_image = model(
            c_t, prompt, deterministic=False, r=val_r, noise_map=noise
        )

    # Rescale before conversion (presumably the model emits values in
    # [-1, 1] -- confirm against ``Pix2Pix_Turbo``).
    output_pil = F.to_pil_image(output_image[0].cpu() * 0.5 + 0.5)
    output_pil.save(output_path)
    print(f"Output image saved to {output_path}")
if __name__ == "__main__":
    # Script entry point: run one sketch-to-image conversion with the default
    # guidance value and seed, reading "sketch.png" and writing "output.png".
    process_sketch(
        sketch_path="sketch.png",
        output_path="output.png",
        prompt="a fantasy concept art of a magical castle in the sky, ",
    )
|