Commit c45703e · Initial Demo.
- .gitattributes +36 -0
- .gitignore +163 -0
- README.md +11 -0
- app.py +89 -0
- artifacts/cat_condvit_b16.pth +3 -0
- artifacts/gallery_imgs.parquet +3 -0
- artifacts/gallery_index.faiss +3 -0
- examples/1811.jpg +0 -0
- examples/3.jpg +0 -0
- examples/757.jpg +0 -0
- examples/769.jpg +0 -0
- requirements.txt +7 -0
- src/custom_functions.js +5 -0
- src/custom_js.py +27 -0
- src/examples.py +21 -0
- src/js_loader.py +25 -0
- src/model.py +213 -0
- src/process_images.py +15 -0
- src/style.css +42 -0
- src/transform.py +29 -0
.gitattributes
ADDED
@@ -0,0 +1,36 @@
*.7z filter=lfs diff=lfs merge=lfs -text
*.arrow filter=lfs diff=lfs merge=lfs -text
*.bin filter=lfs diff=lfs merge=lfs -text
*.bz2 filter=lfs diff=lfs merge=lfs -text
*.ckpt filter=lfs diff=lfs merge=lfs -text
*.faiss filter=lfs diff=lfs merge=lfs -text
*.ftz filter=lfs diff=lfs merge=lfs -text
*.gz filter=lfs diff=lfs merge=lfs -text
*.h5 filter=lfs diff=lfs merge=lfs -text
*.joblib filter=lfs diff=lfs merge=lfs -text
*.lfs.* filter=lfs diff=lfs merge=lfs -text
*.mlmodel filter=lfs diff=lfs merge=lfs -text
*.model filter=lfs diff=lfs merge=lfs -text
*.msgpack filter=lfs diff=lfs merge=lfs -text
*.npy filter=lfs diff=lfs merge=lfs -text
*.npz filter=lfs diff=lfs merge=lfs -text
*.onnx filter=lfs diff=lfs merge=lfs -text
*.ot filter=lfs diff=lfs merge=lfs -text
*.parquet filter=lfs diff=lfs merge=lfs -text
*.pb filter=lfs diff=lfs merge=lfs -text
*.pickle filter=lfs diff=lfs merge=lfs -text
*.pkl filter=lfs diff=lfs merge=lfs -text
*.pt filter=lfs diff=lfs merge=lfs -text
*.pth filter=lfs diff=lfs merge=lfs -text
*.rar filter=lfs diff=lfs merge=lfs -text
*.safetensors filter=lfs diff=lfs merge=lfs -text
saved_model/**/* filter=lfs diff=lfs merge=lfs -text
*.tar.* filter=lfs diff=lfs merge=lfs -text
*.tflite filter=lfs diff=lfs merge=lfs -text
*.tgz filter=lfs diff=lfs merge=lfs -text
*.wasm filter=lfs diff=lfs merge=lfs -text
*.xz filter=lfs diff=lfs merge=lfs -text
*.zip filter=lfs diff=lfs merge=lfs -text
*.zst filter=lfs diff=lfs merge=lfs -text
*tfevents* filter=lfs diff=lfs merge=lfs -text
artifacts/ filter=lfs diff=lfs merge=lfs -text
.gitignore
ADDED
@@ -0,0 +1,163 @@
# Custom
_*

# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class

# C extensions
*.so

# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST

# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec

# Installer logs
pip-log.txt
pip-delete-this-directory.txt

# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/
cover/

# Translations
*.mo
*.pot

# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal

# Flask stuff:
instance/
.webassets-cache

# Scrapy stuff:
.scrapy

# Sphinx documentation
docs/_build/

# PyBuilder
.pybuilder/
target/

# Jupyter Notebook
.ipynb_checkpoints

# IPython
profile_default/
ipython_config.py

# pyenv
# For a library or package, you might want to ignore these files since the code is
# intended to run in multiple environments; otherwise, check them in:
# .python-version

# pipenv
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
# However, in case of collaboration, if having platform-specific dependencies or dependencies
# having no cross-platform support, pipenv may install dependencies that don't work, or not
# install all needed dependencies.
#Pipfile.lock

# poetry
# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
# This is especially recommended for binary packages to ensure reproducibility, and is more
# commonly ignored for libraries.
# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
#poetry.lock

# pdm
# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
#pdm.lock
# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
# in version control.
# https://pdm.fming.dev/#use-with-ide
.pdm.toml

# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
__pypackages__/

# Celery stuff
celerybeat-schedule
celerybeat.pid

# SageMath parsed files
*.sage.py

# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/

# Spyder project settings
.spyderproject
.spyproject

# Rope project settings
.ropeproject

# mkdocs documentation
/site

# mypy
.mypy_cache/
.dmypy.json
dmypy.json

# Pyre type checker
.pyre/

# pytype static type analyzer
.pytype/

# Cython debug symbols
cython_debug/

# PyCharm
# JetBrains specific template is maintained in a separate JetBrains.gitignore that can
# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
# and can be added to the global gitignore or merged into this file. For a more nuclear
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
#.idea/
README.md
ADDED
@@ -0,0 +1,11 @@
---
title: CondViT LRVSF Demo
emoji: 🔎
colorFrom: red
colorTo: yellow
sdk: gradio
sdk_version: 3.32.0
app_file: app.py
pinned: false
license: mit
---
app.py
ADDED
@@ -0,0 +1,89 @@
import pandas as pd
import torch
import faiss
import gradio as gr

import base64

from PIL import Image
from io import BytesIO

from src.model import ConditionalViT, B16_Params, categories
from src.transform import valid_tf
from src.process_images import process_img, make_img_html
from src.examples import ExamplesHandler
from src.js_loader import JavaScriptLoader

# Load model
m = ConditionalViT(**B16_Params, n_categories=len(categories))
m.load_state_dict(torch.load("./artifacts/cat_condvit_b16.pth", map_location="cpu"))
m.eval()

# Load data
index = faiss.read_index("./artifacts/gallery_index.faiss")
gal_imgs = pd.read_parquet("./artifacts/gallery_imgs.parquet")
tfs = valid_tf((224, 224))

K = 5

examples = [
    ["examples/3.jpg", "Outwear"],
    ["examples/3.jpg", "Lower Body"],
    ["examples/3.jpg", "Feet"],
    ["examples/757.jpg", "Bags"],
    ["examples/757.jpg", "Upper Body"],
    ["examples/769.jpg", "Upper Body"],
    ["examples/1811.jpg", "Lower Body"],
    ["examples/1811.jpg", "Bags"],
]


@torch.inference_mode()
def retrieval(image, category):
    if image is None or category is None:
        return

    # Embed the query image, conditioned on the selected category index.
    q_emb = m(tfs(image).unsqueeze(0), torch.tensor([category]))

    # K nearest neighbors in the precomputed gallery index.
    r = index.search(q_emb, K)

    imgs = [process_img(idx, gal_imgs) for idx in r[1][0]]

    html = [make_img_html(i) for i in imgs]
    html += ["<p></p>"]  # Avoid Gradio's last-child{margin-bottom:0!important;}

    return "\n".join(html)


JavaScriptLoader("src/custom_functions.js")
with gr.Blocks(css="src/style.css") as demo:
    with gr.Column():
        gr.Markdown(
            """
            # Conditional ViT Demo
            [[`Paper`](https://arxiv.org/abs/2306.02928)]
            [[`Code`](https://github.com/Simon-Lepage/CondViT-LRVSF)]
            [[`Dataset`](https://huggingface.co/datasets/Slep/LAION-RVS-Fashion)]

            *Running on 2 vCPU, 16 GB RAM.*

            - **Model:** Categorical CondViT-B/16
            - **Gallery:** 93K images.
            """
        )

        # Input section
        with gr.Row():
            img = gr.Image(label="Query Image", type="pil", elem_id="query_img")
            with gr.Column():
                cat = gr.Dropdown(choices=categories, label="Category", value="Upper Body", type="index", elem_id="dropdown")
                submit = gr.Button("Submit")

        # Examples (the table is hidden by CSS; clicks are forwarded from the HTML gallery below)
        gr.Examples(examples, inputs=[img, cat], fn=retrieval, elem_id="preset_examples", examples_per_page=100)
        gr.HTML(value=ExamplesHandler(examples).to_html(), label="examples", elem_id="html_examples")

        # Outputs
        gr.Markdown("# Retrieved Items")
        out = gr.HTML(label="Results", elem_id="html_output")

    submit.click(fn=retrieval, inputs=[img, cat], outputs=out)

demo.launch()
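Note: the gallery index and parquet under artifacts/ are prebuilt LFS blobs; the script that produced them is not in this commit. Below is a minimal sketch of how such an index could be rebuilt from the gallery parquet, assuming gallery images are embedded without a category token and stored in an exact inner-product IndexFlatIP (both are assumptions, not confirmed by this commit):

import faiss
import pandas as pd
import torch
from io import BytesIO
from PIL import Image

from src.model import ConditionalViT, B16_Params, categories
from src.transform import valid_tf

m = ConditionalViT(**B16_Params, n_categories=len(categories))
m.load_state_dict(torch.load("./artifacts/cat_condvit_b16.pth", map_location="cpu"))
m.eval()
tfs = valid_tf((224, 224))

gal_imgs = pd.read_parquet("./artifacts/gallery_imgs.parquet")

embs = []
with torch.no_grad():
    for jpg in gal_imgs.jpg:
        img = Image.open(BytesIO(jpg)).convert("RGB")
        embs.append(m(tfs(img).unsqueeze(0)))  # c=None: unconditional embedding (assumption)
embs = torch.cat(embs).numpy()

index = faiss.IndexFlatIP(embs.shape[1])  # metric choice is an assumption
index.add(embs)
faiss.write_index(index, "./artifacts/gallery_index.faiss")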
artifacts/cat_condvit_b16.pth
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:badeeedf937ce955946ebbbbe7a1d531b171374e8fc58697436e7dfca31aa694
size 344860235
artifacts/gallery_imgs.parquet
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:3f633c261f353bd64498c267a7d6744c424471a4f09f69f2ac8f9df5537a259e
size 1487433539
artifacts/gallery_index.faiss
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:2ed82e5e12fb2c8d6aaa89336623fc5948f7e29805f961f0e07ce92b0cebd8c2
size 192118829
examples/1811.jpg
ADDED
examples/3.jpg
ADDED
examples/757.jpg
ADDED
examples/769.jpg
ADDED
requirements.txt
ADDED
@@ -0,0 +1,7 @@
torch
torchvision
pillow
faiss-cpu
pandas
pyarrow
gradio
src/custom_functions.js
ADDED
@@ -0,0 +1,5 @@
// Forward a click on HTML example i to row i of the hidden gr.Examples table.
remap_click = (i) => {
    let examples = document.getElementById("preset_examples");
    let rows = examples.getElementsByTagName("tbody")[0].getElementsByTagName("tr");
    rows[i].click();
}
src/custom_js.py
ADDED
@@ -0,0 +1,27 @@
import gradio

# Adapted from https://github.com/gradio-app/gradio/discussions/2932


class JavaScriptLoader:
    def __init__(self, target):
        # Keep a reference to the original template response
        self.original_template = gradio.routes.templates.TemplateResponse
        # Read the JS file
        self.load_js(target)
        # Reassign the template response to our method, so Gradio calls it instead
        gradio.routes.templates.TemplateResponse = self.template_response

    def load_js(self, target):
        with open(target, "r", encoding="utf-8") as file:
            self.loaded_script = f"<script>\n{file.read()}\n</script>"

    def template_response(self, *args, **kwargs):
        """Call the original template response, inject the loaded script just
        before </head>, and return the modified response."""
        response = self.original_template(*args, **kwargs)
        response.body = response.body.replace(
            "</head>".encode("utf-8"), (self.loaded_script + "\n</head>").encode("utf-8")
        )
        response.init_headers()
        return response
src/examples.py
ADDED
@@ -0,0 +1,21 @@
from PIL import Image
from .process_images import make_img_html


class ExamplesHandler:
    def __init__(self, examples):
        self.examples = examples

    def to_html(self):
        ret = ""
        for i, (img_path, category) in enumerate(self.examples):
            # Clicking a figure forwards the click to the hidden gr.Examples row
            # via remap_click() from src/custom_functions.js.
            ret += f"<figure id='example_{i}' onclick='remap_click({i})'>"
            img = Image.open(img_path).convert("RGB")
            ret += make_img_html(img)
            ret += f"<figcaption>{category}</figcaption>"
            ret += "</figure>"

        ret += "<p></p>"

        return ret
src/js_loader.py
ADDED
@@ -0,0 +1,25 @@
import gradio


class JavaScriptLoader:
    def __init__(self, target):
        # Keep a reference to the original template response
        self.original_template = gradio.routes.templates.TemplateResponse
        # Read the JS file
        self.load_js(target)
        # Reassign the template response to our method, so Gradio calls it instead
        gradio.routes.templates.TemplateResponse = self.template_response

    def load_js(self, target):
        with open(target, "r", encoding="utf-8") as file:
            self.loaded_script = f"<script>\n{file.read()}\n</script>"

    def template_response(self, *args, **kwargs):
        """Call the original template response, inject the loaded script just
        before </head>, and return the modified response."""
        response = self.original_template(*args, **kwargs)
        response.body = response.body.replace(
            "</head>".encode("utf-8"), (self.loaded_script + "\n</head>").encode("utf-8")
        )
        response.init_headers()
        return response
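For reference, a minimal usage sketch of this loader (it duplicates src/custom_js.py; app.py imports this copy). It assumes Gradio 3.x, where gradio.routes.templates.TemplateResponse is the hook patched above:

import gradio as gr
from src.js_loader import JavaScriptLoader

# Patch before building the app, so every served page gets the script in <head>.
JavaScriptLoader("src/custom_functions.js")

with gr.Blocks() as demo:
    # remap_click() from custom_functions.js is now defined on the page and can
    # be referenced from inline HTML handlers, as src/examples.py does.
    gr.HTML("<figure onclick='remap_click(0)'><figcaption>example</figcaption></figure>")

demo.launch()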
src/model.py
ADDED
@@ -0,0 +1,213 @@
categories = [
    "Bags",
    "Feet",
    "Hands",
    "Head",
    "Lower Body",
    "Neck",
    "Outwear",
    "Upper Body",
    "Waist",
    "Whole Body",
]

import torch
from torch import nn

from collections import OrderedDict
import logging

logger = logging.getLogger(__name__)


class LayerNorm(nn.LayerNorm):
    """Subclass torch's LayerNorm to handle fp16."""

    def forward(self, x: torch.Tensor):
        if self.weight.dtype != x.dtype:
            orig_type = x.dtype
            ret = super().forward(x.type(self.weight.dtype))
            return ret.type(orig_type)
        else:
            return super().forward(x)


class QuickGELU(nn.Module):
    def forward(self, x: torch.Tensor):
        return x * torch.sigmoid(1.702 * x)


class ResidualAttentionBlock(nn.Module):
    def __init__(
        self,
        d_model: int,
        n_head: int,
        attn_mask: torch.Tensor = None,
    ):
        super().__init__()

        self.attn = nn.MultiheadAttention(d_model, n_head)
        self.ln_1 = LayerNorm(d_model)
        self.mlp = nn.Sequential(
            OrderedDict(
                [
                    ("c_fc", nn.Linear(d_model, d_model * 4)),
                    ("gelu", QuickGELU()),
                    ("c_proj", nn.Linear(d_model * 4, d_model)),
                ]
            )
        )
        self.ln_2 = LayerNorm(d_model)
        self.attn_mask = attn_mask

    def attention(self, x: torch.Tensor):
        self.attn_mask = (
            self.attn_mask.to(dtype=x.dtype, device=x.device)
            if self.attn_mask is not None
            else None
        )
        return self.attn(
            x,
            x,
            x,
            need_weights=False,
            attn_mask=self.attn_mask,
        )[0]

    def forward(self, x: torch.Tensor):
        x = x + self.attention(self.ln_1(x))
        x = x + self.mlp(self.ln_2(x))
        return x


class Transformer(nn.Module):
    def __init__(
        self,
        width: int,
        layers: int,
        heads: int,
        attn_mask: torch.Tensor = None,
    ):
        super().__init__()
        self.width = width
        self.layers = layers
        self.resblocks = nn.Sequential(
            *[ResidualAttentionBlock(width, heads, attn_mask) for _ in range(layers)]
        )

    def forward(self, x: torch.Tensor):
        return self.resblocks(x)


class ConditionalViT(nn.Module):
    def __init__(
        self,
        input_resolution: int,
        patch_size: int,
        width: int,
        layers: int,
        heads: int,
        output_dim: int,
        n_categories: int = None,
        **kwargs,
    ):
        if kwargs:
            logger.warning(f"Got unused kwargs: {kwargs}")

        super().__init__()
        self.input_resolution = input_resolution
        self.output_dim = output_dim
        self.conv1 = nn.Conv2d(
            in_channels=3,
            out_channels=width,
            kernel_size=patch_size,
            stride=patch_size,
            bias=False,
        )

        scale = width**-0.5

        self.class_embedding = nn.Parameter(scale * torch.randn(width))

        self.n_categories = n_categories
        if self.n_categories:
            self.c_embedding = nn.Embedding(self.n_categories, width)
            self.c_pos_embedding = nn.Parameter(scale * torch.randn(1, width))

        self.positional_embedding = nn.Parameter(
            scale * torch.randn((input_resolution // patch_size) ** 2 + 1, width)
        )
        self.ln_pre = LayerNorm(width)

        self.transformer = Transformer(width, layers, heads)
        self.ln_post = LayerNorm(width)
        self.logit_scale = torch.nn.Parameter(torch.ones([]) * 4.6052)

        self.proj = nn.Parameter(scale * torch.randn(width, output_dim))

    def forward(self, imgs: torch.Tensor, c: torch.Tensor = None):
        """
        imgs : Batch of images
        c : category indices. 0 = "No given category".
        """

        x = self.conv1(imgs)  # shape = [*, width, grid, grid]
        x = x.reshape(x.shape[0], x.shape[1], -1)  # shape = [*, width, grid ** 2]
        x = x.permute(0, 2, 1)  # shape = [*, grid ** 2, width]

        # [CLS, grid] + maybe Categories.
        tokens = [self.class_embedding.tile(x.shape[0], 1, 1), x]  # NLD
        pos_embed = [self.positional_embedding]  # LD

        if self.n_categories and c is not None:  # If c is None, we don't add the token
            tokens += [self.c_embedding(c).unsqueeze(1)]  # ND -> N1D
            pos_embed += [self.c_pos_embedding]  # 1D

        x = torch.cat(
            tokens,
            dim=1,
        )  # shape = [*, grid ** 2 + 1|2, width] = N(L|L+1)D
        pos_embed = torch.cat(pos_embed, dim=0).unsqueeze(0)  # 1(L|L+1)D

        x = x + pos_embed
        x = self.ln_pre(x)

        x = x.permute(1, 0, 2)  # NLD -> LND

        x = self.transformer(x)
        x = x.permute(1, 0, 2)  # LND -> NLD

        x = self.ln_post(x[:, 0, :])

        x = x @ self.proj

        return x


# SIZES
B32_Params = {
    "input_resolution": 224,
    "patch_size": 32,
    "width": 768,
    "layers": 12,
    "heads": 12,
    "output_dim": 512,
}

B16_Params = {
    "input_resolution": 224,
    "patch_size": 16,
    "width": 768,
    "layers": 12,
    "heads": 12,
    "output_dim": 512,
}

params = {"B32": B32_Params, "B16": B16_Params}
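A minimal sketch of calling the model directly, with shapes following B16_Params above (the random tensor stands in for a batch of preprocessed 224x224 images):

import torch
from src.model import ConditionalViT, B16_Params, categories

m = ConditionalViT(**B16_Params, n_categories=len(categories)).eval()

imgs = torch.randn(2, 3, 224, 224)  # stand-in for preprocessed images
c = torch.tensor([categories.index("Bags"), categories.index("Feet")])

with torch.inference_mode():
    cond = m(imgs, c)  # category token appended: shape (2, 512)
    uncond = m(imgs)   # c=None, no category token: shape (2, 512)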
src/process_images.py
ADDED
@@ -0,0 +1,15 @@
from PIL import Image
from io import BytesIO
import base64


# Index to PIL
def process_img(idx, ds):
    img = Image.open(BytesIO(ds.iloc[idx].jpg)).convert("RGB")
    return img


def make_img_html(img):
    b = BytesIO()
    img.save(b, format="PNG")
    buffer = b.getvalue()

    return f'<img height=200px src="data:image/png;base64,{base64.b64encode(buffer).decode()}"></img>'
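A quick round-trip check of the helper above (a hypothetical test, not part of the commit): the data URI produced by make_img_html should decode back to the original image.

import base64
import re
from io import BytesIO
from PIL import Image

from src.process_images import make_img_html

img = Image.new("RGB", (32, 32), "blue")
html = make_img_html(img)
b64 = re.search(r'base64,([^"]+)', html).group(1)
decoded = Image.open(BytesIO(base64.b64decode(b64)))
assert decoded.size == (32, 32)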
src/style.css
ADDED
@@ -0,0 +1,42 @@
/* OUTPUT */
#html_output, #html_examples {
    display: flex;
    align-items: center;
    justify-content: center;
    flex-wrap: wrap;
}

#html_output > img {
    align-self: center;
    height: 200px;
    border: 2px solid;
    border-color: var(--block-border-color);
    border-radius: var(--block-radius);
    margin: 1.5em;
}

/* EXAMPLE */
#html_examples > figure > img {
    align-self: center;
    height: 100px;
    border: 2px solid;
    border-color: var(--block-border-color);
    border-radius: var(--block-radius);
    margin: .7em;
}

#html_examples > figure {
    transition-duration: 0.2s;
}

#html_examples > figure:hover {
    transform: scale(1.2);
}

#html_examples > figure > figcaption {
    text-align: center;
}

#preset_examples {
    display: none;
}
src/transform.py
ADDED
@@ -0,0 +1,29 @@
from torchvision.transforms import transforms as tf
import torchvision.transforms.functional as F


class SquarePad:
    def __init__(self, color):
        self.col = color

    def __call__(self, image):
        # Pad the shorter side symmetrically so the image becomes square.
        max_wh = max(image.size)
        p_left, p_top = [(max_wh - s) // 2 for s in image.size]
        p_right, p_bottom = [
            max_wh - (s + pad) for s, pad in zip(image.size, [p_left, p_top])
        ]
        padding = (p_left, p_top, p_right, p_bottom)
        return F.pad(image, padding, self.col, "constant")


def valid_tf(size):
    return tf.Compose(
        [
            SquarePad(255),
            tf.Resize(size),
            tf.ToTensor(),
            tf.Normalize(
                mean=(0.48145466, 0.4578275, 0.40821073),
                std=(0.26862954, 0.26130258, 0.27577711),
            ),
        ]
    )
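For illustration, a small sketch of what this pipeline does to a non-square image: SquarePad(255) pads the shorter side with white so the aspect ratio survives the resize, and the Normalize statistics are CLIP's usual constants.

from PIL import Image
from src.transform import valid_tf

img = Image.new("RGB", (100, 200), "red")  # toy portrait image
x = valid_tf((224, 224))(img)  # padded to 200x200, resized to 224x224, normalized
print(x.shape)  # torch.Size([3, 224, 224])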