Upload 8 files

- .gitattributes +35 -0
- LICENSE +437 -0
- README.md +201 -0
- environment.yaml +197 -0
- main.py +738 -0
- test.py +447 -0
- test.sh +13 -0
- train.sh +1 -0
.gitattributes
ADDED
@@ -0,0 +1,35 @@
+*.7z filter=lfs diff=lfs merge=lfs -text
+*.arrow filter=lfs diff=lfs merge=lfs -text
+*.bin filter=lfs diff=lfs merge=lfs -text
+*.bz2 filter=lfs diff=lfs merge=lfs -text
+*.ckpt filter=lfs diff=lfs merge=lfs -text
+*.ftz filter=lfs diff=lfs merge=lfs -text
+*.gz filter=lfs diff=lfs merge=lfs -text
+*.h5 filter=lfs diff=lfs merge=lfs -text
+*.joblib filter=lfs diff=lfs merge=lfs -text
+*.lfs.* filter=lfs diff=lfs merge=lfs -text
+*.mlmodel filter=lfs diff=lfs merge=lfs -text
+*.model filter=lfs diff=lfs merge=lfs -text
+*.msgpack filter=lfs diff=lfs merge=lfs -text
+*.npy filter=lfs diff=lfs merge=lfs -text
+*.npz filter=lfs diff=lfs merge=lfs -text
+*.onnx filter=lfs diff=lfs merge=lfs -text
+*.ot filter=lfs diff=lfs merge=lfs -text
+*.parquet filter=lfs diff=lfs merge=lfs -text
+*.pb filter=lfs diff=lfs merge=lfs -text
+*.pickle filter=lfs diff=lfs merge=lfs -text
+*.pkl filter=lfs diff=lfs merge=lfs -text
+*.pt filter=lfs diff=lfs merge=lfs -text
+*.pth filter=lfs diff=lfs merge=lfs -text
+*.rar filter=lfs diff=lfs merge=lfs -text
+*.safetensors filter=lfs diff=lfs merge=lfs -text
+saved_model/**/* filter=lfs diff=lfs merge=lfs -text
+*.tar.* filter=lfs diff=lfs merge=lfs -text
+*.tar filter=lfs diff=lfs merge=lfs -text
+*.tflite filter=lfs diff=lfs merge=lfs -text
+*.tgz filter=lfs diff=lfs merge=lfs -text
+*.wasm filter=lfs diff=lfs merge=lfs -text
+*.xz filter=lfs diff=lfs merge=lfs -text
+*.zip filter=lfs diff=lfs merge=lfs -text
+*.zst filter=lfs diff=lfs merge=lfs -text
+*tfevents* filter=lfs diff=lfs merge=lfs -text
LICENSE
ADDED
@@ -0,0 +1,437 @@
+Attribution-NonCommercial-ShareAlike 4.0 International
+
+=======================================================================
+
+Creative Commons Corporation ("Creative Commons") is not a law firm and
+does not provide legal services or legal advice. Distribution of
+Creative Commons public licenses does not create a lawyer-client or
+other relationship. Creative Commons makes its licenses and related
+information available on an "as-is" basis. Creative Commons gives no
+warranties regarding its licenses, any material licensed under their
+terms and conditions, or any related information. Creative Commons
+disclaims all liability for damages resulting from their use to the
+fullest extent possible.
+
+Using Creative Commons Public Licenses
+
+Creative Commons public licenses provide a standard set of terms and
+conditions that creators and other rights holders may use to share
+original works of authorship and other material subject to copyright
+and certain other rights specified in the public license below. The
+following considerations are for informational purposes only, are not
+exhaustive, and do not form part of our licenses.
+
+     Considerations for licensors: Our public licenses are
+     intended for use by those authorized to give the public
+     permission to use material in ways otherwise restricted by
+     copyright and certain other rights. Our licenses are
+     irrevocable. Licensors should read and understand the terms
+     and conditions of the license they choose before applying it.
+     Licensors should also secure all rights necessary before
+     applying our licenses so that the public can reuse the
+     material as expected. Licensors should clearly mark any
+     material not subject to the license. This includes other CC-
+     licensed material, or material used under an exception or
+     limitation to copyright. More considerations for licensors:
+     wiki.creativecommons.org/Considerations_for_licensors
+
+     Considerations for the public: By using one of our public
+     licenses, a licensor grants the public permission to use the
+     licensed material under specified terms and conditions. If
+     the licensor's permission is not necessary for any reason--for
+     example, because of any applicable exception or limitation to
+     copyright--then that use is not regulated by the license. Our
+     licenses grant only permissions under copyright and certain
+     other rights that a licensor has authority to grant. Use of
+     the licensed material may still be restricted for other
+     reasons, including because others have copyright or other
+     rights in the material. A licensor may make special requests,
+     such as asking that all changes be marked or described.
+     Although not required by our licenses, you are encouraged to
+     respect those requests where reasonable. More considerations
+     for the public:
+     wiki.creativecommons.org/Considerations_for_licensees
+
+=======================================================================
+
+Creative Commons Attribution-NonCommercial-ShareAlike 4.0 International
+Public License
+
+By exercising the Licensed Rights (defined below), You accept and agree
+to be bound by the terms and conditions of this Creative Commons
+Attribution-NonCommercial-ShareAlike 4.0 International Public License
+("Public License"). To the extent this Public License may be
+interpreted as a contract, You are granted the Licensed Rights in
+consideration of Your acceptance of these terms and conditions, and the
+Licensor grants You such rights in consideration of benefits the
+Licensor receives from making the Licensed Material available under
+these terms and conditions.
+
+
+Section 1 -- Definitions.
+
+  a. Adapted Material means material subject to Copyright and Similar
+     Rights that is derived from or based upon the Licensed Material
+     and in which the Licensed Material is translated, altered,
+     arranged, transformed, or otherwise modified in a manner requiring
+     permission under the Copyright and Similar Rights held by the
+     Licensor. For purposes of this Public License, where the Licensed
+     Material is a musical work, performance, or sound recording,
+     Adapted Material is always produced where the Licensed Material is
+     synched in timed relation with a moving image.
+
+  b. Adapter's License means the license You apply to Your Copyright
+     and Similar Rights in Your contributions to Adapted Material in
+     accordance with the terms and conditions of this Public License.
+
+  c. BY-NC-SA Compatible License means a license listed at
+     creativecommons.org/compatiblelicenses, approved by Creative
+     Commons as essentially the equivalent of this Public License.
+
+  d. Copyright and Similar Rights means copyright and/or similar rights
+     closely related to copyright including, without limitation,
+     performance, broadcast, sound recording, and Sui Generis Database
+     Rights, without regard to how the rights are labeled or
+     categorized. For purposes of this Public License, the rights
+     specified in Section 2(b)(1)-(2) are not Copyright and Similar
+     Rights.
+
+  e. Effective Technological Measures means those measures that, in the
+     absence of proper authority, may not be circumvented under laws
+     fulfilling obligations under Article 11 of the WIPO Copyright
+     Treaty adopted on December 20, 1996, and/or similar international
+     agreements.
+
+  f. Exceptions and Limitations means fair use, fair dealing, and/or
+     any other exception or limitation to Copyright and Similar Rights
+     that applies to Your use of the Licensed Material.
+
+  g. License Elements means the license attributes listed in the name
+     of a Creative Commons Public License. The License Elements of this
+     Public License are Attribution, NonCommercial, and ShareAlike.
+
+  h. Licensed Material means the artistic or literary work, database,
+     or other material to which the Licensor applied this Public
+     License.
+
+  i. Licensed Rights means the rights granted to You subject to the
+     terms and conditions of this Public License, which are limited to
+     all Copyright and Similar Rights that apply to Your use of the
+     Licensed Material and that the Licensor has authority to license.
+
+  j. Licensor means the individual(s) or entity(ies) granting rights
+     under this Public License.
+
+  k. NonCommercial means not primarily intended for or directed towards
+     commercial advantage or monetary compensation. For purposes of
+     this Public License, the exchange of the Licensed Material for
+     other material subject to Copyright and Similar Rights by digital
+     file-sharing or similar means is NonCommercial provided there is
+     no payment of monetary compensation in connection with the
+     exchange.
+
+  l. Share means to provide material to the public by any means or
+     process that requires permission under the Licensed Rights, such
+     as reproduction, public display, public performance, distribution,
+     dissemination, communication, or importation, and to make material
+     available to the public including in ways that members of the
+     public may access the material from a place and at a time
+     individually chosen by them.
+
+  m. Sui Generis Database Rights means rights other than copyright
+     resulting from Directive 96/9/EC of the European Parliament and of
+     the Council of 11 March 1996 on the legal protection of databases,
+     as amended and/or succeeded, as well as other essentially
+     equivalent rights anywhere in the world.
+
+  n. You means the individual or entity exercising the Licensed Rights
+     under this Public License. Your has a corresponding meaning.
+
+
+Section 2 -- Scope.
+
+  a. License grant.
+
+       1. Subject to the terms and conditions of this Public License,
+          the Licensor hereby grants You a worldwide, royalty-free,
+          non-sublicensable, non-exclusive, irrevocable license to
+          exercise the Licensed Rights in the Licensed Material to:
+
+            a. reproduce and Share the Licensed Material, in whole or
+               in part, for NonCommercial purposes only; and
+
+            b. produce, reproduce, and Share Adapted Material for
+               NonCommercial purposes only.
+
+       2. Exceptions and Limitations. For the avoidance of doubt, where
+          Exceptions and Limitations apply to Your use, this Public
+          License does not apply, and You do not need to comply with
+          its terms and conditions.
+
+       3. Term. The term of this Public License is specified in Section
+          6(a).
+
+       4. Media and formats; technical modifications allowed. The
+          Licensor authorizes You to exercise the Licensed Rights in
+          all media and formats whether now known or hereafter created,
+          and to make technical modifications necessary to do so. The
+          Licensor waives and/or agrees not to assert any right or
+          authority to forbid You from making technical modifications
+          necessary to exercise the Licensed Rights, including
+          technical modifications necessary to circumvent Effective
+          Technological Measures. For purposes of this Public License,
+          simply making modifications authorized by this Section 2(a)
+          (4) never produces Adapted Material.
+
+       5. Downstream recipients.
+
+            a. Offer from the Licensor -- Licensed Material. Every
+               recipient of the Licensed Material automatically
+               receives an offer from the Licensor to exercise the
+               Licensed Rights under the terms and conditions of this
+               Public License.
+
+            b. Additional offer from the Licensor -- Adapted Material.
+               Every recipient of Adapted Material from You
+               automatically receives an offer from the Licensor to
+               exercise the Licensed Rights in the Adapted Material
+               under the conditions of the Adapter's License You apply.
+
+            c. No downstream restrictions. You may not offer or impose
+               any additional or different terms or conditions on, or
+               apply any Effective Technological Measures to, the
+               Licensed Material if doing so restricts exercise of the
+               Licensed Rights by any recipient of the Licensed
+               Material.
+
+       6. No endorsement. Nothing in this Public License constitutes or
+          may be construed as permission to assert or imply that You
+          are, or that Your use of the Licensed Material is, connected
+          with, or sponsored, endorsed, or granted official status by,
+          the Licensor or others designated to receive attribution as
+          provided in Section 3(a)(1)(A)(i).
+
+  b. Other rights.
+
+       1. Moral rights, such as the right of integrity, are not
+          licensed under this Public License, nor are publicity,
+          privacy, and/or other similar personality rights; however, to
+          the extent possible, the Licensor waives and/or agrees not to
+          assert any such rights held by the Licensor to the limited
+          extent necessary to allow You to exercise the Licensed
+          Rights, but not otherwise.
+
+       2. Patent and trademark rights are not licensed under this
+          Public License.
+
+       3. To the extent possible, the Licensor waives any right to
+          collect royalties from You for the exercise of the Licensed
+          Rights, whether directly or through a collecting society
+          under any voluntary or waivable statutory or compulsory
+          licensing scheme. In all other cases the Licensor expressly
+          reserves any right to collect such royalties, including when
+          the Licensed Material is used other than for NonCommercial
+          purposes.
+
+
+Section 3 -- License Conditions.
+
+Your exercise of the Licensed Rights is expressly made subject to the
+following conditions.
+
+  a. Attribution.
+
+       1. If You Share the Licensed Material (including in modified
+          form), You must:
+
+            a. retain the following if it is supplied by the Licensor
+               with the Licensed Material:
+
+                 i. identification of the creator(s) of the Licensed
+                    Material and any others designated to receive
+                    attribution, in any reasonable manner requested by
+                    the Licensor (including by pseudonym if
+                    designated);
+
+                ii. a copyright notice;
+
+               iii. a notice that refers to this Public License;
+
+                iv. a notice that refers to the disclaimer of
+                    warranties;
+
+                 v. a URI or hyperlink to the Licensed Material to the
+                    extent reasonably practicable;
+
+            b. indicate if You modified the Licensed Material and
+               retain an indication of any previous modifications; and
+
+            c. indicate the Licensed Material is licensed under this
+               Public License, and include the text of, or the URI or
+               hyperlink to, this Public License.
+
+       2. You may satisfy the conditions in Section 3(a)(1) in any
+          reasonable manner based on the medium, means, and context in
+          which You Share the Licensed Material. For example, it may be
+          reasonable to satisfy the conditions by providing a URI or
+          hyperlink to a resource that includes the required
+          information.
+       3. If requested by the Licensor, You must remove any of the
+          information required by Section 3(a)(1)(A) to the extent
+          reasonably practicable.
+
+  b. ShareAlike.
+
+     In addition to the conditions in Section 3(a), if You Share
+     Adapted Material You produce, the following conditions also apply.
+
+       1. The Adapter's License You apply must be a Creative Commons
+          license with the same License Elements, this version or
+          later, or a BY-NC-SA Compatible License.
+
+       2. You must include the text of, or the URI or hyperlink to, the
+          Adapter's License You apply. You may satisfy this condition
+          in any reasonable manner based on the medium, means, and
+          context in which You Share Adapted Material.
+
+       3. You may not offer or impose any additional or different terms
+          or conditions on, or apply any Effective Technological
+          Measures to, Adapted Material that restrict exercise of the
+          rights granted under the Adapter's License You apply.
+
+
+Section 4 -- Sui Generis Database Rights.
+
+Where the Licensed Rights include Sui Generis Database Rights that
+apply to Your use of the Licensed Material:
+
+  a. for the avoidance of doubt, Section 2(a)(1) grants You the right
+     to extract, reuse, reproduce, and Share all or a substantial
+     portion of the contents of the database for NonCommercial purposes
+     only;
+
+  b. if You include all or a substantial portion of the database
+     contents in a database in which You have Sui Generis Database
+     Rights, then the database in which You have Sui Generis Database
+     Rights (but not its individual contents) is Adapted Material,
+     including for purposes of Section 3(b); and
+
+  c. You must comply with the conditions in Section 3(a) if You Share
+     all or a substantial portion of the contents of the database.
+
+For the avoidance of doubt, this Section 4 supplements and does not
+replace Your obligations under this Public License where the Licensed
+Rights include other Copyright and Similar Rights.
+
+
+Section 5 -- Disclaimer of Warranties and Limitation of Liability.
+
+  a. UNLESS OTHERWISE SEPARATELY UNDERTAKEN BY THE LICENSOR, TO THE
+     EXTENT POSSIBLE, THE LICENSOR OFFERS THE LICENSED MATERIAL AS-IS
+     AND AS-AVAILABLE, AND MAKES NO REPRESENTATIONS OR WARRANTIES OF
+     ANY KIND CONCERNING THE LICENSED MATERIAL, WHETHER EXPRESS,
+     IMPLIED, STATUTORY, OR OTHER. THIS INCLUDES, WITHOUT LIMITATION,
+     WARRANTIES OF TITLE, MERCHANTABILITY, FITNESS FOR A PARTICULAR
+     PURPOSE, NON-INFRINGEMENT, ABSENCE OF LATENT OR OTHER DEFECTS,
+     ACCURACY, OR THE PRESENCE OR ABSENCE OF ERRORS, WHETHER OR NOT
+     KNOWN OR DISCOVERABLE. WHERE DISCLAIMERS OF WARRANTIES ARE NOT
+     ALLOWED IN FULL OR IN PART, THIS DISCLAIMER MAY NOT APPLY TO YOU.
+
+  b. TO THE EXTENT POSSIBLE, IN NO EVENT WILL THE LICENSOR BE LIABLE
+     TO YOU ON ANY LEGAL THEORY (INCLUDING, WITHOUT LIMITATION,
+     NEGLIGENCE) OR OTHERWISE FOR ANY DIRECT, SPECIAL, INDIRECT,
+     INCIDENTAL, CONSEQUENTIAL, PUNITIVE, EXEMPLARY, OR OTHER LOSSES,
+     COSTS, EXPENSES, OR DAMAGES ARISING OUT OF THIS PUBLIC LICENSE OR
+     USE OF THE LICENSED MATERIAL, EVEN IF THE LICENSOR HAS BEEN
+     ADVISED OF THE POSSIBILITY OF SUCH LOSSES, COSTS, EXPENSES, OR
+     DAMAGES. WHERE A LIMITATION OF LIABILITY IS NOT ALLOWED IN FULL OR
+     IN PART, THIS LIMITATION MAY NOT APPLY TO YOU.
+
+  c. The disclaimer of warranties and limitation of liability provided
+     above shall be interpreted in a manner that, to the extent
+     possible, most closely approximates an absolute disclaimer and
+     waiver of all liability.
+
+
+Section 6 -- Term and Termination.
+
+  a. This Public License applies for the term of the Copyright and
+     Similar Rights licensed here. However, if You fail to comply with
+     this Public License, then Your rights under this Public License
+     terminate automatically.
+
+  b. Where Your right to use the Licensed Material has terminated under
+     Section 6(a), it reinstates:
+
+       1. automatically as of the date the violation is cured, provided
+          it is cured within 30 days of Your discovery of the
+          violation; or
+
+       2. upon express reinstatement by the Licensor.
+
+     For the avoidance of doubt, this Section 6(b) does not affect any
+     right the Licensor may have to seek remedies for Your violations
+     of this Public License.
+
+  c. For the avoidance of doubt, the Licensor may also offer the
+     Licensed Material under separate terms or conditions or stop
+     distributing the Licensed Material at any time; however, doing so
+     will not terminate this Public License.
+
+  d. Sections 1, 5, 6, 7, and 8 survive termination of this Public
+     License.
+
+
+Section 7 -- Other Terms and Conditions.
+
+  a. The Licensor shall not be bound by any additional or different
+     terms or conditions communicated by You unless expressly agreed.
+
+  b. Any arrangements, understandings, or agreements regarding the
+     Licensed Material not stated herein are separate from and
+     independent of the terms and conditions of this Public License.
+
+
+Section 8 -- Interpretation.
+
+  a. For the avoidance of doubt, this Public License does not, and
+     shall not be interpreted to, reduce, limit, restrict, or impose
+     conditions on any use of the Licensed Material that could lawfully
+     be made without permission under this Public License.
+
+  b. To the extent possible, if any provision of this Public License is
+     deemed unenforceable, it shall be automatically reformed to the
+     minimum extent necessary to make it enforceable. If the provision
+     cannot be reformed, it shall be severed from this Public License
+     without affecting the enforceability of the remaining terms and
+     conditions.
+
+  c. No term or condition of this Public License will be waived and no
+     failure to comply consented to unless expressly agreed to by the
+     Licensor.
+
+  d. Nothing in this Public License constitutes or may be interpreted
+     as a limitation upon, or waiver of, any privileges and immunities
+     that apply to the Licensor or You, including from the legal
+     processes of any jurisdiction or authority.
+
+=======================================================================
+
+Creative Commons is not a party to its public
+licenses. Notwithstanding, Creative Commons may elect to apply one of
+its public licenses to material it publishes and in those instances
+will be considered the "Licensor." The text of the Creative Commons
+public licenses is dedicated to the public domain under the CC0 Public
+Domain Dedication. Except for the limited purpose of indicating that
+material is shared under a Creative Commons public license or as
+otherwise permitted by the Creative Commons policies published at
+creativecommons.org/policies, Creative Commons does not authorize the
+use of the trademark "Creative Commons" or any other trademark or logo
+of Creative Commons without its prior written consent including,
+without limitation, in connection with any unauthorized modifications
+to any of its public licenses or any other arrangements,
+understandings, or agreements concerning use of licensed material. For
+the avoidance of doubt, this paragraph does not form part of the
+public licenses.
+
+Creative Commons may be contacted at creativecommons.org.
README.md
ADDED
@@ -0,0 +1,201 @@
+<<<<<<< HEAD
+# MV-VTON
+
+PyTorch implementation of **MV-VTON: Multi-View Virtual Try-On with Diffusion Models**
+
+[arXiv Paper](https://arxiv.org/abs/2404.17364)
+[Project Page](https://hywang2002.github.io/MV-VTON/)
+[License: CC BY-NC-SA 4.0](https://creativecommons.org/licenses/by-nc-sa/4.0/)
+
+## News
+- 🔥 The first multi-view virtual try-on dataset, MVG, is now available.
+- 🔥 Checkpoints for both the frontal-view and multi-view virtual try-on tasks are released.
+
+## Overview
+
+*(pipeline overview figure)*
+
+> **Abstract:**
+> The goal of image-based virtual try-on is to generate an image of the target person naturally wearing the given
+> clothing. However, most existing methods solely focus on frontal try-on using frontal clothing. When the views
+> of the clothing and person are significantly inconsistent, particularly when the person's view is non-frontal, the
+> results are unsatisfactory. To address this challenge, we introduce Multi-View Virtual Try-On (MV-VTON), which aims to
+> reconstruct the dressing results of a person from multiple views using the given clothes. On the one hand, given that
+> single-view clothes provide insufficient information for MV-VTON, we instead employ two images, i.e., the frontal and
+> back views of the clothing, to encompass the complete view as much as possible. On the other hand, diffusion models
+> that have demonstrated superior abilities are adopted to perform our MV-VTON. In particular, we propose a
+> view-adaptive selection method where hard-selection and soft-selection are applied to the global and local clothing
+> feature extraction, respectively. This ensures that the clothing features roughly fit the person's view. Subsequently,
+> we suggest a joint attention block to align and fuse clothing features with person features. Additionally, we collect
+> an MV-VTON dataset, i.e., Multi-View Garment (MVG), in which each person has multiple photos with diverse views and
+> poses. Experiments show that the proposed method not only achieves state-of-the-art results on the MV-VTON task using
+> our MVG dataset, but also has superiority on the frontal-view virtual try-on task using the VITON-HD and DressCode
+> datasets.
+
+## Getting Started
+
+### Installation
+
+1. Clone the repository
+
+```shell
+git clone https://github.com/hywang2002/MV-VTON.git
+cd MV-VTON
+```
+
+2. Install the Python dependencies
+
+```shell
+conda env create -f environment.yaml
+conda activate mv-vton
+```
+
+3. Download the pretrained [vgg](https://drive.google.com/file/d/1rvow8jStPt8t2prDcSRlnf8yzXhrYeGo/view?usp=sharing)
+   checkpoint and put it in `models/vgg/` for Multi-View VTON and in `Frontal-View VTON/models/vgg/` for Frontal-View VTON.
+4. Download the pretrained models `mvg.ckpt` via [Baidu Cloud](https://pan.baidu.com/s/17SC8fHE5w2g7gEtzJgRRew?pwd=cshy) or [Google Drive](https://drive.google.com/file/d/1J91PoT8A9yqHWNxkgRe6ZCnDEhN-H9O6/view?usp=sharing)
+   and `vitonhd.ckpt` via [Baidu Cloud](https://pan.baidu.com/s/1R2yGgm35UwTpnXPEU6-tlA?pwd=cshy) or [Google Drive](https://drive.google.com/file/d/13A0uzUY6PuvitLOqzyHzWASOh0dNXdem/view?usp=sharing),
+   then put `mvg.ckpt` in `checkpoint/` and `vitonhd.ckpt` in `Frontal-View VTON/checkpoint/`.
+
+### Datasets
+
+#### MVG
+
+1. Fill in the `Dataset Request Form` via [Baidu Cloud](https://pan.baidu.com/s/12HAq0V4FfgpU_q8AeyZzwA?pwd=cshy) or [Google Drive](https://drive.google.com/file/d/1zWt6HYBz7Vzaxu8rp1bwkhRoBkxbwQjw/view?usp=sharing),
+   and contact `cshy2mvvton@outlook.com` with this form to get the MVG dataset (non-institutional
+   emails, e.g. gmail.com, are not allowed; please provide your institutional email address).
+
+After these steps, the folder structure should look like this (the warp_feat_unpair* directories are only included in the test directory):
+
+```
+├── MVG
+|   ├── unpaired.txt
+│   ├── [train | test]
+|   |   ├── image-wo-bg
+│   │   ├── cloth
+│   │   ├── cloth-mask
+│   │   ├── warp_feat
+│   │   ├── warp_feat_unpair
+│   │   ├── ...
+```
+
+#### VITON-HD
+
+1. Download the [VITON-HD](https://github.com/shadow2496/VITON-HD) dataset.
+2. Download the pre-warped cloth images/masks via [Baidu Cloud](https://pan.baidu.com/s/1uQM0IOltOmbeqwdOKX5kCw?pwd=cshy) or [Google Drive](https://drive.google.com/file/d/18DTWfhxUnfg41nnwwpCKN--akC4eT9DM/view?usp=sharing)
+   and put them under the VITON-HD dataset.
+
+After these steps, the folder structure should look like this (the unpaired-cloth* directories are only included in the test directory):
+
+```
+├── VITON-HD
+|   ├── test_pairs.txt
+|   ├── train_pairs.txt
+│   ├── [train | test]
+|   |   ├── image
+│   │   │   ├── [000006_00.jpg | 000008_00.jpg | ...]
+│   │   ├── cloth
+│   │   │   ├── [000006_00.jpg | 000008_00.jpg | ...]
+│   │   ├── cloth-mask
+│   │   │   ├── [000006_00.jpg | 000008_00.jpg | ...]
+│   │   ├── cloth-warp
+│   │   │   ├── [000006_00.jpg | 000008_00.jpg | ...]
+│   │   ├── cloth-warp-mask
+│   │   │   ├── [000006_00.jpg | 000008_00.jpg | ...]
+│   │   ├── unpaired-cloth-warp
+│   │   │   ├── [000006_00.jpg | 000008_00.jpg | ...]
+│   │   ├── unpaired-cloth-warp-mask
+│   │   │   ├── [000006_00.jpg | 000008_00.jpg | ...]
+```
+
+### Inference
+
+#### MVG
+
+To test under the paired setting (using `cp_dataset_mv_paired.py`), you can modify `configs/viton512.yaml` and `main.py`,
+or directly rename `cp_dataset_mv_paired.py` to `cp_dataset.py` (recommended). Then run:
+
+```shell
+sh test.sh
+```
+
+To test under the unpaired setting, rename `cp_dataset_mv_unpaired.py` to `cp_dataset.py` and do the same.
+
+#### VITON-HD
+
+To test under the paired setting, run `cd Frontal-View\ VTON/`, then directly run:
+
+```shell
+sh test.sh
+```
+
+To test under the unpaired setting, run `cd Frontal-View\ VTON/`, add `--unpaired` to `test.sh`, and then run:
+
+```shell
+sh test.sh
+```
+
+#### Metrics
+
+We compute `LPIPS`, `SSIM`, `FID`, and `KID` using the same tools as [LaDI-VTON](https://github.com/miccunifi/ladi-vton) (see the sketch after this file section).
+
+### Training
+
+#### MVG
+
+We use Paint-by-Example as initialization; please download the pretrained model
+from [Google Drive](https://drive.google.com/file/d/15QzaTWsvZonJcXsNv-ilMRCYaQLhzR_i/view) and save it to the
+`checkpoints` directory. Rename `cp_dataset_mv_paired.py` to `cp_dataset.py`, then run:
+
+```shell
+sh train.sh
+```
+
+#### VITON-HD
+
+Run `cd Frontal-View\ VTON/`, then directly run:
+
+```shell
+sh train.sh
+```
+
+## Acknowledgements
+
+Our code is heavily borrowed from [Paint-by-Example](https://github.com/Fantasy-Studio/Paint-by-Example)
+and [DCI-VTON](https://github.com/bcmi/DCI-VTON-Virtual-Try-On). We also
+thank the previous work [PF-AFN](https://github.com/geyuying/PF-AFN), [GP-VTON](https://github.com/xiezhy6/GP-VTON),
+[LaDI-VTON](https://github.com/miccunifi/ladi-vton),
+and [StableVITON](https://github.com/rlawjdghek/StableVITON).
+
+## LICENSE
+MV-VTON: Multi-View Virtual Try-On with Diffusion Models © 2024 by Haoyu Wang, Zhilu Zhang, Donglin Di, Shiliang Zhang, and Wangmeng Zuo is licensed under CC BY-NC-SA 4.0.
+
+## Citation
+
+```
+@article{wang2024mv,
+  title={MV-VTON: Multi-View Virtual Try-On with Diffusion Models},
+  author={Wang, Haoyu and Zhang, Zhilu and Di, Donglin and Zhang, Shiliang and Zuo, Wangmeng},
+  journal={arXiv preprint arXiv:2404.17364},
+  year={2024}
+}
+```
+=======
+---
+title: Mv Vton Demo
+emoji: 👁
+colorFrom: gray
+colorTo: pink
+sdk: gradio
+sdk_version: 5.26.0
+app_file: app.py
+pinned: false
+---
+
+Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
+>>>>>>> 2a4541c57faf075fa9e813ae2777dfaa55fc0306
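
(Editor's note on the Metrics subsection above: the repository states that LPIPS, SSIM, FID, and KID are computed with LaDI-VTON's tooling, but no evaluation code is included in this commit. Purely as a hedged illustration of what such a paired evaluation could look like — not the authors' actual script — here is a minimal Python sketch using `torchmetrics`. The `gen_dir`/`gt_dir` layout, the image size, and the use of a recent `torchmetrics` release (newer than the `0.6.0` pinned in `environment.yaml` below) are all assumptions.)

```python
# Hypothetical metric sketch -- NOT the repo's evaluation code (that follows LaDI-VTON's tools).
# Assumes: torchmetrics >= 0.11, torch, torchvision, Pillow installed.
import os
import torch
from PIL import Image
from torchvision import transforms
from torchmetrics.image.fid import FrechetInceptionDistance
from torchmetrics.image.kid import KernelInceptionDistance
from torchmetrics.image.lpip import LearnedPerceptualImagePatchSimilarity
from torchmetrics import StructuralSimilarityIndexMeasure


def load_batch(folder, names, size=(512, 384)):
    """Load images as a float tensor in [0, 1], shape (N, 3, H, W)."""
    to_tensor = transforms.Compose([transforms.Resize(size), transforms.ToTensor()])
    return torch.stack(
        [to_tensor(Image.open(os.path.join(folder, n)).convert("RGB")) for n in names]
    )


@torch.no_grad()
def evaluate(gen_dir, gt_dir):
    # Assumes generated and ground-truth folders hold identically named images.
    names = sorted(os.listdir(gt_dir))
    gen, gt = load_batch(gen_dir, names), load_batch(gt_dir, names)

    fid = FrechetInceptionDistance(normalize=True)  # normalize=True: floats in [0, 1]
    kid = KernelInceptionDistance(subset_size=min(50, len(names)), normalize=True)
    lpips = LearnedPerceptualImagePatchSimilarity(net_type="alex", normalize=True)
    ssim = StructuralSimilarityIndexMeasure(data_range=1.0)

    fid.update(gt, real=True)
    fid.update(gen, real=False)
    kid.update(gt, real=True)
    kid.update(gen, real=False)
    kid_mean, _ = kid.compute()  # KID returns (mean, std) over subsets

    return {
        "FID": fid.compute().item(),
        "KID": kid_mean.item(),
        "LPIPS": lpips(gen, gt).item(),
        "SSIM": ssim(gen, gt).item(),
    }


if __name__ == "__main__":
    print(evaluate("results/unpaired", "data/test/image"))  # hypothetical paths
```

Note that FID/KID compare the two image distributions as a whole, while LPIPS/SSIM are averaged over aligned pairs, which is why the sketch only makes sense in the paired setting; unpaired results would use FID/KID alone.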
environment.yaml
ADDED
@@ -0,0 +1,197 @@
+name: mv-vton
+channels:
+  - pytorch
+  - defaults
+dependencies:
+  - _libgcc_mutex=0.1=main
+  - _openmp_mutex=5.1=1_gnu
+  - blas=1.0=mkl
+  - brotli-python=1.0.9=py38h6a678d5_7
+  - bzip2=1.0.8=h7b6447c_0
+  - ca-certificates=2023.08.22=h06a4308_0
+  - certifi=2023.11.17=py38h06a4308_0
+  - cffi=1.15.1=py38h74dc2b5_0
+  - charset-normalizer=2.0.4=pyhd3eb1b0_0
+  - cryptography=41.0.3=py38h130f0dd_0
+  - cudatoolkit=11.3.1=h2bc3f7f_2
+  - ffmpeg=4.3=hf484d3e_0
+  - freetype=2.12.1=h4a9f257_0
+  - giflib=5.2.1=h5eee18b_3
+  - gmp=6.2.1=h295c915_3
+  - gnutls=3.6.15=he1e5248_0
+  - idna=3.4=py38h06a4308_0
+  - intel-openmp=2021.4.0=h06a4308_3561
+  - jpeg=9e=h5eee18b_1
+  - lame=3.100=h7b6447c_0
+  - lcms2=2.12=h3be6417_0
+  - ld_impl_linux-64=2.38=h1181459_1
+  - lerc=3.0=h295c915_0
+  - libdeflate=1.17=h5eee18b_1
+  - libffi=3.3=he6710b0_2
+  - libgcc-ng=11.2.0=h1234567_1
+  - libgfortran-ng=11.2.0=h00389a5_1
+  - libgfortran5=11.2.0=h1234567_1
+  - libgomp=11.2.0=h1234567_1
+  - libiconv=1.16=h7f8727e_2
+  - libidn2=2.3.4=h5eee18b_0
+  - libpng=1.6.39=h5eee18b_0
+  - libstdcxx-ng=11.2.0=h1234567_1
+  - libtasn1=4.19.0=h5eee18b_0
+  - libtiff=4.5.1=h6a678d5_0
+  - libunistring=0.9.10=h27cfd23_0
+  - libuv=1.44.2=h5eee18b_0
+  - libwebp=1.3.2=h11a3e52_0
+  - libwebp-base=1.3.2=h5eee18b_0
+  - lz4-c=1.9.4=h6a678d5_0
+  - mkl=2021.4.0=h06a4308_640
+  - mkl-service=2.4.0=py38h7f8727e_0
+  - mkl_fft=1.3.1=py38hd3c417c_0
+  - mkl_random=1.2.2=py38h51133e4_0
+  - ncurses=6.4=h6a678d5_0
+  - nettle=3.7.3=hbbd107a_1
+  - openh264=2.1.1=h4ff587b_0
+  - openjpeg=2.4.0=h3ad879b_0
+  - openssl=1.1.1w=h7f8727e_0
+  - pillow=10.0.1=py38ha6cbd5a_0
+  - pip=20.3.3=py38h06a4308_0
+  - pycparser=2.21=pyhd3eb1b0_0
+  - pyopenssl=23.2.0=py38h06a4308_0
+  - pysocks=1.7.1=py38h06a4308_0
+  - python=3.8.5=h7579374_1
+  - pytorch=1.11.0=py3.8_cuda11.3_cudnn8.2.0_0
+  - pytorch-mutex=1.0=cuda
+  - readline=8.2=h5eee18b_0
+  - requests=2.31.0=py38h06a4308_0
+  - setuptools=68.0.0=py38h06a4308_0
+  - six=1.16.0=pyhd3eb1b0_1
+  - sqlite=3.41.2=h5eee18b_0
+  - tk=8.6.12=h1ccaba5_0
+  - torchvision=0.12.0=py38_cu113
+  - typing_extensions=4.7.1=py38h06a4308_0
+  - urllib3=1.26.18=py38h06a4308_0
+  - wheel=0.41.2=py38h06a4308_0
+  - xz=5.4.5=h5eee18b_0
+  - zlib=1.2.13=h5eee18b_0
+  - zstd=1.5.5=hc292b87_0
+  - pip:
+    - absl-py==2.0.0
+    - aiohttp==3.9.1
+    - aiosignal==1.3.1
+    - albumentations==0.4.3
+    - altair==5.2.0
+    - antlr4-python3-runtime==4.9.3
+    - async-timeout==4.0.3
+    - attrs==23.1.0
+    - av==12.0.0
+    - backports-zoneinfo==0.2.1
+    - bezier==2023.7.28
+    - black==24.2.0
+    - blinker==1.7.0
+    - cachetools==5.3.2
+    - click==8.1.7
+    - clip==0.2.0
+    - cloudpickle==3.0.0
+    - contourpy==1.1.1
+    - cupy==12.3.0
+    - cycler==0.12.1
+    - diffusers==0.20.0
+    - einops==0.3.0
+    - fastrlock==0.8.2
+    - filelock==3.13.1
+    - fonttools==4.45.1
+    - frozenlist==1.4.0
+    - fsspec==2023.10.0
+    - future==0.18.3
+    - fvcore==0.1.5.post20221221
+    - gitdb==4.0.11
+    - gitpython==3.1.40
+    - google-auth==2.23.4
+    - google-auth-oauthlib==1.0.0
+    - grpcio==1.59.3
+    - huggingface-hub==0.19.4
+    - hydra-core==1.3.2
+    - imageio==2.9.0
+    - imageio-ffmpeg==0.4.2
+    - imgaug==0.2.6
+    - importlib-metadata==6.8.0
+    - importlib-resources==6.1.1
+    - invisible-watermark==0.2.0
+    - iopath==0.1.9
+    - jinja2==3.1.2
+    - jsonschema==4.20.0
+    - jsonschema-specifications==2023.11.1
+    - kiwisolver==1.4.5
+    - kornia==0.6.0
+    - lazy-loader==0.3
+    - markdown==3.5.1
+    - markdown-it-py==3.0.0
+    - markupsafe==2.1.3
+    - matplotlib==3.7.4
+    - mdurl==0.1.2
+    - multidict==6.0.4
+    - mypy-extensions==1.0.0
+    - networkx==3.1
+    - numpy==1.24.4
+    - oauthlib==3.2.2
+    - omegaconf==2.3.0
+    - opencv-python==4.1.2.30
+    - opencv-python-headless==4.8.1.78
+    - packaging==23.2
+    - pandas==2.0.3
+    - pathspec==0.12.1
+    - pkgutil-resolve-name==1.3.10
+    - platformdirs==4.2.0
+    - portalocker==2.8.2
+    - protobuf==4.25.1
+    - pudb==2019.2
+    - pyarrow==14.0.1
+    - pyasn1==0.5.1
+    - pyasn1-modules==0.3.0
+    - pycocotools==2.0.7
+    - pydeck==0.8.1b0
+    - pydeprecate==0.3.1
+    - pygments==2.17.2
+    - pyparsing==3.1.1
+    - python-dateutil==2.8.2
+    - pytorch-lightning==1.4.2
+    - pytz==2023.3.post1
+    - pywavelets==1.4.1
+    - pyyaml==6.0.1
+    - referencing==0.31.1
+    - regex==2023.10.3
+    - requests-oauthlib==1.3.1
+    - rich==13.7.0
+    - rpds-py==0.13.2
+    - rsa==4.9
+    - safetensors==0.4.1
+    - scikit-image==0.20.0
+    - scipy==1.9.1
+    - smmap==5.0.1
+    - streamlit==1.28.2
+    - tabulate==0.9.0
+    - taming-transformers==0.0.1
+    - tenacity==8.2.3
+    - tensorboard==2.14.0
+    - tensorboard-data-server==0.7.2
+    - termcolor==2.4.0
+    - test-tube==0.7.5
+    - tifffile==2023.7.10
+    - tokenizers==0.12.1
+    - toml==0.10.2
+    - tomli==2.0.1
+    - toolz==0.12.0
+    - torch-fidelity==0.3.0
+    - torchmetrics==0.6.0
+    - tornado==6.4
+    - tqdm==4.66.1
+    - transformers==4.27.3
+    - tzdata==2023.3
+    - tzlocal==5.2
+    - urwid==2.2.3
+    - validators==0.22.0
+    - watchdog==3.0.0
+    - werkzeug==3.0.1
+    - yacs==0.1.8
+    - yarl==1.9.3
+    - zipp==3.17.0
+prefix: /mnt/pfs-mc0p4k/cvg/team/didonglin/conda_envs/mv-vton
main.py
ADDED
@@ -0,0 +1,738 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import argparse, os, sys, datetime, glob, importlib, csv
|
2 |
+
import numpy as np
|
3 |
+
import time
|
4 |
+
import torch
|
5 |
+
import torchvision
|
6 |
+
import pytorch_lightning as pl
|
7 |
+
|
8 |
+
sys.setrecursionlimit(10000)
|
9 |
+
from packaging import version
|
10 |
+
from omegaconf import OmegaConf
|
11 |
+
from torch.utils.data import random_split, DataLoader, Dataset, Subset
|
12 |
+
from functools import partial
|
13 |
+
from PIL import Image
|
14 |
+
|
15 |
+
from pytorch_lightning import seed_everything
|
16 |
+
from pytorch_lightning.trainer import Trainer
|
17 |
+
from pytorch_lightning.callbacks import ModelCheckpoint, Callback, LearningRateMonitor
|
18 |
+
from pytorch_lightning.utilities.distributed import rank_zero_only
|
19 |
+
from pytorch_lightning.utilities import rank_zero_info
|
20 |
+
|
21 |
+
from ldm.data.base import Txt2ImgIterableBaseDataset
|
22 |
+
from ldm.util import instantiate_from_config
|
23 |
+
import socket
|
24 |
+
from pytorch_lightning.plugins.environments import ClusterEnvironment, SLURMEnvironment
|
25 |
+
|
26 |
+
|
27 |
+
def get_parser(**parser_kwargs):
|
28 |
+
def str2bool(v):
|
29 |
+
if isinstance(v, bool):
|
30 |
+
return v
|
31 |
+
if v.lower() in ("yes", "true", "t", "y", "1"):
|
32 |
+
return True
|
33 |
+
elif v.lower() in ("no", "false", "f", "n", "0"):
|
34 |
+
return False
|
35 |
+
else:
|
36 |
+
raise argparse.ArgumentTypeError("Boolean value expected.")
|
37 |
+
|
38 |
+
parser = argparse.ArgumentParser(**parser_kwargs)
|
39 |
+
parser.add_argument(
|
40 |
+
"-n",
|
41 |
+
"--name",
|
42 |
+
type=str,
|
43 |
+
const=True,
|
44 |
+
default="",
|
45 |
+
nargs="?",
|
46 |
+
help="postfix for logdir",
|
47 |
+
)
|
48 |
+
parser.add_argument(
|
49 |
+
"-r",
|
50 |
+
"--resume",
|
51 |
+
type=str,
|
52 |
+
const=True,
|
53 |
+
default="",
|
54 |
+
nargs="?",
|
55 |
+
help="resume from logdir or checkpoint in logdir",
|
56 |
+
)
|
57 |
+
parser.add_argument(
|
58 |
+
"-b",
|
59 |
+
"--base",
|
60 |
+
nargs="*",
|
61 |
+
metavar="base_config.yaml",
|
62 |
+
help="paths to base configs. Loaded from left-to-right. "
|
63 |
+
"Parameters can be overwritten or added with command-line options of the form `--key value`.",
|
64 |
+
default=["configs/stable-diffusion/v1-inference-inpaint.yaml"],
|
65 |
+
)
|
66 |
+
parser.add_argument(
|
67 |
+
"-t",
|
68 |
+
"--train",
|
69 |
+
type=str2bool,
|
70 |
+
const=True,
|
71 |
+
default=True,
|
72 |
+
nargs="?",
|
73 |
+
help="train",
|
74 |
+
)
|
75 |
+
parser.add_argument(
|
76 |
+
"--no-test",
|
77 |
+
type=str2bool,
|
78 |
+
const=True,
|
79 |
+
default=False,
|
80 |
+
nargs="?",
|
81 |
+
help="disable test",
|
82 |
+
)
|
83 |
+
parser.add_argument(
|
84 |
+
"-p",
|
85 |
+
"--project",
|
86 |
+
help="name of new or path to existing project"
|
87 |
+
)
|
88 |
+
parser.add_argument(
|
89 |
+
"-d",
|
90 |
+
"--debug",
|
91 |
+
type=str2bool,
|
92 |
+
nargs="?",
|
93 |
+
const=True,
|
94 |
+
default=False,
|
95 |
+
help="enable post-mortem debugging",
|
96 |
+
)
|
97 |
+
parser.add_argument(
|
98 |
+
"-s",
|
99 |
+
"--seed",
|
100 |
+
type=int,
|
101 |
+
default=23,
|
102 |
+
help="seed for seed_everything",
|
103 |
+
)
|
104 |
+
parser.add_argument(
|
105 |
+
"-f",
|
106 |
+
"--postfix",
|
107 |
+
type=str,
|
108 |
+
default="",
|
109 |
+
help="post-postfix for default name",
|
110 |
+
)
|
111 |
+
parser.add_argument(
|
112 |
+
"-l",
|
113 |
+
"--logdir",
|
114 |
+
type=str,
|
115 |
+
default="logs",
|
116 |
+
help="directory for logging dat shit",
|
117 |
+
)
|
118 |
+
parser.add_argument(
|
119 |
+
"--pretrained_model",
|
120 |
+
type=str,
|
121 |
+
default="",
|
122 |
+
help="path to pretrained model",
|
123 |
+
)
|
124 |
+
parser.add_argument(
|
125 |
+
"--scale_lr",
|
126 |
+
type=str2bool,
|
127 |
+
nargs="?",
|
128 |
+
const=True,
|
129 |
+
default=True,
|
130 |
+
help="scale base-lr by ngpu * batch_size * n_accumulate",
|
131 |
+
)
|
132 |
+
parser.add_argument(
|
133 |
+
"--train_from_scratch",
|
134 |
+
type=str2bool,
|
135 |
+
nargs="?",
|
136 |
+
const=True,
|
137 |
+
default=False,
|
138 |
+
help="Train from scratch",
|
139 |
+
)
|
140 |
+
return parser
|
141 |
+
|
142 |
+
|
143 |
+
def nondefault_trainer_args(opt):
|
144 |
+
parser = argparse.ArgumentParser()
|
145 |
+
parser = Trainer.add_argparse_args(parser)
|
146 |
+
args = parser.parse_args([])
|
147 |
+
return sorted(k for k in vars(args) if getattr(opt, k) != getattr(args, k))
|
148 |
+
|
149 |
+
|
150 |
+
class WrappedDataset(Dataset):
|
151 |
+
"""Wraps an arbitrary object with __len__ and __getitem__ into a pytorch dataset"""
|
152 |
+
|
153 |
+
def __init__(self, dataset):
|
154 |
+
self.data = dataset
|
155 |
+
|
156 |
+
def __len__(self):
|
157 |
+
return len(self.data)
|
158 |
+
|
159 |
+
def __getitem__(self, idx):
|
160 |
+
return self.data[idx]
|
161 |
+
|
162 |
+
|
163 |
+
def worker_init_fn(_):
|
164 |
+
worker_info = torch.utils.data.get_worker_info()
|
165 |
+
|
166 |
+
dataset = worker_info.dataset
|
167 |
+
worker_id = worker_info.id
|
168 |
+
|
169 |
+
if isinstance(dataset, Txt2ImgIterableBaseDataset):
|
170 |
+
split_size = dataset.num_records // worker_info.num_workers
|
171 |
+
# reset num_records to the true number to retain reliable length information
|
172 |
+
dataset.sample_ids = dataset.valid_ids[worker_id * split_size:(worker_id + 1) * split_size]
|
173 |
+
current_id = np.random.choice(len(np.random.get_state()[1]), 1)
|
174 |
+
return np.random.seed(np.random.get_state()[1][current_id] + worker_id)
|
175 |
+
else:
|
176 |
+
return np.random.seed(np.random.get_state()[1][0] + worker_id)
|
177 |
+
|
178 |
+
|
179 |
+
class DataModuleFromConfig(pl.LightningDataModule):
|
180 |
+
def __init__(self, batch_size, train=None, validation=None, test=None, predict=None,
|
181 |
+
wrap=False, num_workers=None, shuffle_test_loader=False, use_worker_init_fn=False,
|
182 |
+
shuffle_val_dataloader=False):
|
183 |
+
super().__init__()
|
184 |
+
self.batch_size = batch_size
|
185 |
+
self.dataset_configs = dict()
|
186 |
+
self.num_workers = num_workers if num_workers is not None else batch_size * 2
|
187 |
+
self.use_worker_init_fn = use_worker_init_fn
|
188 |
+
if train is not None:
|
189 |
+
self.dataset_configs["train"] = train
|
190 |
+
self.train_dataloader = self._train_dataloader
|
191 |
+
if validation is not None:
|
192 |
+
self.dataset_configs["validation"] = validation
|
193 |
+
self.val_dataloader = partial(self._val_dataloader, shuffle=shuffle_val_dataloader)
|
194 |
+
if test is not None:
|
195 |
+
self.dataset_configs["test"] = test
|
196 |
+
self.test_dataloader = partial(self._test_dataloader, shuffle=shuffle_test_loader)
|
197 |
+
if predict is not None:
|
198 |
+
self.dataset_configs["predict"] = predict
|
199 |
+
self.predict_dataloader = self._predict_dataloader
|
200 |
+
self.wrap = wrap
|
201 |
+
|
202 |
+
def prepare_data(self):
|
203 |
+
for data_cfg in self.dataset_configs.values():
|
204 |
+
instantiate_from_config(data_cfg)
|
205 |
+
|
206 |
+
def setup(self, stage=None):
|
207 |
+
self.datasets = dict(
|
208 |
+
(k, instantiate_from_config(self.dataset_configs[k]))
|
209 |
+
for k in self.dataset_configs)
|
210 |
+
if self.wrap:
|
211 |
+
for k in self.datasets:
|
212 |
+
self.datasets[k] = WrappedDataset(self.datasets[k])
|
213 |
+
|
214 |
+
def _train_dataloader(self):
|
215 |
+
is_iterable_dataset = isinstance(self.datasets['train'], Txt2ImgIterableBaseDataset)
|
216 |
+
if is_iterable_dataset or self.use_worker_init_fn:
|
217 |
+
init_fn = worker_init_fn
|
218 |
+
else:
|
219 |
+
init_fn = None
|
220 |
+
return DataLoader(self.datasets["train"], batch_size=self.batch_size,
|
221 |
+
num_workers=self.num_workers, shuffle=False if is_iterable_dataset else True,
|
222 |
+
worker_init_fn=init_fn)
|
223 |
+
|
224 |
+
def _val_dataloader(self, shuffle=False):
|
225 |
+
if isinstance(self.datasets['validation'], Txt2ImgIterableBaseDataset) or self.use_worker_init_fn:
|
226 |
+
init_fn = worker_init_fn
|
227 |
+
else:
|
228 |
+
init_fn = None
|
229 |
+
return DataLoader(self.datasets["validation"],
|
230 |
+
batch_size=self.batch_size,
|
231 |
+
num_workers=self.num_workers,
|
232 |
+
worker_init_fn=init_fn,
|
233 |
+
shuffle=shuffle)
|
234 |
+
|
235 |
+
def _test_dataloader(self, shuffle=False):
|
236 |
+
is_iterable_dataset = isinstance(self.datasets['train'], Txt2ImgIterableBaseDataset)
|
237 |
+
if is_iterable_dataset or self.use_worker_init_fn:
|
238 |
+
init_fn = worker_init_fn
|
239 |
+
else:
|
240 |
+
init_fn = None
|
241 |
+
|
242 |
+
# do not shuffle dataloader for iterable dataset
|
243 |
+
shuffle = shuffle and (not is_iterable_dataset)
|
244 |
+
|
245 |
+
return DataLoader(self.datasets["test"], batch_size=self.batch_size,
|
246 |
+
num_workers=self.num_workers, worker_init_fn=init_fn, shuffle=shuffle)
|
247 |
+
|
248 |
+
def _predict_dataloader(self, shuffle=False):
|
249 |
+
if isinstance(self.datasets['predict'], Txt2ImgIterableBaseDataset) or self.use_worker_init_fn:
|
250 |
+
init_fn = worker_init_fn
|
251 |
+
else:
|
252 |
+
init_fn = None
|
253 |
+
return DataLoader(self.datasets["predict"], batch_size=self.batch_size,
|
254 |
+
num_workers=self.num_workers, worker_init_fn=init_fn)
|
255 |
+
|
256 |
+
|
257 |
+
class SetupCallback(Callback):
|
258 |
+
def __init__(self, resume, now, logdir, ckptdir, cfgdir, config, lightning_config):
|
259 |
+
super().__init__()
|
260 |
+
self.resume = resume
|
261 |
+
self.now = now
|
262 |
+
self.logdir = logdir
|
263 |
+
self.ckptdir = ckptdir
|
264 |
+
self.cfgdir = cfgdir
|
265 |
+
self.config = config
|
266 |
+
self.lightning_config = lightning_config
|
267 |
+
|
268 |
+
    def on_keyboard_interrupt(self, trainer, pl_module):
        if trainer.global_rank == 0:
            print("Summoning checkpoint.")
            if hasattr(self.config, 'lora_config'):
                ckpt_path = os.path.join(self.ckptdir, "lora_last.ckpt")
                from lora.lora import save_lora_weight
                save_lora_weight(trainer.model, path=ckpt_path)
            else:
                ckpt_path = os.path.join(self.ckptdir, "last.ckpt")
                trainer.save_checkpoint(ckpt_path)

    def on_pretrain_routine_start(self, trainer, pl_module):
        if trainer.global_rank == 0:
            # Create logdirs and save configs
            os.makedirs(self.logdir, exist_ok=True)
            os.makedirs(self.ckptdir, exist_ok=True)
            os.makedirs(self.cfgdir, exist_ok=True)

            if "callbacks" in self.lightning_config:
                if 'metrics_over_trainsteps_checkpoint' in self.lightning_config['callbacks']:
                    os.makedirs(os.path.join(self.ckptdir, 'trainstep_checkpoints'), exist_ok=True)
            print("Project config")
            print(OmegaConf.to_yaml(self.config))
            OmegaConf.save(self.config,
                           os.path.join(self.cfgdir, "{}-project.yaml".format(self.now)))

            print("Lightning config")
            print(OmegaConf.to_yaml(self.lightning_config))
            OmegaConf.save(OmegaConf.create({"lightning": self.lightning_config}),
                           os.path.join(self.cfgdir, "{}-lightning.yaml".format(self.now)))

        else:
            # ModelCheckpoint callback created log directory --- remove it
            if not self.resume and os.path.exists(self.logdir):
                dst, name = os.path.split(self.logdir)
                dst = os.path.join(dst, "child_runs", name)
                os.makedirs(os.path.split(dst)[0], exist_ok=True)
                try:
                    os.rename(self.logdir, dst)
                except FileNotFoundError:
                    pass

class ImageLogger(Callback):
    def __init__(self, batch_frequency, max_images, clamp=True, increase_log_steps=True,
                 rescale=True, disabled=False, log_on_batch_idx=False, log_first_step=False,
                 log_images_kwargs=None):
        super().__init__()
        self.rescale = rescale
        self.batch_freq = batch_frequency
        self.max_images = max_images
        self.logger_log_images = {
            pl.loggers.TestTubeLogger: self._testtube,
        }
        self.log_steps = [2 ** n for n in range(int(np.log2(self.batch_freq)) + 1)]
        if not increase_log_steps:
            self.log_steps = [self.batch_freq]
        self.clamp = clamp
        self.disabled = disabled
        self.log_on_batch_idx = log_on_batch_idx
        self.log_images_kwargs = log_images_kwargs if log_images_kwargs else {}
        self.log_first_step = log_first_step

    @rank_zero_only
    def _testtube(self, pl_module, images, batch_idx, split):
        for k in images:
            grid = torchvision.utils.make_grid(images[k])
            grid = (grid + 1.0) / 2.0  # -1,1 -> 0,1; c,h,w

            tag = f"{split}/{k}"
            pl_module.logger.experiment.add_image(
                tag, grid,
                global_step=pl_module.global_step)

    @rank_zero_only
    def log_local(self, save_dir, split, images,
                  global_step, current_epoch, batch_idx):
        root = os.path.join(save_dir, "images", split)
        for k in images:
            grid = torchvision.utils.make_grid(images[k], nrow=4)
            if self.rescale:
                grid = (grid + 1.0) / 2.0  # -1,1 -> 0,1; c,h,w
            grid = grid.transpose(0, 1).transpose(1, 2).squeeze(-1)
            grid = grid.numpy()
            grid = (grid * 255).astype(np.uint8)
            filename = "{}_gs-{:06}_e-{:06}_b-{:06}.png".format(
                k,
                global_step,
                current_epoch,
                batch_idx)
            path = os.path.join(root, filename)
            os.makedirs(os.path.split(path)[0], exist_ok=True)
            Image.fromarray(grid).save(path)

    def log_img(self, pl_module, batch, batch_idx, split="train"):
        check_idx = batch_idx if self.log_on_batch_idx else pl_module.global_step
        if (self.check_frequency(check_idx) and  # batch_idx % self.batch_freq == 0
                hasattr(pl_module, "log_images") and
                callable(pl_module.log_images) and
                self.max_images > 0):
            logger = type(pl_module.logger)

            is_train = pl_module.training
            if is_train:
                pl_module.eval()

            with torch.no_grad():
                images = pl_module.log_images(batch, split=split, **self.log_images_kwargs)

            for k in images:
                N = min(images[k].shape[0], self.max_images)
                images[k] = images[k][:N]
                if isinstance(images[k], torch.Tensor):
                    images[k] = images[k].detach().cpu()
                    if self.clamp:
                        images[k] = torch.clamp(images[k], -1., 1.)

            self.log_local(pl_module.logger.save_dir, split, images,
                           pl_module.global_step, pl_module.current_epoch, batch_idx)

            logger_log_images = self.logger_log_images.get(logger, lambda *args, **kwargs: None)
            logger_log_images(pl_module, images, pl_module.global_step, split)

            if is_train:
                pl_module.train()

    def check_frequency(self, check_idx):
        if ((check_idx % self.batch_freq) == 0 or (check_idx in self.log_steps)) and (
                check_idx > 0 or self.log_first_step):
            try:
                self.log_steps.pop(0)
            except IndexError as e:
                print(e)
            return True
        return False

    def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx):
        if not self.disabled and (pl_module.global_step > 0 or self.log_first_step):
            self.log_img(pl_module, batch, batch_idx, split="train")

    def on_validation_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx):
        if not self.disabled and pl_module.global_step > 0:
            self.log_img(pl_module, batch, batch_idx, split="val")
        if hasattr(pl_module, 'calibrate_grad_norm'):
            if (pl_module.calibrate_grad_norm and batch_idx % 25 == 0) and batch_idx > 0:
                self.log_gradients(trainer, pl_module, batch_idx=batch_idx)

class CUDACallback(Callback):
    # see https://github.com/SeanNaren/minGPT/blob/master/mingpt/callback.py
    def on_train_epoch_start(self, trainer, pl_module):
        # Reset the memory use counter
        torch.cuda.reset_peak_memory_stats(trainer.root_gpu)
        torch.cuda.synchronize(trainer.root_gpu)
        self.start_time = time.time()

    def on_train_epoch_end(self, trainer, pl_module, outputs):
        torch.cuda.synchronize(trainer.root_gpu)
        max_memory = torch.cuda.max_memory_allocated(trainer.root_gpu) / 2 ** 20
        epoch_time = time.time() - self.start_time

        try:
            max_memory = trainer.training_type_plugin.reduce(max_memory)
            epoch_time = trainer.training_type_plugin.reduce(epoch_time)

            rank_zero_info(f"Average Epoch time: {epoch_time:.2f} seconds")
            rank_zero_info(f"Average Peak memory {max_memory:.2f}MiB")
        except AttributeError:
            pass

if __name__ == "__main__":

    now = datetime.datetime.now().strftime("%Y-%m-%dT%H-%M-%S")
    sys.path.append(os.getcwd())

    parser = get_parser()
    parser = Trainer.add_argparse_args(parser)

    opt, unknown = parser.parse_known_args()
    if opt.name and opt.resume:
        raise ValueError(
            "-n/--name and -r/--resume cannot both be specified. "
            "If you want to resume training in a new log folder, "
            "use -n/--name in combination with --resume_from_checkpoint"
        )
    if opt.resume:
        if not os.path.exists(opt.resume):
            raise ValueError("Cannot find {}".format(opt.resume))
        if os.path.isfile(opt.resume):
            paths = opt.resume.split("/")
            # idx = len(paths)-paths[::-1].index("logs")+1
            # logdir = "/".join(paths[:idx])
            logdir = "/".join(paths[:-2])
            ckpt = opt.resume
        else:
            assert os.path.isdir(opt.resume), opt.resume
            logdir = opt.resume.rstrip("/")
            ckpt = os.path.join(logdir, "checkpoints", "last.ckpt")

        opt.resume_from_checkpoint = ckpt
        base_configs = sorted(glob.glob(os.path.join(logdir, "configs/*.yaml")))
        opt.base = base_configs + opt.base
        _tmp = logdir.split("/")
        nowname = _tmp[-1]
    else:
        if opt.name:
            name = "_" + opt.name
        elif opt.base:
            cfg_fname = os.path.split(opt.base[0])[-1]
            cfg_name = os.path.splitext(cfg_fname)[0]
            name = "_" + cfg_name
        else:
            name = ""
        nowname = now + name + opt.postfix
        logdir = os.path.join(opt.logdir, nowname)

    ckptdir = os.path.join(logdir, "checkpoints")
    cfgdir = os.path.join(logdir, "configs")
    seed_everything(opt.seed)

    # init and save configs
    configs = [OmegaConf.load(cfg) for cfg in opt.base]
    cli = OmegaConf.from_dotlist(unknown)
    config = OmegaConf.merge(*configs, cli)
    lightning_config = config.pop("lightning", OmegaConf.create())
    # merge trainer cli with config
    trainer_config = lightning_config.get("trainer", OmegaConf.create())
    # default to ddp
    trainer_config["accelerator"] = "ddp"
    for k in nondefault_trainer_args(opt):
        trainer_config[k] = getattr(opt, k)
    if "gpus" not in trainer_config:
        del trainer_config["accelerator"]
        cpu = True
    else:
        gpuinfo = trainer_config["gpus"]
        print(f"Running on GPUs {gpuinfo}")
        cpu = False
    trainer_opt = argparse.Namespace(**trainer_config)
    lightning_config.trainer = trainer_config

    # model
    model = instantiate_from_config(config.model)
    if not opt.resume:
        if opt.train_from_scratch:
            ckpt_file = torch.load(opt.pretrained_model, map_location='cpu')['state_dict']
            ckpt_file = {key: value for key, value in ckpt_file.items() if not key.startswith('model.')}
            model.load_state_dict(ckpt_file, strict=False)
            print("Train from scratch!")
        else:
            model.load_state_dict(torch.load(opt.pretrained_model, map_location='cpu')['state_dict'], strict=False)
            print("Load Stable Diffusion v1-4!")

    # lora
    if hasattr(config, 'lora_config'):
        model.eval()
        model._requires_grad = False
        from lora.lora import inject_trainable_lora_extended

        params, names = inject_trainable_lora_extended(model, r=config.lora_config.rank)

    model.requires_grad_(False)
    for name, param in model.named_parameters():
        if "diffusion_model.output_blocks" in name and "transformer_blocks" in name:
            param.requires_grad = True
        if "local_controlnet" in name or "pose" in name:
            param.requires_grad = True
    # open a file to record the names of the trainable modules
    with open("module_names.txt", "w") as file:
        # iterate over all model parameters and write the trainable ones to the file
        for name, param in model.named_parameters():
            if param.requires_grad:
                file.write(name + "\n")

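    # Note: at this point only the UNet decoder attention blocks
    # ("diffusion_model.output_blocks...transformer_blocks") and the
    # local_controlnet / pose modules remain trainable; all other
    # parameters were frozen by model.requires_grad_(False) above.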
    # trainer and callbacks
    trainer_kwargs = dict()

    # default logger configs
    default_logger_cfgs = {
        "wandb": {
            "target": "pytorch_lightning.loggers.WandbLogger",
            "params": {
                "name": nowname,
                "save_dir": logdir,
                "offline": opt.debug,
                "id": nowname,
            }
        },
        "testtube": {
            "target": "pytorch_lightning.loggers.TestTubeLogger",
            "params": {
                "name": "testtube",
                "save_dir": logdir,
            }
        },
    }
    default_logger_cfg = default_logger_cfgs["testtube"]
    if "logger" in lightning_config:
        logger_cfg = lightning_config.logger
    else:
        logger_cfg = OmegaConf.create()
    logger_cfg = OmegaConf.merge(default_logger_cfg, logger_cfg)
    trainer_kwargs["logger"] = instantiate_from_config(logger_cfg)

    # modelcheckpoint - use TrainResult/EvalResult(checkpoint_on=metric) to
    # specify which metric is used to determine best models
    default_modelckpt_cfg = {
        "target": "pytorch_lightning.callbacks.ModelCheckpoint",
        "params": {
            "dirpath": ckptdir,
            "filename": "{epoch:06}",
            "verbose": True,
            "save_last": False,
            "every_n_epochs": 1
        }
    }
    if hasattr(model, "monitor"):
        print(f"Monitoring {model.monitor} as checkpoint metric.")
        default_modelckpt_cfg["params"]["monitor"] = model.monitor
        default_modelckpt_cfg["params"]["save_top_k"] = 30

    if "modelcheckpoint" in lightning_config:
        modelckpt_cfg = lightning_config.modelcheckpoint
    else:
        modelckpt_cfg = OmegaConf.create()
    modelckpt_cfg = OmegaConf.merge(default_modelckpt_cfg, modelckpt_cfg)
    print(f"Merged modelckpt-cfg: \n{modelckpt_cfg}")
    if version.parse(pl.__version__) < version.parse('1.4.0'):
        trainer_kwargs["checkpoint_callback"] = instantiate_from_config(modelckpt_cfg)

    # add callback which sets up log directory
    default_callbacks_cfg = {
        "setup_callback": {
            "target": "main.SetupCallback",
            "params": {
                "resume": opt.resume,
                "now": now,
                "logdir": logdir,
                "ckptdir": ckptdir,
                "cfgdir": cfgdir,
                "config": config,
                "lightning_config": lightning_config,
            }
        },
        "image_logger": {
            "target": "main.ImageLogger",
            "params": {
                "batch_frequency": 500,
                "max_images": 4,
                "clamp": True
            }
        },
        "learning_rate_logger": {
            "target": "main.LearningRateMonitor",
            "params": {
                "logging_interval": "step",
                # "log_momentum": True
            }
        },
        "cuda_callback": {
            "target": "main.CUDACallback"
        },
    }
    if version.parse(pl.__version__) >= version.parse('1.4.0'):
        default_callbacks_cfg.update({'checkpoint_callback': modelckpt_cfg})

    if "callbacks" in lightning_config:
        callbacks_cfg = lightning_config.callbacks
    else:
        callbacks_cfg = OmegaConf.create()

    if 'metrics_over_trainsteps_checkpoint' in callbacks_cfg:
        print(
            'Caution: Saving checkpoints every n train steps without deleting. This might require some free space.')
        default_metrics_over_trainsteps_ckpt_dict = {
            'metrics_over_trainsteps_checkpoint': {
                'target': 'pytorch_lightning.callbacks.ModelCheckpoint',
                'params': {
                    'dirpath': os.path.join(ckptdir, 'trainstep_checkpoints'),
                    'filename': "{epoch:06}-{step:09}",
                    'verbose': True,
                    'save_top_k': -1,
                    'every_n_train_steps': 10000,
                    'save_weights_only': True
                }
            }
        }
        default_callbacks_cfg.update(default_metrics_over_trainsteps_ckpt_dict)

    callbacks_cfg = OmegaConf.merge(default_callbacks_cfg, callbacks_cfg)
    if 'ignore_keys_callback' in callbacks_cfg and hasattr(trainer_opt, 'resume_from_checkpoint'):
        callbacks_cfg.ignore_keys_callback.params['ckpt_path'] = trainer_opt.resume_from_checkpoint
    elif 'ignore_keys_callback' in callbacks_cfg:
        del callbacks_cfg['ignore_keys_callback']

    trainer_kwargs["callbacks"] = [instantiate_from_config(callbacks_cfg[k]) for k in callbacks_cfg]

    trainer = Trainer.from_argparse_args(trainer_opt, **trainer_kwargs)
    # trainer.plugins = [MyCluster()]
    trainer.logdir = logdir  ###

    # data
    data = instantiate_from_config(config.data)
    # NOTE according to https://pytorch-lightning.readthedocs.io/en/latest/datamodules.html
    # calling these ourselves should not be necessary but it is.
    # lightning still takes care of proper multiprocessing though
    data.prepare_data()
    data.setup()
    print("#### Data #####")
    for k in data.datasets:
        print(f"{k}, {data.datasets[k].__class__.__name__}, {len(data.datasets[k])}")

    # configure learning rate
    bs, base_lr = config.data.params.batch_size, config.model.base_learning_rate
    if not cpu:
        ngpu = len(lightning_config.trainer.gpus.strip(",").split(','))
    else:
        ngpu = 1
    if 'accumulate_grad_batches' in lightning_config.trainer:
        accumulate_grad_batches = lightning_config.trainer.accumulate_grad_batches
    else:
        accumulate_grad_batches = 1
    # if 'num_nodes' in lightning_config.trainer:
    #     num_nodes = lightning_config.trainer.num_nodes
    # else:
    num_nodes = 1
    print(f"accumulate_grad_batches = {accumulate_grad_batches}")
    lightning_config.trainer.accumulate_grad_batches = accumulate_grad_batches
    if opt.scale_lr:
        model.learning_rate = accumulate_grad_batches * num_nodes * ngpu * bs * base_lr
        print(
            "Setting learning rate to {:.2e} = {} (accumulate_grad_batches) * {} (num_nodes) * {} (num_gpus) * {} (batchsize) * {:.2e} (base_lr)".format(
                model.learning_rate, accumulate_grad_batches, num_nodes, ngpu, bs, base_lr))
    else:
        model.learning_rate = base_lr
        print("++++ NOT USING LR SCALING ++++")
        print(f"Setting learning rate to {model.learning_rate:.2e}")

    # allow checkpointing via USR1
    def melk(*args, **kwargs):
        # run all checkpoint hooks
        if trainer.global_rank == 0:
            print("Summoning checkpoint.")
            ckpt_path = os.path.join(ckptdir, "last.ckpt")
            trainer.save_checkpoint(ckpt_path)


    def divein(*args, **kwargs):
        if trainer.global_rank == 0:
            import pudb
            pudb.set_trace()


    import signal

    signal.signal(signal.SIGUSR1, melk)
    signal.signal(signal.SIGUSR2, divein)

    # run
    if opt.train:
        try:
            trainer.fit(model, data)
        except Exception:
            melk()
            raise
    if not opt.no_test and not trainer.interrupted:
        trainer.test(model, data)
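For reference, a minimal standalone sketch (not part of the repo) of the linear learning-rate scaling rule implemented above, with purely illustrative values:

    # hypothetical values for illustration only
    accumulate_grad_batches, num_nodes, ngpu = 2, 1, 2
    bs, base_lr = 4, 1.0e-5
    effective_lr = accumulate_grad_batches * num_nodes * ngpu * bs * base_lr
    print(f"{effective_lr:.2e}")  # 1.60e-04: base_lr scaled by the effective total batch size

The SIGUSR1 handler registered above also means a running job can be checkpointed without stopping it, e.g. by sending kill -USR1 <pid> from a shell.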
test.py
ADDED
@@ -0,0 +1,447 @@
import argparse, os, sys, glob
import cv2
import torch
import numpy as np
from omegaconf import OmegaConf
from PIL import Image
from torch.utils.data import DataLoader
from torchvision import transforms
from tqdm import tqdm, trange
from itertools import islice
from einops import rearrange
from torchvision.utils import make_grid
import time
from pytorch_lightning import seed_everything
from torch import autocast
from contextlib import contextmanager, nullcontext
import torchvision

from ldm.data.cp_dataset import CPDataset
from ldm.resizer import Resizer
from ldm.util import instantiate_from_config
from ldm.models.diffusion.ddim import DDIMSampler
from ldm.models.diffusion.plms import PLMSSampler
from ldm.data.deepfashions import DFPairDataset

import clip
from torchvision.transforms import Resize


def chunk(it, size):
    it = iter(it)
    return iter(lambda: tuple(islice(it, size)), ())


def get_tensor_clip(normalize=True, toTensor=True):
    transform_list = []
    if toTensor:
        transform_list += [torchvision.transforms.ToTensor()]

    if normalize:
        transform_list += [torchvision.transforms.Normalize((0.48145466, 0.4578275, 0.40821073),
                                                            (0.26862954, 0.26130258, 0.27577711))]
    return torchvision.transforms.Compose(transform_list)

def numpy_to_pil(images):
    """
    Convert a numpy image or a batch of images to a PIL image.
    """
    if images.ndim == 3:
        images = images[None, ...]
    images = (images * 255).round().astype("uint8")
    pil_images = [Image.fromarray(image) for image in images]

    return pil_images


def load_model_from_config(config, ckpt, verbose=False):
    print(f"Loading model from {ckpt}")
    pl_sd = torch.load(ckpt, map_location="cpu")
    if "global_step" in pl_sd:
        print(f"Global Step: {pl_sd['global_step']}")
    sd = pl_sd["state_dict"]
    model = instantiate_from_config(config.model)
    m, u = model.load_state_dict(sd, strict=False)
    if len(m) > 0 and verbose:
        print("missing keys:")
        print(m)
    if len(u) > 0 and verbose:
        print("unexpected keys:")
        print(u)

    model.cuda()
    model.eval()
    return model

def put_watermark(img, wm_encoder=None):
    if wm_encoder is not None:
        img = cv2.cvtColor(np.array(img), cv2.COLOR_RGB2BGR)
        img = wm_encoder.encode(img, 'dwtDct')
        img = Image.fromarray(img[:, :, ::-1])
    return img


def load_replacement(x):
    try:
        hwc = x.shape
        y = Image.open("assets/rick.jpeg").convert("RGB").resize((hwc[1], hwc[0]))
        y = (np.array(y) / 255.0).astype(x.dtype)
        assert y.shape == x.shape
        return y
    except Exception:
        return x

def get_tensor(normalize=True, toTensor=True):
    transform_list = []
    if toTensor:
        transform_list += [torchvision.transforms.ToTensor()]

    if normalize:
        transform_list += [torchvision.transforms.Normalize((0.5, 0.5, 0.5),
                                                            (0.5, 0.5, 0.5))]
    return torchvision.transforms.Compose(transform_list)

def main():
    parser = argparse.ArgumentParser()

    parser.add_argument(
        "--outdir",
        type=str,
        nargs="?",
        help="dir to write results to",
        default="outputs/txt2img-samples"
    )
    parser.add_argument(
        "--skip_grid",
        action='store_true',
        help="do not save a grid, only individual samples. Helpful when evaluating lots of samples",
    )
    parser.add_argument(
        "--skip_save",
        action='store_true',
        help="do not save individual samples. For speed measurements.",
    )
    parser.add_argument(
        "--gpu_id",
        type=int,
        default=0,
        help="which gpu to use",
    )
    parser.add_argument(
        "--ddim_steps",
        type=int,
        default=30,
        help="number of ddim sampling steps",
    )
    parser.add_argument(
        "--plms",
        action='store_true',
        help="use plms sampling",
    )
    parser.add_argument(
        "--fixed_code",
        action='store_true',
        help="if enabled, uses the same starting code across samples",
    )
    parser.add_argument(
        "--ddim_eta",
        type=float,
        default=0.0,
        help="ddim eta (eta=0.0 corresponds to deterministic sampling)",
    )
    parser.add_argument(
        "--n_iter",
        type=int,
        default=2,
        help="sample this often",
    )
    parser.add_argument(
        "--H",
        type=int,
        default=512,
        help="image height, in pixel space",
    )
    parser.add_argument(
        "--W",
        type=int,
        default=512,
        help="image width, in pixel space",
    )
    parser.add_argument(
        "--n_imgs",
        type=int,
        default=100,
        help="number of images to sample",
    )
    parser.add_argument(
        "--C",
        type=int,
        default=4,
        help="latent channels",
    )
    parser.add_argument(
        "--f",
        type=int,
        default=8,
        help="downsampling factor",
    )
    parser.add_argument(
        "--n_samples",
        type=int,
        default=1,
        help="how many samples to produce for each given reference image. A.k.a. batch size",
    )
    parser.add_argument(
        "--n_rows",
        type=int,
        default=0,
        help="rows in the grid (default: n_samples)",
    )
    parser.add_argument(
        "--scale",
        type=float,
        default=1,
        help="unconditional guidance scale: eps = eps(x, empty) + scale * (eps(x, cond) - eps(x, empty))",
    )
    parser.add_argument(
        "--config",
        type=str,
        default="",
        help="path to config which constructs model",
    )
    parser.add_argument(
        "--ckpt",
        type=str,
        default="",
        help="path to checkpoint of model",
    )
    parser.add_argument(
        "--seed",
        type=int,
        default=42,
        help="the seed (for reproducible sampling)",
    )
    parser.add_argument(
        "--precision",
        type=str,
        help="evaluate at this precision",
        choices=["full", "autocast"],
        default="autocast"
    )
    parser.add_argument(
        "--unpaired",
        action='store_true',
        help="if enabled, tests in the unpaired (cross-garment) setting",
    )
    parser.add_argument(
        "--dataroot",
        type=str,
        help="path to dataroot of the dataset",
        default=""
    )

    opt = parser.parse_args()

    seed_everything(opt.seed)

    device = torch.device("cuda:{}".format(opt.gpu_id)) if torch.cuda.is_available() else torch.device("cpu")
    torch.cuda.set_device(device)

    config = OmegaConf.load(f"{opt.config}")
    version = opt.config.split('/')[-1].split('.')[0]
    model = load_model_from_config(config, f"{opt.ckpt}")

    # model = model.to(device)
    dataset = CPDataset(opt.dataroot, opt.H, mode='test', unpaired=opt.unpaired)
    loader = DataLoader(dataset, batch_size=opt.n_samples, shuffle=False, num_workers=4, pin_memory=True)
    if opt.plms:
        sampler = PLMSSampler(model)
    else:
        sampler = DDIMSampler(model)

    os.makedirs(opt.outdir, exist_ok=True)
    outpath = opt.outdir

    result_path = os.path.join(outpath, "upper_body")
    os.makedirs(result_path, exist_ok=True)

    start_code = None
    if opt.fixed_code:
        start_code = torch.randn([opt.n_samples, opt.C, opt.H // opt.f, opt.W // opt.f], device=device)

    iterator = tqdm(loader, desc='Test Dataset', total=len(loader))
    precision_scope = autocast if opt.precision == "autocast" else nullcontext
    with torch.no_grad():
        with precision_scope("cuda"):
            with model.ema_scope():
                for data in iterator:
                    mask_tensor = data['inpaint_mask']
                    inpaint_image = data['inpaint_image']
                    ref_tensor_f = data['ref_imgs_f']
                    ref_tensor_b = data['ref_imgs_b']
                    skeleton_cf = data['skeleton_cf']
                    skeleton_cb = data['skeleton_cb']
                    skeleton_p = data['skeleton_p']
                    order = data['order']
                    feat_tensor = data['warp_feat']
                    image_tensor = data['GT']

                    controlnet_cond_f = data['controlnet_cond_f']
                    controlnet_cond_b = data['controlnet_cond_b']

                    ref_tensor = ref_tensor_f
                    for i in range(len(order)):
                        if order[i] == "1" or order[i] == "2":
                            continue
                        elif order[i] == "3":
                            ref_tensor[i] = ref_tensor_b[i]
                        else:
                            raise ValueError("Invalid order")

                    # filename = data['file_name']

                    test_model_kwargs = {}
                    test_model_kwargs['inpaint_mask'] = mask_tensor.to(device)
                    test_model_kwargs['inpaint_image'] = inpaint_image.to(device)
                    feat_tensor = feat_tensor.to(device)
                    ref_tensor = ref_tensor.to(device)

                    controlnet_cond_f = controlnet_cond_f.to(device)
                    controlnet_cond_b = controlnet_cond_b.to(device)
                    skeleton_cf = skeleton_cf.to(device)
                    skeleton_cb = skeleton_cb.to(device)
                    skeleton_p = skeleton_p.to(device)

                    uc = None
                    if opt.scale != 1.0:
                        uc = model.learnable_vector
                        uc = uc.repeat(ref_tensor.size(0), 1, 1)
                    c = model.get_learned_conditioning(ref_tensor.to(torch.float16))
                    c = model.proj_out(c)

                    # z_gt = model.encode_first_stage(image_tensor.to(device))
                    # z_gt = model.get_first_stage_encoding(z_gt).detach()

                    z_inpaint = model.encode_first_stage(test_model_kwargs['inpaint_image'])
                    z_inpaint = model.get_first_stage_encoding(z_inpaint).detach()
                    test_model_kwargs['inpaint_image'] = z_inpaint
                    test_model_kwargs['inpaint_mask'] = Resize([z_inpaint.shape[-2], z_inpaint.shape[-1]])(
                        test_model_kwargs['inpaint_mask'])

                    warp_feat = model.encode_first_stage(feat_tensor)
                    warp_feat = model.get_first_stage_encoding(warp_feat).detach()

                    ts = torch.full((1,), 999, device=device, dtype=torch.long)
                    start_code = model.q_sample(warp_feat, ts)

                    # local_controlnet
                    ehs_cf = model.pose_model(skeleton_cf)
                    ehs_cb = model.pose_model(skeleton_cb)
                    ehs_p = model.pose_model(skeleton_p)
                    ehs_text = torch.zeros((c.shape[0], 1, 768)).to("cuda")
                    # controlnet_cond = torch.cat((controlnet_cond_f, controlnet_cond_b, ehs_cf, ehs_cb, ehs_p), dim=1)
                    x_noisy = torch.cat(
                        (start_code, test_model_kwargs['inpaint_image'], test_model_kwargs['inpaint_mask']), dim=1)

                    down_samples_f, mid_samples_f = model.local_controlnet(
                        x_noisy, ts, encoder_hidden_states=ehs_text.to("cuda"),
                        controlnet_cond=controlnet_cond_f, ehs_c=ehs_cf, ehs_p=ehs_p)
                    down_samples_b, mid_samples_b = model.local_controlnet(
                        x_noisy, ts, encoder_hidden_states=ehs_text.to("cuda"),
                        controlnet_cond=controlnet_cond_b, ehs_c=ehs_cb, ehs_p=ehs_p)
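                    # The two local_controlnet passes above condition on the front (_f) and
                    # back (_b) garment views separately; their down-block residuals are
                    # concatenated channel-wise below before being fed to the sampler.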
                    # print(torch.max(down_samples_f[0]))
                    # print(torch.min(down_samples_f[0]))

                    # normalized_tensor = (down_samples_f[0] + 1) / 2

                    # # rescale tensor values from [0, 1] to [0, 255]
                    # scaled_tensor = normalized_tensor * 255

                    # # convert the tensor to a NumPy array
                    # numpy_array = scaled_tensor.squeeze().cpu().numpy().astype(np.uint8)

                    # # convert the NumPy array to a PIL image
                    # image = Image.fromarray(numpy_array)

                    # # save the image
                    # image.save("down_samples_f.jpg")

                    # normalized_tensor = (down_samples_b[0] + 1) / 2

                    # # rescale tensor values from [0, 1] to [0, 255]
                    # scaled_tensor = normalized_tensor * 255

                    # # convert the tensor to a NumPy array
                    # numpy_array = scaled_tensor.squeeze().cpu().numpy().astype(np.uint8)

                    # # convert the NumPy array to a PIL image
                    # image = Image.fromarray(numpy_array)

                    # # save the image
                    # image.save("down_samples_b.jpg")
                    mid_samples = mid_samples_f + mid_samples_b
                    down_samples = ()
                    for ds in range(len(down_samples_f)):
                        tmp = torch.cat((down_samples_f[ds], down_samples_b[ds]), dim=1)
                        down_samples = down_samples + (tmp,)

                    shape = [opt.C, opt.H // opt.f, opt.W // opt.f]
                    samples_ddim, _ = sampler.sample(S=opt.ddim_steps,
                                                     conditioning=c,
                                                     batch_size=opt.n_samples,
                                                     shape=shape,
                                                     verbose=False,
                                                     unconditional_guidance_scale=opt.scale,
                                                     unconditional_conditioning=uc,
                                                     eta=opt.ddim_eta,
                                                     x_T=start_code,
                                                     down_samples=down_samples,
                                                     test_model_kwargs=test_model_kwargs)

                    x_samples_ddim = model.decode_first_stage(samples_ddim)
                    x_sample_result = x_samples_ddim
                    x_samples_ddim = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0)
                    x_samples_ddim = x_samples_ddim.cpu().permute(0, 2, 3, 1).numpy()

                    x_checked_image = x_samples_ddim
                    x_checked_image_torch = torch.from_numpy(x_checked_image).permute(0, 3, 1, 2)
                    x_source = torch.clamp((image_tensor + 1.0) / 2.0, min=0.0, max=1.0)
                    x_result = x_checked_image_torch * (1 - mask_tensor) + mask_tensor * x_source
                    # x_result = x_checked_image_torch

                    resize = transforms.Resize((opt.H, int(opt.H / 256 * 192)))

                    if not opt.skip_save:

                        def un_norm(x):
                            return (x + 1.0) / 2.0

                        for i, x_sample in enumerate(x_result):
                            filename = data['file_name'][i]
                            # filename = data['file_name']
                            save_x = resize(x_sample)
                            save_x = 255. * rearrange(save_x.cpu().numpy(), 'c h w -> h w c')
                            img = Image.fromarray(save_x.astype(np.uint8))
                            img.save(os.path.join(result_path, filename[:-4] + ".png"))

    print(f"Your samples are ready and waiting for you here: \n{outpath} \n"
          f" \nEnjoy.")


if __name__ == "__main__":
    main()
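One detail worth noting in the loop above: the final compositing step keeps the generated pixels only outside the mask and copies everything under the mask from the source image. A minimal standalone sketch of that blend, with illustrative shapes and tensor names that are not from the repo:

    import torch

    sample = torch.rand(1, 3, 4, 4)  # decoded model output, rescaled to [0, 1]
    source = torch.rand(1, 3, 4, 4)  # ground-truth image, rescaled to [0, 1]
    mask = torch.zeros(1, 1, 4, 4)   # 1 where the source image is preserved
    mask[..., :2, :] = 1.0

    # generated content where mask == 0, source content where mask == 1,
    # mirroring x_result = x_checked_image_torch * (1 - mask_tensor) + mask_tensor * x_source
    result = sample * (1 - mask) + mask * source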
test.sh
ADDED
@@ -0,0 +1,13 @@
#!/bin/bash
CUDA_VISIBLE_DEVICES=3 python test.py --gpu_id 0 \
    --ddim_steps 50 \
    --outdir results/try/ \
    --config configs/viton512.yaml \
    --dataroot /datasets/NVG \
    --ckpt checkpoints/mvg.ckpt \
    --n_samples 1 \
    --seed 23 \
    --scale 1 \
    --H 512 \
    --W 384
train.sh
ADDED
@@ -0,0 +1 @@
CUDA_VISIBLE_DEVICES=4,5 python -u main.py --logdir models/oc --pretrained_model checkpoints/model.ckpt --base configs/viton512.yaml --scale_lr False
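A note on the two launch scripts: test.sh restricts the process to a single physical GPU via CUDA_VISIBLE_DEVICES and then addresses it as --gpu_id 0, while train.sh exposes two GPUs and presumably relies on the lightning section of configs/viton512.yaml for the gpus/ddp settings; if gpus is absent from the merged trainer config, main.py falls back to CPU. Passing --scale_lr False keeps the base learning rate instead of applying the linear scaling shown in main.py.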