Inmental commited on
Commit
38d88fc
·
verified ·
1 Parent(s): 26eda4d

Upload folder using huggingface_hub

Browse files
This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. .gitattributes +11 -0
  2. .gitignore +171 -0
  3. LICENSE +21 -0
  4. README.md +234 -12
  5. assets/cat_2x.gif +3 -0
  6. assets/clear2rainy_results.jpg +3 -0
  7. assets/day2night_results.jpg +3 -0
  8. assets/edge_to_image_results.jpg +3 -0
  9. assets/examples/bird.png +3 -0
  10. assets/examples/bird_canny.png +0 -0
  11. assets/examples/bird_canny_blue.png +0 -0
  12. assets/examples/circles_inference_input.png +0 -0
  13. assets/examples/circles_inference_output.png +0 -0
  14. assets/examples/clear2rainy_input.png +0 -0
  15. assets/examples/clear2rainy_output.png +0 -0
  16. assets/examples/day2night_input.png +0 -0
  17. assets/examples/day2night_output.png +0 -0
  18. assets/examples/my_horse2zebra_input.jpg +0 -0
  19. assets/examples/my_horse2zebra_output.jpg +0 -0
  20. assets/examples/night2day_input.png +0 -0
  21. assets/examples/night2day_output.png +0 -0
  22. assets/examples/rainy2clear_input.png +0 -0
  23. assets/examples/rainy2clear_output.png +0 -0
  24. assets/examples/sketch_input.png +0 -0
  25. assets/examples/sketch_output.png +0 -0
  26. assets/examples/training_evaluation.png +0 -0
  27. assets/examples/training_evaluation_unpaired.png +0 -0
  28. assets/examples/training_step_0.png +0 -0
  29. assets/examples/training_step_500.png +0 -0
  30. assets/examples/training_step_6000.png +0 -0
  31. assets/fish_2x.gif +3 -0
  32. assets/gen_variations.jpg +3 -0
  33. assets/method.jpg +0 -0
  34. assets/night2day_results.jpg +3 -0
  35. assets/rainy2clear.jpg +3 -0
  36. assets/teaser_results.jpg +3 -0
  37. docs/training_cyclegan_turbo.md +98 -0
  38. docs/training_pix2pix_turbo.md +118 -0
  39. environment.yaml +34 -0
  40. gradio_canny2image.py +78 -0
  41. gradio_sketch2image.py +382 -0
  42. python==3.9.8/Lib/site-packages/wheel/cli/tags.py +139 -0
  43. python==3.9.8/conda-meta/history +19 -0
  44. requirements.txt +28 -0
  45. scripts/download_fill50k.sh +5 -0
  46. scripts/download_horse2zebra.sh +5 -0
  47. src/cyclegan_turbo.py +254 -0
  48. src/image_prep.py +12 -0
  49. src/inference_paired.py +65 -0
  50. src/inference_unpaired.py +53 -0
.gitattributes CHANGED
@@ -33,3 +33,14 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
 
 
 
 
 
 
 
 
 
 
 
 
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ assets/cat_2x.gif filter=lfs diff=lfs merge=lfs -text
37
+ assets/clear2rainy_results.jpg filter=lfs diff=lfs merge=lfs -text
38
+ assets/day2night_results.jpg filter=lfs diff=lfs merge=lfs -text
39
+ assets/edge_to_image_results.jpg filter=lfs diff=lfs merge=lfs -text
40
+ assets/examples/bird.png filter=lfs diff=lfs merge=lfs -text
41
+ assets/fish_2x.gif filter=lfs diff=lfs merge=lfs -text
42
+ assets/gen_variations.jpg filter=lfs diff=lfs merge=lfs -text
43
+ assets/night2day_results.jpg filter=lfs diff=lfs merge=lfs -text
44
+ assets/rainy2clear.jpg filter=lfs diff=lfs merge=lfs -text
45
+ assets/teaser_results.jpg filter=lfs diff=lfs merge=lfs -text
46
+ triton-2.1.0-cp310-cp310-win_amd64.whl filter=lfs diff=lfs merge=lfs -text
.gitignore ADDED
@@ -0,0 +1,171 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ single_step_translation/
2
+
3
+ # Byte-compiled / optimized / DLL files
4
+ __pycache__/
5
+ *.py[cod]
6
+ *$py.class
7
+
8
+ # C extensions
9
+ *.so
10
+
11
+ # Distribution / packaging
12
+ .Python
13
+ build/
14
+ develop-eggs/
15
+ dist/
16
+ downloads/
17
+ eggs/
18
+ .eggs/
19
+ lib/
20
+ lib64/
21
+ parts/
22
+ sdist/
23
+ var/
24
+ wheels/
25
+ share/python-wheels/
26
+ *.egg-info/
27
+ .installed.cfg
28
+ *.egg
29
+ MANIFEST
30
+
31
+ # PyInstaller
32
+ # Usually these files are written by a python script from a template
33
+ # before PyInstaller builds the exe, so as to inject date/other infos into it.
34
+ *.manifest
35
+ *.spec
36
+
37
+ # Installer logs
38
+ pip-log.txt
39
+ pip-delete-this-directory.txt
40
+
41
+ # Unit test / coverage reports
42
+ htmlcov/
43
+ .tox/
44
+ .nox/
45
+ .coverage
46
+ .coverage.*
47
+ .cache
48
+ nosetests.xml
49
+ coverage.xml
50
+ *.cover
51
+ *.py,cover
52
+ .hypothesis/
53
+ .pytest_cache/
54
+ cover/
55
+
56
+ # Translations
57
+ *.mo
58
+ *.pot
59
+
60
+ # Django stuff:
61
+ *.log
62
+ local_settings.py
63
+ db.sqlite3
64
+ db.sqlite3-journal
65
+
66
+ # Flask stuff:
67
+ instance/
68
+ .webassets-cache
69
+
70
+ # Scrapy stuff:
71
+ .scrapy
72
+
73
+ # Sphinx documentation
74
+ docs/_build/
75
+
76
+ # PyBuilder
77
+ .pybuilder/
78
+ target/
79
+
80
+ # Jupyter Notebook
81
+ .ipynb_checkpoints
82
+
83
+ # IPython
84
+ profile_default/
85
+ ipython_config.py
86
+
87
+ # pyenv
88
+ # For a library or package, you might want to ignore these files since the code is
89
+ # intended to run in multiple environments; otherwise, check them in:
90
+ # .python-version
91
+
92
+ # pipenv
93
+ # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
94
+ # However, in case of collaboration, if having platform-specific dependencies or dependencies
95
+ # having no cross-platform support, pipenv may install dependencies that don't work, or not
96
+ # install all needed dependencies.
97
+ #Pipfile.lock
98
+
99
+ # poetry
100
+ # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
101
+ # This is especially recommended for binary packages to ensure reproducibility, and is more
102
+ # commonly ignored for libraries.
103
+ # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
104
+ #poetry.lock
105
+
106
+ # pdm
107
+ # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
108
+ #pdm.lock
109
+ # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
110
+ # in version control.
111
+ # https://pdm.fming.dev/#use-with-ide
112
+ .pdm.toml
113
+
114
+ # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
115
+ __pypackages__/
116
+
117
+ # Celery stuff
118
+ celerybeat-schedule
119
+ celerybeat.pid
120
+
121
+ # SageMath parsed files
122
+ *.sage.py
123
+
124
+ # Environments
125
+ .env
126
+ .venv
127
+ env/
128
+ venv/
129
+ ENV/
130
+ env.bak/
131
+ venv.bak/
132
+
133
+ # Spyder project settings
134
+ .spyderproject
135
+ .spyproject
136
+
137
+ # Rope project settings
138
+ .ropeproject
139
+
140
+ # mkdocs documentation
141
+ /site
142
+
143
+ # mypy
144
+ .mypy_cache/
145
+ .dmypy.json
146
+ dmypy.json
147
+
148
+ # Pyre type checker
149
+ .pyre/
150
+
151
+ # pytype static type analyzer
152
+ .pytype/
153
+
154
+ # Cython debug symbols
155
+ cython_debug/
156
+
157
+ # PyCharm
158
+ # JetBrains specific template is maintained in a separate JetBrains.gitignore that can
159
+ # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
160
+ # and can be added to the global gitignore or merged into this file. For a more nuclear
161
+ # option (not recommended) you can uncomment the following to ignore the entire idea folder.
162
+ #.idea/
163
+ single_step_translation
164
+ gradio
165
+ checkpoints/
166
+ img2img-turbo-sketch
167
+ outputs/
168
+ outputs/bird.png
169
+ data
170
+ wandb
171
+ output/
LICENSE ADDED
@@ -0,0 +1,21 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ MIT License
2
+
3
+ Copyright (c) 2024 img-to-img-turbo
4
+
5
+ Permission is hereby granted, free of charge, to any person obtaining a copy
6
+ of this software and associated documentation files (the "Software"), to deal
7
+ in the Software without restriction, including without limitation the rights
8
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9
+ copies of the Software, and to permit persons to whom the Software is
10
+ furnished to do so, subject to the following conditions:
11
+
12
+ The above copyright notice and this permission notice shall be included in all
13
+ copies or substantial portions of the Software.
14
+
15
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21
+ SOFTWARE.
README.md CHANGED
@@ -1,12 +1,234 @@
1
- ---
2
- title: Img2img Turbo
3
- emoji: 🏃
4
- colorFrom: red
5
- colorTo: indigo
6
- sdk: gradio
7
- sdk_version: 4.41.0
8
- app_file: app.py
9
- pinned: false
10
- ---
11
-
12
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ title: img2img-turbo
3
+ app_file: gradio_sketch2image.py
4
+ sdk: gradio
5
+ sdk_version: 3.43.1
6
+ ---
7
+ # img2img-turbo
8
+
9
+ [**Paper**](https://arxiv.org/abs/2403.12036) | [**Sketch2Image Demo**](https://huggingface.co/spaces/gparmar/img2img-turbo-sketch)
10
+ #### **Quick start:** [**Running Locally**](#getting-started) | [**Gradio (locally hosted)**](#gradio-demo) | [**Training**](#training-with-your-own-data)
11
+
12
+ ### Cat Sketching
13
+ <p align="left" >
14
+ <img src="https://raw.githubusercontent.com/GaParmar/img2img-turbo/main/assets/cat_2x.gif" width="800" />
15
+ </p>
16
+
17
+ ### Fish Sketching
18
+ <p align="left">
19
+ <img src="https://raw.githubusercontent.com/GaParmar/img2img-turbo/main/assets/fish_2x.gif" width="800" />
20
+ </p>
21
+
22
+
23
+ We propose a general method for adapting a single-step diffusion model, such as SD-Turbo, to new tasks and domains through adversarial learning. This enables us to leverage the internal knowledge of pre-trained diffusion models while achieving efficient inference (e.g., for 512x512 images, 0.29 seconds on A6000 and 0.11 seconds on A100).
24
+
25
+ Our one-step conditional models **CycleGAN-Turbo** and **pix2pix-turbo** can perform various image-to-image translation tasks for both unpaired and paired settings. CycleGAN-Turbo outperforms existing GAN-based and diffusion-based methods, while pix2pix-turbo is on par with recent works such as ControlNet for Sketch2Photo and Edge2Image, but with one-step inference.
26
+
27
+ [One-Step Image Translation with Text-to-Image Models](https://arxiv.org/abs/2403.12036)<br>
28
+ [Gaurav Parmar](https://gauravparmar.com/), [Taesung Park](https://taesung.me/), [Srinivasa Narasimhan](https://www.cs.cmu.edu/~srinivas/), [Jun-Yan Zhu](https://github.com/junyanz/)<br>
29
+ CMU and Adobe, arXiv 2403.12036
30
+
31
+ <br>
32
+ <div>
33
+ <p align="center">
34
+ <img src='assets/teaser_results.jpg' align="center" width=1000px>
35
+ </p>
36
+ </div>
37
+
38
+
39
+
40
+
41
+ ## Results
42
+
43
+ ### Paired Translation with pix2pix-turbo
44
+ **Edge to Image**
45
+ <div>
46
+ <p align="center">
47
+ <img src='assets/edge_to_image_results.jpg' align="center" width=800px>
48
+ </p>
49
+ </div>
50
+
51
+ <!-- **Sketch to Image**
52
+ TODO -->
53
+ ### Generating Diverse Outputs
54
+ By varying the input noise map, our method can generate diverse outputs from the same input conditioning.
55
+ The output style can be controlled by changing the text prompt.
56
+ <div> <p align="center">
57
+ <img src='assets/gen_variations.jpg' align="center" width=800px>
58
+ </p> </div>
59
+
60
+ ### Unpaired Translation with CycleGAN-Turbo
61
+
62
+ **Day to Night**
63
+ <div> <p align="center">
64
+ <img src='assets/day2night_results.jpg' align="center" width=800px>
65
+ </p> </div>
66
+
67
+ **Night to Day**
68
+ <div><p align="center">
69
+ <img src='assets/night2day_results.jpg' align="center" width=800px>
70
+ </p> </div>
71
+
72
+ **Clear to Rainy**
73
+ <div>
74
+ <p align="center">
75
+ <img src='assets/clear2rainy_results.jpg' align="center" width=800px>
76
+ </p>
77
+ </div>
78
+
79
+ **Rainy to Clear**
80
+ <div>
81
+ <p align="center">
82
+ <img src='assets/rainy2clear.jpg' align="center" width=800px>
83
+ </p>
84
+ </div>
85
+ <hr>
86
+
87
+
88
+ ## Method
89
+ **Our Generator Architecture:**
90
+ We tightly integrate three separate modules in the original latent diffusion models into a single end-to-end network with small trainable weights. This architecture allows us to translate the input image x to the output y, while retaining the input scene structure. We use LoRA adapters in each module, introduce skip connections and Zero-Convs between input and output, and retrain the first layer of the U-Net. Blue boxes indicate trainable layers. Semi-transparent layers are frozen. The same generator can be used for various GAN objectives.
91
+ <div>
92
+ <p align="center">
93
+ <img src='assets/method.jpg' align="center" width=900px>
94
+ </p>
95
+ </div>
96
+
97
+
98
+ ## Getting Started
99
+ **Environment Setup**
100
+ - We provide a [conda env file](environment.yml) that contains all the required dependencies.
101
+ ```
102
+ conda env create -f environment.yaml
103
+ ```
104
+ - Following this, you can activate the conda environment with the command below.
105
+ ```
106
+ conda activate img2img-turbo
107
+ ```
108
+ - Or use virtual environment:
109
+ ```
110
+ python3 -m venv venv
111
+ source venv/bin/activate
112
+ pip install -r requirements.txt
113
+ ```
114
+ **Paired Image Translation (pix2pix-turbo)**
115
+ - The following command takes an image file and a prompt as inputs, extracts the canny edges, and saves the results in the directory specified.
116
+ ```bash
117
+ python src/inference_paired.py --model_name "edge_to_image" \
118
+ --input_image "assets/examples/bird.png" \
119
+ --prompt "a blue bird" \
120
+ --output_dir "outputs"
121
+ ```
122
+ <table>
123
+ <th>Input Image</th>
124
+ <th>Canny Edges</th>
125
+ <th>Model Output</th>
126
+ </tr>
127
+ <tr>
128
+ <td><img src='assets/examples/bird.png' width="200px"></td>
129
+ <td><img src='assets/examples/bird_canny.png' width="200px"></td>
130
+ <td><img src='assets/examples/bird_canny_blue.png' width="200px"></td>
131
+ </tr>
132
+ </table>
133
+ <br>
134
+
135
+ - The following command takes a sketch and a prompt as inputs, and saves the results in the directory specified.
136
+ ```bash
137
+ python src/inference_paired.py --model_name "sketch_to_image_stochastic" \
138
+ --input_image "assets/examples/sketch_input.png" --gamma 0.4 \
139
+ --prompt "ethereal fantasy concept art of an asteroid. magnificent, celestial, ethereal, painterly, epic, majestic, magical, fantasy art, cover art, dreamy" \
140
+ --output_dir "outputs"
141
+ ```
142
+ <table>
143
+ <th>Input</th>
144
+ <th>Model Output</th>
145
+ </tr>
146
+ <tr>
147
+ <td><img src='assets/examples/sketch_input.png' width="400px"></td>
148
+ <td><img src='assets/examples/sketch_output.png' width="400px"></td>
149
+ </tr>
150
+ </table>
151
+ <br>
152
+
153
+ **Unpaired Image Translation (CycleGAN-Turbo)**
154
+ - The following command takes a **day** image file as input, and saves the output **night** in the directory specified.
155
+ ```
156
+ python src/inference_unpaired.py --model_name "day_to_night" \
157
+ --input_image "assets/examples/day2night_input.png" --output_dir "outputs"
158
+ ```
159
+ <table>
160
+ <th>Input (day)</th>
161
+ <th>Model Output (night)</th>
162
+ </tr>
163
+ <tr>
164
+ <td><img src='assets/examples/day2night_input.png' width="400px"></td>
165
+ <td><img src='assets/examples/day2night_output.png' width="400px"></td>
166
+ </tr>
167
+ </table>
168
+
169
+ - The following command takes a **night** image file as input, and saves the output **day** in the directory specified.
170
+ ```
171
+ python src/inference_unpaired.py --model_name "night_to_day" \
172
+ --input_image "assets/examples/night2day_input.png" --output_dir "outputs"
173
+ ```
174
+ <table>
175
+ <th>Input (night)</th>
176
+ <th>Model Output (day)</th>
177
+ </tr>
178
+ <tr>
179
+ <td><img src='assets/examples/night2day_input.png' width="400px"></td>
180
+ <td><img src='assets/examples/night2day_output.png' width="400px"></td>
181
+ </tr>
182
+ </table>
183
+
184
+ - The following command takes a **clear** image file as input, and saves the output **rainy** in the directory specified.
185
+ ```
186
+ python src/inference_unpaired.py --model_name "clear_to_rainy" \
187
+ --input_image "assets/examples/clear2rainy_input.png" --output_dir "outputs"
188
+ ```
189
+ <table>
190
+ <th>Input (clear)</th>
191
+ <th>Model Output (rainy)</th>
192
+ </tr>
193
+ <tr>
194
+ <td><img src='assets/examples/clear2rainy_input.png' width="400px"></td>
195
+ <td><img src='assets/examples/clear2rainy_output.png' width="400px"></td>
196
+ </tr>
197
+ </table>
198
+
199
+ - The following command takes a **rainy** image file as input, and saves the output **clear** in the directory specified.
200
+ ```
201
+ python src/inference_unpaired.py --model_name "rainy_to_clear" \
202
+ --input_image "assets/examples/rainy2clear_input.png" --output_dir "outputs"
203
+ ```
204
+ <table>
205
+ <th>Input (rainy)</th>
206
+ <th>Model Output (clear)</th>
207
+ </tr>
208
+ <tr>
209
+ <td><img src='assets/examples/rainy2clear_input.png' width="400px"></td>
210
+ <td><img src='assets/examples/rainy2clear_output.png' width="400px"></td>
211
+ </tr>
212
+ </table>
213
+
214
+
215
+
216
+ ## Gradio Demo
217
+ - We provide a Gradio demo for the paired image translation tasks.
218
+ - The following command will launch the sketch to image locally using gradio.
219
+ ```
220
+ gradio gradio_sketch2image.py
221
+ ```
222
+ - The following command will launch the canny edge to image gradio demo locally.
223
+ ```
224
+ gradio gradio_canny2image.py
225
+ ```
226
+
227
+
228
+ ## Training with your own data
229
+ - See the steps [here](docs/training_pix2pix_turbo.md) for training a pix2pix-turbo model on your paired data.
230
+ - See the steps [here](docs/training_cyclegan_turbo.md) for training a CycleGAN-Turbo model on your unpaired data.
231
+
232
+
233
+ ## Acknowledgment
234
+ Our work uses the Stable Diffusion-Turbo as the base model with the following [LICENSE](https://huggingface.co/stabilityai/sd-turbo/blob/main/LICENSE).
assets/cat_2x.gif ADDED

Git LFS Details

  • SHA256: 65a49403cf594d7b5300547edded6794e1306b61fb5f6837a96320a17954e826
  • Pointer size: 132 Bytes
  • Size of remote file: 4.63 MB
assets/clear2rainy_results.jpg ADDED

Git LFS Details

  • SHA256: f8b03789185cdb546080d0a3173e1e7054a4a013c2f3581d4d69fb4f99fe94d2
  • Pointer size: 132 Bytes
  • Size of remote file: 2.87 MB
assets/day2night_results.jpg ADDED

Git LFS Details

  • SHA256: 152448e2de3e09184f34e2d4bf8f41af02669fb6dafd77f4994a5da3b50410bf
  • Pointer size: 132 Bytes
  • Size of remote file: 2.91 MB
assets/edge_to_image_results.jpg ADDED

Git LFS Details

  • SHA256: c0e900c2fe954443b87c8643980c287ff91066a5adb21fbec75595c00a4ab615
  • Pointer size: 132 Bytes
  • Size of remote file: 2.37 MB
assets/examples/bird.png ADDED

Git LFS Details

  • SHA256: cad49fc7d3071b2bcd078bc8dde365f8fa62eaa6d43705fd50c212794a3aac35
  • Pointer size: 132 Bytes
  • Size of remote file: 1.07 MB
assets/examples/bird_canny.png ADDED
assets/examples/bird_canny_blue.png ADDED
assets/examples/circles_inference_input.png ADDED
assets/examples/circles_inference_output.png ADDED
assets/examples/clear2rainy_input.png ADDED
assets/examples/clear2rainy_output.png ADDED
assets/examples/day2night_input.png ADDED
assets/examples/day2night_output.png ADDED
assets/examples/my_horse2zebra_input.jpg ADDED
assets/examples/my_horse2zebra_output.jpg ADDED
assets/examples/night2day_input.png ADDED
assets/examples/night2day_output.png ADDED
assets/examples/rainy2clear_input.png ADDED
assets/examples/rainy2clear_output.png ADDED
assets/examples/sketch_input.png ADDED
assets/examples/sketch_output.png ADDED
assets/examples/training_evaluation.png ADDED
assets/examples/training_evaluation_unpaired.png ADDED
assets/examples/training_step_0.png ADDED
assets/examples/training_step_500.png ADDED
assets/examples/training_step_6000.png ADDED
assets/fish_2x.gif ADDED

Git LFS Details

  • SHA256: 9668ef45316f92d7c36db1e6d1854d2d413a2d87b32d73027149aeb02cc94e9d
  • Pointer size: 132 Bytes
  • Size of remote file: 2.48 MB
assets/gen_variations.jpg ADDED

Git LFS Details

  • SHA256: f9443d34ae70cc7d6d5123f7517b7f6e601ba6a59fedd63935e8dcd2dbf507e7
  • Pointer size: 132 Bytes
  • Size of remote file: 3.33 MB
assets/method.jpg ADDED
assets/night2day_results.jpg ADDED

Git LFS Details

  • SHA256: 2c2e0c3e5673e803482d881ab4df66e4e3103803e52daf48da43fb398742a3e8
  • Pointer size: 132 Bytes
  • Size of remote file: 2.37 MB
assets/rainy2clear.jpg ADDED

Git LFS Details

  • SHA256: ba435223d2c72430a9defeb7da94d43af9ddf67c32f11beb78c463f6a95347f5
  • Pointer size: 132 Bytes
  • Size of remote file: 2.49 MB
assets/teaser_results.jpg ADDED

Git LFS Details

  • SHA256: 55f14cff3825bf475ed7cf3847182a9689d4e7745204acbcd6ae8023d855e9ea
  • Pointer size: 132 Bytes
  • Size of remote file: 2.06 MB
docs/training_cyclegan_turbo.md ADDED
@@ -0,0 +1,98 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ## Training with Unpaired Data (CycleGAN-turbo)
2
+ Here, we show how to train a CycleGAN-turbo model using unpaired data.
3
+ We will use the [horse2zebra dataset](https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix/blob/master/docs/datasets.md) introduced by [CycleGAN](https://junyanz.github.io/CycleGAN/) as an example dataset.
4
+
5
+
6
+ ### Step 1. Get the Dataset
7
+ - First download the horse2zebra dataset from [here](https://www.cs.cmu.edu/~img2img-turbo/data/my_horse2zebra.zip) using the command below.
8
+ ```
9
+ bash scripts/download_horse2zebra.sh
10
+ ```
11
+
12
+ - Our training scripts expect the dataset to be in the following format:
13
+ ```
14
+ data
15
+ ├── dataset_name
16
+ │ ├── train_A
17
+ │ │ ├── 000000.png
18
+ │ │ ├── 000001.png
19
+ │ │ └── ...
20
+ │ ├── train_B
21
+ │ │ ├── 000000.png
22
+ │ │ ├── 000001.png
23
+ │ │ └── ...
24
+ │ └── fixed_prompt_a.txt
25
+ | └── fixed_prompt_b.txt
26
+ |
27
+ | ├── test_A
28
+ │ │ ├── 000000.png
29
+ │ │ ├── 000001.png
30
+ │ │ └── ...
31
+ │ ├── test_B
32
+ │ │ ├── 000000.png
33
+ │ │ ├── 000001.png
34
+ │ │ └── ...
35
+ ```
36
+ - The `fixed_prompt_a.txt` and `fixed_prompt_b.txt` files contain the **fixed caption** used for the source and target domains respectively.
37
+
38
+
39
+ ### Step 2. Train the Model
40
+ - Initialize the `accelerate` environment with the following command:
41
+ ```
42
+ accelerate config
43
+ ```
44
+
45
+ - Run the following command to train the model.
46
+ ```
47
+ export NCCL_P2P_DISABLE=1
48
+ accelerate launch --main_process_port 29501 src/train_cyclegan_turbo.py \
49
+ --pretrained_model_name_or_path="stabilityai/sd-turbo" \
50
+ --output_dir="output/cyclegan_turbo/my_horse2zebra" \
51
+ --dataset_folder "data/my_horse2zebra" \
52
+ --train_img_prep "resize_286_randomcrop_256x256_hflip" --val_img_prep "no_resize" \
53
+ --learning_rate="1e-5" --max_train_steps=25000 \
54
+ --train_batch_size=1 --gradient_accumulation_steps=1 \
55
+ --report_to "wandb" --tracker_project_name "gparmar_unpaired_h2z_cycle_debug_v2" \
56
+ --enable_xformers_memory_efficient_attention --validation_steps 250 \
57
+ --lambda_gan 0.5 --lambda_idt 1 --lambda_cycle 1
58
+ ```
59
+
60
+ - Additional optional flags:
61
+ - `--enable_xformers_memory_efficient_attention`: Enable memory-efficient attention in the model.
62
+
63
+ ### Step 3. Monitor the training progress
64
+ - You can monitor the training progress using the [Weights & Biases](https://wandb.ai/site) dashboard.
65
+
66
+ - The training script will visualizing the training batch, the training losses, and validation set L2, LPIPS, and FID scores (if specified).
67
+ <div>
68
+ <p align="center">
69
+ <img src='../assets/examples/training_evaluation.png' align="center" width=800px>
70
+ </p>
71
+ </div>
72
+
73
+
74
+ - The model checkpoints will be saved in the `<output_dir>/checkpoints` directory.
75
+
76
+
77
+ ### Step 4. Running Inference with the trained models
78
+
79
+ - You can run inference using the trained model using the following command:
80
+ ```
81
+ python src/inference_unpaired.py --model_path "output/cyclegan_turbo/my_horse2zebra/checkpoints/model_1001.pkl" \
82
+ --input_image "data/my_horse2zebra/test_A/n02381460_20.jpg" \
83
+ --prompt "picture of a zebra" --direction "a2b" \
84
+ --output_dir "outputs" --image_prep "no_resize"
85
+ ```
86
+
87
+ - The above command should generate the following output:
88
+ <table>
89
+ <tr>
90
+ <th>Model Input</th>
91
+ <th>Model Output</th>
92
+ </tr>
93
+ <tr>
94
+ <td><img src='../assets/examples/my_horse2zebra_input.jpg' width="200px"></td>
95
+ <td><img src='../assets/examples/my_horse2zebra_output.jpg' width="200px"></td>
96
+ </tr>
97
+ </table>
98
+
docs/training_pix2pix_turbo.md ADDED
@@ -0,0 +1,118 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ## Training with Paired Data (pix2pix-turbo)
2
+ Here, we show how to train a pix2pix-turbo model using paired data.
3
+ We will use the [Fill50k dataset](https://github.com/lllyasviel/ControlNet/blob/main/docs/train.md) used by [ControlNet](https://github.com/lllyasviel/ControlNet) as an example dataset.
4
+
5
+
6
+ ### Step 1. Get the Dataset
7
+ - First download a modified Fill50k dataset from [here](https://www.cs.cmu.edu/~img2img-turbo/data/my_fill50k.zip) using the command below.
8
+ ```
9
+ bash scripts/download_fill50k.sh
10
+ ```
11
+
12
+ - Our training scripts expect the dataset to be in the following format:
13
+ ```
14
+ data
15
+ ├── dataset_name
16
+ │ ├── train_A
17
+ │ │ ├── 000000.png
18
+ │ │ ├── 000001.png
19
+ │ │ └── ...
20
+ │ ├── train_B
21
+ │ │ ├── 000000.png
22
+ │ │ ├── 000001.png
23
+ │ │ └── ...
24
+ │ └── train_prompts.json
25
+ |
26
+ | ├── test_A
27
+ │ │ ├── 000000.png
28
+ │ │ ├── 000001.png
29
+ │ │ └── ...
30
+ │ ├── test_B
31
+ │ │ ├── 000000.png
32
+ │ │ ├── 000001.png
33
+ │ │ └── ...
34
+ │ └── test_prompts.json
35
+ ```
36
+
37
+
38
+ ### Step 2. Train the Model
39
+ - Initialize the `accelerate` environment with the following command:
40
+ ```
41
+ accelerate config
42
+ ```
43
+
44
+ - Run the following command to train the model.
45
+ ```
46
+ accelerate launch src/train_pix2pix_turbo.py \
47
+ --pretrained_model_name_or_path="stabilityai/sd-turbo" \
48
+ --output_dir="output/pix2pix_turbo/fill50k" \
49
+ --dataset_folder="data/my_fill50k" \
50
+ --resolution=512 \
51
+ --train_batch_size=2 \
52
+ --enable_xformers_memory_efficient_attention --viz_freq 25 \
53
+ --track_val_fid \
54
+ --report_to "wandb" --tracker_project_name "pix2pix_turbo_fill50k"
55
+ ```
56
+
57
+ - Additional optional flags:
58
+ - `--track_val_fid`: Track FID score on the validation set using the [Clean-FID](https://github.com/GaParmar/clean-fid) implementation.
59
+ - `--enable_xformers_memory_efficient_attention`: Enable memory-efficient attention in the model.
60
+ - `--viz_freq`: Frequency of visualizing the results during training.
61
+
62
+ ### Step 3. Monitor the training progress
63
+ - You can monitor the training progress using the [Weights & Biases](https://wandb.ai/site) dashboard.
64
+
65
+ - The training script will visualizing the training batch, the training losses, and validation set L2, LPIPS, and FID scores (if specified).
66
+ <div>
67
+ <p align="center">
68
+ <img src='../assets/examples/training_evaluation.png' align="center" width=800px>
69
+ </p>
70
+ </div>
71
+
72
+
73
+ - The model checkpoints will be saved in the `<output_dir>/checkpoints` directory.
74
+
75
+ - Screenshots of the training progress are shown below:
76
+ - Step 0:
77
+ <div>
78
+ <p align="center">
79
+ <img src='../assets/examples/training_step_0.png' align="center" width=800px>
80
+ </p>
81
+ </div>
82
+
83
+ - Step 500:
84
+ <div>
85
+ <p align="center">
86
+ <img src='../assets/examples/training_step_500.png' align="center" width=800px>
87
+ </p>
88
+ </div>
89
+
90
+ - Step 6000:
91
+ <div>
92
+ <p align="center">
93
+ <img src='../assets/examples/training_step_6000.png' align="center" width=800px>
94
+ </p>
95
+ </div>
96
+
97
+
98
+ ### Step 4. Running Inference with the trained models
99
+
100
+ - You can run inference using the trained model using the following command:
101
+ ```
102
+ python src/inference_paired.py --model_path "output/pix2pix_turbo/fill50k/checkpoints/model_6001.pkl" \
103
+ --input_image "data/my_fill50k/test_A/40000.png" \
104
+ --prompt "violet circle with orange background" \
105
+ --output_dir "outputs"
106
+ ```
107
+
108
+ - The above command should generate the following output:
109
+ <table>
110
+ <tr>
111
+ <th>Model Input</th>
112
+ <th>Model Output</th>
113
+ </tr>
114
+ <tr>
115
+ <td><img src='../assets/examples/circles_inference_input.png' width="200px"></td>
116
+ <td><img src='../assets/examples/circles_inference_output.png' width="200px"></td>
117
+ </tr>
118
+ </table>
environment.yaml ADDED
@@ -0,0 +1,34 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ name: img2img-turbo
2
+ channels:
3
+ - pytorch
4
+ - defaults
5
+ dependencies:
6
+ - python=3.10
7
+ - pip:
8
+ - clip @ git+https://github.com/openai/CLIP.git
9
+ - einops>=0.6.1
10
+ - numpy>=1.24.4
11
+ - open-clip-torch>=2.20.0
12
+ - opencv-python==4.6.0.66
13
+ - pillow>=9.5.0
14
+ - scipy==1.11.1
15
+ - timm>=0.9.2
16
+ - tokenizers
17
+ - torch>=2.0.1
18
+
19
+ - torchaudio>=2.0.2
20
+ - torchdata==0.6.1
21
+ - torchmetrics>=1.0.1
22
+ - torchvision>=0.15.2
23
+
24
+ - tqdm>=4.65.0
25
+ - transformers==4.35.2
26
+ - urllib3<1.27,>=1.25.4
27
+ - xformers>=0.0.20
28
+ - streamlit-keyup==0.2.0
29
+ - lpips
30
+ - clean-fid
31
+ - peft
32
+ - dominate
33
+ - diffusers==0.25.1
34
+ - gradio==3.43.1
gradio_canny2image.py ADDED
@@ -0,0 +1,78 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ from PIL import Image
3
+ import torch
4
+ from torchvision import transforms
5
+ import gradio as gr
6
+ from src.image_prep import canny_from_pil
7
+ from src.pix2pix_turbo import Pix2Pix_Turbo
8
+
9
+ model = Pix2Pix_Turbo("edge_to_image")
10
+
11
+
12
+ def process(input_image, prompt, low_threshold, high_threshold):
13
+ # resize to be a multiple of 8
14
+ new_width = input_image.width - input_image.width % 8
15
+ new_height = input_image.height - input_image.height % 8
16
+ input_image = input_image.resize((new_width, new_height))
17
+ canny = canny_from_pil(input_image, low_threshold, high_threshold)
18
+ with torch.no_grad():
19
+ c_t = transforms.ToTensor()(canny).unsqueeze(0).cuda()
20
+ output_image = model(c_t, prompt)
21
+ output_pil = transforms.ToPILImage()(output_image[0].cpu() * 0.5 + 0.5)
22
+ # flippy canny values, map all 0s to 1s and 1s to 0s
23
+ canny_viz = 1 - (np.array(canny) / 255)
24
+ canny_viz = Image.fromarray((canny_viz * 255).astype(np.uint8))
25
+ return canny_viz, output_pil
26
+
27
+
28
+ if __name__ == "__main__":
29
+ # load the model
30
+ with gr.Blocks() as demo:
31
+ gr.Markdown("# Pix2pix-Turbo: **Canny Edge -> Image**")
32
+ with gr.Row():
33
+ with gr.Column():
34
+ input_image = gr.Image(sources="upload", type="pil")
35
+ prompt = gr.Textbox(label="Prompt")
36
+ low_threshold = gr.Slider(
37
+ label="Canny low threshold",
38
+ minimum=1,
39
+ maximum=255,
40
+ value=100,
41
+ step=10,
42
+ )
43
+ high_threshold = gr.Slider(
44
+ label="Canny high threshold",
45
+ minimum=1,
46
+ maximum=255,
47
+ value=200,
48
+ step=10,
49
+ )
50
+ run_button = gr.Button(value="Run")
51
+ with gr.Column():
52
+ result_canny = gr.Image(type="pil")
53
+ with gr.Column():
54
+ result_output = gr.Image(type="pil")
55
+
56
+ prompt.submit(
57
+ fn=process,
58
+ inputs=[input_image, prompt, low_threshold, high_threshold],
59
+ outputs=[result_canny, result_output],
60
+ )
61
+ low_threshold.change(
62
+ fn=process,
63
+ inputs=[input_image, prompt, low_threshold, high_threshold],
64
+ outputs=[result_canny, result_output],
65
+ )
66
+ high_threshold.change(
67
+ fn=process,
68
+ inputs=[input_image, prompt, low_threshold, high_threshold],
69
+ outputs=[result_canny, result_output],
70
+ )
71
+ run_button.click(
72
+ fn=process,
73
+ inputs=[input_image, prompt, low_threshold, high_threshold],
74
+ outputs=[result_canny, result_output],
75
+ )
76
+
77
+ demo.queue()
78
+ demo.launch(debug=True, share=False)
gradio_sketch2image.py ADDED
@@ -0,0 +1,382 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import random
2
+ import numpy as np
3
+ from PIL import Image
4
+ import base64
5
+ from io import BytesIO
6
+
7
+ import torch
8
+ import torchvision.transforms.functional as F
9
+ import gradio as gr
10
+
11
+ from src.pix2pix_turbo import Pix2Pix_Turbo
12
+
13
+ model = Pix2Pix_Turbo("sketch_to_image_stochastic")
14
+
15
+ style_list = [
16
+ {
17
+ "name": "Cinematic",
18
+ "prompt": "cinematic still {prompt} . emotional, harmonious, vignette, highly detailed, high budget, bokeh, cinemascope, moody, epic, gorgeous, film grain, grainy",
19
+ },
20
+ {
21
+ "name": "3D Model",
22
+ "prompt": "professional 3d model {prompt} . octane render, highly detailed, volumetric, dramatic lighting",
23
+ },
24
+ {
25
+ "name": "Anime",
26
+ "prompt": "anime artwork {prompt} . anime style, key visual, vibrant, studio anime, highly detailed",
27
+ },
28
+ {
29
+ "name": "Digital Art",
30
+ "prompt": "concept art {prompt} . digital artwork, illustrative, painterly, matte painting, highly detailed",
31
+ },
32
+ {
33
+ "name": "Photographic",
34
+ "prompt": "cinematic photo {prompt} . 35mm photograph, film, bokeh, professional, 4k, highly detailed",
35
+ },
36
+ {
37
+ "name": "Pixel art",
38
+ "prompt": "pixel-art {prompt} . low-res, blocky, pixel art style, 8-bit graphics",
39
+ },
40
+ {
41
+ "name": "Fantasy art",
42
+ "prompt": "ethereal fantasy concept art of {prompt} . magnificent, celestial, ethereal, painterly, epic, majestic, magical, fantasy art, cover art, dreamy",
43
+ },
44
+ {
45
+ "name": "Neonpunk",
46
+ "prompt": "neonpunk style {prompt} . cyberpunk, vaporwave, neon, vibes, vibrant, stunningly beautiful, crisp, detailed, sleek, ultramodern, magenta highlights, dark purple shadows, high contrast, cinematic, ultra detailed, intricate, professional",
47
+ },
48
+ {
49
+ "name": "Manga",
50
+ "prompt": "manga style {prompt} . vibrant, high-energy, detailed, iconic, Japanese comic style",
51
+ },
52
+ ]
53
+
54
+ styles = {k["name"]: k["prompt"] for k in style_list}
55
+ STYLE_NAMES = list(styles.keys())
56
+ DEFAULT_STYLE_NAME = "Fantasy art"
57
+ MAX_SEED = np.iinfo(np.int32).max
58
+
59
+
60
+ def pil_image_to_data_uri(img, format="PNG"):
61
+ buffered = BytesIO()
62
+ img.save(buffered, format=format)
63
+ img_str = base64.b64encode(buffered.getvalue()).decode()
64
+ return f"data:image/{format.lower()};base64,{img_str}"
65
+
66
+
67
+ def run(image, prompt, prompt_template, style_name, seed, val_r):
68
+ print(f"prompt: {prompt}")
69
+ print("sketch updated")
70
+ if image is None:
71
+ ones = Image.new("L", (512, 512), 255)
72
+ temp_uri = pil_image_to_data_uri(ones)
73
+ return ones, gr.update(link=temp_uri), gr.update(link=temp_uri)
74
+ prompt = prompt_template.replace("{prompt}", prompt)
75
+ image = image.convert("RGB")
76
+ image_t = F.to_tensor(image) > 0.5
77
+ print(f"r_val={val_r}, seed={seed}")
78
+ with torch.no_grad():
79
+ c_t = image_t.unsqueeze(0).cuda().float()
80
+ torch.manual_seed(seed)
81
+ B, C, H, W = c_t.shape
82
+ noise = torch.randn((1, 4, H // 8, W // 8), device=c_t.device)
83
+ output_image = model(c_t, prompt, deterministic=False, r=val_r, noise_map=noise)
84
+ output_pil = F.to_pil_image(output_image[0].cpu() * 0.5 + 0.5)
85
+ input_sketch_uri = pil_image_to_data_uri(Image.fromarray(255 - np.array(image)))
86
+ output_image_uri = pil_image_to_data_uri(output_pil)
87
+ return (
88
+ output_pil,
89
+ gr.update(link=input_sketch_uri),
90
+ gr.update(link=output_image_uri),
91
+ )
92
+
93
+
94
+ def update_canvas(use_line, use_eraser):
95
+ if use_eraser:
96
+ _color = "#ffffff"
97
+ brush_size = 20
98
+ if use_line:
99
+ _color = "#000000"
100
+ brush_size = 4
101
+ return gr.update(brush_radius=brush_size, brush_color=_color, interactive=True)
102
+
103
+
104
+ def upload_sketch(file):
105
+ _img = Image.open(file.name)
106
+ _img = _img.convert("L")
107
+ return gr.update(value=_img, source="upload", interactive=True)
108
+
109
+
110
+ scripts = """
111
+ async () => {
112
+ globalThis.theSketchDownloadFunction = () => {
113
+ console.log("test")
114
+ var link = document.createElement("a");
115
+ dataUri = document.getElementById('download_sketch').href
116
+ link.setAttribute("href", dataUri)
117
+ link.setAttribute("download", "sketch.png")
118
+ document.body.appendChild(link); // Required for Firefox
119
+ link.click();
120
+ document.body.removeChild(link); // Clean up
121
+
122
+ // also call the output download function
123
+ theOutputDownloadFunction();
124
+ return false
125
+ }
126
+
127
+ globalThis.theOutputDownloadFunction = () => {
128
+ console.log("test output download function")
129
+ var link = document.createElement("a");
130
+ dataUri = document.getElementById('download_output').href
131
+ link.setAttribute("href", dataUri);
132
+ link.setAttribute("download", "output.png");
133
+ document.body.appendChild(link); // Required for Firefox
134
+ link.click();
135
+ document.body.removeChild(link); // Clean up
136
+ return false
137
+ }
138
+
139
+ globalThis.UNDO_SKETCH_FUNCTION = () => {
140
+ console.log("undo sketch function")
141
+ var button_undo = document.querySelector('#input_image > div.image-container.svelte-p3y7hu > div.svelte-s6ybro > button:nth-child(1)');
142
+ // Create a new 'click' event
143
+ var event = new MouseEvent('click', {
144
+ 'view': window,
145
+ 'bubbles': true,
146
+ 'cancelable': true
147
+ });
148
+ button_undo.dispatchEvent(event);
149
+ }
150
+
151
+ globalThis.DELETE_SKETCH_FUNCTION = () => {
152
+ console.log("delete sketch function")
153
+ var button_del = document.querySelector('#input_image > div.image-container.svelte-p3y7hu > div.svelte-s6ybro > button:nth-child(2)');
154
+ // Create a new 'click' event
155
+ var event = new MouseEvent('click', {
156
+ 'view': window,
157
+ 'bubbles': true,
158
+ 'cancelable': true
159
+ });
160
+ button_del.dispatchEvent(event);
161
+ }
162
+
163
+ globalThis.togglePencil = () => {
164
+ el_pencil = document.getElementById('my-toggle-pencil');
165
+ el_pencil.classList.toggle('clicked');
166
+ // simulate a click on the gradio button
167
+ btn_gradio = document.querySelector("#cb-line > label > input");
168
+ var event = new MouseEvent('click', {
169
+ 'view': window,
170
+ 'bubbles': true,
171
+ 'cancelable': true
172
+ });
173
+ btn_gradio.dispatchEvent(event);
174
+ if (el_pencil.classList.contains('clicked')) {
175
+ document.getElementById('my-toggle-eraser').classList.remove('clicked');
176
+ document.getElementById('my-div-pencil').style.backgroundColor = "gray";
177
+ document.getElementById('my-div-eraser').style.backgroundColor = "white";
178
+ }
179
+ else {
180
+ document.getElementById('my-toggle-eraser').classList.add('clicked');
181
+ document.getElementById('my-div-pencil').style.backgroundColor = "white";
182
+ document.getElementById('my-div-eraser').style.backgroundColor = "gray";
183
+ }
184
+ }
185
+
186
+ globalThis.toggleEraser = () => {
187
+ element = document.getElementById('my-toggle-eraser');
188
+ element.classList.toggle('clicked');
189
+ // simulate a click on the gradio button
190
+ btn_gradio = document.querySelector("#cb-eraser > label > input");
191
+ var event = new MouseEvent('click', {
192
+ 'view': window,
193
+ 'bubbles': true,
194
+ 'cancelable': true
195
+ });
196
+ btn_gradio.dispatchEvent(event);
197
+ if (element.classList.contains('clicked')) {
198
+ document.getElementById('my-toggle-pencil').classList.remove('clicked');
199
+ document.getElementById('my-div-pencil').style.backgroundColor = "white";
200
+ document.getElementById('my-div-eraser').style.backgroundColor = "gray";
201
+ }
202
+ else {
203
+ document.getElementById('my-toggle-pencil').classList.add('clicked');
204
+ document.getElementById('my-div-pencil').style.backgroundColor = "gray";
205
+ document.getElementById('my-div-eraser').style.backgroundColor = "white";
206
+ }
207
+ }
208
+ }
209
+ """
210
+
211
+ with gr.Blocks(css="style.css") as demo:
212
+
213
+ gr.HTML(
214
+ """
215
+ <div style="display: flex; justify-content: center; align-items: center; text-align: center;">
216
+ <div>
217
+ <h2><a href="https://github.com/GaParmar/img2img-turbo">One-Step Image Translation with Text-to-Image Models</a></h2>
218
+ <div>
219
+ <div style="display: flex; justify-content: center; align-items: center; text-align: center;">
220
+ <a href='https://gauravparmar.com/'>Gaurav Parmar, </a>
221
+ &nbsp;
222
+ <a href='https://taesung.me/'> Taesung Park,</a>
223
+ &nbsp;
224
+ <a href='https://www.cs.cmu.edu/~srinivas/'>Srinivasa Narasimhan, </a>
225
+ &nbsp;
226
+ <a href='https://www.cs.cmu.edu/~junyanz/'> Jun-Yan Zhu </a>
227
+ </div>
228
+ </div>
229
+ </br>
230
+ <div style="display: flex; justify-content: center; align-items: center; text-align: center;">
231
+ <a href='https://arxiv.org/abs/2403.12036'>
232
+ <img src="https://img.shields.io/badge/arXiv-2403.12036-red">
233
+ </a>
234
+ &nbsp;
235
+ <a href='https://github.com/GaParmar/img2img-turbo'>
236
+ <img src='https://img.shields.io/badge/github-%23121011.svg'>
237
+ </a>
238
+ &nbsp;
239
+ <a href='https://github.com/GaParmar/img2img-turbo/blob/main/LICENSE'>
240
+ <img src='https://img.shields.io/badge/license-MIT-lightgrey'>
241
+ </a>
242
+ </div>
243
+ </div>
244
+ </div>
245
+ <div>
246
+ </br>
247
+ </div>
248
+ """
249
+ )
250
+
251
+ # these are hidden buttons that are used to trigger the canvas changes
252
+ line = gr.Checkbox(label="line", value=False, elem_id="cb-line")
253
+ eraser = gr.Checkbox(label="eraser", value=False, elem_id="cb-eraser")
254
+ with gr.Row(elem_id="main_row"):
255
+ with gr.Column(elem_id="column_input"):
256
+ gr.Markdown("## INPUT", elem_id="input_header")
257
+ image = gr.Image(
258
+ source="canvas",
259
+ tool="color-sketch",
260
+ type="pil",
261
+ image_mode="L",
262
+ invert_colors=True,
263
+ shape=(512, 512),
264
+ brush_radius=4,
265
+ height=440,
266
+ width=440,
267
+ brush_color="#000000",
268
+ interactive=True,
269
+ show_download_button=True,
270
+ elem_id="input_image",
271
+ show_label=False,
272
+ )
273
+ download_sketch = gr.Button(
274
+ "Download sketch", scale=1, elem_id="download_sketch"
275
+ )
276
+
277
+ gr.HTML(
278
+ """
279
+ <div class="button-row">
280
+ <div id="my-div-pencil" class="pad2"> <button id="my-toggle-pencil" onclick="return togglePencil(this)"></button> </div>
281
+ <div id="my-div-eraser" class="pad2"> <button id="my-toggle-eraser" onclick="return toggleEraser(this)"></button> </div>
282
+ <div class="pad2"> <button id="my-button-undo" onclick="return UNDO_SKETCH_FUNCTION(this)"></button> </div>
283
+ <div class="pad2"> <button id="my-button-clear" onclick="return DELETE_SKETCH_FUNCTION(this)"></button> </div>
284
+ <div class="pad2"> <button href="TODO" download="image" id="my-button-down" onclick='return theSketchDownloadFunction()'></button> </div>
285
+ </div>
286
+ """
287
+ )
288
+ # gr.Markdown("## Prompt", elem_id="tools_header")
289
+ prompt = gr.Textbox(label="Prompt", value="", show_label=True)
290
+ with gr.Row():
291
+ style = gr.Dropdown(
292
+ label="Style",
293
+ choices=STYLE_NAMES,
294
+ value=DEFAULT_STYLE_NAME,
295
+ scale=1,
296
+ )
297
+ prompt_temp = gr.Textbox(
298
+ label="Prompt Style Template",
299
+ value=styles[DEFAULT_STYLE_NAME],
300
+ scale=2,
301
+ max_lines=1,
302
+ )
303
+
304
+ with gr.Row():
305
+ val_r = gr.Slider(
306
+ label="Sketch guidance: ",
307
+ show_label=True,
308
+ minimum=0,
309
+ maximum=1,
310
+ value=0.4,
311
+ step=0.01,
312
+ scale=3,
313
+ )
314
+ seed = gr.Textbox(label="Seed", value=42, scale=1, min_width=50)
315
+ randomize_seed = gr.Button("Random", scale=1, min_width=50)
316
+
317
+ with gr.Column(elem_id="column_process", min_width=50, scale=0.4):
318
+ gr.Markdown("## pix2pix-turbo", elem_id="description")
319
+ run_button = gr.Button("Run", min_width=50)
320
+
321
+ with gr.Column(elem_id="column_output"):
322
+ gr.Markdown("## OUTPUT", elem_id="output_header")
323
+ result = gr.Image(
324
+ label="Result",
325
+ height=440,
326
+ width=440,
327
+ elem_id="output_image",
328
+ show_label=False,
329
+ show_download_button=True,
330
+ )
331
+ download_output = gr.Button("Download output", elem_id="download_output")
332
+ gr.Markdown("### Instructions")
333
+ gr.Markdown("**1**. Enter a text prompt (e.g. cat)")
334
+ gr.Markdown("**2**. Start sketching")
335
+ gr.Markdown("**3**. Change the image style using a style template")
336
+ gr.Markdown("**4**. Adjust the effect of sketch guidance using the slider")
337
+ gr.Markdown("**5**. Try different seeds to generate different results")
338
+
339
+ eraser.change(
340
+ fn=lambda x: gr.update(value=not x),
341
+ inputs=[eraser],
342
+ outputs=[line],
343
+ queue=False,
344
+ api_name=False,
345
+ ).then(update_canvas, [line, eraser], [image])
346
+ line.change(
347
+ fn=lambda x: gr.update(value=not x),
348
+ inputs=[line],
349
+ outputs=[eraser],
350
+ queue=False,
351
+ api_name=False,
352
+ ).then(update_canvas, [line, eraser], [image])
353
+
354
+ demo.load(None, None, None, _js=scripts)
355
+ randomize_seed.click(
356
+ lambda x: random.randint(0, MAX_SEED),
357
+ inputs=[],
358
+ outputs=seed,
359
+ queue=False,
360
+ api_name=False,
361
+ )
362
+ inputs = [image, prompt, prompt_temp, style, seed, val_r]
363
+ outputs = [result, download_sketch, download_output]
364
+ prompt.submit(fn=run, inputs=inputs, outputs=outputs, api_name=False)
365
+ style.change(
366
+ lambda x: styles[x],
367
+ inputs=[style],
368
+ outputs=[prompt_temp],
369
+ queue=False,
370
+ api_name=False,
371
+ ).then(
372
+ fn=run,
373
+ inputs=inputs,
374
+ outputs=outputs,
375
+ api_name=False,
376
+ )
377
+ val_r.change(run, inputs=inputs, outputs=outputs, queue=False, api_name=False)
378
+ run_button.click(fn=run, inputs=inputs, outputs=outputs, api_name=False)
379
+ image.change(run, inputs=inputs, outputs=outputs, queue=False, api_name=False)
380
+
381
+ if __name__ == "__main__":
382
+ demo.queue().launch(debug=True, share=True)
python==3.9.8/Lib/site-packages/wheel/cli/tags.py ADDED
@@ -0,0 +1,139 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ import email.policy
4
+ import itertools
5
+ import os
6
+ from collections.abc import Iterable
7
+ from email.parser import BytesParser
8
+
9
+ from ..wheelfile import WheelFile
10
+
11
+
12
+ def _compute_tags(original_tags: Iterable[str], new_tags: str | None) -> set[str]:
13
+ """Add or replace tags. Supports dot-separated tags"""
14
+ if new_tags is None:
15
+ return set(original_tags)
16
+
17
+ if new_tags.startswith("+"):
18
+ return {*original_tags, *new_tags[1:].split(".")}
19
+
20
+ if new_tags.startswith("-"):
21
+ return set(original_tags) - set(new_tags[1:].split("."))
22
+
23
+ return set(new_tags.split("."))
24
+
25
+
26
+ def tags(
27
+ wheel: str,
28
+ python_tags: str | None = None,
29
+ abi_tags: str | None = None,
30
+ platform_tags: str | None = None,
31
+ build_tag: str | None = None,
32
+ remove: bool = False,
33
+ ) -> str:
34
+ """Change the tags on a wheel file.
35
+
36
+ The tags are left unchanged if they are not specified. To specify "none",
37
+ use ["none"]. To append to the previous tags, a tag should start with a
38
+ "+". If a tag starts with "-", it will be removed from existing tags.
39
+ Processing is done left to right.
40
+
41
+ :param wheel: The paths to the wheels
42
+ :param python_tags: The Python tags to set
43
+ :param abi_tags: The ABI tags to set
44
+ :param platform_tags: The platform tags to set
45
+ :param build_tag: The build tag to set
46
+ :param remove: Remove the original wheel
47
+ """
48
+ with WheelFile(wheel, "r") as f:
49
+ assert f.filename, f"{f.filename} must be available"
50
+
51
+ wheel_info = f.read(f.dist_info_path + "/WHEEL")
52
+ info = BytesParser(policy=email.policy.compat32).parsebytes(wheel_info)
53
+
54
+ original_wheel_name = os.path.basename(f.filename)
55
+ namever = f.parsed_filename.group("namever")
56
+ build = f.parsed_filename.group("build")
57
+ original_python_tags = f.parsed_filename.group("pyver").split(".")
58
+ original_abi_tags = f.parsed_filename.group("abi").split(".")
59
+ original_plat_tags = f.parsed_filename.group("plat").split(".")
60
+
61
+ tags: list[str] = info.get_all("Tag", [])
62
+ existing_build_tag = info.get("Build")
63
+
64
+ impls = {tag.split("-")[0] for tag in tags}
65
+ abivers = {tag.split("-")[1] for tag in tags}
66
+ platforms = {tag.split("-")[2] for tag in tags}
67
+
68
+ if impls != set(original_python_tags):
69
+ msg = f"Wheel internal tags {impls!r} != filename tags {original_python_tags!r}"
70
+ raise AssertionError(msg)
71
+
72
+ if abivers != set(original_abi_tags):
73
+ msg = f"Wheel internal tags {abivers!r} != filename tags {original_abi_tags!r}"
74
+ raise AssertionError(msg)
75
+
76
+ if platforms != set(original_plat_tags):
77
+ msg = (
78
+ f"Wheel internal tags {platforms!r} != filename tags {original_plat_tags!r}"
79
+ )
80
+ raise AssertionError(msg)
81
+
82
+ if existing_build_tag != build:
83
+ msg = (
84
+ f"Incorrect filename '{build}' "
85
+ f"& *.dist-info/WHEEL '{existing_build_tag}' build numbers"
86
+ )
87
+ raise AssertionError(msg)
88
+
89
+ # Start changing as needed
90
+ if build_tag is not None:
91
+ build = build_tag
92
+
93
+ final_python_tags = sorted(_compute_tags(original_python_tags, python_tags))
94
+ final_abi_tags = sorted(_compute_tags(original_abi_tags, abi_tags))
95
+ final_plat_tags = sorted(_compute_tags(original_plat_tags, platform_tags))
96
+
97
+ final_tags = [
98
+ namever,
99
+ ".".join(final_python_tags),
100
+ ".".join(final_abi_tags),
101
+ ".".join(final_plat_tags),
102
+ ]
103
+ if build:
104
+ final_tags.insert(1, build)
105
+
106
+ final_wheel_name = "-".join(final_tags) + ".whl"
107
+
108
+ if original_wheel_name != final_wheel_name:
109
+ del info["Tag"], info["Build"]
110
+ for a, b, c in itertools.product(
111
+ final_python_tags, final_abi_tags, final_plat_tags
112
+ ):
113
+ info["Tag"] = f"{a}-{b}-{c}"
114
+ if build:
115
+ info["Build"] = build
116
+
117
+ original_wheel_path = os.path.join(
118
+ os.path.dirname(f.filename), original_wheel_name
119
+ )
120
+ final_wheel_path = os.path.join(os.path.dirname(f.filename), final_wheel_name)
121
+
122
+ with WheelFile(original_wheel_path, "r") as fin, WheelFile(
123
+ final_wheel_path, "w"
124
+ ) as fout:
125
+ fout.comment = fin.comment # preserve the comment
126
+ for item in fin.infolist():
127
+ if item.is_dir():
128
+ continue
129
+ if item.filename == f.dist_info_path + "/RECORD":
130
+ continue
131
+ if item.filename == f.dist_info_path + "/WHEEL":
132
+ fout.writestr(item, info.as_bytes())
133
+ else:
134
+ fout.writestr(item, fin.read(item))
135
+
136
+ if remove:
137
+ os.remove(original_wheel_path)
138
+
139
+ return final_wheel_name
python==3.9.8/conda-meta/history ADDED
@@ -0,0 +1,19 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ==> 2024-08-08 09:16:40 <==
2
+ # cmd: C:\ProgramData\miniconda3\Scripts\conda-script.py env create -f environment.yaml --p python==3.9.8
3
+ # conda version: 24.1.2
4
+ +defaults/noarch::tzdata-2024a-h04d1e81_0
5
+ +defaults/win-64::bzip2-1.0.8-h2bbff1b_6
6
+ +defaults/win-64::ca-certificates-2024.7.2-haa95532_0
7
+ +defaults/win-64::libffi-3.4.4-hd77b12b_1
8
+ +defaults/win-64::openssl-3.0.14-h827c3e9_0
9
+ +defaults/win-64::pip-24.0-py310haa95532_0
10
+ +defaults/win-64::python-3.10.14-he1021f5_1
11
+ +defaults/win-64::setuptools-72.1.0-py310haa95532_0
12
+ +defaults/win-64::sqlite-3.45.3-h2bbff1b_0
13
+ +defaults/win-64::tk-8.6.14-h0416ee5_0
14
+ +defaults/win-64::vc-14.2-h2eaa2aa_4
15
+ +defaults/win-64::vs2015_runtime-14.29.30133-h43f2093_4
16
+ +defaults/win-64::wheel-0.43.0-py310haa95532_0
17
+ +defaults/win-64::xz-5.4.6-h8cc25b3_1
18
+ +defaults/win-64::zlib-1.2.13-h8cc25b3_1
19
+ # update specs: ['pip', 'python=3.10']
requirements.txt ADDED
@@ -0,0 +1,28 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ clip @ git+https://github.com/openai/CLIP.git
2
+ einops>=0.6.1
3
+ numpy>=1.24.4
4
+ open-clip-torch>=2.20.0
5
+ opencv-python==4.6.0.66
6
+ pillow>=9.5.0
7
+ scipy==1.11.1
8
+ timm>=0.9.2
9
+ tokenizers
10
+ torch>=2.0.1
11
+
12
+ torchaudio>=2.0.2
13
+ torchdata==0.6.1
14
+ torchmetrics>=1.0.1
15
+ torchvision>=0.15.2
16
+
17
+ tqdm>=4.65.0
18
+ transformers==4.35.2
19
+ triton==2.0.0
20
+ urllib3<1.27,>=1.25.4
21
+ xformers>=0.0.20
22
+ streamlit-keyup==0.2.0
23
+ lpips
24
+ clean-fid
25
+ peft
26
+ dominate
27
+ diffusers==0.25.1
28
+ gradio==3.43.1
scripts/download_fill50k.sh ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ mkdir -p data
2
+ wget https://www.cs.cmu.edu/~img2img-turbo/data/my_fill50k.zip -O data/my_fill50k.zip
3
+ cd data
4
+ unzip my_fill50k.zip
5
+ rm my_fill50k.zip
scripts/download_horse2zebra.sh ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ mkdir -p data
2
+ wget https://www.cs.cmu.edu/~img2img-turbo/data/my_horse2zebra.zip -O data/my_horse2zebra.zip
3
+ cd data
4
+ unzip my_horse2zebra.zip
5
+ rm my_horse2zebra.zip
src/cyclegan_turbo.py ADDED
@@ -0,0 +1,254 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import sys
3
+ import copy
4
+ import torch
5
+ import torch.nn as nn
6
+ from transformers import AutoTokenizer, CLIPTextModel
7
+ from diffusers import AutoencoderKL, UNet2DConditionModel
8
+ from peft import LoraConfig
9
+ from peft.utils import get_peft_model_state_dict
10
+ p = "src/"
11
+ sys.path.append(p)
12
+ from model import make_1step_sched, my_vae_encoder_fwd, my_vae_decoder_fwd, download_url
13
+
14
+
15
+ class VAE_encode(nn.Module):
16
+ def __init__(self, vae, vae_b2a=None):
17
+ super(VAE_encode, self).__init__()
18
+ self.vae = vae
19
+ self.vae_b2a = vae_b2a
20
+
21
+ def forward(self, x, direction):
22
+ assert direction in ["a2b", "b2a"]
23
+ if direction == "a2b":
24
+ _vae = self.vae
25
+ else:
26
+ _vae = self.vae_b2a
27
+ return _vae.encode(x).latent_dist.sample() * _vae.config.scaling_factor
28
+
29
+
30
+ class VAE_decode(nn.Module):
31
+ def __init__(self, vae, vae_b2a=None):
32
+ super(VAE_decode, self).__init__()
33
+ self.vae = vae
34
+ self.vae_b2a = vae_b2a
35
+
36
+ def forward(self, x, direction):
37
+ assert direction in ["a2b", "b2a"]
38
+ if direction == "a2b":
39
+ _vae = self.vae
40
+ else:
41
+ _vae = self.vae_b2a
42
+ assert _vae.encoder.current_down_blocks is not None
43
+ _vae.decoder.incoming_skip_acts = _vae.encoder.current_down_blocks
44
+ x_decoded = (_vae.decode(x / _vae.config.scaling_factor).sample).clamp(-1, 1)
45
+ return x_decoded
46
+
47
+
48
+ def initialize_unet(rank, return_lora_module_names=False):
49
+ unet = UNet2DConditionModel.from_pretrained("stabilityai/sd-turbo", subfolder="unet")
50
+ unet.requires_grad_(False)
51
+ unet.train()
52
+ l_target_modules_encoder, l_target_modules_decoder, l_modules_others = [], [], []
53
+ l_grep = ["to_k", "to_q", "to_v", "to_out.0", "conv", "conv1", "conv2", "conv_in", "conv_shortcut", "conv_out", "proj_out", "proj_in", "ff.net.2", "ff.net.0.proj"]
54
+ for n, p in unet.named_parameters():
55
+ if "bias" in n or "norm" in n: continue
56
+ for pattern in l_grep:
57
+ if pattern in n and ("down_blocks" in n or "conv_in" in n):
58
+ l_target_modules_encoder.append(n.replace(".weight",""))
59
+ break
60
+ elif pattern in n and "up_blocks" in n:
61
+ l_target_modules_decoder.append(n.replace(".weight",""))
62
+ break
63
+ elif pattern in n:
64
+ l_modules_others.append(n.replace(".weight",""))
65
+ break
66
+ lora_conf_encoder = LoraConfig(r=rank, init_lora_weights="gaussian",target_modules=l_target_modules_encoder, lora_alpha=rank)
67
+ lora_conf_decoder = LoraConfig(r=rank, init_lora_weights="gaussian",target_modules=l_target_modules_decoder, lora_alpha=rank)
68
+ lora_conf_others = LoraConfig(r=rank, init_lora_weights="gaussian",target_modules=l_modules_others, lora_alpha=rank)
69
+ unet.add_adapter(lora_conf_encoder, adapter_name="default_encoder")
70
+ unet.add_adapter(lora_conf_decoder, adapter_name="default_decoder")
71
+ unet.add_adapter(lora_conf_others, adapter_name="default_others")
72
+ unet.set_adapters(["default_encoder", "default_decoder", "default_others"])
73
+ if return_lora_module_names:
74
+ return unet, l_target_modules_encoder, l_target_modules_decoder, l_modules_others
75
+ else:
76
+ return unet
77
+
78
+
79
+ def initialize_vae(rank=4, return_lora_module_names=False):
80
+ vae = AutoencoderKL.from_pretrained("stabilityai/sd-turbo", subfolder="vae")
81
+ vae.requires_grad_(False)
82
+ vae.encoder.forward = my_vae_encoder_fwd.__get__(vae.encoder, vae.encoder.__class__)
83
+ vae.decoder.forward = my_vae_decoder_fwd.__get__(vae.decoder, vae.decoder.__class__)
84
+ vae.requires_grad_(True)
85
+ vae.train()
86
+ # add the skip connection convs
87
+ vae.decoder.skip_conv_1 = torch.nn.Conv2d(512, 512, kernel_size=(1, 1), stride=(1, 1), bias=False).cuda().requires_grad_(True)
88
+ vae.decoder.skip_conv_2 = torch.nn.Conv2d(256, 512, kernel_size=(1, 1), stride=(1, 1), bias=False).cuda().requires_grad_(True)
89
+ vae.decoder.skip_conv_3 = torch.nn.Conv2d(128, 512, kernel_size=(1, 1), stride=(1, 1), bias=False).cuda().requires_grad_(True)
90
+ vae.decoder.skip_conv_4 = torch.nn.Conv2d(128, 256, kernel_size=(1, 1), stride=(1, 1), bias=False).cuda().requires_grad_(True)
91
+ torch.nn.init.constant_(vae.decoder.skip_conv_1.weight, 1e-5)
92
+ torch.nn.init.constant_(vae.decoder.skip_conv_2.weight, 1e-5)
93
+ torch.nn.init.constant_(vae.decoder.skip_conv_3.weight, 1e-5)
94
+ torch.nn.init.constant_(vae.decoder.skip_conv_4.weight, 1e-5)
95
+ vae.decoder.ignore_skip = False
96
+ vae.decoder.gamma = 1
97
+ l_vae_target_modules = ["conv1","conv2","conv_in", "conv_shortcut",
98
+ "conv", "conv_out", "skip_conv_1", "skip_conv_2", "skip_conv_3",
99
+ "skip_conv_4", "to_k", "to_q", "to_v", "to_out.0",
100
+ ]
101
+ vae_lora_config = LoraConfig(r=rank, init_lora_weights="gaussian", target_modules=l_vae_target_modules)
102
+ vae.add_adapter(vae_lora_config, adapter_name="vae_skip")
103
+ if return_lora_module_names:
104
+ return vae, l_vae_target_modules
105
+ else:
106
+ return vae
107
+
108
+
109
+ class CycleGAN_Turbo(torch.nn.Module):
110
+ def __init__(self, pretrained_name=None, pretrained_path=None, ckpt_folder="checkpoints", lora_rank_unet=8, lora_rank_vae=4):
111
+ super().__init__()
112
+ self.tokenizer = AutoTokenizer.from_pretrained("stabilityai/sd-turbo", subfolder="tokenizer")
113
+ self.text_encoder = CLIPTextModel.from_pretrained("stabilityai/sd-turbo", subfolder="text_encoder").cuda()
114
+ self.sched = make_1step_sched()
115
+ vae = AutoencoderKL.from_pretrained("stabilityai/sd-turbo", subfolder="vae")
116
+ unet = UNet2DConditionModel.from_pretrained("stabilityai/sd-turbo", subfolder="unet")
117
+ vae.encoder.forward = my_vae_encoder_fwd.__get__(vae.encoder, vae.encoder.__class__)
118
+ vae.decoder.forward = my_vae_decoder_fwd.__get__(vae.decoder, vae.decoder.__class__)
119
+ # add the skip connection convs
120
+ vae.decoder.skip_conv_1 = torch.nn.Conv2d(512, 512, kernel_size=(1, 1), stride=(1, 1), bias=False).cuda()
121
+ vae.decoder.skip_conv_2 = torch.nn.Conv2d(256, 512, kernel_size=(1, 1), stride=(1, 1), bias=False).cuda()
122
+ vae.decoder.skip_conv_3 = torch.nn.Conv2d(128, 512, kernel_size=(1, 1), stride=(1, 1), bias=False).cuda()
123
+ vae.decoder.skip_conv_4 = torch.nn.Conv2d(128, 256, kernel_size=(1, 1), stride=(1, 1), bias=False).cuda()
124
+ vae.decoder.ignore_skip = False
125
+ self.unet, self.vae = unet, vae
126
+ if pretrained_name == "day_to_night":
127
+ url = "https://www.cs.cmu.edu/~img2img-turbo/models/day2night.pkl"
128
+ self.load_ckpt_from_url(url, ckpt_folder)
129
+ self.timesteps = torch.tensor([999], device="cuda").long()
130
+ self.caption = "driving in the night"
131
+ self.direction = "a2b"
132
+ elif pretrained_name == "night_to_day":
133
+ url = "https://www.cs.cmu.edu/~img2img-turbo/models/night2day.pkl"
134
+ self.load_ckpt_from_url(url, ckpt_folder)
135
+ self.timesteps = torch.tensor([999], device="cuda").long()
136
+ self.caption = "driving in the day"
137
+ self.direction = "b2a"
138
+ elif pretrained_name == "clear_to_rainy":
139
+ url = "https://www.cs.cmu.edu/~img2img-turbo/models/clear2rainy.pkl"
140
+ self.load_ckpt_from_url(url, ckpt_folder)
141
+ self.timesteps = torch.tensor([999], device="cuda").long()
142
+ self.caption = "driving in heavy rain"
143
+ self.direction = "a2b"
144
+ elif pretrained_name == "rainy_to_clear":
145
+ url = "https://www.cs.cmu.edu/~img2img-turbo/models/rainy2clear.pkl"
146
+ self.load_ckpt_from_url(url, ckpt_folder)
147
+ self.timesteps = torch.tensor([999], device="cuda").long()
148
+ self.caption = "driving in the day"
149
+ self.direction = "b2a"
150
+
151
+ elif pretrained_path is not None:
152
+ sd = torch.load(pretrained_path)
153
+ self.load_ckpt_from_state_dict(sd)
154
+ self.timesteps = torch.tensor([999], device="cuda").long()
155
+ self.caption = None
156
+ self.direction = None
157
+
158
+ self.vae_enc.cuda()
159
+ self.vae_dec.cuda()
160
+ self.unet.cuda()
161
+
162
+ def load_ckpt_from_state_dict(self, sd):
163
+ lora_conf_encoder = LoraConfig(r=sd["rank_unet"], init_lora_weights="gaussian", target_modules=sd["l_target_modules_encoder"], lora_alpha=sd["rank_unet"])
164
+ lora_conf_decoder = LoraConfig(r=sd["rank_unet"], init_lora_weights="gaussian", target_modules=sd["l_target_modules_decoder"], lora_alpha=sd["rank_unet"])
165
+ lora_conf_others = LoraConfig(r=sd["rank_unet"], init_lora_weights="gaussian", target_modules=sd["l_modules_others"], lora_alpha=sd["rank_unet"])
166
+ self.unet.add_adapter(lora_conf_encoder, adapter_name="default_encoder")
167
+ self.unet.add_adapter(lora_conf_decoder, adapter_name="default_decoder")
168
+ self.unet.add_adapter(lora_conf_others, adapter_name="default_others")
169
+ for n, p in self.unet.named_parameters():
170
+ name_sd = n.replace(".default_encoder.weight", ".weight")
171
+ if "lora" in n and "default_encoder" in n:
172
+ p.data.copy_(sd["sd_encoder"][name_sd])
173
+ for n, p in self.unet.named_parameters():
174
+ name_sd = n.replace(".default_decoder.weight", ".weight")
175
+ if "lora" in n and "default_decoder" in n:
176
+ p.data.copy_(sd["sd_decoder"][name_sd])
177
+ for n, p in self.unet.named_parameters():
178
+ name_sd = n.replace(".default_others.weight", ".weight")
179
+ if "lora" in n and "default_others" in n:
180
+ p.data.copy_(sd["sd_other"][name_sd])
181
+ self.unet.set_adapter(["default_encoder", "default_decoder", "default_others"])
182
+
183
+ vae_lora_config = LoraConfig(r=sd["rank_vae"], init_lora_weights="gaussian", target_modules=sd["vae_lora_target_modules"])
184
+ self.vae.add_adapter(vae_lora_config, adapter_name="vae_skip")
185
+ self.vae.decoder.gamma = 1
186
+ self.vae_b2a = copy.deepcopy(self.vae)
187
+ self.vae_enc = VAE_encode(self.vae, vae_b2a=self.vae_b2a)
188
+ self.vae_enc.load_state_dict(sd["sd_vae_enc"])
189
+ self.vae_dec = VAE_decode(self.vae, vae_b2a=self.vae_b2a)
190
+ self.vae_dec.load_state_dict(sd["sd_vae_dec"])
191
+
192
+ def load_ckpt_from_url(self, url, ckpt_folder):
193
+ os.makedirs(ckpt_folder, exist_ok=True)
194
+ outf = os.path.join(ckpt_folder, os.path.basename(url))
195
+ download_url(url, outf)
196
+ sd = torch.load(outf)
197
+ self.load_ckpt_from_state_dict(sd)
198
+
199
+ @staticmethod
200
+ def forward_with_networks(x, direction, vae_enc, unet, vae_dec, sched, timesteps, text_emb):
201
+ B = x.shape[0]
202
+ assert direction in ["a2b", "b2a"]
203
+ x_enc = vae_enc(x, direction=direction).to(x.dtype)
204
+ model_pred = unet(x_enc, timesteps, encoder_hidden_states=text_emb,).sample
205
+ x_out = torch.stack([sched.step(model_pred[i], timesteps[i], x_enc[i], return_dict=True).prev_sample for i in range(B)])
206
+ x_out_decoded = vae_dec(x_out, direction=direction)
207
+ return x_out_decoded
208
+
209
+ @staticmethod
210
+ def get_traininable_params(unet, vae_a2b, vae_b2a):
211
+ # add all unet parameters
212
+ params_gen = list(unet.conv_in.parameters())
213
+ unet.conv_in.requires_grad_(True)
214
+ unet.set_adapters(["default_encoder", "default_decoder", "default_others"])
215
+ for n,p in unet.named_parameters():
216
+ if "lora" in n and "default" in n:
217
+ assert p.requires_grad
218
+ params_gen.append(p)
219
+
220
+ # add all vae_a2b parameters
221
+ for n,p in vae_a2b.named_parameters():
222
+ if "lora" in n and "vae_skip" in n:
223
+ assert p.requires_grad
224
+ params_gen.append(p)
225
+ params_gen = params_gen + list(vae_a2b.decoder.skip_conv_1.parameters())
226
+ params_gen = params_gen + list(vae_a2b.decoder.skip_conv_2.parameters())
227
+ params_gen = params_gen + list(vae_a2b.decoder.skip_conv_3.parameters())
228
+ params_gen = params_gen + list(vae_a2b.decoder.skip_conv_4.parameters())
229
+
230
+ # add all vae_b2a parameters
231
+ for n,p in vae_b2a.named_parameters():
232
+ if "lora" in n and "vae_skip" in n:
233
+ assert p.requires_grad
234
+ params_gen.append(p)
235
+ params_gen = params_gen + list(vae_b2a.decoder.skip_conv_1.parameters())
236
+ params_gen = params_gen + list(vae_b2a.decoder.skip_conv_2.parameters())
237
+ params_gen = params_gen + list(vae_b2a.decoder.skip_conv_3.parameters())
238
+ params_gen = params_gen + list(vae_b2a.decoder.skip_conv_4.parameters())
239
+ return params_gen
240
+
241
+ def forward(self, x_t, direction=None, caption=None, caption_emb=None):
242
+ if direction is None:
243
+ assert self.direction is not None
244
+ direction = self.direction
245
+ if caption is None and caption_emb is None:
246
+ assert self.caption is not None
247
+ caption = self.caption
248
+ if caption_emb is not None:
249
+ caption_enc = caption_emb
250
+ else:
251
+ caption_tokens = self.tokenizer(caption, max_length=self.tokenizer.model_max_length,
252
+ padding="max_length", truncation=True, return_tensors="pt").input_ids.to(x_t.device)
253
+ caption_enc = self.text_encoder(caption_tokens)[0].detach().clone()
254
+ return self.forward_with_networks(x_t, direction, self.vae_enc, self.unet, self.vae_dec, self.sched, self.timesteps, caption_enc)
src/image_prep.py ADDED
@@ -0,0 +1,12 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ from PIL import Image
3
+ import cv2
4
+
5
+
6
+ def canny_from_pil(image, low_threshold=100, high_threshold=200):
7
+ image = np.array(image)
8
+ image = cv2.Canny(image, low_threshold, high_threshold)
9
+ image = image[:, :, None]
10
+ image = np.concatenate([image, image, image], axis=2)
11
+ control_image = Image.fromarray(image)
12
+ return control_image
src/inference_paired.py ADDED
@@ -0,0 +1,65 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import argparse
3
+ import numpy as np
4
+ from PIL import Image
5
+ import torch
6
+ from torchvision import transforms
7
+ import torchvision.transforms.functional as F
8
+ from pix2pix_turbo import Pix2Pix_Turbo
9
+ from image_prep import canny_from_pil
10
+
11
+ if __name__ == "__main__":
12
+ parser = argparse.ArgumentParser()
13
+ parser.add_argument('--input_image', type=str, required=True, help='path to the input image')
14
+ parser.add_argument('--prompt', type=str, required=True, help='the prompt to be used')
15
+ parser.add_argument('--model_name', type=str, default='', help='name of the pretrained model to be used')
16
+ parser.add_argument('--model_path', type=str, default='', help='path to a model state dict to be used')
17
+ parser.add_argument('--output_dir', type=str, default='output', help='the directory to save the output')
18
+ parser.add_argument('--low_threshold', type=int, default=100, help='Canny low threshold')
19
+ parser.add_argument('--high_threshold', type=int, default=200, help='Canny high threshold')
20
+ parser.add_argument('--gamma', type=float, default=0.4, help='The sketch interpolation guidance amount')
21
+ parser.add_argument('--seed', type=int, default=42, help='Random seed to be used')
22
+ args = parser.parse_args()
23
+
24
+ # only one of model_name and model_path should be provided
25
+ if args.model_name == '' != args.model_path == '':
26
+ raise ValueError('Either model_name or model_path should be provided')
27
+
28
+ os.makedirs(args.output_dir, exist_ok=True)
29
+
30
+ # initialize the model
31
+ model = Pix2Pix_Turbo(pretrained_name=args.model_name, pretrained_path=args.model_path)
32
+ model.set_eval()
33
+
34
+ # make sure that the input image is a multiple of 8
35
+ input_image = Image.open(args.input_image).convert('RGB')
36
+ new_width = input_image.width - input_image.width % 8
37
+ new_height = input_image.height - input_image.height % 8
38
+ input_image = input_image.resize((new_width, new_height), Image.LANCZOS)
39
+ bname = os.path.basename(args.input_image)
40
+
41
+ # translate the image
42
+ with torch.no_grad():
43
+ if args.model_name == 'edge_to_image':
44
+ canny = canny_from_pil(input_image, args.low_threshold, args.high_threshold)
45
+ canny_viz_inv = Image.fromarray(255 - np.array(canny))
46
+ canny_viz_inv.save(os.path.join(args.output_dir, bname.replace('.png', '_canny.png')))
47
+ c_t = F.to_tensor(canny).unsqueeze(0).cuda()
48
+ output_image = model(c_t, args.prompt)
49
+
50
+ elif args.model_name == 'sketch_to_image_stochastic':
51
+ image_t = F.to_tensor(input_image) < 0.5
52
+ c_t = image_t.unsqueeze(0).cuda().float()
53
+ torch.manual_seed(args.seed)
54
+ B, C, H, W = c_t.shape
55
+ noise = torch.randn((1, 4, H // 8, W // 8), device=c_t.device)
56
+ output_image = model(c_t, args.prompt, deterministic=False, r=args.gamma, noise_map=noise)
57
+
58
+ else:
59
+ c_t = F.to_tensor(input_image).unsqueeze(0).cuda()
60
+ output_image = model(c_t, args.prompt)
61
+
62
+ output_pil = transforms.ToPILImage()(output_image[0].cpu() * 0.5 + 0.5)
63
+
64
+ # save the output image
65
+ output_pil.save(os.path.join(args.output_dir, bname))
src/inference_unpaired.py ADDED
@@ -0,0 +1,53 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import argparse
3
+ from PIL import Image
4
+ import torch
5
+ from torchvision import transforms
6
+ from cyclegan_turbo import CycleGAN_Turbo
7
+ from my_utils.training_utils import build_transform
8
+
9
+
10
+ if __name__ == "__main__":
11
+ parser = argparse.ArgumentParser()
12
+ parser.add_argument('--input_image', type=str, required=True, help='path to the input image')
13
+ parser.add_argument('--prompt', type=str, required=False, help='the prompt to be used. It is required when loading a custom model_path.')
14
+ parser.add_argument('--model_name', type=str, default=None, help='name of the pretrained model to be used')
15
+ parser.add_argument('--model_path', type=str, default=None, help='path to a local model state dict to be used')
16
+ parser.add_argument('--output_dir', type=str, default='output', help='the directory to save the output')
17
+ parser.add_argument('--image_prep', type=str, default='resize_512x512', help='the image preparation method')
18
+ parser.add_argument('--direction', type=str, default=None, help='the direction of translation. None for pretrained models, a2b or b2a for custom paths.')
19
+ args = parser.parse_args()
20
+
21
+ # only one of model_name and model_path should be provided
22
+ if args.model_name is None != args.model_path is None:
23
+ raise ValueError('Either model_name or model_path should be provided')
24
+
25
+ if args.model_path is not None and args.prompt is None:
26
+ raise ValueError('prompt is required when loading a custom model_path.')
27
+
28
+ if args.model_name is not None:
29
+ assert args.prompt is None, 'prompt is not required when loading a pretrained model.'
30
+ assert args.direction is None, 'direction is not required when loading a pretrained model.'
31
+
32
+ # initialize the model
33
+ model = CycleGAN_Turbo(pretrained_name=args.model_name, pretrained_path=args.model_path)
34
+ model.eval()
35
+ model.unet.enable_xformers_memory_efficient_attention()
36
+
37
+ T_val = build_transform(args.image_prep)
38
+
39
+ input_image = Image.open(args.input_image).convert('RGB')
40
+ # translate the image
41
+ with torch.no_grad():
42
+ input_img = T_val(input_image)
43
+ x_t = transforms.ToTensor()(input_img)
44
+ x_t = transforms.Normalize([0.5], [0.5])(x_t).unsqueeze(0).cuda()
45
+ output = model(x_t, direction=args.direction, caption=args.prompt)
46
+
47
+ output_pil = transforms.ToPILImage()(output[0].cpu() * 0.5 + 0.5)
48
+ output_pil = output_pil.resize((input_image.width, input_image.height), Image.LANCZOS)
49
+
50
+ # save the output image
51
+ bname = os.path.basename(args.input_image)
52
+ os.makedirs(args.output_dir, exist_ok=True)
53
+ output_pil.save(os.path.join(args.output_dir, bname))