"""Gradio app that runs neural style transfer on top of a pretrained VGG19.

Loss modules, loss-network assembly, the input optimizer and the input
transform are helpers imported from the local ``model`` module; this script
drives the optimization loop and exposes it through a Gradio interface.
"""
from __future__ import print_function

import copy
import os
import time

import gradio as gr
import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision
import torchvision.models as models
import torchvision.transforms as transforms
import torchvision.transforms.functional as TF
from PIL import Image

from model import (
    ContentLoss,
    StyleLoss,
    get_input_optimizer,
    get_style_model_and_losses,
    gram_matrix,
    image_transform,
)

# --- Model --------------------------------------------------------------------
# Only the convolutional feature extractor of VGG19 is used; eval() switches it
# to inference mode.
# NOTE(review): `pretrained=True` is deprecated in newer torchvision releases in
# favour of `weights=...`; kept as-is for compatibility with the installed version.
cnn = models.vgg19(pretrained=True).features.eval()

# Per-channel ImageNet mean/std expected by the pretrained VGG weights.
cnn_normalization_mean = torch.tensor([0.485, 0.456, 0.406])
cnn_normalization_std = torch.tensor([0.229, 0.224, 0.225])


def _tensor_to_pil(img_tensor):
    """Convert the first image of a batched, channels-first tensor in [0, 1] to PIL."""
    # [0] drops the batch dimension; transpose moves channels last (H, W, C) for PIL.
    array = img_tensor.detach().cpu().numpy()[0].transpose(1, 2, 0)
    return Image.fromarray((array * 255).astype(np.uint8))


def run_style_transfer(cnn, normalization_mean, normalization_std,
                       content_img, style_img, input_img, num_steps=300,
                       style_weight=1000000, content_weight=1):
    """Run the style transfer.

    Builds a loss network from ``cnn`` and optimizes the pixels of
    ``input_img`` (in place) against the weighted sum of the style and content
    losses it exposes.

    Args:
        cnn: Feature extractor used to build the loss network.
        normalization_mean: Per-channel mean passed to the loss-network builder.
        normalization_std: Per-channel std passed to the loss-network builder.
        content_img: Preprocessed content image tensor.
        style_img: Preprocessed style image tensor.
        input_img: Tensor being optimized; it is modified in place.
        num_steps: Loop bound on closure evaluations. The loop exits once the
            counter exceeds this value, so the final ``optimizer.step`` may
            evaluate the closure a few extra times.
        style_weight: Multiplier applied to the summed style loss.
        content_weight: Multiplier applied to the summed content loss.

    Returns:
        The optimized image as a ``PIL.Image.Image``.
    """
    print('Building the style transfer model..')
    model, style_losses, content_losses = get_style_model_and_losses(
        cnn, normalization_mean, normalization_std, style_img, content_img)

    # We optimize the input image, not the model parameters.
    input_img.requires_grad_(True)
    model.requires_grad_(False)

    optimizer = get_input_optimizer(input_img)

    # One-element list so the closure below can mutate the shared counter.
    run = [0]

    def closure():
        # Correct the values of the updated input image back into [0, 1].
        with torch.no_grad():
            input_img.clamp_(0, 1)

        optimizer.zero_grad()
        # The forward pass populates `.loss` on each loss module.
        model(input_img)

        style_score = 0
        content_score = 0
        for sl in style_losses:
            style_score += sl.loss
        for cl in content_losses:
            content_score += cl.loss

        style_score *= style_weight
        content_score *= content_weight

        loss = style_score + content_score
        loss.backward()

        run[0] += 1
        if run[0] % 50 == 0:
            # Print the counter value itself, not the list that wraps it.
            print("run {}:".format(run[0]))
            print('Style Loss : {:4f} Content Loss: {:4f}'.format(
                style_score.item(), content_score.item()))
            print()

        return loss

    print('Optimizing..')
    # The closure does not depend on any loop variable, so it is defined once
    # above and reused for every optimizer step.
    while run[0] <= num_steps:
        optimizer.step(closure)

    # A last correction: the final step may leave pixels outside [0, 1].
    with torch.no_grad():
        input_img.clamp_(0, 1)

    return _tensor_to_pil(input_img)


# Dimensions of the placeholder image returned when an input is missing.
image_width = 128
image_height = 128


def _list_examples(directory="examples"):
    """Return Gradio example rows, one per regular, non-hidden file in *directory*.

    Rows are sorted by file name for a stable order. Returns an empty list when
    the directory does not exist, instead of crashing at import time.
    """
    if not os.path.isdir(directory):
        return []
    return [
        [os.path.join(directory, name)]
        for name in sorted(os.listdir(directory))
        if not name.startswith('.')
        and os.path.isfile(os.path.join(directory, name))
    ]


# NOTE(review): each example row has a single path, but the interface has two
# image inputs, so clicking an example only fills the first (content) input.
# Consider pairing content/style files explicitly.
example_list = _list_examples("examples")

# Random-noise placeholder returned when either input image is missing.
# randint's upper bound is exclusive, so 256 covers the full uint8 range.
default_output_image = Image.fromarray(
    np.random.randint(0, 256, (image_height, image_width, 3), dtype=np.uint8))


def style_transfer(cont_img: Image.Image = None,
                   styl_img: Image.Image = None) -> Image.Image:
    """Gradio prediction function: render *cont_img* in the style of *styl_img*.

    If either input is missing, returns the random-noise placeholder image
    instead of running the optimization.

    NOTE(review): ``gr.Image`` delivers NumPy arrays unless ``type="pil"`` is
    set, so the annotations may not match what Gradio actually passes; this
    relies on ``image_transform`` accepting that input — confirm.
    """
    if cont_img is None or styl_img is None:
        return default_output_image

    start_time = time.time()

    style_img = image_transform(styl_img)
    content_img = image_transform(cont_img)

    # The optimization starts from a copy of the content image.
    input_img = content_img.clone()

    output = run_style_transfer(cnn, cnn_normalization_mean,
                                cnn_normalization_std, content_img,
                                style_img, input_img)

    pred_time = round(time.time() - start_time, 5)
    print('Style transfer finished in {} s'.format(pred_time))
    return output


# --- Gradio app -----------------------------------------------------------------
title = 'Style Transfer'
description = 'A model to transfer the style of one image to another'
article = 'Created at Pytorch Model Deployment'

demo = gr.Interface(
    fn=style_transfer,
    inputs=[
        gr.Image(label="content Image"),
        gr.Image(label="style_image"),
    ],
    outputs="image",
    # Pass None rather than an empty list when no example files were found.
    examples=example_list or None,
    # Gradio takes a string here; a boolean False is deprecated and is only
    # mapped to "never" with a warning.
    allow_flagging="never",
    title=title,
    description=description,
    article=article,
)

if __name__ == "__main__":
    # Launch the Gradio interface only when run as a script, not on import.
    demo.launch(debug=True)