deepsuchak committed
Commit 949df4a · verified · 1 Parent(s): a61d4bf

Upload 8 files

Files changed (8)
  1. .gitattributes +35 -0
  2. LICENSE +437 -0
  3. README.md +201 -0
  4. environment.yaml +197 -0
  5. main.py +738 -0
  6. test.py +447 -0
  7. test.sh +13 -0
  8. train.sh +1 -0
.gitattributes ADDED
@@ -0,0 +1,35 @@
+ *.7z filter=lfs diff=lfs merge=lfs -text
+ *.arrow filter=lfs diff=lfs merge=lfs -text
+ *.bin filter=lfs diff=lfs merge=lfs -text
+ *.bz2 filter=lfs diff=lfs merge=lfs -text
+ *.ckpt filter=lfs diff=lfs merge=lfs -text
+ *.ftz filter=lfs diff=lfs merge=lfs -text
+ *.gz filter=lfs diff=lfs merge=lfs -text
+ *.h5 filter=lfs diff=lfs merge=lfs -text
+ *.joblib filter=lfs diff=lfs merge=lfs -text
+ *.lfs.* filter=lfs diff=lfs merge=lfs -text
+ *.mlmodel filter=lfs diff=lfs merge=lfs -text
+ *.model filter=lfs diff=lfs merge=lfs -text
+ *.msgpack filter=lfs diff=lfs merge=lfs -text
+ *.npy filter=lfs diff=lfs merge=lfs -text
+ *.npz filter=lfs diff=lfs merge=lfs -text
+ *.onnx filter=lfs diff=lfs merge=lfs -text
+ *.ot filter=lfs diff=lfs merge=lfs -text
+ *.parquet filter=lfs diff=lfs merge=lfs -text
+ *.pb filter=lfs diff=lfs merge=lfs -text
+ *.pickle filter=lfs diff=lfs merge=lfs -text
+ *.pkl filter=lfs diff=lfs merge=lfs -text
+ *.pt filter=lfs diff=lfs merge=lfs -text
+ *.pth filter=lfs diff=lfs merge=lfs -text
+ *.rar filter=lfs diff=lfs merge=lfs -text
+ *.safetensors filter=lfs diff=lfs merge=lfs -text
+ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
+ *.tar.* filter=lfs diff=lfs merge=lfs -text
+ *.tar filter=lfs diff=lfs merge=lfs -text
+ *.tflite filter=lfs diff=lfs merge=lfs -text
+ *.tgz filter=lfs diff=lfs merge=lfs -text
+ *.wasm filter=lfs diff=lfs merge=lfs -text
+ *.xz filter=lfs diff=lfs merge=lfs -text
+ *.zip filter=lfs diff=lfs merge=lfs -text
+ *.zst filter=lfs diff=lfs merge=lfs -text
+ *tfevents* filter=lfs diff=lfs merge=lfs -text
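
These are the standard Hugging Face Git LFS rules: any file matching a pattern is stored via the LFS filter rather than directly in Git (new entries are usually added with `git lfs track "<pattern>"`). As a quick way to preview which local files the rules would capture, here is a minimal sketch — a hypothetical helper, not part of this upload, and it only handles simple filename globs, not path patterns like `saved_model/**/*`:

```python
import fnmatch
from pathlib import Path

# A subset of the patterns from the .gitattributes above.
LFS_PATTERNS = ["*.ckpt", "*.pt", "*.pth", "*.safetensors", "*.bin", "*.zip"]

def lfs_candidates(root="."):
    """Yield files under `root` whose names match one of the LFS patterns."""
    for path in Path(root).rglob("*"):
        if path.is_file() and any(fnmatch.fnmatch(path.name, pat) for pat in LFS_PATTERNS):
            yield path

if __name__ == "__main__":
    for f in lfs_candidates():
        print(f)
```
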
LICENSE ADDED
@@ -0,0 +1,437 @@
+ Attribution-NonCommercial-ShareAlike 4.0 International
+
+ =======================================================================
+
+ Creative Commons Corporation ("Creative Commons") is not a law firm and
+ does not provide legal services or legal advice. Distribution of
+ Creative Commons public licenses does not create a lawyer-client or
+ other relationship. Creative Commons makes its licenses and related
+ information available on an "as-is" basis. Creative Commons gives no
+ warranties regarding its licenses, any material licensed under their
+ terms and conditions, or any related information. Creative Commons
+ disclaims all liability for damages resulting from their use to the
+ fullest extent possible.
+
+ Using Creative Commons Public Licenses
+
+ Creative Commons public licenses provide a standard set of terms and
+ conditions that creators and other rights holders may use to share
+ original works of authorship and other material subject to copyright
+ and certain other rights specified in the public license below. The
+ following considerations are for informational purposes only, are not
+ exhaustive, and do not form part of our licenses.
+
+ Considerations for licensors: Our public licenses are
+ intended for use by those authorized to give the public
+ permission to use material in ways otherwise restricted by
+ copyright and certain other rights. Our licenses are
+ irrevocable. Licensors should read and understand the terms
+ and conditions of the license they choose before applying it.
+ Licensors should also secure all rights necessary before
+ applying our licenses so that the public can reuse the
+ material as expected. Licensors should clearly mark any
+ material not subject to the license. This includes other CC-
+ licensed material, or material used under an exception or
+ limitation to copyright. More considerations for licensors:
+ wiki.creativecommons.org/Considerations_for_licensors
+
+ Considerations for the public: By using one of our public
+ licenses, a licensor grants the public permission to use the
+ licensed material under specified terms and conditions. If
+ the licensor's permission is not necessary for any reason--for
+ example, because of any applicable exception or limitation to
+ copyright--then that use is not regulated by the license. Our
+ licenses grant only permissions under copyright and certain
+ other rights that a licensor has authority to grant. Use of
+ the licensed material may still be restricted for other
+ reasons, including because others have copyright or other
+ rights in the material. A licensor may make special requests,
+ such as asking that all changes be marked or described.
+ Although not required by our licenses, you are encouraged to
+ respect those requests where reasonable. More considerations
+ for the public:
+ wiki.creativecommons.org/Considerations_for_licensees
+
+ =======================================================================
+
+ Creative Commons Attribution-NonCommercial-ShareAlike 4.0 International
+ Public License
+
+ By exercising the Licensed Rights (defined below), You accept and agree
+ to be bound by the terms and conditions of this Creative Commons
+ Attribution-NonCommercial-ShareAlike 4.0 International Public License
+ ("Public License"). To the extent this Public License may be
+ interpreted as a contract, You are granted the Licensed Rights in
+ consideration of Your acceptance of these terms and conditions, and the
+ Licensor grants You such rights in consideration of benefits the
+ Licensor receives from making the Licensed Material available under
+ these terms and conditions.
+
+
+ Section 1 -- Definitions.
+
+ a. Adapted Material means material subject to Copyright and Similar
+ Rights that is derived from or based upon the Licensed Material
+ and in which the Licensed Material is translated, altered,
+ arranged, transformed, or otherwise modified in a manner requiring
+ permission under the Copyright and Similar Rights held by the
+ Licensor. For purposes of this Public License, where the Licensed
+ Material is a musical work, performance, or sound recording,
+ Adapted Material is always produced where the Licensed Material is
+ synched in timed relation with a moving image.
+
+ b. Adapter's License means the license You apply to Your Copyright
+ and Similar Rights in Your contributions to Adapted Material in
+ accordance with the terms and conditions of this Public License.
+
+ c. BY-NC-SA Compatible License means a license listed at
+ creativecommons.org/compatiblelicenses, approved by Creative
+ Commons as essentially the equivalent of this Public License.
+
+ d. Copyright and Similar Rights means copyright and/or similar rights
+ closely related to copyright including, without limitation,
+ performance, broadcast, sound recording, and Sui Generis Database
+ Rights, without regard to how the rights are labeled or
+ categorized. For purposes of this Public License, the rights
+ specified in Section 2(b)(1)-(2) are not Copyright and Similar
+ Rights.
+
+ e. Effective Technological Measures means those measures that, in the
+ absence of proper authority, may not be circumvented under laws
+ fulfilling obligations under Article 11 of the WIPO Copyright
+ Treaty adopted on December 20, 1996, and/or similar international
+ agreements.
+
+ f. Exceptions and Limitations means fair use, fair dealing, and/or
+ any other exception or limitation to Copyright and Similar Rights
+ that applies to Your use of the Licensed Material.
+
+ g. License Elements means the license attributes listed in the name
+ of a Creative Commons Public License. The License Elements of this
+ Public License are Attribution, NonCommercial, and ShareAlike.
+
+ h. Licensed Material means the artistic or literary work, database,
+ or other material to which the Licensor applied this Public
+ License.
+
+ i. Licensed Rights means the rights granted to You subject to the
+ terms and conditions of this Public License, which are limited to
+ all Copyright and Similar Rights that apply to Your use of the
+ Licensed Material and that the Licensor has authority to license.
+
+ j. Licensor means the individual(s) or entity(ies) granting rights
+ under this Public License.
+
+ k. NonCommercial means not primarily intended for or directed towards
+ commercial advantage or monetary compensation. For purposes of
+ this Public License, the exchange of the Licensed Material for
+ other material subject to Copyright and Similar Rights by digital
+ file-sharing or similar means is NonCommercial provided there is
+ no payment of monetary compensation in connection with the
+ exchange.
+
+ l. Share means to provide material to the public by any means or
+ process that requires permission under the Licensed Rights, such
+ as reproduction, public display, public performance, distribution,
+ dissemination, communication, or importation, and to make material
+ available to the public including in ways that members of the
+ public may access the material from a place and at a time
+ individually chosen by them.
+
+ m. Sui Generis Database Rights means rights other than copyright
+ resulting from Directive 96/9/EC of the European Parliament and of
+ the Council of 11 March 1996 on the legal protection of databases,
+ as amended and/or succeeded, as well as other essentially
+ equivalent rights anywhere in the world.
+
+ n. You means the individual or entity exercising the Licensed Rights
+ under this Public License. Your has a corresponding meaning.
+
+
+ Section 2 -- Scope.
+
+ a. License grant.
+
+ 1. Subject to the terms and conditions of this Public License,
+ the Licensor hereby grants You a worldwide, royalty-free,
+ non-sublicensable, non-exclusive, irrevocable license to
+ exercise the Licensed Rights in the Licensed Material to:
+
+ a. reproduce and Share the Licensed Material, in whole or
+ in part, for NonCommercial purposes only; and
+
+ b. produce, reproduce, and Share Adapted Material for
+ NonCommercial purposes only.
+
+ 2. Exceptions and Limitations. For the avoidance of doubt, where
+ Exceptions and Limitations apply to Your use, this Public
+ License does not apply, and You do not need to comply with
+ its terms and conditions.
+
+ 3. Term. The term of this Public License is specified in Section
+ 6(a).
+
+ 4. Media and formats; technical modifications allowed. The
+ Licensor authorizes You to exercise the Licensed Rights in
+ all media and formats whether now known or hereafter created,
+ and to make technical modifications necessary to do so. The
+ Licensor waives and/or agrees not to assert any right or
+ authority to forbid You from making technical modifications
+ necessary to exercise the Licensed Rights, including
+ technical modifications necessary to circumvent Effective
+ Technological Measures. For purposes of this Public License,
+ simply making modifications authorized by this Section 2(a)
+ (4) never produces Adapted Material.
+
+ 5. Downstream recipients.
+
+ a. Offer from the Licensor -- Licensed Material. Every
+ recipient of the Licensed Material automatically
+ receives an offer from the Licensor to exercise the
+ Licensed Rights under the terms and conditions of this
+ Public License.
+
+ b. Additional offer from the Licensor -- Adapted Material.
+ Every recipient of Adapted Material from You
+ automatically receives an offer from the Licensor to
+ exercise the Licensed Rights in the Adapted Material
+ under the conditions of the Adapter's License You apply.
+
+ c. No downstream restrictions. You may not offer or impose
+ any additional or different terms or conditions on, or
+ apply any Effective Technological Measures to, the
+ Licensed Material if doing so restricts exercise of the
+ Licensed Rights by any recipient of the Licensed
+ Material.
+
+ 6. No endorsement. Nothing in this Public License constitutes or
+ may be construed as permission to assert or imply that You
+ are, or that Your use of the Licensed Material is, connected
+ with, or sponsored, endorsed, or granted official status by,
+ the Licensor or others designated to receive attribution as
+ provided in Section 3(a)(1)(A)(i).
+
+ b. Other rights.
+
+ 1. Moral rights, such as the right of integrity, are not
+ licensed under this Public License, nor are publicity,
+ privacy, and/or other similar personality rights; however, to
+ the extent possible, the Licensor waives and/or agrees not to
+ assert any such rights held by the Licensor to the limited
+ extent necessary to allow You to exercise the Licensed
+ Rights, but not otherwise.
+
+ 2. Patent and trademark rights are not licensed under this
+ Public License.
+
+ 3. To the extent possible, the Licensor waives any right to
+ collect royalties from You for the exercise of the Licensed
+ Rights, whether directly or through a collecting society
+ under any voluntary or waivable statutory or compulsory
+ licensing scheme. In all other cases the Licensor expressly
+ reserves any right to collect such royalties, including when
+ the Licensed Material is used other than for NonCommercial
+ purposes.
+
+
+ Section 3 -- License Conditions.
+
+ Your exercise of the Licensed Rights is expressly made subject to the
+ following conditions.
+
+ a. Attribution.
+
+ 1. If You Share the Licensed Material (including in modified
+ form), You must:
+
+ a. retain the following if it is supplied by the Licensor
+ with the Licensed Material:
+
+ i. identification of the creator(s) of the Licensed
+ Material and any others designated to receive
+ attribution, in any reasonable manner requested by
+ the Licensor (including by pseudonym if
+ designated);
+
+ ii. a copyright notice;
+
+ iii. a notice that refers to this Public License;
+
+ iv. a notice that refers to the disclaimer of
+ warranties;
+
+ v. a URI or hyperlink to the Licensed Material to the
+ extent reasonably practicable;
+
+ b. indicate if You modified the Licensed Material and
+ retain an indication of any previous modifications; and
+
+ c. indicate the Licensed Material is licensed under this
+ Public License, and include the text of, or the URI or
+ hyperlink to, this Public License.
+
+ 2. You may satisfy the conditions in Section 3(a)(1) in any
+ reasonable manner based on the medium, means, and context in
+ which You Share the Licensed Material. For example, it may be
+ reasonable to satisfy the conditions by providing a URI or
+ hyperlink to a resource that includes the required
+ information.
+
+ 3. If requested by the Licensor, You must remove any of the
+ information required by Section 3(a)(1)(A) to the extent
+ reasonably practicable.
+
+ b. ShareAlike.
+
+ In addition to the conditions in Section 3(a), if You Share
+ Adapted Material You produce, the following conditions also apply.
+
+ 1. The Adapter's License You apply must be a Creative Commons
+ license with the same License Elements, this version or
+ later, or a BY-NC-SA Compatible License.
+
+ 2. You must include the text of, or the URI or hyperlink to, the
+ Adapter's License You apply. You may satisfy this condition
+ in any reasonable manner based on the medium, means, and
+ context in which You Share Adapted Material.
+
+ 3. You may not offer or impose any additional or different terms
+ or conditions on, or apply any Effective Technological
+ Measures to, Adapted Material that restrict exercise of the
+ rights granted under the Adapter's License You apply.
+
+
+ Section 4 -- Sui Generis Database Rights.
+
+ Where the Licensed Rights include Sui Generis Database Rights that
+ apply to Your use of the Licensed Material:
+
+ a. for the avoidance of doubt, Section 2(a)(1) grants You the right
+ to extract, reuse, reproduce, and Share all or a substantial
+ portion of the contents of the database for NonCommercial purposes
+ only;
+
+ b. if You include all or a substantial portion of the database
+ contents in a database in which You have Sui Generis Database
+ Rights, then the database in which You have Sui Generis Database
+ Rights (but not its individual contents) is Adapted Material,
+ including for purposes of Section 3(b); and
+
+ c. You must comply with the conditions in Section 3(a) if You Share
+ all or a substantial portion of the contents of the database.
+
+ For the avoidance of doubt, this Section 4 supplements and does not
+ replace Your obligations under this Public License where the Licensed
+ Rights include other Copyright and Similar Rights.
+
+
+ Section 5 -- Disclaimer of Warranties and Limitation of Liability.
+
+ a. UNLESS OTHERWISE SEPARATELY UNDERTAKEN BY THE LICENSOR, TO THE
+ EXTENT POSSIBLE, THE LICENSOR OFFERS THE LICENSED MATERIAL AS-IS
+ AND AS-AVAILABLE, AND MAKES NO REPRESENTATIONS OR WARRANTIES OF
+ ANY KIND CONCERNING THE LICENSED MATERIAL, WHETHER EXPRESS,
+ IMPLIED, STATUTORY, OR OTHER. THIS INCLUDES, WITHOUT LIMITATION,
+ WARRANTIES OF TITLE, MERCHANTABILITY, FITNESS FOR A PARTICULAR
+ PURPOSE, NON-INFRINGEMENT, ABSENCE OF LATENT OR OTHER DEFECTS,
+ ACCURACY, OR THE PRESENCE OR ABSENCE OF ERRORS, WHETHER OR NOT
+ KNOWN OR DISCOVERABLE. WHERE DISCLAIMERS OF WARRANTIES ARE NOT
+ ALLOWED IN FULL OR IN PART, THIS DISCLAIMER MAY NOT APPLY TO YOU.
+
+ b. TO THE EXTENT POSSIBLE, IN NO EVENT WILL THE LICENSOR BE LIABLE
+ TO YOU ON ANY LEGAL THEORY (INCLUDING, WITHOUT LIMITATION,
+ NEGLIGENCE) OR OTHERWISE FOR ANY DIRECT, SPECIAL, INDIRECT,
+ INCIDENTAL, CONSEQUENTIAL, PUNITIVE, EXEMPLARY, OR OTHER LOSSES,
+ COSTS, EXPENSES, OR DAMAGES ARISING OUT OF THIS PUBLIC LICENSE OR
+ USE OF THE LICENSED MATERIAL, EVEN IF THE LICENSOR HAS BEEN
+ ADVISED OF THE POSSIBILITY OF SUCH LOSSES, COSTS, EXPENSES, OR
+ DAMAGES. WHERE A LIMITATION OF LIABILITY IS NOT ALLOWED IN FULL OR
+ IN PART, THIS LIMITATION MAY NOT APPLY TO YOU.
+
+ c. The disclaimer of warranties and limitation of liability provided
+ above shall be interpreted in a manner that, to the extent
+ possible, most closely approximates an absolute disclaimer and
+ waiver of all liability.
+
+
+ Section 6 -- Term and Termination.
+
+ a. This Public License applies for the term of the Copyright and
+ Similar Rights licensed here. However, if You fail to comply with
+ this Public License, then Your rights under this Public License
+ terminate automatically.
+
+ b. Where Your right to use the Licensed Material has terminated under
+ Section 6(a), it reinstates:
+
+ 1. automatically as of the date the violation is cured, provided
+ it is cured within 30 days of Your discovery of the
+ violation; or
+
+ 2. upon express reinstatement by the Licensor.
+
+ For the avoidance of doubt, this Section 6(b) does not affect any
+ right the Licensor may have to seek remedies for Your violations
+ of this Public License.
+
+ c. For the avoidance of doubt, the Licensor may also offer the
+ Licensed Material under separate terms or conditions or stop
+ distributing the Licensed Material at any time; however, doing so
+ will not terminate this Public License.
+
+ d. Sections 1, 5, 6, 7, and 8 survive termination of this Public
+ License.
+
+
+ Section 7 -- Other Terms and Conditions.
+
+ a. The Licensor shall not be bound by any additional or different
+ terms or conditions communicated by You unless expressly agreed.
+
+ b. Any arrangements, understandings, or agreements regarding the
+ Licensed Material not stated herein are separate from and
+ independent of the terms and conditions of this Public License.
+
+
+ Section 8 -- Interpretation.
+
+ a. For the avoidance of doubt, this Public License does not, and
+ shall not be interpreted to, reduce, limit, restrict, or impose
+ conditions on any use of the Licensed Material that could lawfully
+ be made without permission under this Public License.
+
+ b. To the extent possible, if any provision of this Public License is
+ deemed unenforceable, it shall be automatically reformed to the
+ minimum extent necessary to make it enforceable. If the provision
+ cannot be reformed, it shall be severed from this Public License
+ without affecting the enforceability of the remaining terms and
+ conditions.
+
+ c. No term or condition of this Public License will be waived and no
+ failure to comply consented to unless expressly agreed to by the
+ Licensor.
+
+ d. Nothing in this Public License constitutes or may be interpreted
+ as a limitation upon, or waiver of, any privileges and immunities
+ that apply to the Licensor or You, including from the legal
+ processes of any jurisdiction or authority.
+
+ =======================================================================
+
+ Creative Commons is not a party to its public
+ licenses. Notwithstanding, Creative Commons may elect to apply one of
+ its public licenses to material it publishes and in those instances
+ will be considered the “Licensor.” The text of the Creative Commons
+ public licenses is dedicated to the public domain under the CC0 Public
+ Domain Dedication. Except for the limited purpose of indicating that
+ material is shared under a Creative Commons public license or as
+ otherwise permitted by the Creative Commons policies published at
+ creativecommons.org/policies, Creative Commons does not authorize the
+ use of the trademark "Creative Commons" or any other trademark or logo
+ of Creative Commons without its prior written consent including,
+ without limitation, in connection with any unauthorized modifications
+ to any of its public licenses or any other arrangements,
+ understandings, or agreements concerning use of licensed material. For
+ the avoidance of doubt, this paragraph does not form part of the
+ public licenses.
+
+ Creative Commons may be contacted at creativecommons.org.
README.md ADDED
@@ -0,0 +1,201 @@
+ ---
+ title: Mv Vton Demo
+ emoji: 👁
+ colorFrom: gray
+ colorTo: pink
+ sdk: gradio
+ sdk_version: 5.26.0
+ app_file: app.py
+ pinned: false
+ ---
+
+ Check out the Spaces configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
+
+ # MV-VTON
+
+ PyTorch implementation of **MV-VTON: Multi-View Virtual Try-On with Diffusion Models**
+
+ [![arXiv](https://img.shields.io/badge/arXiv-2404.17364-b31b1b.svg)](https://arxiv.org/abs/2404.17364)
+ [![Project](https://img.shields.io/badge/Project-Website-orange)](https://hywang2002.github.io/MV-VTON/)
+ ![visitors](https://visitor-badge.laobi.icu/badge?page_id=hywang2002.MV-VTON)
+ [![LICENSE](https://img.shields.io/badge/license-CC--BY--NC--SA--4.0-lightgrey)](https://creativecommons.org/licenses/by-nc-sa/4.0/)
+
+ ## News
+ - 🔥 The first multi-view virtual try-on dataset, MVG, is now available.
+ - 🔥 Checkpoints for both the frontal-view and multi-view virtual try-on tasks are released.
+
+ ## Overview
+
+ ![](assets/framework.png)
+ > **Abstract:**
+ > The goal of image-based virtual try-on is to generate an image of the target person naturally wearing the given
+ > clothing. However, most existing methods solely focus on frontal try-on using frontal clothing. When the views
+ > of the clothing and person are significantly inconsistent, particularly when the person’s view is non-frontal, the
+ > results are unsatisfactory. To address this challenge, we introduce Multi-View Virtual Try-On (MV-VTON), which aims to
+ > reconstruct the dressing results of a person from multiple views using the given clothes. On the one hand, given that
+ > single-view clothes provide insufficient information for MV-VTON, we instead employ two images, i.e., the frontal and
+ > back views of the clothing, to encompass the complete view as much as possible. On the other hand, diffusion models
+ > that have demonstrated superior abilities are adopted to perform our MV-VTON. In particular, we propose a
+ > view-adaptive selection method where hard-selection and soft-selection are applied to the global and local clothing
+ > feature extraction, respectively. This ensures that the clothing features roughly fit the person’s view. Subsequently,
+ > we suggest a joint attention block to align and fuse clothing features with person features. Additionally, we collect
+ > an MV-VTON dataset, i.e., Multi-View Garment (MVG), in which each person has multiple photos with diverse views and
+ > poses. Experiments show that the proposed method not only achieves state-of-the-art results on the MV-VTON task using
+ > our MVG dataset, but also has superiority on the frontal-view virtual try-on task using the VITON-HD and DressCode
+ > datasets.
+
+ ## Getting Started
+
+ ### Installation
+
+ 1. Clone the repository
+
+ ```shell
+ git clone https://github.com/hywang2002/MV-VTON.git
+ cd MV-VTON
+ ```
+
+ 2. Install Python dependencies
+
+ ```shell
+ conda env create -f environment.yaml
+ conda activate mv-vton
+ ```
+
+ 3. Download the pretrained [vgg](https://drive.google.com/file/d/1rvow8jStPt8t2prDcSRlnf8yzXhrYeGo/view?usp=sharing)
+ checkpoint and put it in `models/vgg/` for Multi-View VTON and in `Frontal-View VTON/models/vgg/` for Frontal-View VTON.
+ 4. Download the pretrained models: `mvg.ckpt` via [Baidu Cloud](https://pan.baidu.com/s/17SC8fHE5w2g7gEtzJgRRew?pwd=cshy) or [Google Drive](https://drive.google.com/file/d/1J91PoT8A9yqHWNxkgRe6ZCnDEhN-H9O6/view?usp=sharing),
+ and `vitonhd.ckpt` via [Baidu Cloud](https://pan.baidu.com/s/1R2yGgm35UwTpnXPEU6-tlA?pwd=cshy) or [Google Drive](https://drive.google.com/file/d/13A0uzUY6PuvitLOqzyHzWASOh0dNXdem/view?usp=sharing).
+ Put `mvg.ckpt` in `checkpoint/` and `vitonhd.ckpt` in `Frontal-View VTON/checkpoint/`; a quick check of this layout is sketched below.
+
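+ A minimal sketch for checking the layout from steps 3-4 (the vgg checkpoint is only checked at directory level, since its exact filename depends on the download):
+
+ ```python
+ from pathlib import Path
+
+ # Locations taken from the installation steps above.
+ expected = [
+     Path("models/vgg"),
+     Path("Frontal-View VTON/models/vgg"),
+     Path("checkpoint/mvg.ckpt"),
+     Path("Frontal-View VTON/checkpoint/vitonhd.ckpt"),
+ ]
+
+ for p in expected:
+     print(f"[{'ok' if p.exists() else 'missing'}] {p}")
+ ```
+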
+ ### Datasets
+
+ #### MVG
+
+ 1. Fill in the `Dataset Request Form` via [Baidu Cloud](https://pan.baidu.com/s/12HAq0V4FfgpU_q8AeyZzwA?pwd=cshy) or [Google Drive](https://drive.google.com/file/d/1zWt6HYBz7Vzaxu8rp1bwkhRoBkxbwQjw/view?usp=sharing), and
+ contact `cshy2mvvton@outlook.com` with this form to obtain the MVG dataset (non-institutional emails
+ such as gmail.com are not allowed; please provide your institutional email address).
+
+ After these steps, the folder structure should look like this (the `warp_feat_unpair*` folders are only included in the test directory):
+
+ ```
+ ├── MVG
+ │   ├── unpaired.txt
+ │   ├── [train | test]
+ │   │   ├── image-wo-bg
+ │   │   ├── cloth
+ │   │   ├── cloth-mask
+ │   │   ├── warp_feat
+ │   │   ├── warp_feat_unpair
+ │   │   ├── ...
+ ```
+
+ #### VITON-HD
+
+ 1. Download the [VITON-HD](https://github.com/shadow2496/VITON-HD) dataset.
+ 2. Download the pre-warped cloth images/masks via [Baidu Cloud](https://pan.baidu.com/s/1uQM0IOltOmbeqwdOKX5kCw?pwd=cshy) or [Google Drive](https://drive.google.com/file/d/18DTWfhxUnfg41nnwwpCKN--akC4eT9DM/view?usp=sharing) and
+ put them under the VITON-HD dataset directory.
+
+ After these steps, the folder structure should look like this (the `unpaired-cloth*` folders are only included in the test directory):
+
+ ```
+ ├── VITON-HD
+ │   ├── test_pairs.txt
+ │   ├── train_pairs.txt
+ │   ├── [train | test]
+ │   │   ├── image
+ │   │   │   ├── [000006_00.jpg | 000008_00.jpg | ...]
+ │   │   ├── cloth
+ │   │   │   ├── [000006_00.jpg | 000008_00.jpg | ...]
+ │   │   ├── cloth-mask
+ │   │   │   ├── [000006_00.jpg | 000008_00.jpg | ...]
+ │   │   ├── cloth-warp
+ │   │   │   ├── [000006_00.jpg | 000008_00.jpg | ...]
+ │   │   ├── cloth-warp-mask
+ │   │   │   ├── [000006_00.jpg | 000008_00.jpg | ...]
+ │   │   ├── unpaired-cloth-warp
+ │   │   │   ├── [000006_00.jpg | 000008_00.jpg | ...]
+ │   │   ├── unpaired-cloth-warp-mask
+ │   │   │   ├── [000006_00.jpg | 000008_00.jpg | ...]
+ ```
+
+ ### Inference
+
+ #### MVG
+
+ To test in the paired setting (which uses `cp_dataset_mv_paired.py`), you can either modify `configs/viton512.yaml` and `main.py`,
+ or simply rename `cp_dataset_mv_paired.py` to `cp_dataset.py` (recommended). Then run:
+
+ ```shell
+ sh test.sh
+ ```
+
+ To test in the unpaired setting, rename `cp_dataset_mv_unpaired.py` to `cp_dataset.py` instead and run the same command; a helper for switching between the two settings is sketched below.
+
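+ A minimal sketch of such a switch (a hypothetical helper, run from the repo root — it copies rather than renames, so both variants remain available):
+
+ ```python
+ import shutil
+ import sys
+
+ # Usage: python select_setting.py [paired|unpaired]
+ setting = sys.argv[1] if len(sys.argv) > 1 else "paired"
+ src = f"cp_dataset_mv_{setting}.py"
+ shutil.copyfile(src, "cp_dataset.py")
+ print(f"cp_dataset.py now uses the {setting} setting ({src})")
+ ```
+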
+ #### VITON-HD
+
+ To test in the paired setting, run `cd Frontal-View\ VTON/`, then:
+
+ ```shell
+ sh test.sh
+ ```
+
+ To test in the unpaired setting, run `cd Frontal-View\ VTON/`, add `--unpaired` to `test.sh`, and then run:
+
+ ```shell
+ sh test.sh
+ ```
+
+ #### Metrics
+
+ We compute `LPIPS`, `SSIM`, `FID`, and `KID` using the same tools as [LaDI-VTON](https://github.com/miccunifi/ladi-vton); a torchmetrics-based sketch is given below.
+
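+ For reference, the same four metrics can be computed with torchmetrics. A minimal sketch, assuming a recent torchmetrics release (the `torchmetrics==0.6.0` pinned in `environment.yaml` predates some of these classes) and `uint8` image batches of shape `(N, 3, H, W)`:
+
+ ```python
+ from torchmetrics import StructuralSimilarityIndexMeasure
+ from torchmetrics.image.fid import FrechetInceptionDistance
+ from torchmetrics.image.kid import KernelInceptionDistance
+ from torchmetrics.image.lpip import LearnedPerceptualImagePatchSimilarity
+
+ def evaluate(fake_u8, real_u8):
+     """fake_u8 / real_u8: uint8 tensors of shape (N, 3, H, W)."""
+     fid = FrechetInceptionDistance(feature=2048)
+     kid = KernelInceptionDistance(subset_size=50)  # subset_size must be <= N
+     lpips = LearnedPerceptualImagePatchSimilarity(net_type="alex")
+     ssim = StructuralSimilarityIndexMeasure(data_range=1.0)
+
+     fid.update(real_u8, real=True); fid.update(fake_u8, real=False)
+     kid.update(real_u8, real=True); kid.update(fake_u8, real=False)
+
+     fake = fake_u8.float() / 255.0
+     real = real_u8.float() / 255.0
+     lpips.update(fake * 2 - 1, real * 2 - 1)  # LPIPS expects inputs in [-1, 1]
+     ssim.update(fake, real)
+
+     kid_mean, _ = kid.compute()
+     return {"FID": fid.compute().item(), "KID": kid_mean.item(),
+             "LPIPS": lpips.compute().item(), "SSIM": ssim.compute().item()}
+ ```
+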
+ ### Training
+
+ #### MVG
+
+ We use Paint-by-Example as initialization; please download its pretrained model
+ from [Google Drive](https://drive.google.com/file/d/15QzaTWsvZonJcXsNv-ilMRCYaQLhzR_i/view) and save it to the
+ `checkpoints` directory. Rename `cp_dataset_mv_paired.py` to `cp_dataset.py`, then run:
+
+ ```shell
+ sh train.sh
+ ```
+
+ #### VITON-HD
+
+ Run `cd Frontal-View\ VTON/`, then:
+
+ ```shell
+ sh train.sh
+ ```
+
+ ## Acknowledgements
+
+ Our code borrows heavily from [Paint-by-Example](https://github.com/Fantasy-Studio/Paint-by-Example)
+ and [DCI-VTON](https://github.com/bcmi/DCI-VTON-Virtual-Try-On). We also
+ thank the previous works [PF-AFN](https://github.com/geyuying/PF-AFN), [GP-VTON](https://github.com/xiezhy6/GP-VTON),
+ [LaDI-VTON](https://github.com/miccunifi/ladi-vton),
+ and [StableVITON](https://github.com/rlawjdghek/StableVITON).
+
+ ## LICENSE
+ MV-VTON: Multi-View Virtual Try-On with Diffusion Models © 2024 by Haoyu Wang, Zhilu Zhang, Donglin Di, Shiliang Zhang, and Wangmeng Zuo is licensed under CC BY-NC-SA 4.0.
+
+ ## Citation
+
+ ```
+ @article{wang2024mv,
+   title={MV-VTON: Multi-View Virtual Try-On with Diffusion Models},
+   author={Wang, Haoyu and Zhang, Zhilu and Di, Donglin and Zhang, Shiliang and Zuo, Wangmeng},
+   journal={arXiv preprint arXiv:2404.17364},
+   year={2024}
+ }
+ ```
environment.yaml ADDED
@@ -0,0 +1,197 @@
+ name: mv-vton
+ channels:
+ - pytorch
+ - defaults
+ dependencies:
+ - _libgcc_mutex=0.1=main
+ - _openmp_mutex=5.1=1_gnu
+ - blas=1.0=mkl
+ - brotli-python=1.0.9=py38h6a678d5_7
+ - bzip2=1.0.8=h7b6447c_0
+ - ca-certificates=2023.08.22=h06a4308_0
+ - certifi=2023.11.17=py38h06a4308_0
+ - cffi=1.15.1=py38h74dc2b5_0
+ - charset-normalizer=2.0.4=pyhd3eb1b0_0
+ - cryptography=41.0.3=py38h130f0dd_0
+ - cudatoolkit=11.3.1=h2bc3f7f_2
+ - ffmpeg=4.3=hf484d3e_0
+ - freetype=2.12.1=h4a9f257_0
+ - giflib=5.2.1=h5eee18b_3
+ - gmp=6.2.1=h295c915_3
+ - gnutls=3.6.15=he1e5248_0
+ - idna=3.4=py38h06a4308_0
+ - intel-openmp=2021.4.0=h06a4308_3561
+ - jpeg=9e=h5eee18b_1
+ - lame=3.100=h7b6447c_0
+ - lcms2=2.12=h3be6417_0
+ - ld_impl_linux-64=2.38=h1181459_1
+ - lerc=3.0=h295c915_0
+ - libdeflate=1.17=h5eee18b_1
+ - libffi=3.3=he6710b0_2
+ - libgcc-ng=11.2.0=h1234567_1
+ - libgfortran-ng=11.2.0=h00389a5_1
+ - libgfortran5=11.2.0=h1234567_1
+ - libgomp=11.2.0=h1234567_1
+ - libiconv=1.16=h7f8727e_2
+ - libidn2=2.3.4=h5eee18b_0
+ - libpng=1.6.39=h5eee18b_0
+ - libstdcxx-ng=11.2.0=h1234567_1
+ - libtasn1=4.19.0=h5eee18b_0
+ - libtiff=4.5.1=h6a678d5_0
+ - libunistring=0.9.10=h27cfd23_0
+ - libuv=1.44.2=h5eee18b_0
+ - libwebp=1.3.2=h11a3e52_0
+ - libwebp-base=1.3.2=h5eee18b_0
+ - lz4-c=1.9.4=h6a678d5_0
+ - mkl=2021.4.0=h06a4308_640
+ - mkl-service=2.4.0=py38h7f8727e_0
+ - mkl_fft=1.3.1=py38hd3c417c_0
+ - mkl_random=1.2.2=py38h51133e4_0
+ - ncurses=6.4=h6a678d5_0
+ - nettle=3.7.3=hbbd107a_1
+ - openh264=2.1.1=h4ff587b_0
+ - openjpeg=2.4.0=h3ad879b_0
+ - openssl=1.1.1w=h7f8727e_0
+ - pillow=10.0.1=py38ha6cbd5a_0
+ - pip=20.3.3=py38h06a4308_0
+ - pycparser=2.21=pyhd3eb1b0_0
+ - pyopenssl=23.2.0=py38h06a4308_0
+ - pysocks=1.7.1=py38h06a4308_0
+ - python=3.8.5=h7579374_1
+ - pytorch=1.11.0=py3.8_cuda11.3_cudnn8.2.0_0
+ - pytorch-mutex=1.0=cuda
+ - readline=8.2=h5eee18b_0
+ - requests=2.31.0=py38h06a4308_0
+ - setuptools=68.0.0=py38h06a4308_0
+ - six=1.16.0=pyhd3eb1b0_1
+ - sqlite=3.41.2=h5eee18b_0
+ - tk=8.6.12=h1ccaba5_0
+ - torchvision=0.12.0=py38_cu113
+ - typing_extensions=4.7.1=py38h06a4308_0
+ - urllib3=1.26.18=py38h06a4308_0
+ - wheel=0.41.2=py38h06a4308_0
+ - xz=5.4.5=h5eee18b_0
+ - zlib=1.2.13=h5eee18b_0
+ - zstd=1.5.5=hc292b87_0
+ - pip:
+   - absl-py==2.0.0
+   - aiohttp==3.9.1
+   - aiosignal==1.3.1
+   - albumentations==0.4.3
+   - altair==5.2.0
+   - antlr4-python3-runtime==4.9.3
+   - async-timeout==4.0.3
+   - attrs==23.1.0
+   - av==12.0.0
+   - backports-zoneinfo==0.2.1
+   - bezier==2023.7.28
+   - black==24.2.0
+   - blinker==1.7.0
+   - cachetools==5.3.2
+   - click==8.1.7
+   - clip==0.2.0
+   - cloudpickle==3.0.0
+   - contourpy==1.1.1
+   - cupy==12.3.0
+   - cycler==0.12.1
+   - diffusers==0.20.0
+   - einops==0.3.0
+   - fastrlock==0.8.2
+   - filelock==3.13.1
+   - fonttools==4.45.1
+   - frozenlist==1.4.0
+   - fsspec==2023.10.0
+   - future==0.18.3
+   - fvcore==0.1.5.post20221221
+   - gitdb==4.0.11
+   - gitpython==3.1.40
+   - google-auth==2.23.4
+   - google-auth-oauthlib==1.0.0
+   - grpcio==1.59.3
+   - huggingface-hub==0.19.4
+   - hydra-core==1.3.2
+   - imageio==2.9.0
+   - imageio-ffmpeg==0.4.2
+   - imgaug==0.2.6
+   - importlib-metadata==6.8.0
+   - importlib-resources==6.1.1
+   - invisible-watermark==0.2.0
+   - iopath==0.1.9
+   - jinja2==3.1.2
+   - jsonschema==4.20.0
+   - jsonschema-specifications==2023.11.1
+   - kiwisolver==1.4.5
+   - kornia==0.6.0
+   - lazy-loader==0.3
+   - markdown==3.5.1
+   - markdown-it-py==3.0.0
+   - markupsafe==2.1.3
+   - matplotlib==3.7.4
+   - mdurl==0.1.2
+   - multidict==6.0.4
+   - mypy-extensions==1.0.0
+   - networkx==3.1
+   - numpy==1.24.4
+   - oauthlib==3.2.2
+   - omegaconf==2.3.0
+   - opencv-python==4.1.2.30
+   - opencv-python-headless==4.8.1.78
+   - packaging==23.2
+   - pandas==2.0.3
+   - pathspec==0.12.1
+   - pkgutil-resolve-name==1.3.10
+   - platformdirs==4.2.0
+   - portalocker==2.8.2
+   - protobuf==4.25.1
+   - pudb==2019.2
+   - pyarrow==14.0.1
+   - pyasn1==0.5.1
+   - pyasn1-modules==0.3.0
+   - pycocotools==2.0.7
+   - pydeck==0.8.1b0
+   - pydeprecate==0.3.1
+   - pygments==2.17.2
+   - pyparsing==3.1.1
+   - python-dateutil==2.8.2
+   - pytorch-lightning==1.4.2
+   - pytz==2023.3.post1
+   - pywavelets==1.4.1
+   - pyyaml==6.0.1
+   - referencing==0.31.1
+   - regex==2023.10.3
+   - requests-oauthlib==1.3.1
+   - rich==13.7.0
+   - rpds-py==0.13.2
+   - rsa==4.9
+   - safetensors==0.4.1
+   - scikit-image==0.20.0
+   - scipy==1.9.1
+   - smmap==5.0.1
+   - streamlit==1.28.2
+   - tabulate==0.9.0
+   - taming-transformers==0.0.1
+   - tenacity==8.2.3
+   - tensorboard==2.14.0
+   - tensorboard-data-server==0.7.2
+   - termcolor==2.4.0
+   - test-tube==0.7.5
+   - tifffile==2023.7.10
+   - tokenizers==0.12.1
+   - toml==0.10.2
+   - tomli==2.0.1
+   - toolz==0.12.0
+   - torch-fidelity==0.3.0
+   - torchmetrics==0.6.0
+   - tornado==6.4
+   - tqdm==4.66.1
+   - transformers==4.27.3
+   - tzdata==2023.3
+   - tzlocal==5.2
+   - urwid==2.2.3
+   - validators==0.22.0
+   - watchdog==3.0.0
+   - werkzeug==3.0.1
+   - yacs==0.1.8
+   - yarl==1.9.3
+   - zipp==3.17.0
+ prefix: /mnt/pfs-mc0p4k/cvg/team/didonglin/conda_envs/mv-vton
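
Note that the `prefix:` line points at the original author's conda path; the environment name `mv-vton` is what `conda activate` uses. After `conda env create -f environment.yaml`, a quick sanity check that the key pins above took effect:

```python
import torch
import torchvision

print(torch.__version__)          # expected: 1.11.0
print(torchvision.__version__)    # expected: 0.12.0
print(torch.version.cuda)         # expected: 11.3
print(torch.cuda.is_available())  # True on a machine with a working CUDA setup
```
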
main.py ADDED
@@ -0,0 +1,738 @@
1
+ import argparse, os, sys, datetime, glob, importlib, csv
2
+ import numpy as np
3
+ import time
4
+ import torch
5
+ import torchvision
6
+ import pytorch_lightning as pl
7
+
8
+ sys.setrecursionlimit(10000)
9
+ from packaging import version
10
+ from omegaconf import OmegaConf
11
+ from torch.utils.data import random_split, DataLoader, Dataset, Subset
12
+ from functools import partial
13
+ from PIL import Image
14
+
15
+ from pytorch_lightning import seed_everything
16
+ from pytorch_lightning.trainer import Trainer
17
+ from pytorch_lightning.callbacks import ModelCheckpoint, Callback, LearningRateMonitor
18
+ from pytorch_lightning.utilities.distributed import rank_zero_only
19
+ from pytorch_lightning.utilities import rank_zero_info
20
+
21
+ from ldm.data.base import Txt2ImgIterableBaseDataset
22
+ from ldm.util import instantiate_from_config
23
+ import socket
24
+ from pytorch_lightning.plugins.environments import ClusterEnvironment, SLURMEnvironment
25
+
26
+
27
+ def get_parser(**parser_kwargs):
28
+ def str2bool(v):
29
+ if isinstance(v, bool):
30
+ return v
31
+ if v.lower() in ("yes", "true", "t", "y", "1"):
32
+ return True
33
+ elif v.lower() in ("no", "false", "f", "n", "0"):
34
+ return False
35
+ else:
36
+ raise argparse.ArgumentTypeError("Boolean value expected.")
37
+
38
+ parser = argparse.ArgumentParser(**parser_kwargs)
39
+ parser.add_argument(
40
+ "-n",
41
+ "--name",
42
+ type=str,
43
+ const=True,
44
+ default="",
45
+ nargs="?",
46
+ help="postfix for logdir",
47
+ )
48
+ parser.add_argument(
49
+ "-r",
50
+ "--resume",
51
+ type=str,
52
+ const=True,
53
+ default="",
54
+ nargs="?",
55
+ help="resume from logdir or checkpoint in logdir",
56
+ )
57
+ parser.add_argument(
58
+ "-b",
59
+ "--base",
60
+ nargs="*",
61
+ metavar="base_config.yaml",
62
+ help="paths to base configs. Loaded from left-to-right. "
63
+ "Parameters can be overwritten or added with command-line options of the form `--key value`.",
64
+ default=["configs/stable-diffusion/v1-inference-inpaint.yaml"],
65
+ )
66
+ parser.add_argument(
67
+ "-t",
68
+ "--train",
69
+ type=str2bool,
70
+ const=True,
71
+ default=True,
72
+ nargs="?",
73
+ help="train",
74
+ )
75
+ parser.add_argument(
76
+ "--no-test",
77
+ type=str2bool,
78
+ const=True,
79
+ default=False,
80
+ nargs="?",
81
+ help="disable test",
82
+ )
83
+ parser.add_argument(
84
+ "-p",
85
+ "--project",
86
+ help="name of new or path to existing project"
87
+ )
88
+ parser.add_argument(
89
+ "-d",
90
+ "--debug",
91
+ type=str2bool,
92
+ nargs="?",
93
+ const=True,
94
+ default=False,
95
+ help="enable post-mortem debugging",
96
+ )
97
+ parser.add_argument(
98
+ "-s",
99
+ "--seed",
100
+ type=int,
101
+ default=23,
102
+ help="seed for seed_everything",
103
+ )
104
+ parser.add_argument(
105
+ "-f",
106
+ "--postfix",
107
+ type=str,
108
+ default="",
109
+ help="post-postfix for default name",
110
+ )
111
+ parser.add_argument(
112
+ "-l",
113
+ "--logdir",
114
+ type=str,
115
+ default="logs",
116
+ help="directory for logging dat shit",
117
+ )
118
+ parser.add_argument(
119
+ "--pretrained_model",
120
+ type=str,
121
+ default="",
122
+ help="path to pretrained model",
123
+ )
124
+ parser.add_argument(
125
+ "--scale_lr",
126
+ type=str2bool,
127
+ nargs="?",
128
+ const=True,
129
+ default=True,
130
+ help="scale base-lr by ngpu * batch_size * n_accumulate",
131
+ )
132
+ parser.add_argument(
133
+ "--train_from_scratch",
134
+ type=str2bool,
135
+ nargs="?",
136
+ const=True,
137
+ default=False,
138
+ help="Train from scratch",
139
+ )
140
+ return parser
141
+
142
+
143
+ def nondefault_trainer_args(opt):
144
+ parser = argparse.ArgumentParser()
145
+ parser = Trainer.add_argparse_args(parser)
146
+ args = parser.parse_args([])
147
+ return sorted(k for k in vars(args) if getattr(opt, k) != getattr(args, k))
148
+
149
+
150
+ class WrappedDataset(Dataset):
151
+ """Wraps an arbitrary object with __len__ and __getitem__ into a pytorch dataset"""
152
+
153
+ def __init__(self, dataset):
154
+ self.data = dataset
155
+
156
+ def __len__(self):
157
+ return len(self.data)
158
+
159
+ def __getitem__(self, idx):
160
+ return self.data[idx]
161
+
162
+
163
+ def worker_init_fn(_):
164
+ worker_info = torch.utils.data.get_worker_info()
165
+
166
+ dataset = worker_info.dataset
167
+ worker_id = worker_info.id
168
+
169
+ if isinstance(dataset, Txt2ImgIterableBaseDataset):
170
+ split_size = dataset.num_records // worker_info.num_workers
171
+ # reset num_records to the true number to retain reliable length information
172
+ dataset.sample_ids = dataset.valid_ids[worker_id * split_size:(worker_id + 1) * split_size]
173
+ current_id = np.random.choice(len(np.random.get_state()[1]), 1)
174
+ return np.random.seed(np.random.get_state()[1][current_id] + worker_id)
175
+ else:
176
+ return np.random.seed(np.random.get_state()[1][0] + worker_id)
177
+
178
+
179
+ class DataModuleFromConfig(pl.LightningDataModule):
180
+ def __init__(self, batch_size, train=None, validation=None, test=None, predict=None,
181
+ wrap=False, num_workers=None, shuffle_test_loader=False, use_worker_init_fn=False,
182
+ shuffle_val_dataloader=False):
183
+ super().__init__()
184
+ self.batch_size = batch_size
185
+ self.dataset_configs = dict()
186
+ self.num_workers = num_workers if num_workers is not None else batch_size * 2
187
+ self.use_worker_init_fn = use_worker_init_fn
188
+ if train is not None:
189
+ self.dataset_configs["train"] = train
190
+ self.train_dataloader = self._train_dataloader
191
+ if validation is not None:
192
+ self.dataset_configs["validation"] = validation
193
+ self.val_dataloader = partial(self._val_dataloader, shuffle=shuffle_val_dataloader)
194
+ if test is not None:
195
+ self.dataset_configs["test"] = test
196
+ self.test_dataloader = partial(self._test_dataloader, shuffle=shuffle_test_loader)
197
+ if predict is not None:
198
+ self.dataset_configs["predict"] = predict
199
+ self.predict_dataloader = self._predict_dataloader
200
+ self.wrap = wrap
201
+
202
+ def prepare_data(self):
203
+ for data_cfg in self.dataset_configs.values():
204
+ instantiate_from_config(data_cfg)
205
+
206
+ def setup(self, stage=None):
207
+ self.datasets = dict(
208
+ (k, instantiate_from_config(self.dataset_configs[k]))
209
+ for k in self.dataset_configs)
210
+ if self.wrap:
211
+ for k in self.datasets:
212
+ self.datasets[k] = WrappedDataset(self.datasets[k])
213
+
214
+ def _train_dataloader(self):
215
+ is_iterable_dataset = isinstance(self.datasets['train'], Txt2ImgIterableBaseDataset)
216
+ if is_iterable_dataset or self.use_worker_init_fn:
217
+ init_fn = worker_init_fn
218
+ else:
219
+ init_fn = None
220
+ return DataLoader(self.datasets["train"], batch_size=self.batch_size,
221
+ num_workers=self.num_workers, shuffle=False if is_iterable_dataset else True,
222
+ worker_init_fn=init_fn)
223
+
224
+ def _val_dataloader(self, shuffle=False):
225
+ if isinstance(self.datasets['validation'], Txt2ImgIterableBaseDataset) or self.use_worker_init_fn:
226
+ init_fn = worker_init_fn
227
+ else:
228
+ init_fn = None
229
+ return DataLoader(self.datasets["validation"],
230
+ batch_size=self.batch_size,
231
+ num_workers=self.num_workers,
232
+ worker_init_fn=init_fn,
233
+ shuffle=shuffle)
234
+
235
+ def _test_dataloader(self, shuffle=False):
236
+ is_iterable_dataset = isinstance(self.datasets['train'], Txt2ImgIterableBaseDataset)
237
+ if is_iterable_dataset or self.use_worker_init_fn:
238
+ init_fn = worker_init_fn
239
+ else:
240
+ init_fn = None
241
+
242
+ # do not shuffle dataloader for iterable dataset
243
+ shuffle = shuffle and (not is_iterable_dataset)
244
+
245
+ return DataLoader(self.datasets["test"], batch_size=self.batch_size,
246
+ num_workers=self.num_workers, worker_init_fn=init_fn, shuffle=shuffle)
247
+
248
+ def _predict_dataloader(self, shuffle=False):
249
+ if isinstance(self.datasets['predict'], Txt2ImgIterableBaseDataset) or self.use_worker_init_fn:
250
+ init_fn = worker_init_fn
251
+ else:
252
+ init_fn = None
253
+ return DataLoader(self.datasets["predict"], batch_size=self.batch_size,
254
+ num_workers=self.num_workers, worker_init_fn=init_fn)
255
+
256
+
257
+ class SetupCallback(Callback):
258
+ def __init__(self, resume, now, logdir, ckptdir, cfgdir, config, lightning_config):
259
+ super().__init__()
260
+ self.resume = resume
261
+ self.now = now
262
+ self.logdir = logdir
263
+ self.ckptdir = ckptdir
264
+ self.cfgdir = cfgdir
265
+ self.config = config
266
+ self.lightning_config = lightning_config
267
+
268
+ def on_keyboard_interrupt(self, trainer, pl_module):
269
+ if trainer.global_rank == 0:
270
+ print("Summoning checkpoint.")
271
+ if hasattr(self.config, 'lora_config'):
272
+ ckpt_path = os.path.join(self.ckptdir, "lora_last.ckpt")
273
+ from lora.lora import save_lora_weight
274
+ save_lora_weight(trainer.model, path=ckpt_path)
275
+ else:
276
+ ckpt_path = os.path.join(self.ckptdir, "last.ckpt")
277
+ trainer.save_checkpoint(ckpt_path)
278
+
279
+ def on_pretrain_routine_start(self, trainer, pl_module):
280
+ if trainer.global_rank == 0:
281
+ # Create logdirs and save configs
282
+ os.makedirs(self.logdir, exist_ok=True)
283
+ os.makedirs(self.ckptdir, exist_ok=True)
284
+ os.makedirs(self.cfgdir, exist_ok=True)
285
+
286
+ if "callbacks" in self.lightning_config:
287
+ if 'metrics_over_trainsteps_checkpoint' in self.lightning_config['callbacks']:
288
+ os.makedirs(os.path.join(self.ckptdir, 'trainstep_checkpoints'), exist_ok=True)
289
+ print("Project config")
290
+ print(OmegaConf.to_yaml(self.config))
291
+ OmegaConf.save(self.config,
292
+ os.path.join(self.cfgdir, "{}-project.yaml".format(self.now)))
293
+
294
+ print("Lightning config")
295
+ print(OmegaConf.to_yaml(self.lightning_config))
296
+ OmegaConf.save(OmegaConf.create({"lightning": self.lightning_config}),
297
+ os.path.join(self.cfgdir, "{}-lightning.yaml".format(self.now)))
298
+
299
+ else:
300
+ # ModelCheckpoint callback created log directory --- remove it
301
+ if not self.resume and os.path.exists(self.logdir):
302
+ dst, name = os.path.split(self.logdir)
303
+ dst = os.path.join(dst, "child_runs", name)
304
+ os.makedirs(os.path.split(dst)[0], exist_ok=True)
305
+ try:
306
+ os.rename(self.logdir, dst)
307
+ except FileNotFoundError:
308
+ pass
309
+
310
+
311
+ class ImageLogger(Callback):
312
+ def __init__(self, batch_frequency, max_images, clamp=True, increase_log_steps=True,
313
+ rescale=True, disabled=False, log_on_batch_idx=False, log_first_step=False,
314
+ log_images_kwargs=None):
315
+ super().__init__()
316
+ self.rescale = rescale
317
+ self.batch_freq = batch_frequency
318
+ self.max_images = max_images
319
+ self.logger_log_images = {
320
+ pl.loggers.TestTubeLogger: self._testtube,
321
+ }
322
+ self.log_steps = [2 ** n for n in range(int(np.log2(self.batch_freq)) + 1)]
323
+ if not increase_log_steps:
324
+ self.log_steps = [self.batch_freq]
325
+ self.clamp = clamp
326
+ self.disabled = disabled
327
+ self.log_on_batch_idx = log_on_batch_idx
328
+ self.log_images_kwargs = log_images_kwargs if log_images_kwargs else {}
329
+ self.log_first_step = log_first_step
330
+
331
+ @rank_zero_only
332
+ def _testtube(self, pl_module, images, batch_idx, split):
333
+ for k in images:
334
+ grid = torchvision.utils.make_grid(images[k])
335
+ grid = (grid + 1.0) / 2.0 # -1,1 -> 0,1; c,h,w
336
+
337
+ tag = f"{split}/{k}"
338
+ pl_module.logger.experiment.add_image(
339
+ tag, grid,
340
+ global_step=pl_module.global_step)
341
+
342
+ @rank_zero_only
343
+ def log_local(self, save_dir, split, images,
344
+ global_step, current_epoch, batch_idx):
345
+ root = os.path.join(save_dir, "images", split)
346
+ for k in images:
347
+ grid = torchvision.utils.make_grid(images[k], nrow=4)
348
+ if self.rescale:
349
+ grid = (grid + 1.0) / 2.0 # -1,1 -> 0,1; c,h,w
350
+ grid = grid.transpose(0, 1).transpose(1, 2).squeeze(-1)
351
+ grid = grid.numpy()
352
+ grid = (grid * 255).astype(np.uint8)
353
+ filename = "{}_gs-{:06}_e-{:06}_b-{:06}.png".format(
354
+ k,
355
+ global_step,
356
+ current_epoch,
357
+ batch_idx)
358
+ path = os.path.join(root, filename)
359
+ os.makedirs(os.path.split(path)[0], exist_ok=True)
360
+ Image.fromarray(grid).save(path)
361
+
362
+ def log_img(self, pl_module, batch, batch_idx, split="train"):
363
+ check_idx = batch_idx if self.log_on_batch_idx else pl_module.global_step
364
+ if (self.check_frequency(check_idx) and # batch_idx % self.batch_freq == 0
365
+ hasattr(pl_module, "log_images") and
366
+ callable(pl_module.log_images) and
367
+ self.max_images > 0):
368
+ logger = type(pl_module.logger)
369
+
370
+ is_train = pl_module.training
371
+ if is_train:
372
+ pl_module.eval()
373
+
374
+ with torch.no_grad():
375
+ images = pl_module.log_images(batch, split=split, **self.log_images_kwargs)
376
+
377
+ for k in images:
378
+ N = min(images[k].shape[0], self.max_images)
379
+ images[k] = images[k][:N]
380
+ if isinstance(images[k], torch.Tensor):
381
+ images[k] = images[k].detach().cpu()
382
+ if self.clamp:
383
+ images[k] = torch.clamp(images[k], -1., 1.)
384
+
385
+ self.log_local(pl_module.logger.save_dir, split, images,
386
+ pl_module.global_step, pl_module.current_epoch, batch_idx)
387
+
388
+ logger_log_images = self.logger_log_images.get(logger, lambda *args, **kwargs: None)
389
+ logger_log_images(pl_module, images, pl_module.global_step, split)
390
+
391
+ if is_train:
392
+ pl_module.train()
393
+
394
+ def check_frequency(self, check_idx):
395
+ if ((check_idx % self.batch_freq) == 0 or (check_idx in self.log_steps)) and (
396
+ check_idx > 0 or self.log_first_step):
397
+ try:
398
+ self.log_steps.pop(0)
399
+ except IndexError as e:
400
+ print(e)
401
+ pass
402
+ return True
403
+ return False
404
+
405
+ def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx):
406
+ if not self.disabled and (pl_module.global_step > 0 or self.log_first_step):
407
+ self.log_img(pl_module, batch, batch_idx, split="train")
408
+
409
+ def on_validation_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx):
410
+ if not self.disabled and pl_module.global_step > 0:
411
+ self.log_img(pl_module, batch, batch_idx, split="val")
412
+ if hasattr(pl_module, 'calibrate_grad_norm'):
413
+ if (pl_module.calibrate_grad_norm and batch_idx % 25 == 0) and batch_idx > 0:
414
+ self.log_gradients(trainer, pl_module, batch_idx=batch_idx)
415
+
416
+
417
+ class CUDACallback(Callback):
418
+ # see https://github.com/SeanNaren/minGPT/blob/master/mingpt/callback.py
419
+ def on_train_epoch_start(self, trainer, pl_module):
420
+ # Reset the memory use counter
421
+ torch.cuda.reset_peak_memory_stats(trainer.root_gpu)
422
+ torch.cuda.synchronize(trainer.root_gpu)
423
+ self.start_time = time.time()
424
+
425
+ def on_train_epoch_end(self, trainer, pl_module, outputs):
426
+ torch.cuda.synchronize(trainer.root_gpu)
427
+ max_memory = torch.cuda.max_memory_allocated(trainer.root_gpu) / 2 ** 20
428
+ epoch_time = time.time() - self.start_time
429
+
430
+ try:
431
+ max_memory = trainer.training_type_plugin.reduce(max_memory)
432
+ epoch_time = trainer.training_type_plugin.reduce(epoch_time)
433
+
434
+ rank_zero_info(f"Average Epoch time: {epoch_time:.2f} seconds")
435
+ rank_zero_info(f"Average Peak memory {max_memory:.2f}MiB")
436
+ except AttributeError:
437
+ pass
438
+
439
+
440
+ if __name__ == "__main__":
441
+
442
+ now = datetime.datetime.now().strftime("%Y-%m-%dT%H-%M-%S")
443
+ sys.path.append(os.getcwd())
444
+
445
+ parser = get_parser()
446
+ parser = Trainer.add_argparse_args(parser)
447
+
448
+ opt, unknown = parser.parse_known_args()
449
+ if opt.name and opt.resume:
450
+ raise ValueError(
451
+ "-n/--name and -r/--resume cannot be specified both."
452
+ "If you want to resume training in a new log folder, "
453
+ "use -n/--name in combination with --resume_from_checkpoint"
454
+ )
455
+ if opt.resume:
456
+ if not os.path.exists(opt.resume):
457
+ raise ValueError("Cannot find {}".format(opt.resume))
458
+ if os.path.isfile(opt.resume):
459
+ paths = opt.resume.split("/")
460
+ # idx = len(paths)-paths[::-1].index("logs")+1
461
+ # logdir = "/".join(paths[:idx])
462
+ logdir = "/".join(paths[:-2])
463
+ ckpt = opt.resume
464
+ else:
465
+ assert os.path.isdir(opt.resume), opt.resume
466
+ logdir = opt.resume.rstrip("/")
467
+ ckpt = os.path.join(logdir, "checkpoints", "last.ckpt")
468
+
469
+ opt.resume_from_checkpoint = ckpt
470
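+ # Load the configs saved in the resumed run's logdir first, so configs passed via -b/--base can still override them in the merge below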
+ base_configs = sorted(glob.glob(os.path.join(logdir, "configs/*.yaml")))
471
+ opt.base = base_configs + opt.base
472
+ _tmp = logdir.split("/")
473
+ nowname = _tmp[-1]
474
+ else:
475
+ if opt.name:
476
+ name = "_" + opt.name
477
+ elif opt.base:
478
+ cfg_fname = os.path.split(opt.base[0])[-1]
479
+ cfg_name = os.path.splitext(cfg_fname)[0]
480
+ name = "_" + cfg_name
481
+ else:
482
+ name = ""
483
+ nowname = now + name + opt.postfix
484
+ logdir = os.path.join(opt.logdir, nowname)
485
+
486
+ ckptdir = os.path.join(logdir, "checkpoints")
487
+ cfgdir = os.path.join(logdir, "configs")
488
+ seed_everything(opt.seed)
489
+
490
+ # try:
491
+ # init and save configs
492
+ configs = [OmegaConf.load(cfg) for cfg in opt.base]
493
+ cli = OmegaConf.from_dotlist(unknown)
494
+ config = OmegaConf.merge(*configs, cli)
495
+ lightning_config = config.pop("lightning", OmegaConf.create())
496
+ # merge trainer cli with config
497
+ trainer_config = lightning_config.get("trainer", OmegaConf.create())
498
+ # default to ddp
499
+ trainer_config["accelerator"] = "ddp"
500
+ for k in nondefault_trainer_args(opt):
501
+ trainer_config[k] = getattr(opt, k)
502
+ if not "gpus" in trainer_config:
503
+ del trainer_config["accelerator"]
504
+ cpu = True
505
+ else:
506
+ gpuinfo = trainer_config["gpus"]
507
+ print(f"Running on GPUs {gpuinfo}")
508
+ cpu = False
509
+ trainer_opt = argparse.Namespace(**trainer_config)
510
+ lightning_config.trainer = trainer_config
511
+
512
+ # model
513
+ model = instantiate_from_config(config.model)
514
+ if not opt.resume:
515
+ if opt.train_from_scratch:
516
+ ckpt_file = torch.load(opt.pretrained_model, map_location='cpu')['state_dict']
517
+ ckpt_file = {key: value for key, value in ckpt_file.items() if not key.startswith('model.')}
518
+ model.load_state_dict(ckpt_file, strict=False)
519
+ print("Train from scratch!")
520
+ else:
521
+ model.load_state_dict(torch.load(opt.pretrained_model, map_location='cpu')['state_dict'], strict=False)
522
+ print("Load Stable Diffusion v1-4!")
523
+
524
+ # lora
525
+ if hasattr(config, 'lora_config'):
526
+ model.eval()
527
+ model._requires_grad = False
528
+ from lora.lora import inject_trainable_lora_extended
529
+
530
+ params, names = inject_trainable_lora_extended(model, r=config.lora_config.rank)
531
+
532
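+ # Freeze the whole model, then unfreeze only the U-Net decoder transformer blocks and the local ControlNet / pose branches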
+ model.requires_grad_(False)
533
+ for name, param in model.named_parameters():
534
+ if "diffusion_model.output_blocks" in name and "transformer_blocks" in name:
535
+ param.requires_grad = True
536
+ if "local_controlnet" in name or "pose" in name:
537
+ param.requires_grad = True
538
+ # Open a file to record the names of the trainable modules
539
+ with open("module_names.txt", "w") as file:
540
+ # Iterate over all model parameters and write the names of the trainable ones to the file
541
+ for name, param in model.named_parameters():
542
+ if param.requires_grad:
543
+ file.write(name + "\n")
544
+
545
+ # trainer and callbacks
546
+ trainer_kwargs = dict()
547
+
548
+ # default logger configs
549
+ default_logger_cfgs = {
550
+ "wandb": {
551
+ "target": "pytorch_lightning.loggers.WandbLogger",
552
+ "params": {
553
+ "name": nowname,
554
+ "save_dir": logdir,
555
+ "offline": opt.debug,
556
+ "id": nowname,
557
+ }
558
+ },
559
+ "testtube": {
560
+ "target": "pytorch_lightning.loggers.TestTubeLogger",
561
+ "params": {
562
+ "name": "testtube",
563
+ "save_dir": logdir,
564
+ }
565
+ },
566
+ }
567
+ default_logger_cfg = default_logger_cfgs["testtube"]
568
+ if "logger" in lightning_config:
569
+ logger_cfg = lightning_config.logger
570
+ else:
571
+ logger_cfg = OmegaConf.create()
572
+ logger_cfg = OmegaConf.merge(default_logger_cfg, logger_cfg)
573
+ trainer_kwargs["logger"] = instantiate_from_config(logger_cfg)
574
+
575
+ # modelcheckpoint - use TrainResult/EvalResult(checkpoint_on=metric) to
576
+ # specify which metric is used to determine best models
577
+ default_modelckpt_cfg = {
578
+ "target": "pytorch_lightning.callbacks.ModelCheckpoint",
579
+ "params": {
580
+ "dirpath": ckptdir,
581
+ "filename": "{epoch:06}",
582
+ "verbose": True,
583
+ "save_last": False,
584
+ "every_n_epochs": 1
585
+ }
586
+ }
587
+ if hasattr(model, "monitor"):
588
+ print(f"Monitoring {model.monitor} as checkpoint metric.")
589
+ default_modelckpt_cfg["params"]["monitor"] = model.monitor
590
+ default_modelckpt_cfg["params"]["save_top_k"] = 30
591
+
592
+ if "modelcheckpoint" in lightning_config:
593
+ modelckpt_cfg = lightning_config.modelcheckpoint
594
+ else:
595
+ modelckpt_cfg = OmegaConf.create()
596
+ modelckpt_cfg = OmegaConf.merge(default_modelckpt_cfg, modelckpt_cfg)
597
+ print(f"Merged modelckpt-cfg: \n{modelckpt_cfg}")
598
+ if version.parse(pl.__version__) < version.parse('1.4.0'):
599
+ trainer_kwargs["checkpoint_callback"] = instantiate_from_config(modelckpt_cfg)
600
+
601
+ # add callback which sets up log directory
602
+ default_callbacks_cfg = {
603
+ "setup_callback": {
604
+ "target": "main.SetupCallback",
605
+ "params": {
606
+ "resume": opt.resume,
607
+ "now": now,
608
+ "logdir": logdir,
609
+ "ckptdir": ckptdir,
610
+ "cfgdir": cfgdir,
611
+ "config": config,
612
+ "lightning_config": lightning_config,
613
+ }
614
+ },
615
+ "image_logger": {
616
+ "target": "main.ImageLogger",
617
+ "params": {
618
+ "batch_frequency": 500,
619
+ "max_images": 4,
620
+ "clamp": True
621
+ }
622
+ },
623
+ "learning_rate_logger": {
624
+ "target": "main.LearningRateMonitor",
625
+ "params": {
626
+ "logging_interval": "step",
627
+ # "log_momentum": True
628
+ }
629
+ },
630
+ "cuda_callback": {
631
+ "target": "main.CUDACallback"
632
+ },
633
+ }
634
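+ # From PL 1.4 on, ModelCheckpoint is registered as an ordinary callback instead of the checkpoint_callback kwarg used above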
+ if version.parse(pl.__version__) >= version.parse('1.4.0'):
635
+ default_callbacks_cfg.update({'checkpoint_callback': modelckpt_cfg})
636
+
637
+ if "callbacks" in lightning_config:
638
+ callbacks_cfg = lightning_config.callbacks
639
+ else:
640
+ callbacks_cfg = OmegaConf.create()
641
+
642
+ if 'metrics_over_trainsteps_checkpoint' in callbacks_cfg:
643
+ print(
644
+ 'Caution: Saving checkpoints every n train steps without deleting. This might require some free space.')
645
+ default_metrics_over_trainsteps_ckpt_dict = {
646
+ 'metrics_over_trainsteps_checkpoint':
647
+ {"target": 'pytorch_lightning.callbacks.ModelCheckpoint',
648
+ 'params': {
649
+ "dirpath": os.path.join(ckptdir, 'trainstep_checkpoints'),
650
+ "filename": "{epoch:06}-{step:09}",
651
+ "verbose": True,
652
+ 'save_top_k': -1,
653
+ 'every_n_train_steps': 10000,
654
+ 'save_weights_only': True
655
+ }
656
+ }
657
+ }
658
+ default_callbacks_cfg.update(default_metrics_over_trainsteps_ckpt_dict)
659
+
660
+ callbacks_cfg = OmegaConf.merge(default_callbacks_cfg, callbacks_cfg)
661
+ if 'ignore_keys_callback' in callbacks_cfg and hasattr(trainer_opt, 'resume_from_checkpoint'):
662
+ callbacks_cfg.ignore_keys_callback.params['ckpt_path'] = trainer_opt.resume_from_checkpoint
663
+ elif 'ignore_keys_callback' in callbacks_cfg:
664
+ del callbacks_cfg['ignore_keys_callback']
665
+
666
+ trainer_kwargs["callbacks"] = [instantiate_from_config(callbacks_cfg[k]) for k in callbacks_cfg]
667
+
668
+ trainer = Trainer.from_argparse_args(trainer_opt, **trainer_kwargs)
669
+ # trainer.plugins = [MyCluster()]
670
+ trainer.logdir = logdir ###
671
+
672
+ # data
673
+ data = instantiate_from_config(config.data)
674
+ # NOTE according to https://pytorch-lightning.readthedocs.io/en/latest/datamodules.html
675
+ # calling these ourselves should not be necessary but it is.
676
+ # lightning still takes care of proper multiprocessing though
677
+ data.prepare_data()
678
+ data.setup()
679
+ print("#### Data #####")
680
+ for k in data.datasets:
681
+ print(f"{k}, {data.datasets[k].__class__.__name__}, {len(data.datasets[k])}")
682
+
683
+ # configure learning rate
684
+ bs, base_lr = config.data.params.batch_size, config.model.base_learning_rate
685
+ if not cpu:
686
+ ngpu = len(lightning_config.trainer.gpus.strip(",").split(','))
687
+ else:
688
+ ngpu = 1
689
+ if 'accumulate_grad_batches' in lightning_config.trainer:
690
+ accumulate_grad_batches = lightning_config.trainer.accumulate_grad_batches
691
+ else:
692
+ accumulate_grad_batches = 1
693
+ # if 'num_nodes' in lightning_config.trainer:
694
+ # num_nodes = lightning_config.trainer.num_nodes
695
+ # else:
696
+ num_nodes = 1
697
+ print(f"accumulate_grad_batches = {accumulate_grad_batches}")
698
+ lightning_config.trainer.accumulate_grad_batches = accumulate_grad_batches
699
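+ # Linear LR scaling; e.g. accumulate=2, 1 node, 4 GPUs, batch size 8, base_lr=1e-6 gives lr = 2 * 1 * 4 * 8 * 1e-6 = 6.4e-5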
+ if opt.scale_lr:
700
+ model.learning_rate = accumulate_grad_batches * num_nodes * ngpu * bs * base_lr
701
+ print(
702
+ "Setting learning rate to {:.2e} = {} (accumulate_grad_batches) * {} (num_nodes) * {} (num_gpus) * {} (batchsize) * {:.2e} (base_lr)".format(
703
+ model.learning_rate, accumulate_grad_batches, num_nodes, ngpu, bs, base_lr))
704
+ else:
705
+ model.learning_rate = base_lr
706
+ print("++++ NOT USING LR SCALING ++++")
707
+ print(f"Setting learning rate to {model.learning_rate:.2e}")
708
+
709
+
710
+ # allow checkpointing via USR1
711
+ def melk(*args, **kwargs):
712
+ # run all checkpoint hooks
713
+ if trainer.global_rank == 0:
714
+ print("Summoning checkpoint.")
715
+ ckpt_path = os.path.join(ckptdir, "last.ckpt")
716
+ trainer.save_checkpoint(ckpt_path)
717
+
718
+
719
+ def divein(*args, **kwargs):
720
+ if trainer.global_rank == 0:
721
+ import pudb
722
+ pudb.set_trace()
723
+
724
+
725
+ import signal
726
+
727
+ signal.signal(signal.SIGUSR1, melk)
728
+ signal.signal(signal.SIGUSR2, divein)
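+ # e.g. "kill -USR1 <pid>" checkpoints the run to last.ckpt; "kill -USR2 <pid>" drops rank 0 into the pudb debugger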
729
+
730
+ # run
731
+ if opt.train:
732
+ try:
733
+ trainer.fit(model, data)
734
+ except Exception:
735
+ melk()
736
+ raise
737
+ if not opt.no_test and not trainer.interrupted:
738
+ trainer.test(model, data)
test.py ADDED
@@ -0,0 +1,447 @@
1
+ import argparse, os, sys, glob
2
+ import cv2
3
+ import torch
4
+ import numpy as np
5
+ from omegaconf import OmegaConf
6
+ from PIL import Image
7
+ from torch.utils.data import DataLoader
8
+ from torchvision import transforms
9
+ from tqdm import tqdm, trange
10
+ from itertools import islice
11
+ from einops import rearrange
12
+ from torchvision.utils import make_grid
13
+ import time
14
+ from pytorch_lightning import seed_everything
15
+ from torch import autocast
16
+ from contextlib import contextmanager, nullcontext
17
+ import torchvision
18
+
19
+ from ldm.data.cp_dataset import CPDataset
20
+ from ldm.resizer import Resizer
21
+ from ldm.util import instantiate_from_config
22
+ from ldm.models.diffusion.ddim import DDIMSampler
23
+ from ldm.models.diffusion.plms import PLMSSampler
24
+ from ldm.data.deepfashions import DFPairDataset
25
+
26
+ import clip
27
+ from torchvision.transforms import Resize
28
+
29
+
30
+ def chunk(it, size):
31
+ it = iter(it)
32
+ return iter(lambda: tuple(islice(it, size)), ())
33
+
34
+
35
+ def get_tensor_clip(normalize=True, toTensor=True):
36
+ transform_list = []
37
+ if toTensor:
38
+ transform_list += [torchvision.transforms.ToTensor()]
39
+
40
+ if normalize:
41
+ transform_list += [torchvision.transforms.Normalize((0.48145466, 0.4578275, 0.40821073),
42
+ (0.26862954, 0.26130258, 0.27577711))]
43
+ return torchvision.transforms.Compose(transform_list)
44
+
45
+
46
+ def numpy_to_pil(images):
47
+ """
48
+ Convert a numpy image or a batch of images to a PIL image.
49
+ """
50
+ if images.ndim == 3:
51
+ images = images[None, ...]
52
+ images = (images * 255).round().astype("uint8")
53
+ pil_images = [Image.fromarray(image) for image in images]
54
+
55
+ return pil_images
56
+
57
+
58
+ def load_model_from_config(config, ckpt, verbose=False):
59
+ print(f"Loading model from {ckpt}")
60
+ pl_sd = torch.load(ckpt, map_location="cpu")
61
+ if "global_step" in pl_sd:
62
+ print(f"Global Step: {pl_sd['global_step']}")
63
+ sd = pl_sd["state_dict"]
64
+ model = instantiate_from_config(config.model)
65
+ m, u = model.load_state_dict(sd, strict=False)
66
+ if len(m) > 0 and verbose:
67
+ print("missing keys:")
68
+ print(m)
69
+ if len(u) > 0 and verbose:
70
+ print("unexpected keys:")
71
+ print(u)
72
+
73
+ model.cuda()
74
+ model.eval()
75
+ return model
76
+
77
+
78
+ def put_watermark(img, wm_encoder=None):
79
+ if wm_encoder is not None:
80
+ img = cv2.cvtColor(np.array(img), cv2.COLOR_RGB2BGR)
81
+ img = wm_encoder.encode(img, 'dwtDct')
82
+ img = Image.fromarray(img[:, :, ::-1])
83
+ return img
84
+
85
+
86
+ def load_replacement(x):
87
+ try:
88
+ hwc = x.shape
89
+ y = Image.open("assets/rick.jpeg").convert("RGB").resize((hwc[1], hwc[0]))
90
+ y = (np.array(y) / 255.0).astype(x.dtype)
91
+ assert y.shape == x.shape
92
+ return y
93
+ except Exception:
94
+ return x
95
+
96
+
97
+ def get_tensor(normalize=True, toTensor=True):
98
+ transform_list = []
99
+ if toTensor:
100
+ transform_list += [torchvision.transforms.ToTensor()]
101
+
102
+ if normalize:
103
+ transform_list += [torchvision.transforms.Normalize((0.5, 0.5, 0.5),
104
+ (0.5, 0.5, 0.5))]
105
+ return torchvision.transforms.Compose(transform_list)
106
+
118
+
119
+ def main():
120
+ parser = argparse.ArgumentParser()
121
+
122
+ parser.add_argument(
123
+ "--outdir",
124
+ type=str,
125
+ nargs="?",
126
+ help="dir to write results to",
127
+ default="outputs/txt2img-samples"
128
+ )
129
+ parser.add_argument(
130
+ "--skip_grid",
131
+ action='store_true',
132
+ help="do not save a grid, only individual samples. Helpful when evaluating lots of samples",
133
+ )
134
+ parser.add_argument(
135
+ "--skip_save",
136
+ action='store_true',
137
+ help="do not save individual samples. For speed measurements.",
138
+ )
139
+ parser.add_argument(
140
+ "--gpu_id",
141
+ type=int,
142
+ default=0,
143
+ help="which gpu to use",
144
+ )
145
+ parser.add_argument(
146
+ "--ddim_steps",
147
+ type=int,
148
+ default=30,
149
+ help="number of ddim sampling steps",
150
+ )
151
+ parser.add_argument(
152
+ "--plms",
153
+ action='store_true',
154
+ help="use plms sampling",
155
+ )
156
+ parser.add_argument(
157
+ "--fixed_code",
158
+ action='store_true',
159
+ help="if enabled, uses the same starting code across samples ",
160
+ )
161
+ parser.add_argument(
162
+ "--ddim_eta",
163
+ type=float,
164
+ default=0.0,
165
+ help="ddim eta (eta=0.0 corresponds to deterministic sampling",
166
+ )
167
+ parser.add_argument(
168
+ "--n_iter",
169
+ type=int,
170
+ default=2,
171
+ help="sample this often",
172
+ )
173
+ parser.add_argument(
174
+ "--H",
175
+ type=int,
176
+ default=512,
177
+ help="image height, in pixel space",
178
+ )
179
+ parser.add_argument(
180
+ "--W",
181
+ type=int,
182
+ default=512,
183
+ help="image width, in pixel space",
184
+ )
185
+ parser.add_argument(
186
+ "--n_imgs",
187
+ type=int,
188
+ default=100,
189
+ help="image width, in pixel space",
190
+ )
191
+ parser.add_argument(
192
+ "--C",
193
+ type=int,
194
+ default=4,
195
+ help="latent channels",
196
+ )
197
+ parser.add_argument(
198
+ "--f",
199
+ type=int,
200
+ default=8,
201
+ help="downsampling factor",
202
+ )
203
+ parser.add_argument(
204
+ "--n_samples",
205
+ type=int,
206
+ default=1,
207
+ help="how many samples to produce for each given reference image. A.k.a. batch size",
208
+ )
209
+ parser.add_argument(
210
+ "--n_rows",
211
+ type=int,
212
+ default=0,
213
+ help="rows in the grid (default: n_samples)",
214
+ )
215
+ parser.add_argument(
216
+ "--scale",
217
+ type=float,
218
+ default=1,
219
+ help="unconditional guidance scale: eps = eps(x, empty) + scale * (eps(x, cond) - eps(x, empty))",
220
+ )
221
+ parser.add_argument(
222
+ "--config",
223
+ type=str,
224
+ default="",
225
+ help="path to config which constructs model",
226
+ )
227
+ parser.add_argument(
228
+ "--ckpt",
229
+ type=str,
230
+ default="",
231
+ help="path to checkpoint of model",
232
+ )
233
+ parser.add_argument(
234
+ "--seed",
235
+ type=int,
236
+ default=42,
237
+ help="the seed (for reproducible sampling)",
238
+ )
239
+ parser.add_argument(
240
+ "--precision",
241
+ type=str,
242
+ help="evaluate at this precision",
243
+ choices=["full", "autocast"],
244
+ default="autocast"
245
+ )
246
+ parser.add_argument(
247
+ "--unpaired",
248
+ action='store_true',
249
+ help="if enabled, uses the same starting code across samples "
250
+ )
251
+ parser.add_argument(
252
+ "--dataroot",
253
+ type=str,
254
+ help="path to dataroot of the dataset",
255
+ default=""
256
+ )
257
+
258
+ opt = parser.parse_args()
259
+
260
+ seed_everything(opt.seed)
261
+
262
+ device = torch.device("cuda:{}".format(opt.gpu_id)) if torch.cuda.is_available() else torch.device("cpu")
263
+ torch.cuda.set_device(device)
264
+
265
+ config = OmegaConf.load(f"{opt.config}")
266
+ version = opt.config.split('/')[-1].split('.')[0]
267
+ model = load_model_from_config(config, f"{opt.ckpt}")
268
+
269
+ # model = model.to(device)
270
+ dataset = CPDataset(opt.dataroot, opt.H, mode='test', unpaired=opt.unpaired)
271
+ loader = DataLoader(dataset, batch_size=opt.n_samples, shuffle=False, num_workers=4, pin_memory=True)
272
+ if opt.plms:
273
+ sampler = PLMSSampler(model)
274
+ else:
275
+ sampler = DDIMSampler(model)
276
+
277
+ os.makedirs(opt.outdir, exist_ok=True)
278
+ outpath = opt.outdir
279
+
280
+ result_path = os.path.join(outpath, "upper_body")
281
+ os.makedirs(result_path, exist_ok=True)
282
+
283
+ start_code = None
284
+ if opt.fixed_code:
285
+ start_code = torch.randn([opt.n_samples, opt.C, opt.H // opt.f, opt.W // opt.f], device=device)
286
+
287
+ iterator = tqdm(loader, desc='Test Dataset', total=len(loader))
288
+ precision_scope = autocast if opt.precision == "autocast" else nullcontext
289
+ with torch.no_grad():
290
+ with precision_scope("cuda"):
291
+ with model.ema_scope():
292
+ for data in iterator:
293
+ mask_tensor = data['inpaint_mask']
294
+ inpaint_image = data['inpaint_image']
295
+ ref_tensor_f = data['ref_imgs_f']
296
+ ref_tensor_b = data['ref_imgs_b']
297
+ skeleton_cf = data['skeleton_cf']
298
+ skeleton_cb = data['skeleton_cb']
299
+ skeleton_p = data['skeleton_p']
300
+ order = data['order']
301
+ feat_tensor = data['warp_feat']
302
+ image_tensor = data['GT']
303
+
304
+ controlnet_cond_f = data['controlnet_cond_f']
305
+ controlnet_cond_b = data['controlnet_cond_b']
306
+
307
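+ # Pick the reference view per sample: orders "1"/"2" keep the front garment image, order "3" swaps in the back view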
+ ref_tensor = ref_tensor_f
308
+ for i in range(len(order)):
309
+ if order[i] == "1" or order[i] == "2":
310
+ continue
311
+ elif order[i] == "3":
312
+ ref_tensor[i] = ref_tensor_b[i]
313
+ else:
314
+ raise ValueError("Invalid order")
315
+
316
+ # filename = data['file_name']
317
+
318
+ test_model_kwargs = {}
319
+ test_model_kwargs['inpaint_mask'] = mask_tensor.to(device)
320
+ test_model_kwargs['inpaint_image'] = inpaint_image.to(device)
321
+ feat_tensor = feat_tensor.to(device)
322
+ ref_tensor = ref_tensor.to(device)
323
+
324
+ controlnet_cond_f = controlnet_cond_f.to(device)
325
+ controlnet_cond_b = controlnet_cond_b.to(device)
326
+ skeleton_cf = skeleton_cf.to(device)
327
+ skeleton_cb = skeleton_cb.to(device)
328
+ skeleton_p = skeleton_p.to(device)
329
+
330
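+ # Classifier-free guidance: for scale != 1, the model's learnable null embedding serves as the unconditional branch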
+ uc = None
331
+ if opt.scale != 1.0:
332
+ uc = model.learnable_vector
333
+ uc = uc.repeat(ref_tensor.size(0), 1, 1)
334
+ c = model.get_learned_conditioning(ref_tensor.to(torch.float16))
335
+ c = model.proj_out(c)
336
+
337
+ # z_gt = model.encode_first_stage(image_tensor.to(device))
338
+ # z_gt = model.get_first_stage_encoding(z_gt).detach()
339
+
340
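+ # Encode the masked person image into latent space and resize the inpainting mask to the latent resolution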
+ z_inpaint = model.encode_first_stage(test_model_kwargs['inpaint_image'])
341
+ z_inpaint = model.get_first_stage_encoding(z_inpaint).detach()
342
+ test_model_kwargs['inpaint_image'] = z_inpaint
343
+ test_model_kwargs['inpaint_mask'] = Resize([z_inpaint.shape[-2], z_inpaint.shape[-1]])(
344
+ test_model_kwargs['inpaint_mask'])
345
+
346
+ warp_feat = model.encode_first_stage(feat_tensor)
347
+ warp_feat = model.get_first_stage_encoding(warp_feat).detach()
348
+
349
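+ # Diffuse the warped-garment latent to the final timestep (t=999) and use the result as the sampler's start code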
+ ts = torch.full((1,), 999, device=device, dtype=torch.long)
350
+ start_code = model.q_sample(warp_feat, ts)
351
+
352
+ # local_controlnet
353
+ ehs_cf = model.pose_model(skeleton_cf)
354
+ ehs_cb = model.pose_model(skeleton_cb)
355
+ ehs_p = model.pose_model(skeleton_p)
356
+ ehs_text = torch.zeros((c.shape[0], 1, 768)).to(device)
357
+ # controlnet_cond = torch.cat((controlnet_cond_f, controlnet_cond_b, ehs_cf, ehs_cb, ehs_p), dim=1)
358
+ x_noisy = torch.cat(
359
+ (start_code, test_model_kwargs['inpaint_image'], test_model_kwargs['inpaint_mask']), dim=1)
360
+
361
+ down_samples_f, mid_samples_f = model.local_controlnet(x_noisy, ts,
362
+ encoder_hidden_states=ehs_text, controlnet_cond=controlnet_cond_f, ehs_c=ehs_cf, ehs_p=ehs_p)
363
+ down_samples_b, mid_samples_b = model.local_controlnet(x_noisy, ts,
364
+ encoder_hidden_states=ehs_text, controlnet_cond=controlnet_cond_b, ehs_c=ehs_cb, ehs_p=ehs_p)
365
+
366
+ # print(torch.max(down_samples_f[0]))
367
+ # print(torch.min(down_samples_f[0]))
368
+
369
+ # normalized_tensor = (down_samples_f[0] + 1) / 2
370
+
371
+ # # Scale values from [0, 1] to [0, 255]
372
+ # scaled_tensor = normalized_tensor * 255
373
+
374
+ # # Convert the tensor to a NumPy array
375
+ # numpy_array = scaled_tensor.squeeze().cpu().numpy().astype(np.uint8)
376
+
377
+ # # Convert the NumPy array to a PIL image
378
+ # image = Image.fromarray(numpy_array)
379
+
380
+ # # Save the image
381
+ # image.save("down_samples_f.jpg")
382
+
383
+ # normalized_tensor = (down_samples_b[0] + 1) / 2
384
+
385
+ # # Scale values from [0, 1] to [0, 255]
386
+ # scaled_tensor = normalized_tensor * 255
387
+
388
+ # # Convert the tensor to a NumPy array
389
+ # numpy_array = scaled_tensor.squeeze().cpu().numpy().astype(np.uint8)
390
+
391
+ # # Convert the NumPy array to a PIL image
392
+ # image = Image.fromarray(numpy_array)
393
+
394
+ # # Save the image
395
+ # image.save("down_samples_b.jpg")
396
+
397
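+ # Fuse front/back ControlNet outputs: sum the mid-block features and channel-concatenate each pair of down-block features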
+ mid_samples = mid_samples_f + mid_samples_b
398
+ down_samples = ()
399
+ for ds in range(len(down_samples_f)):
400
+ tmp = torch.cat((down_samples_f[ds], down_samples_b[ds]), dim=1)
401
+ down_samples = down_samples + (tmp,)
402
+
403
+ shape = [opt.C, opt.H // opt.f, opt.W // opt.f]
404
+ samples_ddim, _ = sampler.sample(S=opt.ddim_steps,
405
+ conditioning=c,
406
+ batch_size=opt.n_samples,
407
+ shape=shape,
408
+ verbose=False,
409
+ unconditional_guidance_scale=opt.scale,
410
+ unconditional_conditioning=uc,
411
+ eta=opt.ddim_eta,
412
+ x_T=start_code,
413
+ down_samples=down_samples,
414
+ test_model_kwargs=test_model_kwargs)
415
+
416
+ x_samples_ddim = model.decode_first_stage(samples_ddim)
417
+ x_sample_result = x_samples_ddim
418
+ x_samples_ddim = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0)
419
+ x_samples_ddim = x_samples_ddim.cpu().permute(0, 2, 3, 1).numpy()
420
+
421
+ x_checked_image = x_samples_ddim
422
+ x_checked_image_torch = torch.from_numpy(x_checked_image).permute(0, 3, 1, 2)
423
+ x_source = torch.clamp((image_tensor + 1.0) / 2.0, min=0.0, max=1.0)
424
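+ # Composite: keep generated pixels where inpaint_mask == 0 and copy the source image where inpaint_mask == 1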
+ x_result = x_checked_image_torch * (1 - mask_tensor) + mask_tensor * x_source
425
+ # x_result = x_checked_image_torch
426
+
427
+ resize = transforms.Resize((opt.H, int(opt.H / 256 * 192)))
428
+
429
+ if not opt.skip_save:
430
+
431
+ def un_norm(x):
432
+ return (x + 1.0) / 2.0
433
+
434
+ for i, x_sample in enumerate(x_result):
435
+ filename = data['file_name'][i]
436
+ # filename = data['file_name']
437
+ save_x = resize(x_sample)
438
+ save_x = 255. * rearrange(save_x.cpu().numpy(), 'c h w -> h w c')
439
+ img = Image.fromarray(save_x.astype(np.uint8))
440
+ img.save(os.path.join(result_path, filename[:-4] + ".png"))
441
+
442
+ print(f"Your samples are ready and waiting for you here: \n{outpath} \n"
443
+ f" \nEnjoy.")
444
+
445
+
446
+ if __name__ == "__main__":
447
+ main()
test.sh ADDED
@@ -0,0 +1,13 @@
1
+ #!/bin/bash
+ CUDA_VISIBLE_DEVICES=3 python test.py --gpu_id 0 \
2
+ --ddim_steps 50 \
3
+ --outdir results/try/ \
4
+ --config configs/viton512.yaml \
5
+ --dataroot /datasets/NVG \
6
+ --ckpt checkpoints/mvg.ckpt \
7
+ --n_samples 1 \
8
+ --seed 23 \
9
+ --scale 1 \
10
+ --H 512 \
11
+ --W 384
12
+
train.sh ADDED
@@ -0,0 +1 @@
1
+ CUDA_VISIBLE_DEVICES=4,5 python -u main.py --logdir models/oc --pretrained_model checkpoints/model.ckpt --base configs/viton512.yaml --scale_lr False