ailanta's picture
Update app.py
39cae37 verified
# -*- coding: utf-8 -*-
"""Stylegan-nada-ailanta.ipynb
Automatically generated by Colab.
Original file is located at
https://colab.research.google.com/drive/1ysq4Y2sv7WTE0sW-n5W_HSgE28vaUDNE
# Проект "CLIP-Guided Domain Adaptation of Image Generators"
Данный проект представляет собой имплементацию подхода StyleGAN-NADA, предложенного в статье [StyleGAN-NADA: CLIP-Guided Domain Adaptation of Image Generators](https://arxiv.org/pdf/2108.00946).
Представленный ниже функционал предназначен для визуализации реализованного проекта и включает в себя:
- Сдвиг генератора по текстовому промпту
- Генерация примеров
- Генерация примеров из готовых пресетов
- Веб-демо
- Стилизация изображения из файла
## 1. Установка
"""
# @title
# Импорт нужных библиотек
import os
import sys
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import transforms
from torchvision.utils import save_image
from PIL import Image
import numpy as np
import gradio as gr
import subprocess
import gdown
# Настройка устройства
device = "cuda" if torch.cuda.is_available() else "cpu"
if not os.path.exists("stylegan2-pytorch"):
subprocess.run(["git", "clone", "https://github.com/rosinality/stylegan2-pytorch.git"])
os.chdir("stylegan2-pytorch")
gdown.download('https://drive.google.com/uc?id=1EM87UquaoQmk17Q8d5kYIAHqu0dkYqdT')
gdown.download('https://drive.google.com/uc?id=1N0MZSqPRJpLfP4mFQCS14ikrVSe8vQlL')
sys.path.append("/home/user/app/stylegan2-pytorch")
from model import Generator
# Параметры генератора
latent_dim = 512
f_generator = Generator(size=1024, style_dim=latent_dim, n_mlp=8).to(device)
state_dict = torch.load('stylegan2-ffhq-config-f.pt', map_location=device)
f_generator.load_state_dict(state_dict['g_ema'])
f_generator.eval()
# Загрузка пресетов
os.makedirs("/content/presets", exist_ok=True)
gdown.download('https://drive.google.com/uc?id=1trcBvlz7jeBRLNeCyNVCXE4esW25GPaZ', '/content/presets/sketch.pth')
gdown.download('https://drive.google.com/uc?id=1N4C-aTwxeOamZX2GeEElppsMv-ALKojL', '/content/presets/modigliani.pth')
gdown.download('https://drive.google.com/uc?id=1VZHEalFyEFGWIaHei98f9XPyHHvMBp6J', '/content/presets/werewolf.pth')
# Загрузка генератора из файла
def load_model(file_path, latent_dim=512, size=1024):
state_dicts = torch.load(file_path, map_location=device)
# Инициализация
trained_generator = Generator(size=size, style_dim=latent_dim, n_mlp=8).to(device)
# Загрузка весов
trained_generator.load_state_dict(state_dicts)
trained_generator.eval()
return trained_generator
model_paths = {
"Photo -> Pencil Sketch": "/content/presets/sketch.pth",
"Photo -> Modigliani Painting": "/content/presets/modigliani.pth",
"Human -> Werewolf": "/content/presets/werewolf.pth"
}
# Функция обработки
def generate(model_name):
model_path = model_paths[model_name]
g_generator = load_model(model_path)
images = []
with torch.no_grad():
w_optimized = f_generator.style(torch.randn(2, latent_dim).to(device))
w_plus = w_optimized.unsqueeze(1).repeat(1, f_generator.n_latent, 1).clone()
frozen_images = f_generator(w_plus.unsqueeze(0), input_is_latent=True)[0]
frozen_images = (frozen_images.clamp(-1, 1) + 1) / 2.0 # Нормализация к [0, 1]
frozen_images = frozen_images.permute(0, 2, 3, 1).cpu().numpy()
images.extend(frozen_images)
trained_images = g_generator(w_plus.unsqueeze(0), input_is_latent=True)[0]
trained_images = (trained_images.clamp(-1, 1) + 1) / 2.0 # Нормализация к [0, 1]
trained_images = trained_images.permute(0, 2, 3, 1).cpu().numpy()
images.extend(trained_images)
return images
# Интерфейс
iface = gr.Interface(
fn=generate,
inputs=gr.Dropdown(choices=list(model_paths.keys()), label="Выберите пресет"),
outputs=gr.Gallery(label="Результаты генерации", columns=2),
title="Выбор модели",
description="Выберите преобразование из списка."
)
iface.launch(debug=True)