"""Refine the shapes of an SVG so that its rendering matches a target image.

Usage: <script> SVG TARGET [--use_lpips_loss] [--num_iter N]

All outputs (initial render, per-iteration renders, SVG snapshots, final
render and an ffmpeg-assembled video) are written under results/refine_svg/.
"""
import pydiffvg
import argparse
import ttools.modules
import torch
import skimage.io

# Exponent applied to the target image after loading, and passed as the
# `gamma` argument of pydiffvg.imwrite when renders are saved.
gamma = 1.0


def _render(canvas_width, canvas_height, shapes, shape_groups):
    """Serialize the current scene and render it with 2x2 samples, seed 0."""
    scene_args = pydiffvg.RenderFunction.serialize_scene(
        canvas_width, canvas_height, shapes, shape_groups)
    return pydiffvg.RenderFunction.apply(
        canvas_width,   # width
        canvas_height,  # height
        2,              # num_samples_x
        2,              # num_samples_y
        0,              # seed
        None,           # bg
        *scene_args)


def main(args):
    """Optimize path points and fill colors of an SVG against a target image.

    Args:
        args: parsed command-line namespace with the attributes
            ``svg`` (source SVG path), ``target`` (target image path),
            ``use_lpips_loss`` (use the LPIPS loss instead of the mean
            squared error) and ``num_iter`` (number of Adam iterations).

    Side effects:
        Writes init.png, iter_<t>.png, iter_<t>.svg (every 10th iteration and
        the last one), final.png and out.mp4 under results/refine_svg/, and
        invokes the ``ffmpeg`` executable.
    """
    device = pydiffvg.get_device()

    # Only build the perceptual network when it is actually going to be used.
    perception_loss = None
    if args.use_lpips_loss:
        perception_loss = ttools.modules.LPIPS().to(device)

    # NOTE(review): the code below assumes the target loads as an 8-bit
    # H x W x 3 array matching the SVG canvas size -- TODO confirm; images
    # with an alpha channel or a single channel are not handled here.
    target = torch.from_numpy(skimage.io.imread(args.target)).to(torch.float32) / 255.0
    target = target.pow(gamma)
    target = target.to(device)
    target = target.unsqueeze(0)
    target = target.permute(0, 3, 1, 2)  # NHWC -> NCHW

    canvas_width, canvas_height, shapes, shape_groups = \
        pydiffvg.svg_to_scene(args.svg)

    img = _render(canvas_width, canvas_height, shapes, shape_groups)
    # The output image is in linear RGB space. Do Gamma correction before
    # saving the image.
    pydiffvg.imwrite(img.cpu(), 'results/refine_svg/init.png', gamma=gamma)

    # Every path's control points are optimized.
    points_vars = []
    for path in shapes:
        path.points.requires_grad = True
        points_vars.append(path.points)

    # Key the fill colors by data pointer so that a color tensor shared by
    # several groups is handed to the optimizer only once.
    color_vars = {}
    for group in shape_groups:
        group.fill_color.requires_grad = True
        color_vars[group.fill_color.data_ptr()] = group.fill_color
    color_vars = list(color_vars.values())

    # Optimize
    points_optim = torch.optim.Adam(points_vars, lr=1.0)
    color_optim = torch.optim.Adam(color_vars, lr=0.01)

    # Adam iterations.
    for t in range(args.num_iter):
        print('iteration:', t)
        points_optim.zero_grad()
        color_optim.zero_grad()

        # Forward pass: render the image.
        img = _render(canvas_width, canvas_height, shapes, shape_groups)

        # Compose img with white background. (1 - alpha) broadcasts over the
        # three color channels, which equals multiplying it by an all-ones
        # RGB image.
        alpha = img[:, :, 3:4]
        img = alpha * img[:, :, :3] + (1 - alpha)

        # Save the intermediate render.
        pydiffvg.imwrite(img.cpu(), 'results/refine_svg/iter_{}.png'.format(t), gamma=gamma)

        # Convert img from HWC to NCHW
        img = img.unsqueeze(0)
        img = img.permute(0, 3, 1, 2)  # NHWC -> NCHW

        if args.use_lpips_loss:
            loss = perception_loss(img, target)
        else:
            loss = (img - target).pow(2).mean()
        print('render loss:', loss.item())

        # Backpropagate the gradients.
        loss.backward()

        # Take a gradient descent step.
        points_optim.step()
        color_optim.step()

        # Keep the colors inside the valid [0, 1] range after the update.
        for group in shape_groups:
            group.fill_color.data.clamp_(0.0, 1.0)

        if t % 10 == 0 or t == args.num_iter - 1:
            pydiffvg.save_svg('results/refine_svg/iter_{}.svg'.format(t),
                              canvas_width, canvas_height, shapes, shape_groups)

    # Render the final result.
    img = _render(canvas_width, canvas_height, shapes, shape_groups)
    # Save the final render. The file name is fixed; it must not depend on
    # the loop variable, which is unbound when num_iter is 0.
    pydiffvg.imwrite(img.cpu(), 'results/refine_svg/final.png', gamma=gamma)

    # Convert the intermediate renderings to a video.
    from subprocess import call
    call(["ffmpeg", "-framerate", "24", "-i",
          "results/refine_svg/iter_%d.png", "-vb", "20M",
          "results/refine_svg/out.mp4"])


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("svg", help="source SVG path")
    parser.add_argument("target", help="target image path")
    parser.add_argument("--use_lpips_loss", dest='use_lpips_loss', action='store_true')
    parser.add_argument("--num_iter", type=int, default=250)
    args = parser.parse_args()
    main(args)