Spaces:
Runtime error
Runtime error
# -*- coding: utf-8 -*- | |
"""Stylegan-nada-ailanta.ipynb | |
Automatically generated by Colab. | |
Original file is located at | |
https://colab.research.google.com/drive/1ysq4Y2sv7WTE0sW-n5W_HSgE28vaUDNE | |
# Проект "CLIP-Guided Domain Adaptation of Image Generators" | |
Данный проект представляет собой имплементацию подхода StyleGAN-NADA, предложенного в статье [StyleGAN-NADA: CLIP-Guided Domain Adaptation of Image Generators](https://arxiv.org/pdf/2108.00946). | |
Представленный ниже функционал предназначен для визуализации реализованного проекта и включает в себя: | |
- Сдвиг генератора по текстовому промпту | |
- Генерация примеров | |
- Генерация примеров из готовых пресетов | |
- Веб-демо | |
- Стилизация изображения из файла | |
## 1. Установка | |
""" | |
# @title | |
# Импорт нужных библиотек | |
import os | |
import sys | |
import torch | |
import torch.nn as nn | |
import torch.optim as optim | |
from torchvision import transforms | |
from torchvision.utils import save_image | |
from PIL import Image | |
import numpy as np | |
import gradio as gr | |
import subprocess | |
import gdown | |
# Настройка устройства | |
device = "cuda" if torch.cuda.is_available() else "cpu" | |
if not os.path.exists("stylegan2-pytorch"): | |
subprocess.run(["git", "clone", "https://github.com/rosinality/stylegan2-pytorch.git"]) | |
os.chdir("stylegan2-pytorch") | |
gdown.download('https://drive.google.com/uc?id=1EM87UquaoQmk17Q8d5kYIAHqu0dkYqdT') | |
gdown.download('https://drive.google.com/uc?id=1N0MZSqPRJpLfP4mFQCS14ikrVSe8vQlL') | |
sys.path.append("/home/user/app/stylegan2-pytorch") | |
from model import Generator | |
# Параметры генератора | |
latent_dim = 512 | |
f_generator = Generator(size=1024, style_dim=latent_dim, n_mlp=8).to(device) | |
state_dict = torch.load('stylegan2-ffhq-config-f.pt', map_location=device) | |
f_generator.load_state_dict(state_dict['g_ema']) | |
f_generator.eval() | |
# Загрузка пресетов | |
os.makedirs("/content/presets", exist_ok=True) | |
gdown.download('https://drive.google.com/uc?id=1trcBvlz7jeBRLNeCyNVCXE4esW25GPaZ', '/content/presets/sketch.pth') | |
gdown.download('https://drive.google.com/uc?id=1N4C-aTwxeOamZX2GeEElppsMv-ALKojL', '/content/presets/modigliani.pth') | |
gdown.download('https://drive.google.com/uc?id=1VZHEalFyEFGWIaHei98f9XPyHHvMBp6J', '/content/presets/werewolf.pth') | |
# Загрузка генератора из файла | |
def load_model(file_path, latent_dim=512, size=1024): | |
state_dicts = torch.load(file_path, map_location=device) | |
# Инициализация | |
trained_generator = Generator(size=size, style_dim=latent_dim, n_mlp=8).to(device) | |
# Загрузка весов | |
trained_generator.load_state_dict(state_dicts) | |
trained_generator.eval() | |
return trained_generator | |
model_paths = { | |
"Photo -> Pencil Sketch": "/content/presets/sketch.pth", | |
"Photo -> Modigliani Painting": "/content/presets/modigliani.pth", | |
"Human -> Werewolf": "/content/presets/werewolf.pth" | |
} | |
# Функция обработки | |
def generate(model_name): | |
model_path = model_paths[model_name] | |
g_generator = load_model(model_path) | |
images = [] | |
with torch.no_grad(): | |
w_optimized = f_generator.style(torch.randn(2, latent_dim).to(device)) | |
w_plus = w_optimized.unsqueeze(1).repeat(1, f_generator.n_latent, 1).clone() | |
frozen_images = f_generator(w_plus.unsqueeze(0), input_is_latent=True)[0] | |
frozen_images = (frozen_images.clamp(-1, 1) + 1) / 2.0 # Нормализация к [0, 1] | |
frozen_images = frozen_images.permute(0, 2, 3, 1).cpu().numpy() | |
images.extend(frozen_images) | |
trained_images = g_generator(w_plus.unsqueeze(0), input_is_latent=True)[0] | |
trained_images = (trained_images.clamp(-1, 1) + 1) / 2.0 # Нормализация к [0, 1] | |
trained_images = trained_images.permute(0, 2, 3, 1).cpu().numpy() | |
images.extend(trained_images) | |
return images | |
# Интерфейс | |
iface = gr.Interface( | |
fn=generate, | |
inputs=gr.Dropdown(choices=list(model_paths.keys()), label="Выберите пресет"), | |
outputs=gr.Gallery(label="Результаты генерации", columns=2), | |
title="Выбор модели", | |
description="Выберите преобразование из списка." | |
) | |
iface.launch(debug=True) | |