ailanta commited on
Commit
db177eb
·
verified ·
1 Parent(s): 002c87e

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +137 -0
app.py ADDED
@@ -0,0 +1,137 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+ """Stylegan-nada-ailanta.ipynb
3
+
4
+ Automatically generated by Colab.
5
+
6
+ Original file is located at
7
+ https://colab.research.google.com/drive/1ysq4Y2sv7WTE0sW-n5W_HSgE28vaUDNE
8
+
9
+ # Проект "CLIP-Guided Domain Adaptation of Image Generators"
10
+
11
+ Данный проект представляет собой имплементацию подхода StyleGAN-NADA, предложенного в статье [StyleGAN-NADA: CLIP-Guided Domain Adaptation of Image Generators](https://arxiv.org/pdf/2108.00946).
12
+ Представленный ниже функционал предназначен для визуализации реализованного проекта и включает в себя:
13
+ - Сдвиг генератора по текстовому промпту
14
+ - Генерация примеров
15
+ - Генерация примеров из готовых пресетов
16
+ - Веб-демо
17
+ - Стилизация изображения из файла
18
+
19
+ ## 1. Установка
20
+ """
21
+
22
+ # @title
23
+ # Импорт нужных библиотек
24
+ import os
25
+ import sys
26
+ from tqdm import tqdm
27
+ import torch
28
+ import torch.nn as nn
29
+ import torch.optim as optim
30
+ from torchvision import transforms
31
+ from torchvision.utils import save_image
32
+ from PIL import Image
33
+ import numpy as np
34
+ import matplotlib.pyplot as plt
35
+
36
+ # Настройка устройства
37
+ device = "cuda" if torch.cuda.is_available() else "cpu"
38
+
39
+ # Установка библиотек
40
+ !pip install ftfy regex tqdm
41
+ !pip install git+https://github.com/openai/CLIP.git
42
+ !pip install Ninja
43
+
44
+ # Клонирование репозитория StyleGAN2
45
+ !git clone https://github.com/rosinality/stylegan2-pytorch.git
46
+ os.chdir("stylegan2-pytorch")
47
+
48
+ # Скачивание весов StyleGAN
49
+ !pip install gdown
50
+ !gdown https://drive.google.com/uc?id=1EM87UquaoQmk17Q8d5kYIAHqu0dkYqdT
51
+ !gdown https://drive.google.com/uc?id=1N0MZSqPRJpLfP4mFQCS14ikrVSe8vQlL
52
+ from model import Generator
53
+
54
+ # Параметры генератора
55
+ latent_dim = 512
56
+ f_generator = Generator(size=1024, style_dim=latent_dim, n_mlp=8).to(device)
57
+ state_dict = torch.load('stylegan2-ffhq-config-f.pt', map_location=device)
58
+ f_generator.load_state_dict(state_dict['g_ema'])
59
+ f_generator.eval()
60
+
61
+ g_generator = Generator(size=1024, style_dim=latent_dim, n_mlp=8).to(device)
62
+ g_generator.load_state_dict(state_dict['g_ema'])
63
+
64
+ # Загрузка модели CLIP
65
+ import clip
66
+ clip_model, preprocess = clip.load("ViT-B/32", device=device)
67
+
68
+ latent_dim=512
69
+ batch_size=4
70
+
71
+ """## 6. Готовые пресеты"""
72
+
73
+ # @title Загрузка пресетов
74
+ os.makedirs("/content/presets", exist_ok=True)
75
+
76
+ !gdown --output /content/presets/sketch.pth https://drive.google.com/uc?id=1trcBvlz7jeBRLNeCyNVCXE4esW25GPaZ
77
+ !gdown --output /content/presets/modigliani.pth https://drive.google.com/uc?id=1N4C-aTwxeOamZX2GeEElppsMv-ALKojL
78
+ !gdown --output /content/presets/werewolf.pth https://drive.google.com/uc?id=1VZHEalFyEFGWIaHei98f9XPyHHvMBp6J
79
+
80
+
81
+ # @title Генерация примеров из пресета
82
+ # Загрузка генератора из файла
83
+ def load_model(file_path, latent_dim=512, size=1024):
84
+
85
+ state_dicts = torch.load(file_path, map_location=device)
86
+
87
+ # Инициализация
88
+ trained_generator = Generator(size=size, style_dim=latent_dim, n_mlp=8).to(device)
89
+
90
+ # Загрузка весов
91
+ trained_generator.load_state_dict(state_dicts)
92
+
93
+ trained_generator.eval()
94
+
95
+ return trained_generator
96
+
97
+ model_paths = {
98
+ "Photo -> Pencil Sketch": "/content/presets/sketch.pth",
99
+ "Photo -> Modigliani Painting": "/content/presets/modigliani.pth",
100
+ "Human -> Werewolf": "/content/presets/werewolf.pth"
101
+ }
102
+
103
+
104
+ """## 8. Веб-демо"""
105
+ !pip install gradio
106
+
107
+ import gradio as gr
108
+
109
+ # Функция обработки
110
+ def generate(model_name):
111
+ model_path = model_paths[model_name]
112
+ g_generator = load_model(model_path)
113
+ images = []
114
+ with torch.no_grad():
115
+ w_optimized = f_generator.style(torch.randn(2, latent_dim).to(device))
116
+ w_plus = w_optimized.unsqueeze(1).repeat(1, f_generator.n_latent, 1).clone()
117
+
118
+ frozen_images = f_generator(w_plus.unsqueeze(0), input_is_latent=True)[0]
119
+ frozen_images = (frozen_images.clamp(-1, 1) + 1) / 2.0 # Нормализация к [0, 1]
120
+ frozen_images = frozen_images.permute(0, 2, 3, 1).cpu().numpy()
121
+ images.extend(frozen_images)
122
+ trained_images = g_generator(w_plus.unsqueeze(0), input_is_latent=True)[0]
123
+ trained_images = (trained_images.clamp(-1, 1) + 1) / 2.0 # Нормализация к [0, 1]
124
+ trained_images = trained_images.permute(0, 2, 3, 1).cpu().numpy()
125
+ images.extend(trained_images)
126
+ return images
127
+
128
+ # Интерфейс
129
+ iface = gr.Interface(
130
+ fn=generate,
131
+ inputs=gr.Dropdown(choices=list(model_paths.keys()), label="Выберите пресет"),
132
+ outputs=gr.Gallery(label="Результаты генерации", columns=2),
133
+ title="Выбор модели",
134
+ description="Выберите преобразование из списка."
135
+ )
136
+
137
+ iface.launch(debug=True)