|
from __future__ import print_function |
|
import torch |
|
import torchvision |
|
import torch.nn as nn |
|
import gradio as gr |
|
import os |
|
import time |
|
import torch.nn.functional as F |
|
import torch.optim as optim |
|
import matplotlib.pyplot as plt |
|
import torchvision.transforms as transforms |
|
import copy |
|
import torchvision.models as models |
|
import torchvision.transforms.functional as TF |
|
from PIL import Image |
|
import numpy as np |
|
|
|
def image_transform(image):
    """Load an image and convert it into a batched tensor on ``device``.

    Args:
        image: A filesystem path (``str`` or ``os.PathLike``), a
            ``PIL.Image.Image`` (what Gradio passes for ``type="pil"``
            inputs), or an array-like of pixel values (what Gradio passes
            for ``type="numpy"`` inputs).

    Returns:
        The result of the module-level ``transform`` applied to the RGB
        image, with a leading batch dimension added by ``unsqueeze(0)`` and
        moved to the module-level ``device``.

    Raises:
        ValueError: If ``image`` is ``None`` (e.g. an empty Gradio input).
    """
    if image is None:
        raise ValueError("No image was provided.")

    if isinstance(image, (str, os.PathLike)):
        # Use a context manager so the file handle is closed once the pixels
        # have been decoded by convert().
        with Image.open(image) as opened:
            pil_image = opened.convert('RGB')
    elif isinstance(image, Image.Image):
        pil_image = image.convert('RGB')
    else:
        # Let Pillow infer the mode from the array shape (L / RGB / RGBA),
        # then normalise to 3-channel RGB. Forcing mode 'RGB' here would
        # misread greyscale or 4-channel arrays.
        pil_image = Image.fromarray(np.asarray(image).astype('uint8')).convert('RGB')

    # NOTE(review): ``transform`` and ``device`` are module-level globals that
    # are not defined in the visible file. Confirm that they are created
    # (e.g. a torchvision Compose and a torch.device) before this function runs.
    return transform(pil_image).unsqueeze(0).to(device)
|
|
|
|
|
|
|
def style_transfer(cont_img, styl_img):
    """Render ``cont_img`` in the style of ``styl_img``.

    Both images are preprocessed with ``image_transform``. The optimisation
    starts from a copy of the content tensor and is delegated to
    ``run_style_transfer``.

    Args:
        cont_img: Content image (anything ``image_transform`` accepts).
        styl_img: Style image (anything ``image_transform`` accepts).

    Returns:
        The stylised image. If ``run_style_transfer`` returns a
        ``torch.Tensor``, it is converted to a ``PIL.Image.Image`` so that
        Gradio's ``"image"`` output can display it. Any other return value
        is passed through unchanged.
    """
    style_img = image_transform(styl_img)
    content_img = image_transform(cont_img)

    # The optimisation starts from (a copy of) the content image.
    input_img = content_img.clone()

    # NOTE(review): run_style_transfer, cnn, cnn_normalization_mean and
    # cnn_normalization_std are not defined in the visible file. They appear
    # to come from the PyTorch neural-style tutorial; confirm they exist at runtime.
    output = run_style_transfer(cnn, cnn_normalization_mean, cnn_normalization_std,
                                content_img, style_img, input_img)

    if isinstance(output, torch.Tensor):
        # Gradio's image output cannot render a tensor. This assumes a float
        # image tensor in [0, 1], shaped (1, C, H, W) or (C, H, W); TODO
        # confirm against run_style_transfer.
        output = output.detach().cpu()
        if output.dim() == 4:
            output = output[0]
        output = TF.to_pil_image(output.clamp(0, 1))

    return output
|
|
|
|
|
title = 'Style Transfer'

description = 'A model to transfer the style of one image to another'

article = 'Created at Pytorch Model Deployment'

# File extensions treated as example images; anything else in the folder
# (e.g. .DS_Store, README files) is ignored.
_IMAGE_EXTENSIONS = ('.jpg', '.jpeg', '.png', '.bmp', '.gif', '.webp')


def _build_example_list(folder="examples"):
    """Build Gradio example rows of ``[content_path, style_path]`` from *folder*.

    The interface has two image inputs, so every example row must carry two
    values. The images are sorted for a stable order. Each image is then
    paired with the next one (wrapping around) as its style image.

    Returns:
        A list of two-element rows, or ``None`` (no examples) when the folder
        is missing or contains fewer than two images.
    """
    if not os.path.isdir(folder):
        return None
    paths = sorted(
        os.path.join(folder, name)
        for name in os.listdir(folder)
        if name.lower().endswith(_IMAGE_EXTENSIONS)
    )
    if len(paths) < 2:
        return None
    # NOTE(review): the pairing convention is an assumption. Adjust it if the
    # examples folder encodes explicit content/style pairs.
    return [[content, style] for content, style in zip(paths, paths[1:] + paths[:1])]


example_list = _build_example_list()

demo = gr.Interface(
    fn=style_transfer,
    inputs=[
        # type must be the string "pil"; the bare name `pil` is undefined.
        gr.Image(label="content image", type="pil"),
        gr.Image(label="style_image", type="pil"),
    ],
    outputs="image",
    examples=example_list,
    allow_flagging="never",
    title=title,
    description=description,
    article=article,
)

if __name__ == "__main__":
    demo.launch(debug=True)
|
|
|
|