Upload folder using huggingface_hub
This view is limited to 50 files because the commit contains too many changes. See the raw diff for the full change set.
- .gitattributes +11 -0
- .gitignore +171 -0
- LICENSE +21 -0
- README.md +234 -12
- assets/cat_2x.gif +3 -0
- assets/clear2rainy_results.jpg +3 -0
- assets/day2night_results.jpg +3 -0
- assets/edge_to_image_results.jpg +3 -0
- assets/examples/bird.png +3 -0
- assets/examples/bird_canny.png +0 -0
- assets/examples/bird_canny_blue.png +0 -0
- assets/examples/circles_inference_input.png +0 -0
- assets/examples/circles_inference_output.png +0 -0
- assets/examples/clear2rainy_input.png +0 -0
- assets/examples/clear2rainy_output.png +0 -0
- assets/examples/day2night_input.png +0 -0
- assets/examples/day2night_output.png +0 -0
- assets/examples/my_horse2zebra_input.jpg +0 -0
- assets/examples/my_horse2zebra_output.jpg +0 -0
- assets/examples/night2day_input.png +0 -0
- assets/examples/night2day_output.png +0 -0
- assets/examples/rainy2clear_input.png +0 -0
- assets/examples/rainy2clear_output.png +0 -0
- assets/examples/sketch_input.png +0 -0
- assets/examples/sketch_output.png +0 -0
- assets/examples/training_evaluation.png +0 -0
- assets/examples/training_evaluation_unpaired.png +0 -0
- assets/examples/training_step_0.png +0 -0
- assets/examples/training_step_500.png +0 -0
- assets/examples/training_step_6000.png +0 -0
- assets/fish_2x.gif +3 -0
- assets/gen_variations.jpg +3 -0
- assets/method.jpg +0 -0
- assets/night2day_results.jpg +3 -0
- assets/rainy2clear.jpg +3 -0
- assets/teaser_results.jpg +3 -0
- docs/training_cyclegan_turbo.md +98 -0
- docs/training_pix2pix_turbo.md +118 -0
- environment.yaml +34 -0
- gradio_canny2image.py +78 -0
- gradio_sketch2image.py +382 -0
- python==3.9.8/Lib/site-packages/wheel/cli/tags.py +139 -0
- python==3.9.8/conda-meta/history +19 -0
- requirements.txt +28 -0
- scripts/download_fill50k.sh +5 -0
- scripts/download_horse2zebra.sh +5 -0
- src/cyclegan_turbo.py +254 -0
- src/image_prep.py +12 -0
- src/inference_paired.py +65 -0
- src/inference_unpaired.py +53 -0
.gitattributes
CHANGED
@@ -33,3 +33,14 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
 *.zip filter=lfs diff=lfs merge=lfs -text
 *.zst filter=lfs diff=lfs merge=lfs -text
 *tfevents* filter=lfs diff=lfs merge=lfs -text
+assets/cat_2x.gif filter=lfs diff=lfs merge=lfs -text
+assets/clear2rainy_results.jpg filter=lfs diff=lfs merge=lfs -text
+assets/day2night_results.jpg filter=lfs diff=lfs merge=lfs -text
+assets/edge_to_image_results.jpg filter=lfs diff=lfs merge=lfs -text
+assets/examples/bird.png filter=lfs diff=lfs merge=lfs -text
+assets/fish_2x.gif filter=lfs diff=lfs merge=lfs -text
+assets/gen_variations.jpg filter=lfs diff=lfs merge=lfs -text
+assets/night2day_results.jpg filter=lfs diff=lfs merge=lfs -text
+assets/rainy2clear.jpg filter=lfs diff=lfs merge=lfs -text
+assets/teaser_results.jpg filter=lfs diff=lfs merge=lfs -text
+triton-2.1.0-cp310-cp310-win_amd64.whl filter=lfs diff=lfs merge=lfs -text
.gitignore
ADDED
@@ -0,0 +1,171 @@
single_step_translation/

# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class

# C extensions
*.so

# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST

# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec

# Installer logs
pip-log.txt
pip-delete-this-directory.txt

# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/
cover/

# Translations
*.mo
*.pot

# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal

# Flask stuff:
instance/
.webassets-cache

# Scrapy stuff:
.scrapy

# Sphinx documentation
docs/_build/

# PyBuilder
.pybuilder/
target/

# Jupyter Notebook
.ipynb_checkpoints

# IPython
profile_default/
ipython_config.py

# pyenv
# For a library or package, you might want to ignore these files since the code is
# intended to run in multiple environments; otherwise, check them in:
# .python-version

# pipenv
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
# However, in case of collaboration, if having platform-specific dependencies or dependencies
# having no cross-platform support, pipenv may install dependencies that don't work, or not
# install all needed dependencies.
#Pipfile.lock

# poetry
# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
# This is especially recommended for binary packages to ensure reproducibility, and is more
# commonly ignored for libraries.
# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
#poetry.lock

# pdm
# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
#pdm.lock
# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
# in version control.
# https://pdm.fming.dev/#use-with-ide
.pdm.toml

# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
__pypackages__/

# Celery stuff
celerybeat-schedule
celerybeat.pid

# SageMath parsed files
*.sage.py

# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/

# Spyder project settings
.spyderproject
.spyproject

# Rope project settings
.ropeproject

# mkdocs documentation
/site

# mypy
.mypy_cache/
.dmypy.json
dmypy.json

# Pyre type checker
.pyre/

# pytype static type analyzer
.pytype/

# Cython debug symbols
cython_debug/

# PyCharm
# JetBrains specific template is maintained in a separate JetBrains.gitignore that can
# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
# and can be added to the global gitignore or merged into this file. For a more nuclear
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
#.idea/
single_step_translation
gradio
checkpoints/
img2img-turbo-sketch
outputs/
outputs/bird.png
data
wandb
output/
LICENSE
ADDED
@@ -0,0 +1,21 @@
MIT License

Copyright (c) 2024 img-to-img-turbo

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
README.md
CHANGED
@@ -1,12 +1,234 @@
----
-title:
-
-
-
-
-
-
-
-
-
-
+---
+title: img2img-turbo
+app_file: gradio_sketch2image.py
+sdk: gradio
+sdk_version: 3.43.1
+---
+# img2img-turbo
+
+[**Paper**](https://arxiv.org/abs/2403.12036) | [**Sketch2Image Demo**](https://huggingface.co/spaces/gparmar/img2img-turbo-sketch)
+#### **Quick start:** [**Running Locally**](#getting-started) | [**Gradio (locally hosted)**](#gradio-demo) | [**Training**](#training-with-your-own-data)
+
+### Cat Sketching
+<p align="left">
+<img src="https://raw.githubusercontent.com/GaParmar/img2img-turbo/main/assets/cat_2x.gif" width="800" />
+</p>
+
+### Fish Sketching
+<p align="left">
+<img src="https://raw.githubusercontent.com/GaParmar/img2img-turbo/main/assets/fish_2x.gif" width="800" />
+</p>
+
+
+We propose a general method for adapting a single-step diffusion model, such as SD-Turbo, to new tasks and domains through adversarial learning. This enables us to leverage the internal knowledge of pre-trained diffusion models while achieving efficient inference (e.g., for 512x512 images, 0.29 seconds on an A6000 and 0.11 seconds on an A100).
+
+Our one-step conditional models **CycleGAN-Turbo** and **pix2pix-turbo** can perform various image-to-image translation tasks for both unpaired and paired settings. CycleGAN-Turbo outperforms existing GAN-based and diffusion-based methods, while pix2pix-turbo is on par with recent works such as ControlNet for Sketch2Photo and Edge2Image, but with one-step inference.
+
+[One-Step Image Translation with Text-to-Image Models](https://arxiv.org/abs/2403.12036)<br>
+[Gaurav Parmar](https://gauravparmar.com/), [Taesung Park](https://taesung.me/), [Srinivasa Narasimhan](https://www.cs.cmu.edu/~srinivas/), [Jun-Yan Zhu](https://github.com/junyanz/)<br>
+CMU and Adobe, arXiv 2403.12036
+
+<br>
+<div>
+<p align="center">
+<img src='assets/teaser_results.jpg' align="center" width=1000px>
+</p>
+</div>
+
+
+
+
+## Results
+
+### Paired Translation with pix2pix-turbo
+**Edge to Image**
+<div>
+<p align="center">
+<img src='assets/edge_to_image_results.jpg' align="center" width=800px>
+</p>
+</div>
+
+<!-- **Sketch to Image**
+TODO -->
+### Generating Diverse Outputs
+By varying the input noise map, our method can generate diverse outputs from the same input conditioning.
+The output style can be controlled by changing the text prompt.
+<div> <p align="center">
+<img src='assets/gen_variations.jpg' align="center" width=800px>
+</p> </div>
+
+### Unpaired Translation with CycleGAN-Turbo
+
+**Day to Night**
+<div> <p align="center">
+<img src='assets/day2night_results.jpg' align="center" width=800px>
+</p> </div>
+
+**Night to Day**
+<div><p align="center">
+<img src='assets/night2day_results.jpg' align="center" width=800px>
+</p> </div>
+
+**Clear to Rainy**
+<div>
+<p align="center">
+<img src='assets/clear2rainy_results.jpg' align="center" width=800px>
+</p>
+</div>
+
+**Rainy to Clear**
+<div>
+<p align="center">
+<img src='assets/rainy2clear.jpg' align="center" width=800px>
+</p>
+</div>
+<hr>
+
+
+## Method
+**Our Generator Architecture:**
+We tightly integrate three separate modules of the original latent diffusion model into a single end-to-end network with a small number of trainable weights. This architecture allows us to translate the input image x to the output y while retaining the input scene structure. We use LoRA adapters in each module, introduce skip connections and Zero-Convs between input and output, and retrain the first layer of the U-Net. Blue boxes indicate trainable layers; semi-transparent layers are frozen. The same generator can be used for various GAN objectives.
+<div>
+<p align="center">
+<img src='assets/method.jpg' align="center" width=900px>
+</p>
+</div>
+
+
+## Getting Started
+**Environment Setup**
+- We provide a [conda env file](environment.yaml) that contains all the required dependencies.
+```
+conda env create -f environment.yaml
+```
+- Following this, you can activate the conda environment with the command below.
+```
+conda activate img2img-turbo
+```
+- Alternatively, use a virtual environment:
+```
+python3 -m venv venv
+source venv/bin/activate
+pip install -r requirements.txt
+```
+**Paired Image Translation (pix2pix-turbo)**
+- The following command takes an image file and a prompt as inputs, extracts the Canny edges, and saves the result in the specified directory.
+```bash
+python src/inference_paired.py --model_name "edge_to_image" \
+    --input_image "assets/examples/bird.png" \
+    --prompt "a blue bird" \
+    --output_dir "outputs"
+```
+<table>
+<th>Input Image</th>
+<th>Canny Edges</th>
+<th>Model Output</th>
+</tr>
+<tr>
+<td><img src='assets/examples/bird.png' width="200px"></td>
+<td><img src='assets/examples/bird_canny.png' width="200px"></td>
+<td><img src='assets/examples/bird_canny_blue.png' width="200px"></td>
+</tr>
+</table>
+<br>
+
+- The following command takes a sketch and a prompt as inputs, and saves the result in the specified directory.
+```bash
+python src/inference_paired.py --model_name "sketch_to_image_stochastic" \
+    --input_image "assets/examples/sketch_input.png" --gamma 0.4 \
+    --prompt "ethereal fantasy concept art of an asteroid. magnificent, celestial, ethereal, painterly, epic, majestic, magical, fantasy art, cover art, dreamy" \
+    --output_dir "outputs"
+```
+<table>
+<th>Input</th>
+<th>Model Output</th>
+</tr>
+<tr>
+<td><img src='assets/examples/sketch_input.png' width="400px"></td>
+<td><img src='assets/examples/sketch_output.png' width="400px"></td>
+</tr>
+</table>
+<br>
+
+**Unpaired Image Translation (CycleGAN-Turbo)**
+- The following command takes a **day** image as input, and saves the translated **night** image in the specified directory.
+```
+python src/inference_unpaired.py --model_name "day_to_night" \
+    --input_image "assets/examples/day2night_input.png" --output_dir "outputs"
+```
+<table>
+<th>Input (day)</th>
+<th>Model Output (night)</th>
+</tr>
+<tr>
+<td><img src='assets/examples/day2night_input.png' width="400px"></td>
+<td><img src='assets/examples/day2night_output.png' width="400px"></td>
+</tr>
+</table>
+
+- The following command takes a **night** image as input, and saves the translated **day** image in the specified directory.
+```
+python src/inference_unpaired.py --model_name "night_to_day" \
+    --input_image "assets/examples/night2day_input.png" --output_dir "outputs"
+```
+<table>
+<th>Input (night)</th>
+<th>Model Output (day)</th>
+</tr>
+<tr>
+<td><img src='assets/examples/night2day_input.png' width="400px"></td>
+<td><img src='assets/examples/night2day_output.png' width="400px"></td>
+</tr>
+</table>
+
+- The following command takes a **clear** image as input, and saves the translated **rainy** image in the specified directory.
+```
+python src/inference_unpaired.py --model_name "clear_to_rainy" \
+    --input_image "assets/examples/clear2rainy_input.png" --output_dir "outputs"
+```
+<table>
+<th>Input (clear)</th>
+<th>Model Output (rainy)</th>
+</tr>
+<tr>
+<td><img src='assets/examples/clear2rainy_input.png' width="400px"></td>
+<td><img src='assets/examples/clear2rainy_output.png' width="400px"></td>
+</tr>
+</table>
+
+- The following command takes a **rainy** image as input, and saves the translated **clear** image in the specified directory.
+```
+python src/inference_unpaired.py --model_name "rainy_to_clear" \
+    --input_image "assets/examples/rainy2clear_input.png" --output_dir "outputs"
+```
+<table>
+<th>Input (rainy)</th>
+<th>Model Output (clear)</th>
+</tr>
+<tr>
+<td><img src='assets/examples/rainy2clear_input.png' width="400px"></td>
+<td><img src='assets/examples/rainy2clear_output.png' width="400px"></td>
+</tr>
+</table>
+
+
+
+## Gradio Demo
+- We provide a Gradio demo for the paired image translation tasks.
+- The following command launches the sketch-to-image demo locally using Gradio.
+```
+gradio gradio_sketch2image.py
+```
+- The following command launches the canny-edge-to-image demo locally using Gradio.
+```
+gradio gradio_canny2image.py
+```
+
+
+## Training with your own data
+- See the steps [here](docs/training_pix2pix_turbo.md) for training a pix2pix-turbo model on your paired data.
+- See the steps [here](docs/training_cyclegan_turbo.md) for training a CycleGAN-Turbo model on your unpaired data.
+
+
+## Acknowledgment
+Our work uses SD-Turbo as the base model, which is released under the following [LICENSE](https://huggingface.co/stabilityai/sd-turbo/blob/main/LICENSE).
assets/cat_2x.gif ADDED (image; stored with Git LFS)
assets/clear2rainy_results.jpg ADDED (image; stored with Git LFS)
assets/day2night_results.jpg ADDED (image; stored with Git LFS)
assets/edge_to_image_results.jpg ADDED (image; stored with Git LFS)
assets/examples/bird.png ADDED (image; stored with Git LFS)
assets/examples/bird_canny.png ADDED (image)
assets/examples/bird_canny_blue.png ADDED (image)
assets/examples/circles_inference_input.png ADDED (image)
assets/examples/circles_inference_output.png ADDED (image)
assets/examples/clear2rainy_input.png ADDED (image)
assets/examples/clear2rainy_output.png ADDED (image)
assets/examples/day2night_input.png ADDED (image)
assets/examples/day2night_output.png ADDED (image)
assets/examples/my_horse2zebra_input.jpg ADDED (image)
assets/examples/my_horse2zebra_output.jpg ADDED (image)
assets/examples/night2day_input.png ADDED (image)
assets/examples/night2day_output.png ADDED (image)
assets/examples/rainy2clear_input.png ADDED (image)
assets/examples/rainy2clear_output.png ADDED (image)
assets/examples/sketch_input.png ADDED (image)
assets/examples/sketch_output.png ADDED (image)
assets/examples/training_evaluation.png ADDED (image)
assets/examples/training_evaluation_unpaired.png ADDED (image)
assets/examples/training_step_0.png ADDED (image)
assets/examples/training_step_500.png ADDED (image)
assets/examples/training_step_6000.png ADDED (image)
assets/fish_2x.gif ADDED (image; stored with Git LFS)
assets/gen_variations.jpg ADDED (image; stored with Git LFS)
assets/method.jpg ADDED (image)
assets/night2day_results.jpg ADDED (image; stored with Git LFS)
assets/rainy2clear.jpg ADDED (image; stored with Git LFS)
assets/teaser_results.jpg ADDED (image; stored with Git LFS)
docs/training_cyclegan_turbo.md
ADDED
@@ -0,0 +1,98 @@
## Training with Unpaired Data (CycleGAN-turbo)
Here, we show how to train a CycleGAN-turbo model using unpaired data.
We will use the [horse2zebra dataset](https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix/blob/master/docs/datasets.md) introduced by [CycleGAN](https://junyanz.github.io/CycleGAN/) as an example dataset.


### Step 1. Get the Dataset
- First download the horse2zebra dataset from [here](https://www.cs.cmu.edu/~img2img-turbo/data/my_horse2zebra.zip) using the command below.
```
bash scripts/download_horse2zebra.sh
```

- Our training scripts expect the dataset to be in the following format:
```
data
├── dataset_name
│   ├── train_A
│   │   ├── 000000.png
│   │   ├── 000001.png
│   │   └── ...
│   ├── train_B
│   │   ├── 000000.png
│   │   ├── 000001.png
│   │   └── ...
│   ├── fixed_prompt_a.txt
│   └── fixed_prompt_b.txt
│
│   ├── test_A
│   │   ├── 000000.png
│   │   ├── 000001.png
│   │   └── ...
│   ├── test_B
│   │   ├── 000000.png
│   │   ├── 000001.png
│   │   └── ...
```
- The `fixed_prompt_a.txt` and `fixed_prompt_b.txt` files contain the **fixed caption** used for the source and target domains respectively.


### Step 2. Train the Model
- Initialize the `accelerate` environment with the following command:
```
accelerate config
```

- Run the following command to train the model.
```
export NCCL_P2P_DISABLE=1
accelerate launch --main_process_port 29501 src/train_cyclegan_turbo.py \
    --pretrained_model_name_or_path="stabilityai/sd-turbo" \
    --output_dir="output/cyclegan_turbo/my_horse2zebra" \
    --dataset_folder "data/my_horse2zebra" \
    --train_img_prep "resize_286_randomcrop_256x256_hflip" --val_img_prep "no_resize" \
    --learning_rate="1e-5" --max_train_steps=25000 \
    --train_batch_size=1 --gradient_accumulation_steps=1 \
    --report_to "wandb" --tracker_project_name "gparmar_unpaired_h2z_cycle_debug_v2" \
    --enable_xformers_memory_efficient_attention --validation_steps 250 \
    --lambda_gan 0.5 --lambda_idt 1 --lambda_cycle 1
```

- Additional optional flags:
    - `--enable_xformers_memory_efficient_attention`: Enable memory-efficient attention in the model.

### Step 3. Monitor the training progress
- You can monitor the training progress using the [Weights & Biases](https://wandb.ai/site) dashboard.

- The training script will visualize the training batch, the training losses, and the validation set L2, LPIPS, and FID scores (if specified).
<div>
    <p align="center">
    <img src='../assets/examples/training_evaluation.png' align="center" width=800px>
    </p>
</div>


- The model checkpoints will be saved in the `<output_dir>/checkpoints` directory.


### Step 4. Running Inference with the trained models

- You can run inference with the trained model using the following command:
```
python src/inference_unpaired.py --model_path "output/cyclegan_turbo/my_horse2zebra/checkpoints/model_1001.pkl" \
    --input_image "data/my_horse2zebra/test_A/n02381460_20.jpg" \
    --prompt "picture of a zebra" --direction "a2b" \
    --output_dir "outputs" --image_prep "no_resize"
```

- The above command should generate the following output:
<table>
    <tr>
    <th>Model Input</th>
    <th>Model Output</th>
    </tr>
    <tr>
    <td><img src='../assets/examples/my_horse2zebra_input.jpg' width="200px"></td>
    <td><img src='../assets/examples/my_horse2zebra_output.jpg' width="200px"></td>
    </tr>
</table>
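
Before launching the unpaired training above, it can help to sanity-check that the dataset matches the layout described in Step 1. The helper below is a hypothetical convenience script, not part of this commit; the folder and caption-file names come directly from the tree shown in the doc.

```python
# Sketch: validate the unpaired dataset layout (train_A/train_B/test_A/test_B + fixed captions).
from pathlib import Path


def check_unpaired_dataset(root: str) -> None:
    base = Path(root)
    for sub in ["train_A", "train_B", "test_A", "test_B"]:
        images = list((base / sub).glob("*"))
        assert images, f"{base / sub} is empty or missing"
        print(f"{sub}: {len(images)} files")
    for cap in ["fixed_prompt_a.txt", "fixed_prompt_b.txt"]:
        text = (base / cap).read_text().strip()
        assert text, f"{cap} should contain the fixed caption for its domain"
        print(f"{cap}: {text!r}")


check_unpaired_dataset("data/my_horse2zebra")
```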
docs/training_pix2pix_turbo.md
ADDED
@@ -0,0 +1,118 @@
## Training with Paired Data (pix2pix-turbo)
Here, we show how to train a pix2pix-turbo model using paired data.
We will use the [Fill50k dataset](https://github.com/lllyasviel/ControlNet/blob/main/docs/train.md) used by [ControlNet](https://github.com/lllyasviel/ControlNet) as an example dataset.


### Step 1. Get the Dataset
- First download a modified Fill50k dataset from [here](https://www.cs.cmu.edu/~img2img-turbo/data/my_fill50k.zip) using the command below.
```
bash scripts/download_fill50k.sh
```

- Our training scripts expect the dataset to be in the following format:
```
data
├── dataset_name
│   ├── train_A
│   │   ├── 000000.png
│   │   ├── 000001.png
│   │   └── ...
│   ├── train_B
│   │   ├── 000000.png
│   │   ├── 000001.png
│   │   └── ...
│   └── train_prompts.json
│
│   ├── test_A
│   │   ├── 000000.png
│   │   ├── 000001.png
│   │   └── ...
│   ├── test_B
│   │   ├── 000000.png
│   │   ├── 000001.png
│   │   └── ...
│   └── test_prompts.json
```


### Step 2. Train the Model
- Initialize the `accelerate` environment with the following command:
```
accelerate config
```

- Run the following command to train the model.
```
accelerate launch src/train_pix2pix_turbo.py \
    --pretrained_model_name_or_path="stabilityai/sd-turbo" \
    --output_dir="output/pix2pix_turbo/fill50k" \
    --dataset_folder="data/my_fill50k" \
    --resolution=512 \
    --train_batch_size=2 \
    --enable_xformers_memory_efficient_attention --viz_freq 25 \
    --track_val_fid \
    --report_to "wandb" --tracker_project_name "pix2pix_turbo_fill50k"
```

- Additional optional flags:
    - `--track_val_fid`: Track FID score on the validation set using the [Clean-FID](https://github.com/GaParmar/clean-fid) implementation.
    - `--enable_xformers_memory_efficient_attention`: Enable memory-efficient attention in the model.
    - `--viz_freq`: Frequency of visualizing the results during training.

### Step 3. Monitor the training progress
- You can monitor the training progress using the [Weights & Biases](https://wandb.ai/site) dashboard.

- The training script will visualize the training batch, the training losses, and the validation set L2, LPIPS, and FID scores (if specified).
<div>
    <p align="center">
    <img src='../assets/examples/training_evaluation.png' align="center" width=800px>
    </p>
</div>


- The model checkpoints will be saved in the `<output_dir>/checkpoints` directory.

- Screenshots of the training progress are shown below:
    - Step 0:
    <div>
        <p align="center">
        <img src='../assets/examples/training_step_0.png' align="center" width=800px>
        </p>
    </div>

    - Step 500:
    <div>
        <p align="center">
        <img src='../assets/examples/training_step_500.png' align="center" width=800px>
        </p>
    </div>

    - Step 6000:
    <div>
        <p align="center">
        <img src='../assets/examples/training_step_6000.png' align="center" width=800px>
        </p>
    </div>


### Step 4. Running Inference with the trained models

- You can run inference with the trained model using the following command:
```
python src/inference_paired.py --model_path "output/pix2pix_turbo/fill50k/checkpoints/model_6001.pkl" \
    --input_image "data/my_fill50k/test_A/40000.png" \
    --prompt "violet circle with orange background" \
    --output_dir "outputs"
```

- The above command should generate the following output:
<table>
    <tr>
    <th>Model Input</th>
    <th>Model Output</th>
    </tr>
    <tr>
    <td><img src='../assets/examples/circles_inference_input.png' width="200px"></td>
    <td><img src='../assets/examples/circles_inference_output.png' width="200px"></td>
    </tr>
</table>
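
The `--track_val_fid` flag above relies on the Clean-FID library. As a rough standalone check of a trained model, one could also compute FID between a folder of generated outputs and the paired ground-truth folder; the sketch below assumes the standard `compute_fid` entry point of the clean-fid package and uses illustrative folder paths.

```python
# Sketch: offline FID between generated outputs and the ground-truth test_B folder.
from cleanfid import fid

score = fid.compute_fid("outputs", "data/my_fill50k/test_B")
print(f"FID: {score:.2f}")
```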
environment.yaml
ADDED
@@ -0,0 +1,34 @@
name: img2img-turbo
channels:
  - pytorch
  - defaults
dependencies:
  - python=3.10
  - pip:
    - clip @ git+https://github.com/openai/CLIP.git
    - einops>=0.6.1
    - numpy>=1.24.4
    - open-clip-torch>=2.20.0
    - opencv-python==4.6.0.66
    - pillow>=9.5.0
    - scipy==1.11.1
    - timm>=0.9.2
    - tokenizers
    - torch>=2.0.1

    - torchaudio>=2.0.2
    - torchdata==0.6.1
    - torchmetrics>=1.0.1
    - torchvision>=0.15.2

    - tqdm>=4.65.0
    - transformers==4.35.2
    - urllib3<1.27,>=1.25.4
    - xformers>=0.0.20
    - streamlit-keyup==0.2.0
    - lpips
    - clean-fid
    - peft
    - dominate
    - diffusers==0.25.1
    - gradio==3.43.1
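
After creating the environment, a quick check like the following can confirm that the pinned packages above resolved and that CUDA is visible. This is a convenience sketch, not part of the commit; the expected version strings come from the pins in this file.

```python
# Post-install sanity check for the environment defined above.
import torch, diffusers, transformers, gradio

print("torch", torch.__version__, "cuda available:", torch.cuda.is_available())
print("diffusers", diffusers.__version__)      # expected 0.25.1 per the pin above
print("transformers", transformers.__version__)  # expected 4.35.2
print("gradio", gradio.__version__)            # expected 3.43.1
```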
gradio_canny2image.py
ADDED
@@ -0,0 +1,78 @@
import numpy as np
from PIL import Image
import torch
from torchvision import transforms
import gradio as gr
from src.image_prep import canny_from_pil
from src.pix2pix_turbo import Pix2Pix_Turbo

model = Pix2Pix_Turbo("edge_to_image")


def process(input_image, prompt, low_threshold, high_threshold):
    # resize to be a multiple of 8
    new_width = input_image.width - input_image.width % 8
    new_height = input_image.height - input_image.height % 8
    input_image = input_image.resize((new_width, new_height))
    canny = canny_from_pil(input_image, low_threshold, high_threshold)
    with torch.no_grad():
        c_t = transforms.ToTensor()(canny).unsqueeze(0).cuda()
        output_image = model(c_t, prompt)
        output_pil = transforms.ToPILImage()(output_image[0].cpu() * 0.5 + 0.5)
    # flip the canny values for visualization: map all 0s to 1s and 1s to 0s
    canny_viz = 1 - (np.array(canny) / 255)
    canny_viz = Image.fromarray((canny_viz * 255).astype(np.uint8))
    return canny_viz, output_pil


if __name__ == "__main__":
    # load the model
    with gr.Blocks() as demo:
        gr.Markdown("# Pix2pix-Turbo: **Canny Edge -> Image**")
        with gr.Row():
            with gr.Column():
                input_image = gr.Image(sources="upload", type="pil")
                prompt = gr.Textbox(label="Prompt")
                low_threshold = gr.Slider(
                    label="Canny low threshold",
                    minimum=1,
                    maximum=255,
                    value=100,
                    step=10,
                )
                high_threshold = gr.Slider(
                    label="Canny high threshold",
                    minimum=1,
                    maximum=255,
                    value=200,
                    step=10,
                )
                run_button = gr.Button(value="Run")
            with gr.Column():
                result_canny = gr.Image(type="pil")
            with gr.Column():
                result_output = gr.Image(type="pil")

        prompt.submit(
            fn=process,
            inputs=[input_image, prompt, low_threshold, high_threshold],
            outputs=[result_canny, result_output],
        )
        low_threshold.change(
            fn=process,
            inputs=[input_image, prompt, low_threshold, high_threshold],
            outputs=[result_canny, result_output],
        )
        high_threshold.change(
            fn=process,
            inputs=[input_image, prompt, low_threshold, high_threshold],
            outputs=[result_canny, result_output],
        )
        run_button.click(
            fn=process,
            inputs=[input_image, prompt, low_threshold, high_threshold],
            outputs=[result_canny, result_output],
        )

    demo.queue()
    demo.launch(debug=True, share=False)
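
Because `process()` above is a plain function that returns PIL images, it can also be reused without launching the UI, for example to batch-process a folder of inputs. The snippet below is an illustrative sketch only; it assumes the module imports as-is (which triggers the model load at import time) and uses example paths from this repo.

```python
# Sketch: reuse process() from gradio_canny2image.py headlessly.
import os
from PIL import Image

from gradio_canny2image import process  # importing loads the edge_to_image model

os.makedirs("outputs", exist_ok=True)
for name in ["bird.png"]:
    img = Image.open(f"assets/examples/{name}")
    canny_viz, output = process(img, "a blue bird", 100, 200)
    output.save(f"outputs/{os.path.splitext(name)[0]}_canny2image.png")
```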
gradio_sketch2image.py
ADDED
@@ -0,0 +1,382 @@
import random
import numpy as np
from PIL import Image
import base64
from io import BytesIO

import torch
import torchvision.transforms.functional as F
import gradio as gr

from src.pix2pix_turbo import Pix2Pix_Turbo

model = Pix2Pix_Turbo("sketch_to_image_stochastic")

style_list = [
    {
        "name": "Cinematic",
        "prompt": "cinematic still {prompt} . emotional, harmonious, vignette, highly detailed, high budget, bokeh, cinemascope, moody, epic, gorgeous, film grain, grainy",
    },
    {
        "name": "3D Model",
        "prompt": "professional 3d model {prompt} . octane render, highly detailed, volumetric, dramatic lighting",
    },
    {
        "name": "Anime",
        "prompt": "anime artwork {prompt} . anime style, key visual, vibrant, studio anime, highly detailed",
    },
    {
        "name": "Digital Art",
        "prompt": "concept art {prompt} . digital artwork, illustrative, painterly, matte painting, highly detailed",
    },
    {
        "name": "Photographic",
        "prompt": "cinematic photo {prompt} . 35mm photograph, film, bokeh, professional, 4k, highly detailed",
    },
    {
        "name": "Pixel art",
        "prompt": "pixel-art {prompt} . low-res, blocky, pixel art style, 8-bit graphics",
    },
    {
        "name": "Fantasy art",
        "prompt": "ethereal fantasy concept art of {prompt} . magnificent, celestial, ethereal, painterly, epic, majestic, magical, fantasy art, cover art, dreamy",
    },
    {
        "name": "Neonpunk",
        "prompt": "neonpunk style {prompt} . cyberpunk, vaporwave, neon, vibes, vibrant, stunningly beautiful, crisp, detailed, sleek, ultramodern, magenta highlights, dark purple shadows, high contrast, cinematic, ultra detailed, intricate, professional",
    },
    {
        "name": "Manga",
        "prompt": "manga style {prompt} . vibrant, high-energy, detailed, iconic, Japanese comic style",
    },
]

styles = {k["name"]: k["prompt"] for k in style_list}
STYLE_NAMES = list(styles.keys())
DEFAULT_STYLE_NAME = "Fantasy art"
MAX_SEED = np.iinfo(np.int32).max


def pil_image_to_data_uri(img, format="PNG"):
    buffered = BytesIO()
    img.save(buffered, format=format)
    img_str = base64.b64encode(buffered.getvalue()).decode()
    return f"data:image/{format.lower()};base64,{img_str}"


def run(image, prompt, prompt_template, style_name, seed, val_r):
    print(f"prompt: {prompt}")
    print("sketch updated")
    if image is None:
        ones = Image.new("L", (512, 512), 255)
        temp_uri = pil_image_to_data_uri(ones)
        return ones, gr.update(link=temp_uri), gr.update(link=temp_uri)
    prompt = prompt_template.replace("{prompt}", prompt)
    image = image.convert("RGB")
    image_t = F.to_tensor(image) > 0.5
    print(f"r_val={val_r}, seed={seed}")
    with torch.no_grad():
        c_t = image_t.unsqueeze(0).cuda().float()
        torch.manual_seed(seed)
        B, C, H, W = c_t.shape
        noise = torch.randn((1, 4, H // 8, W // 8), device=c_t.device)
        output_image = model(c_t, prompt, deterministic=False, r=val_r, noise_map=noise)
    output_pil = F.to_pil_image(output_image[0].cpu() * 0.5 + 0.5)
    input_sketch_uri = pil_image_to_data_uri(Image.fromarray(255 - np.array(image)))
    output_image_uri = pil_image_to_data_uri(output_pil)
    return (
        output_pil,
        gr.update(link=input_sketch_uri),
        gr.update(link=output_image_uri),
    )


def update_canvas(use_line, use_eraser):
    if use_eraser:
        _color = "#ffffff"
        brush_size = 20
    if use_line:
        _color = "#000000"
        brush_size = 4
    return gr.update(brush_radius=brush_size, brush_color=_color, interactive=True)


def upload_sketch(file):
    _img = Image.open(file.name)
    _img = _img.convert("L")
    return gr.update(value=_img, source="upload", interactive=True)


scripts = """
async () => {
    globalThis.theSketchDownloadFunction = () => {
        console.log("test")
        var link = document.createElement("a");
        dataUri = document.getElementById('download_sketch').href
        link.setAttribute("href", dataUri)
        link.setAttribute("download", "sketch.png")
        document.body.appendChild(link); // Required for Firefox
        link.click();
        document.body.removeChild(link); // Clean up

        // also call the output download function
        theOutputDownloadFunction();
        return false
    }

    globalThis.theOutputDownloadFunction = () => {
        console.log("test output download function")
        var link = document.createElement("a");
        dataUri = document.getElementById('download_output').href
        link.setAttribute("href", dataUri);
        link.setAttribute("download", "output.png");
        document.body.appendChild(link); // Required for Firefox
        link.click();
        document.body.removeChild(link); // Clean up
        return false
    }

    globalThis.UNDO_SKETCH_FUNCTION = () => {
        console.log("undo sketch function")
        var button_undo = document.querySelector('#input_image > div.image-container.svelte-p3y7hu > div.svelte-s6ybro > button:nth-child(1)');
        // Create a new 'click' event
        var event = new MouseEvent('click', {
            'view': window,
            'bubbles': true,
            'cancelable': true
        });
        button_undo.dispatchEvent(event);
    }

    globalThis.DELETE_SKETCH_FUNCTION = () => {
        console.log("delete sketch function")
        var button_del = document.querySelector('#input_image > div.image-container.svelte-p3y7hu > div.svelte-s6ybro > button:nth-child(2)');
        // Create a new 'click' event
        var event = new MouseEvent('click', {
            'view': window,
            'bubbles': true,
            'cancelable': true
        });
        button_del.dispatchEvent(event);
    }

    globalThis.togglePencil = () => {
        el_pencil = document.getElementById('my-toggle-pencil');
        el_pencil.classList.toggle('clicked');
        // simulate a click on the gradio button
        btn_gradio = document.querySelector("#cb-line > label > input");
        var event = new MouseEvent('click', {
            'view': window,
            'bubbles': true,
            'cancelable': true
        });
        btn_gradio.dispatchEvent(event);
        if (el_pencil.classList.contains('clicked')) {
            document.getElementById('my-toggle-eraser').classList.remove('clicked');
            document.getElementById('my-div-pencil').style.backgroundColor = "gray";
            document.getElementById('my-div-eraser').style.backgroundColor = "white";
        }
        else {
            document.getElementById('my-toggle-eraser').classList.add('clicked');
            document.getElementById('my-div-pencil').style.backgroundColor = "white";
            document.getElementById('my-div-eraser').style.backgroundColor = "gray";
        }
    }

    globalThis.toggleEraser = () => {
        element = document.getElementById('my-toggle-eraser');
        element.classList.toggle('clicked');
        // simulate a click on the gradio button
        btn_gradio = document.querySelector("#cb-eraser > label > input");
        var event = new MouseEvent('click', {
            'view': window,
            'bubbles': true,
            'cancelable': true
        });
        btn_gradio.dispatchEvent(event);
        if (element.classList.contains('clicked')) {
            document.getElementById('my-toggle-pencil').classList.remove('clicked');
            document.getElementById('my-div-pencil').style.backgroundColor = "white";
            document.getElementById('my-div-eraser').style.backgroundColor = "gray";
        }
        else {
            document.getElementById('my-toggle-pencil').classList.add('clicked');
            document.getElementById('my-div-pencil').style.backgroundColor = "gray";
            document.getElementById('my-div-eraser').style.backgroundColor = "white";
        }
    }
}
"""

with gr.Blocks(css="style.css") as demo:

    gr.HTML(
        """
        <div style="display: flex; justify-content: center; align-items: center; text-align: center;">
        <div>
            <h2><a href="https://github.com/GaParmar/img2img-turbo">One-Step Image Translation with Text-to-Image Models</a></h2>
            <div>
                <div style="display: flex; justify-content: center; align-items: center; text-align: center;">
                    <a href='https://gauravparmar.com/'>Gaurav Parmar, </a>

                    <a href='https://taesung.me/'> Taesung Park,</a>

                    <a href='https://www.cs.cmu.edu/~srinivas/'>Srinivasa Narasimhan, </a>

                    <a href='https://www.cs.cmu.edu/~junyanz/'> Jun-Yan Zhu </a>
                </div>
            </div>
            </br>
            <div style="display: flex; justify-content: center; align-items: center; text-align: center;">
                <a href='https://arxiv.org/abs/2403.12036'>
                    <img src="https://img.shields.io/badge/arXiv-2403.12036-red">
                </a>

                <a href='https://github.com/GaParmar/img2img-turbo'>
                    <img src='https://img.shields.io/badge/github-%23121011.svg'>
                </a>

                <a href='https://github.com/GaParmar/img2img-turbo/blob/main/LICENSE'>
                    <img src='https://img.shields.io/badge/license-MIT-lightgrey'>
                </a>
            </div>
        </div>
        </div>
        <div>
        </br>
        </div>
        """
    )

    # these are hidden buttons that are used to trigger the canvas changes
    line = gr.Checkbox(label="line", value=False, elem_id="cb-line")
    eraser = gr.Checkbox(label="eraser", value=False, elem_id="cb-eraser")
    with gr.Row(elem_id="main_row"):
        with gr.Column(elem_id="column_input"):
            gr.Markdown("## INPUT", elem_id="input_header")
            image = gr.Image(
                source="canvas",
                tool="color-sketch",
                type="pil",
                image_mode="L",
                invert_colors=True,
                shape=(512, 512),
                brush_radius=4,
                height=440,
                width=440,
                brush_color="#000000",
                interactive=True,
                show_download_button=True,
                elem_id="input_image",
                show_label=False,
            )
            download_sketch = gr.Button(
                "Download sketch", scale=1, elem_id="download_sketch"
            )

            gr.HTML(
                """
                <div class="button-row">
                    <div id="my-div-pencil" class="pad2"> <button id="my-toggle-pencil" onclick="return togglePencil(this)"></button> </div>
                    <div id="my-div-eraser" class="pad2"> <button id="my-toggle-eraser" onclick="return toggleEraser(this)"></button> </div>
                    <div class="pad2"> <button id="my-button-undo" onclick="return UNDO_SKETCH_FUNCTION(this)"></button> </div>
                    <div class="pad2"> <button id="my-button-clear" onclick="return DELETE_SKETCH_FUNCTION(this)"></button> </div>
                    <div class="pad2"> <button href="TODO" download="image" id="my-button-down" onclick='return theSketchDownloadFunction()'></button> </div>
                </div>
                """
            )
            # gr.Markdown("## Prompt", elem_id="tools_header")
            prompt = gr.Textbox(label="Prompt", value="", show_label=True)
            with gr.Row():
                style = gr.Dropdown(
                    label="Style",
                    choices=STYLE_NAMES,
                    value=DEFAULT_STYLE_NAME,
                    scale=1,
                )
                prompt_temp = gr.Textbox(
                    label="Prompt Style Template",
                    value=styles[DEFAULT_STYLE_NAME],
                    scale=2,
                    max_lines=1,
                )

            with gr.Row():
                val_r = gr.Slider(
                    label="Sketch guidance: ",
                    show_label=True,
                    minimum=0,
                    maximum=1,
                    value=0.4,
                    step=0.01,
                    scale=3,
                )
                seed = gr.Textbox(label="Seed", value=42, scale=1, min_width=50)
                randomize_seed = gr.Button("Random", scale=1, min_width=50)

        with gr.Column(elem_id="column_process", min_width=50, scale=0.4):
            gr.Markdown("## pix2pix-turbo", elem_id="description")
            run_button = gr.Button("Run", min_width=50)

        with gr.Column(elem_id="column_output"):
            gr.Markdown("## OUTPUT", elem_id="output_header")
            result = gr.Image(
                label="Result",
                height=440,
                width=440,
                elem_id="output_image",
                show_label=False,
                show_download_button=True,
            )
            download_output = gr.Button("Download output", elem_id="download_output")
            gr.Markdown("### Instructions")
            gr.Markdown("**1**. Enter a text prompt (e.g. cat)")
            gr.Markdown("**2**. Start sketching")
            gr.Markdown("**3**. Change the image style using a style template")
            gr.Markdown("**4**. Adjust the effect of sketch guidance using the slider")
            gr.Markdown("**5**. Try different seeds to generate different results")

    eraser.change(
        fn=lambda x: gr.update(value=not x),
        inputs=[eraser],
        outputs=[line],
        queue=False,
        api_name=False,
    ).then(update_canvas, [line, eraser], [image])
    line.change(
        fn=lambda x: gr.update(value=not x),
        inputs=[line],
        outputs=[eraser],
        queue=False,
        api_name=False,
    ).then(update_canvas, [line, eraser], [image])

    demo.load(None, None, None, _js=scripts)
    randomize_seed.click(
        lambda x: random.randint(0, MAX_SEED),
        inputs=[],
        outputs=seed,
        queue=False,
        api_name=False,
    )
    inputs = [image, prompt, prompt_temp, style, seed, val_r]
    outputs = [result, download_sketch, download_output]
    prompt.submit(fn=run, inputs=inputs, outputs=outputs, api_name=False)
    style.change(
        lambda x: styles[x],
        inputs=[style],
        outputs=[prompt_temp],
        queue=False,
        api_name=False,
    ).then(
        fn=run,
        inputs=inputs,
        outputs=outputs,
        api_name=False,
    )
    val_r.change(run, inputs=inputs, outputs=outputs, queue=False, api_name=False)
    run_button.click(fn=run, inputs=inputs, outputs=outputs, api_name=False)
    image.change(run, inputs=inputs, outputs=outputs, queue=False, api_name=False)

if __name__ == "__main__":
    demo.queue().launch(debug=True, share=True)
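
The sketch demo calls the model with `deterministic=False` and an explicit `noise_map`, which is what produces the diverse outputs shown in the README. Below is a minimal sketch of reproducing that outside Gradio, using only the call signature from `run()` above; the 512x512 resize and the output paths are assumptions for illustration.

```python
# Sketch: generate several variations of one sketch by varying the seed.
import os

import torch
import torchvision.transforms.functional as F
from PIL import Image

from src.pix2pix_turbo import Pix2Pix_Turbo

model = Pix2Pix_Turbo("sketch_to_image_stochastic")
sketch = Image.open("assets/examples/sketch_input.png").convert("RGB").resize((512, 512))
c_t = (F.to_tensor(sketch) > 0.5).unsqueeze(0).cuda().float()   # binarized sketch, as in run()
prompt = "ethereal fantasy concept art of an asteroid. magnificent, celestial, ethereal, painterly, epic, majestic, magical, fantasy art, cover art, dreamy"

os.makedirs("outputs", exist_ok=True)
for seed in [13, 42, 1234]:
    torch.manual_seed(seed)
    noise = torch.randn((1, 4, c_t.shape[2] // 8, c_t.shape[3] // 8), device=c_t.device)
    with torch.no_grad():
        out = model(c_t, prompt, deterministic=False, r=0.4, noise_map=noise)
    F.to_pil_image(out[0].cpu() * 0.5 + 0.5).save(f"outputs/sketch_seed{seed}.png")
```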
python==3.9.8/Lib/site-packages/wheel/cli/tags.py
ADDED
@@ -0,0 +1,139 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
+from __future__ import annotations
+
+import email.policy
+import itertools
+import os
+from collections.abc import Iterable
+from email.parser import BytesParser
+
+from ..wheelfile import WheelFile
+
+
+def _compute_tags(original_tags: Iterable[str], new_tags: str | None) -> set[str]:
+    """Add or replace tags. Supports dot-separated tags"""
+    if new_tags is None:
+        return set(original_tags)
+
+    if new_tags.startswith("+"):
+        return {*original_tags, *new_tags[1:].split(".")}
+
+    if new_tags.startswith("-"):
+        return set(original_tags) - set(new_tags[1:].split("."))
+
+    return set(new_tags.split("."))
+
+
+def tags(
+    wheel: str,
+    python_tags: str | None = None,
+    abi_tags: str | None = None,
+    platform_tags: str | None = None,
+    build_tag: str | None = None,
+    remove: bool = False,
+) -> str:
+    """Change the tags on a wheel file.
+
+    The tags are left unchanged if they are not specified. To specify "none",
+    use ["none"]. To append to the previous tags, a tag should start with a
+    "+". If a tag starts with "-", it will be removed from existing tags.
+    Processing is done left to right.
+
+    :param wheel: The paths to the wheels
+    :param python_tags: The Python tags to set
+    :param abi_tags: The ABI tags to set
+    :param platform_tags: The platform tags to set
+    :param build_tag: The build tag to set
+    :param remove: Remove the original wheel
+    """
+    with WheelFile(wheel, "r") as f:
+        assert f.filename, f"{f.filename} must be available"
+
+        wheel_info = f.read(f.dist_info_path + "/WHEEL")
+        info = BytesParser(policy=email.policy.compat32).parsebytes(wheel_info)
+
+        original_wheel_name = os.path.basename(f.filename)
+        namever = f.parsed_filename.group("namever")
+        build = f.parsed_filename.group("build")
+        original_python_tags = f.parsed_filename.group("pyver").split(".")
+        original_abi_tags = f.parsed_filename.group("abi").split(".")
+        original_plat_tags = f.parsed_filename.group("plat").split(".")
+
+        tags: list[str] = info.get_all("Tag", [])
+        existing_build_tag = info.get("Build")
+
+    impls = {tag.split("-")[0] for tag in tags}
+    abivers = {tag.split("-")[1] for tag in tags}
+    platforms = {tag.split("-")[2] for tag in tags}
+
+    if impls != set(original_python_tags):
+        msg = f"Wheel internal tags {impls!r} != filename tags {original_python_tags!r}"
+        raise AssertionError(msg)
+
+    if abivers != set(original_abi_tags):
+        msg = f"Wheel internal tags {abivers!r} != filename tags {original_abi_tags!r}"
+        raise AssertionError(msg)
+
+    if platforms != set(original_plat_tags):
+        msg = (
+            f"Wheel internal tags {platforms!r} != filename tags {original_plat_tags!r}"
+        )
+        raise AssertionError(msg)
+
+    if existing_build_tag != build:
+        msg = (
+            f"Incorrect filename '{build}' "
+            f"& *.dist-info/WHEEL '{existing_build_tag}' build numbers"
+        )
+        raise AssertionError(msg)
+
+    # Start changing as needed
+    if build_tag is not None:
+        build = build_tag
+
+    final_python_tags = sorted(_compute_tags(original_python_tags, python_tags))
+    final_abi_tags = sorted(_compute_tags(original_abi_tags, abi_tags))
+    final_plat_tags = sorted(_compute_tags(original_plat_tags, platform_tags))
+
+    final_tags = [
+        namever,
+        ".".join(final_python_tags),
+        ".".join(final_abi_tags),
+        ".".join(final_plat_tags),
+    ]
+    if build:
+        final_tags.insert(1, build)
+
+    final_wheel_name = "-".join(final_tags) + ".whl"
+
+    if original_wheel_name != final_wheel_name:
+        del info["Tag"], info["Build"]
+        for a, b, c in itertools.product(
+            final_python_tags, final_abi_tags, final_plat_tags
+        ):
+            info["Tag"] = f"{a}-{b}-{c}"
+        if build:
+            info["Build"] = build
+
+        original_wheel_path = os.path.join(
+            os.path.dirname(f.filename), original_wheel_name
+        )
+        final_wheel_path = os.path.join(os.path.dirname(f.filename), final_wheel_name)
+
+        with WheelFile(original_wheel_path, "r") as fin, WheelFile(
+            final_wheel_path, "w"
+        ) as fout:
+            fout.comment = fin.comment  # preserve the comment
+            for item in fin.infolist():
+                if item.is_dir():
+                    continue
+                if item.filename == f.dist_info_path + "/RECORD":
+                    continue
+                if item.filename == f.dist_info_path + "/WHEEL":
+                    fout.writestr(item, info.as_bytes())
+                else:
+                    fout.writestr(item, fin.read(item))
+
+        if remove:
+            os.remove(original_wheel_path)
+
+    return final_wheel_name
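This module belongs to the bundled wheel package inside the accidentally uploaded conda environment, not to the img2img-turbo sources. As a quick illustration of the "+" / "-" / replace semantics described in its docstring, here is a standalone re-implementation of the _compute_tags rule (an example sketch only, not the wheel package's public API):

from __future__ import annotations
from collections.abc import Iterable

def compute_tags(original_tags: Iterable[str], new_tags: str | None) -> set[str]:
    # same rule as _compute_tags above: "+x" appends, "-x" removes, plain "x" replaces
    if new_tags is None:
        return set(original_tags)
    if new_tags.startswith("+"):
        return {*original_tags, *new_tags[1:].split(".")}
    if new_tags.startswith("-"):
        return set(original_tags) - set(new_tags[1:].split("."))
    return set(new_tags.split("."))

print(compute_tags(["cp39"], "+cp310.cp311"))   # {'cp39', 'cp310', 'cp311'}
print(compute_tags(["cp39", "cp310"], "-cp39")) # {'cp310'}
print(compute_tags(["cp39"], "py3"))            # {'py3'}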
python==3.9.8/conda-meta/history
ADDED
@@ -0,0 +1,19 @@
+==> 2024-08-08 09:16:40 <==
+# cmd: C:\ProgramData\miniconda3\Scripts\conda-script.py env create -f environment.yaml --p python==3.9.8
+# conda version: 24.1.2
++defaults/noarch::tzdata-2024a-h04d1e81_0
++defaults/win-64::bzip2-1.0.8-h2bbff1b_6
++defaults/win-64::ca-certificates-2024.7.2-haa95532_0
++defaults/win-64::libffi-3.4.4-hd77b12b_1
++defaults/win-64::openssl-3.0.14-h827c3e9_0
++defaults/win-64::pip-24.0-py310haa95532_0
++defaults/win-64::python-3.10.14-he1021f5_1
++defaults/win-64::setuptools-72.1.0-py310haa95532_0
++defaults/win-64::sqlite-3.45.3-h2bbff1b_0
++defaults/win-64::tk-8.6.14-h0416ee5_0
++defaults/win-64::vc-14.2-h2eaa2aa_4
++defaults/win-64::vs2015_runtime-14.29.30133-h43f2093_4
++defaults/win-64::wheel-0.43.0-py310haa95532_0
++defaults/win-64::xz-5.4.6-h8cc25b3_1
++defaults/win-64::zlib-1.2.13-h8cc25b3_1
+# update specs: ['pip', 'python=3.10']
requirements.txt
ADDED
@@ -0,0 +1,28 @@
+clip @ git+https://github.com/openai/CLIP.git
+einops>=0.6.1
+numpy>=1.24.4
+open-clip-torch>=2.20.0
+opencv-python==4.6.0.66
+pillow>=9.5.0
+scipy==1.11.1
+timm>=0.9.2
+tokenizers
+torch>=2.0.1
+
+torchaudio>=2.0.2
+torchdata==0.6.1
+torchmetrics>=1.0.1
+torchvision>=0.15.2
+
+tqdm>=4.65.0
+transformers==4.35.2
+triton==2.0.0
+urllib3<1.27,>=1.25.4
+xformers>=0.0.20
+streamlit-keyup==0.2.0
+lpips
+clean-fid
+peft
+dominate
+diffusers==0.25.1
+gradio==3.43.1
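A minimal sketch of installing these pinned dependencies into the current environment (an assumed workflow, not taken from the repo's own docs):

import subprocess
import sys

# install the pinned dependencies listed in requirements.txt
subprocess.check_call([sys.executable, "-m", "pip", "install", "-r", "requirements.txt"])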
scripts/download_fill50k.sh
ADDED
@@ -0,0 +1,5 @@
+mkdir -p data
+wget https://www.cs.cmu.edu/~img2img-turbo/data/my_fill50k.zip -O data/my_fill50k.zip
+cd data
+unzip my_fill50k.zip
+rm my_fill50k.zip
scripts/download_horse2zebra.sh
ADDED
@@ -0,0 +1,5 @@
+mkdir -p data
+wget https://www.cs.cmu.edu/~img2img-turbo/data/my_horse2zebra.zip -O data/my_horse2zebra.zip
+cd data
+unzip my_horse2zebra.zip
+rm my_horse2zebra.zip
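Both download scripts follow the same pattern: fetch a zip into data/ and unpack it in place. A hedged sketch of invoking them from Python, assuming bash, wget, and unzip are available on the PATH:

import subprocess

# run the dataset download scripts shipped with the repo
for script in ("scripts/download_fill50k.sh", "scripts/download_horse2zebra.sh"):
    subprocess.check_call(["bash", script])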
src/cyclegan_turbo.py
ADDED
@@ -0,0 +1,254 @@
+import os
+import sys
+import copy
+import torch
+import torch.nn as nn
+from transformers import AutoTokenizer, CLIPTextModel
+from diffusers import AutoencoderKL, UNet2DConditionModel
+from peft import LoraConfig
+from peft.utils import get_peft_model_state_dict
+p = "src/"
+sys.path.append(p)
+from model import make_1step_sched, my_vae_encoder_fwd, my_vae_decoder_fwd, download_url
+
+
+class VAE_encode(nn.Module):
+    def __init__(self, vae, vae_b2a=None):
+        super(VAE_encode, self).__init__()
+        self.vae = vae
+        self.vae_b2a = vae_b2a
+
+    def forward(self, x, direction):
+        assert direction in ["a2b", "b2a"]
+        if direction == "a2b":
+            _vae = self.vae
+        else:
+            _vae = self.vae_b2a
+        return _vae.encode(x).latent_dist.sample() * _vae.config.scaling_factor
+
+
+class VAE_decode(nn.Module):
+    def __init__(self, vae, vae_b2a=None):
+        super(VAE_decode, self).__init__()
+        self.vae = vae
+        self.vae_b2a = vae_b2a
+
+    def forward(self, x, direction):
+        assert direction in ["a2b", "b2a"]
+        if direction == "a2b":
+            _vae = self.vae
+        else:
+            _vae = self.vae_b2a
+        assert _vae.encoder.current_down_blocks is not None
+        _vae.decoder.incoming_skip_acts = _vae.encoder.current_down_blocks
+        x_decoded = (_vae.decode(x / _vae.config.scaling_factor).sample).clamp(-1, 1)
+        return x_decoded
+
+
+def initialize_unet(rank, return_lora_module_names=False):
+    unet = UNet2DConditionModel.from_pretrained("stabilityai/sd-turbo", subfolder="unet")
+    unet.requires_grad_(False)
+    unet.train()
+    l_target_modules_encoder, l_target_modules_decoder, l_modules_others = [], [], []
+    l_grep = ["to_k", "to_q", "to_v", "to_out.0", "conv", "conv1", "conv2", "conv_in", "conv_shortcut", "conv_out", "proj_out", "proj_in", "ff.net.2", "ff.net.0.proj"]
+    for n, p in unet.named_parameters():
+        if "bias" in n or "norm" in n: continue
+        for pattern in l_grep:
+            if pattern in n and ("down_blocks" in n or "conv_in" in n):
+                l_target_modules_encoder.append(n.replace(".weight",""))
+                break
+            elif pattern in n and "up_blocks" in n:
+                l_target_modules_decoder.append(n.replace(".weight",""))
+                break
+            elif pattern in n:
+                l_modules_others.append(n.replace(".weight",""))
+                break
+    lora_conf_encoder = LoraConfig(r=rank, init_lora_weights="gaussian", target_modules=l_target_modules_encoder, lora_alpha=rank)
+    lora_conf_decoder = LoraConfig(r=rank, init_lora_weights="gaussian", target_modules=l_target_modules_decoder, lora_alpha=rank)
+    lora_conf_others = LoraConfig(r=rank, init_lora_weights="gaussian", target_modules=l_modules_others, lora_alpha=rank)
+    unet.add_adapter(lora_conf_encoder, adapter_name="default_encoder")
+    unet.add_adapter(lora_conf_decoder, adapter_name="default_decoder")
+    unet.add_adapter(lora_conf_others, adapter_name="default_others")
+    unet.set_adapters(["default_encoder", "default_decoder", "default_others"])
+    if return_lora_module_names:
+        return unet, l_target_modules_encoder, l_target_modules_decoder, l_modules_others
+    else:
+        return unet
+
+
+def initialize_vae(rank=4, return_lora_module_names=False):
+    vae = AutoencoderKL.from_pretrained("stabilityai/sd-turbo", subfolder="vae")
+    vae.requires_grad_(False)
+    vae.encoder.forward = my_vae_encoder_fwd.__get__(vae.encoder, vae.encoder.__class__)
+    vae.decoder.forward = my_vae_decoder_fwd.__get__(vae.decoder, vae.decoder.__class__)
+    vae.requires_grad_(True)
+    vae.train()
+    # add the skip connection convs
+    vae.decoder.skip_conv_1 = torch.nn.Conv2d(512, 512, kernel_size=(1, 1), stride=(1, 1), bias=False).cuda().requires_grad_(True)
+    vae.decoder.skip_conv_2 = torch.nn.Conv2d(256, 512, kernel_size=(1, 1), stride=(1, 1), bias=False).cuda().requires_grad_(True)
+    vae.decoder.skip_conv_3 = torch.nn.Conv2d(128, 512, kernel_size=(1, 1), stride=(1, 1), bias=False).cuda().requires_grad_(True)
+    vae.decoder.skip_conv_4 = torch.nn.Conv2d(128, 256, kernel_size=(1, 1), stride=(1, 1), bias=False).cuda().requires_grad_(True)
+    torch.nn.init.constant_(vae.decoder.skip_conv_1.weight, 1e-5)
+    torch.nn.init.constant_(vae.decoder.skip_conv_2.weight, 1e-5)
+    torch.nn.init.constant_(vae.decoder.skip_conv_3.weight, 1e-5)
+    torch.nn.init.constant_(vae.decoder.skip_conv_4.weight, 1e-5)
+    vae.decoder.ignore_skip = False
+    vae.decoder.gamma = 1
+    l_vae_target_modules = ["conv1","conv2","conv_in", "conv_shortcut",
+        "conv", "conv_out", "skip_conv_1", "skip_conv_2", "skip_conv_3",
+        "skip_conv_4", "to_k", "to_q", "to_v", "to_out.0",
+    ]
+    vae_lora_config = LoraConfig(r=rank, init_lora_weights="gaussian", target_modules=l_vae_target_modules)
+    vae.add_adapter(vae_lora_config, adapter_name="vae_skip")
+    if return_lora_module_names:
+        return vae, l_vae_target_modules
+    else:
+        return vae
+
+
+class CycleGAN_Turbo(torch.nn.Module):
+    def __init__(self, pretrained_name=None, pretrained_path=None, ckpt_folder="checkpoints", lora_rank_unet=8, lora_rank_vae=4):
+        super().__init__()
+        self.tokenizer = AutoTokenizer.from_pretrained("stabilityai/sd-turbo", subfolder="tokenizer")
+        self.text_encoder = CLIPTextModel.from_pretrained("stabilityai/sd-turbo", subfolder="text_encoder").cuda()
+        self.sched = make_1step_sched()
+        vae = AutoencoderKL.from_pretrained("stabilityai/sd-turbo", subfolder="vae")
+        unet = UNet2DConditionModel.from_pretrained("stabilityai/sd-turbo", subfolder="unet")
+        vae.encoder.forward = my_vae_encoder_fwd.__get__(vae.encoder, vae.encoder.__class__)
+        vae.decoder.forward = my_vae_decoder_fwd.__get__(vae.decoder, vae.decoder.__class__)
+        # add the skip connection convs
+        vae.decoder.skip_conv_1 = torch.nn.Conv2d(512, 512, kernel_size=(1, 1), stride=(1, 1), bias=False).cuda()
+        vae.decoder.skip_conv_2 = torch.nn.Conv2d(256, 512, kernel_size=(1, 1), stride=(1, 1), bias=False).cuda()
+        vae.decoder.skip_conv_3 = torch.nn.Conv2d(128, 512, kernel_size=(1, 1), stride=(1, 1), bias=False).cuda()
+        vae.decoder.skip_conv_4 = torch.nn.Conv2d(128, 256, kernel_size=(1, 1), stride=(1, 1), bias=False).cuda()
+        vae.decoder.ignore_skip = False
+        self.unet, self.vae = unet, vae
+        if pretrained_name == "day_to_night":
+            url = "https://www.cs.cmu.edu/~img2img-turbo/models/day2night.pkl"
+            self.load_ckpt_from_url(url, ckpt_folder)
+            self.timesteps = torch.tensor([999], device="cuda").long()
+            self.caption = "driving in the night"
+            self.direction = "a2b"
+        elif pretrained_name == "night_to_day":
+            url = "https://www.cs.cmu.edu/~img2img-turbo/models/night2day.pkl"
+            self.load_ckpt_from_url(url, ckpt_folder)
+            self.timesteps = torch.tensor([999], device="cuda").long()
+            self.caption = "driving in the day"
+            self.direction = "b2a"
+        elif pretrained_name == "clear_to_rainy":
+            url = "https://www.cs.cmu.edu/~img2img-turbo/models/clear2rainy.pkl"
+            self.load_ckpt_from_url(url, ckpt_folder)
+            self.timesteps = torch.tensor([999], device="cuda").long()
+            self.caption = "driving in heavy rain"
+            self.direction = "a2b"
+        elif pretrained_name == "rainy_to_clear":
+            url = "https://www.cs.cmu.edu/~img2img-turbo/models/rainy2clear.pkl"
+            self.load_ckpt_from_url(url, ckpt_folder)
+            self.timesteps = torch.tensor([999], device="cuda").long()
+            self.caption = "driving in the day"
+            self.direction = "b2a"
+
+        elif pretrained_path is not None:
+            sd = torch.load(pretrained_path)
+            self.load_ckpt_from_state_dict(sd)
+            self.timesteps = torch.tensor([999], device="cuda").long()
+            self.caption = None
+            self.direction = None
+
+        self.vae_enc.cuda()
+        self.vae_dec.cuda()
+        self.unet.cuda()
+
+    def load_ckpt_from_state_dict(self, sd):
+        lora_conf_encoder = LoraConfig(r=sd["rank_unet"], init_lora_weights="gaussian", target_modules=sd["l_target_modules_encoder"], lora_alpha=sd["rank_unet"])
+        lora_conf_decoder = LoraConfig(r=sd["rank_unet"], init_lora_weights="gaussian", target_modules=sd["l_target_modules_decoder"], lora_alpha=sd["rank_unet"])
+        lora_conf_others = LoraConfig(r=sd["rank_unet"], init_lora_weights="gaussian", target_modules=sd["l_modules_others"], lora_alpha=sd["rank_unet"])
+        self.unet.add_adapter(lora_conf_encoder, adapter_name="default_encoder")
+        self.unet.add_adapter(lora_conf_decoder, adapter_name="default_decoder")
+        self.unet.add_adapter(lora_conf_others, adapter_name="default_others")
+        for n, p in self.unet.named_parameters():
+            name_sd = n.replace(".default_encoder.weight", ".weight")
+            if "lora" in n and "default_encoder" in n:
+                p.data.copy_(sd["sd_encoder"][name_sd])
+        for n, p in self.unet.named_parameters():
+            name_sd = n.replace(".default_decoder.weight", ".weight")
+            if "lora" in n and "default_decoder" in n:
+                p.data.copy_(sd["sd_decoder"][name_sd])
+        for n, p in self.unet.named_parameters():
+            name_sd = n.replace(".default_others.weight", ".weight")
+            if "lora" in n and "default_others" in n:
+                p.data.copy_(sd["sd_other"][name_sd])
+        self.unet.set_adapter(["default_encoder", "default_decoder", "default_others"])
+
+        vae_lora_config = LoraConfig(r=sd["rank_vae"], init_lora_weights="gaussian", target_modules=sd["vae_lora_target_modules"])
+        self.vae.add_adapter(vae_lora_config, adapter_name="vae_skip")
+        self.vae.decoder.gamma = 1
+        self.vae_b2a = copy.deepcopy(self.vae)
+        self.vae_enc = VAE_encode(self.vae, vae_b2a=self.vae_b2a)
+        self.vae_enc.load_state_dict(sd["sd_vae_enc"])
+        self.vae_dec = VAE_decode(self.vae, vae_b2a=self.vae_b2a)
+        self.vae_dec.load_state_dict(sd["sd_vae_dec"])
+
+    def load_ckpt_from_url(self, url, ckpt_folder):
+        os.makedirs(ckpt_folder, exist_ok=True)
+        outf = os.path.join(ckpt_folder, os.path.basename(url))
+        download_url(url, outf)
+        sd = torch.load(outf)
+        self.load_ckpt_from_state_dict(sd)
+
+    @staticmethod
+    def forward_with_networks(x, direction, vae_enc, unet, vae_dec, sched, timesteps, text_emb):
+        B = x.shape[0]
+        assert direction in ["a2b", "b2a"]
+        x_enc = vae_enc(x, direction=direction).to(x.dtype)
+        model_pred = unet(x_enc, timesteps, encoder_hidden_states=text_emb,).sample
+        x_out = torch.stack([sched.step(model_pred[i], timesteps[i], x_enc[i], return_dict=True).prev_sample for i in range(B)])
+        x_out_decoded = vae_dec(x_out, direction=direction)
+        return x_out_decoded
+
+    @staticmethod
+    def get_traininable_params(unet, vae_a2b, vae_b2a):
+        # add all unet parameters
+        params_gen = list(unet.conv_in.parameters())
+        unet.conv_in.requires_grad_(True)
+        unet.set_adapters(["default_encoder", "default_decoder", "default_others"])
+        for n, p in unet.named_parameters():
+            if "lora" in n and "default" in n:
+                assert p.requires_grad
+                params_gen.append(p)
+
+        # add all vae_a2b parameters
+        for n, p in vae_a2b.named_parameters():
+            if "lora" in n and "vae_skip" in n:
+                assert p.requires_grad
+                params_gen.append(p)
+        params_gen = params_gen + list(vae_a2b.decoder.skip_conv_1.parameters())
+        params_gen = params_gen + list(vae_a2b.decoder.skip_conv_2.parameters())
+        params_gen = params_gen + list(vae_a2b.decoder.skip_conv_3.parameters())
+        params_gen = params_gen + list(vae_a2b.decoder.skip_conv_4.parameters())
+
+        # add all vae_b2a parameters
+        for n, p in vae_b2a.named_parameters():
+            if "lora" in n and "vae_skip" in n:
+                assert p.requires_grad
+                params_gen.append(p)
+        params_gen = params_gen + list(vae_b2a.decoder.skip_conv_1.parameters())
+        params_gen = params_gen + list(vae_b2a.decoder.skip_conv_2.parameters())
+        params_gen = params_gen + list(vae_b2a.decoder.skip_conv_3.parameters())
+        params_gen = params_gen + list(vae_b2a.decoder.skip_conv_4.parameters())
+        return params_gen
+
+    def forward(self, x_t, direction=None, caption=None, caption_emb=None):
+        if direction is None:
+            assert self.direction is not None
+            direction = self.direction
+        if caption is None and caption_emb is None:
+            assert self.caption is not None
+            caption = self.caption
+        if caption_emb is not None:
+            caption_enc = caption_emb
+        else:
+            caption_tokens = self.tokenizer(caption, max_length=self.tokenizer.model_max_length,
+                padding="max_length", truncation=True, return_tensors="pt").input_ids.to(x_t.device)
+            caption_enc = self.text_encoder(caption_tokens)[0].detach().clone()
+        return self.forward_with_networks(x_t, direction, self.vae_enc, self.unet, self.vae_dec, self.sched, self.timesteps, caption_enc)
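A minimal usage sketch of CycleGAN_Turbo with one of its pretrained checkpoints; it mirrors what src/inference_unpaired.py below does, assumes a CUDA GPU, and uses a hypothetical local image path:

import torch
from PIL import Image
from torchvision import transforms
from cyclegan_turbo import CycleGAN_Turbo

# load a pretrained day->night translation (the checkpoint is downloaded on first use)
model = CycleGAN_Turbo(pretrained_name="day_to_night")
model.eval()

# "input.png" is a placeholder path; preprocessing matches inference_unpaired.py
img = Image.open("input.png").convert("RGB").resize((512, 512), Image.LANCZOS)
x = transforms.ToTensor()(img)
x = transforms.Normalize([0.5], [0.5])(x).unsqueeze(0).cuda()

with torch.no_grad():
    # uses the caption and direction stored for the pretrained checkpoint
    out = model(x)

transforms.ToPILImage()(out[0].cpu() * 0.5 + 0.5).save("output.png")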
src/image_prep.py
ADDED
@@ -0,0 +1,12 @@
+import numpy as np
+from PIL import Image
+import cv2
+
+
+def canny_from_pil(image, low_threshold=100, high_threshold=200):
+    image = np.array(image)
+    image = cv2.Canny(image, low_threshold, high_threshold)
+    image = image[:, :, None]
+    image = np.concatenate([image, image, image], axis=2)
+    control_image = Image.fromarray(image)
+    return control_image
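A short example of canny_from_pil using the sample image shipped in assets/ (assumes it is run from the src/ directory so the import resolves):

from PIL import Image
from image_prep import canny_from_pil

# extract a 3-channel Canny edge map from the bundled bird example
img = Image.open("../assets/examples/bird.png").convert("RGB")
edges = canny_from_pil(img, low_threshold=100, high_threshold=200)
edges.save("bird_canny.png")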
src/inference_paired.py
ADDED
@@ -0,0 +1,65 @@
+import os
+import argparse
+import numpy as np
+from PIL import Image
+import torch
+from torchvision import transforms
+import torchvision.transforms.functional as F
+from pix2pix_turbo import Pix2Pix_Turbo
+from image_prep import canny_from_pil
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser()
+    parser.add_argument('--input_image', type=str, required=True, help='path to the input image')
+    parser.add_argument('--prompt', type=str, required=True, help='the prompt to be used')
+    parser.add_argument('--model_name', type=str, default='', help='name of the pretrained model to be used')
+    parser.add_argument('--model_path', type=str, default='', help='path to a model state dict to be used')
+    parser.add_argument('--output_dir', type=str, default='output', help='the directory to save the output')
+    parser.add_argument('--low_threshold', type=int, default=100, help='Canny low threshold')
+    parser.add_argument('--high_threshold', type=int, default=200, help='Canny high threshold')
+    parser.add_argument('--gamma', type=float, default=0.4, help='The sketch interpolation guidance amount')
+    parser.add_argument('--seed', type=int, default=42, help='Random seed to be used')
+    args = parser.parse_args()
+
+    # exactly one of model_name and model_path should be provided
+    if (args.model_name == '') == (args.model_path == ''):
+        raise ValueError('Exactly one of model_name or model_path should be provided')
+
+    os.makedirs(args.output_dir, exist_ok=True)
+
+    # initialize the model
+    model = Pix2Pix_Turbo(pretrained_name=args.model_name, pretrained_path=args.model_path)
+    model.set_eval()
+
+    # make sure that the input image is a multiple of 8
+    input_image = Image.open(args.input_image).convert('RGB')
+    new_width = input_image.width - input_image.width % 8
+    new_height = input_image.height - input_image.height % 8
+    input_image = input_image.resize((new_width, new_height), Image.LANCZOS)
+    bname = os.path.basename(args.input_image)
+
+    # translate the image
+    with torch.no_grad():
+        if args.model_name == 'edge_to_image':
+            canny = canny_from_pil(input_image, args.low_threshold, args.high_threshold)
+            canny_viz_inv = Image.fromarray(255 - np.array(canny))
+            canny_viz_inv.save(os.path.join(args.output_dir, bname.replace('.png', '_canny.png')))
+            c_t = F.to_tensor(canny).unsqueeze(0).cuda()
+            output_image = model(c_t, args.prompt)
+
+        elif args.model_name == 'sketch_to_image_stochastic':
+            image_t = F.to_tensor(input_image) < 0.5
+            c_t = image_t.unsqueeze(0).cuda().float()
+            torch.manual_seed(args.seed)
+            B, C, H, W = c_t.shape
+            noise = torch.randn((1, 4, H // 8, W // 8), device=c_t.device)
+            output_image = model(c_t, args.prompt, deterministic=False, r=args.gamma, noise_map=noise)
+
+        else:
+            c_t = F.to_tensor(input_image).unsqueeze(0).cuda()
+            output_image = model(c_t, args.prompt)
+
+    output_pil = transforms.ToPILImage()(output_image[0].cpu() * 0.5 + 0.5)
+
+    # save the output image
+    output_pil.save(os.path.join(args.output_dir, bname))
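A hypothetical invocation of the paired (edge-to-image) inference script on the bundled example image; the prompt text is only an illustration:

import subprocess
import sys

# run edge-to-image inference from the repository root
subprocess.check_call([
    sys.executable, "src/inference_paired.py",
    "--model_name", "edge_to_image",
    "--input_image", "assets/examples/bird.png",
    "--prompt", "a colorful bird sitting on a branch",
    "--output_dir", "outputs",
])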
src/inference_unpaired.py
ADDED
@@ -0,0 +1,53 @@
+import os
+import argparse
+from PIL import Image
+import torch
+from torchvision import transforms
+from cyclegan_turbo import CycleGAN_Turbo
+from my_utils.training_utils import build_transform
+
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser()
+    parser.add_argument('--input_image', type=str, required=True, help='path to the input image')
+    parser.add_argument('--prompt', type=str, required=False, help='the prompt to be used. It is required when loading a custom model_path.')
+    parser.add_argument('--model_name', type=str, default=None, help='name of the pretrained model to be used')
+    parser.add_argument('--model_path', type=str, default=None, help='path to a local model state dict to be used')
+    parser.add_argument('--output_dir', type=str, default='output', help='the directory to save the output')
+    parser.add_argument('--image_prep', type=str, default='resize_512x512', help='the image preparation method')
+    parser.add_argument('--direction', type=str, default=None, help='the direction of translation. None for pretrained models, a2b or b2a for custom paths.')
+    args = parser.parse_args()
+
+    # exactly one of model_name and model_path should be provided
+    if (args.model_name is None) == (args.model_path is None):
+        raise ValueError('Exactly one of model_name or model_path should be provided')
+
+    if args.model_path is not None and args.prompt is None:
+        raise ValueError('prompt is required when loading a custom model_path.')
+
+    if args.model_name is not None:
+        assert args.prompt is None, 'prompt should not be provided when loading a pretrained model.'
+        assert args.direction is None, 'direction should not be provided when loading a pretrained model.'
+
+    # initialize the model
+    model = CycleGAN_Turbo(pretrained_name=args.model_name, pretrained_path=args.model_path)
+    model.eval()
+    model.unet.enable_xformers_memory_efficient_attention()
+
+    T_val = build_transform(args.image_prep)
+
+    input_image = Image.open(args.input_image).convert('RGB')
+    # translate the image
+    with torch.no_grad():
+        input_img = T_val(input_image)
+        x_t = transforms.ToTensor()(input_img)
+        x_t = transforms.Normalize([0.5], [0.5])(x_t).unsqueeze(0).cuda()
+        output = model(x_t, direction=args.direction, caption=args.prompt)
+
+    output_pil = transforms.ToPILImage()(output[0].cpu() * 0.5 + 0.5)
+    output_pil = output_pil.resize((input_image.width, input_image.height), Image.LANCZOS)
+
+    # save the output image
+    bname = os.path.basename(args.input_image)
+    os.makedirs(args.output_dir, exist_ok=True)
+    output_pil.save(os.path.join(args.output_dir, bname))
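A hypothetical invocation of the unpaired inference script with one of the pretrained translations defined in CycleGAN_Turbo above (no --prompt or --direction is passed, as the script asserts for pretrained models):

import subprocess
import sys

# run day-to-night translation on the bundled example image from the repository root
subprocess.check_call([
    sys.executable, "src/inference_unpaired.py",
    "--model_name", "day_to_night",
    "--input_image", "assets/examples/day2night_input.png",
    "--output_dir", "outputs",
])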