Slep committed
Commit c45703e · 0 Parent(s)

Initial Demo.

.gitattributes ADDED
@@ -0,0 +1,36 @@
+ *.7z filter=lfs diff=lfs merge=lfs -text
+ *.arrow filter=lfs diff=lfs merge=lfs -text
+ *.bin filter=lfs diff=lfs merge=lfs -text
+ *.bz2 filter=lfs diff=lfs merge=lfs -text
+ *.ckpt filter=lfs diff=lfs merge=lfs -text
+ *.faiss filter=lfs diff=lfs merge=lfs -text
+ *.ftz filter=lfs diff=lfs merge=lfs -text
+ *.gz filter=lfs diff=lfs merge=lfs -text
+ *.h5 filter=lfs diff=lfs merge=lfs -text
+ *.joblib filter=lfs diff=lfs merge=lfs -text
+ *.lfs.* filter=lfs diff=lfs merge=lfs -text
+ *.mlmodel filter=lfs diff=lfs merge=lfs -text
+ *.model filter=lfs diff=lfs merge=lfs -text
+ *.msgpack filter=lfs diff=lfs merge=lfs -text
+ *.npy filter=lfs diff=lfs merge=lfs -text
+ *.npz filter=lfs diff=lfs merge=lfs -text
+ *.onnx filter=lfs diff=lfs merge=lfs -text
+ *.ot filter=lfs diff=lfs merge=lfs -text
+ *.parquet filter=lfs diff=lfs merge=lfs -text
+ *.pb filter=lfs diff=lfs merge=lfs -text
+ *.pickle filter=lfs diff=lfs merge=lfs -text
+ *.pkl filter=lfs diff=lfs merge=lfs -text
+ *.pt filter=lfs diff=lfs merge=lfs -text
+ *.pth filter=lfs diff=lfs merge=lfs -text
+ *.rar filter=lfs diff=lfs merge=lfs -text
+ *.safetensors filter=lfs diff=lfs merge=lfs -text
+ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
+ *.tar.* filter=lfs diff=lfs merge=lfs -text
+ *.tflite filter=lfs diff=lfs merge=lfs -text
+ *.tgz filter=lfs diff=lfs merge=lfs -text
+ *.wasm filter=lfs diff=lfs merge=lfs -text
+ *.xz filter=lfs diff=lfs merge=lfs -text
+ *.zip filter=lfs diff=lfs merge=lfs -text
+ *.zst filter=lfs diff=lfs merge=lfs -text
+ *tfevents* filter=lfs diff=lfs merge=lfs -text
+ artifacts/ filter=lfs diff=lfs merge=lfs -text
.gitignore ADDED
@@ -0,0 +1,163 @@
+ # Custom
+ _*
+
+ # Byte-compiled / optimized / DLL files
+ __pycache__/
+ *.py[cod]
+ *$py.class
+
+ # C extensions
+ *.so
+
+ # Distribution / packaging
+ .Python
+ build/
+ develop-eggs/
+ dist/
+ downloads/
+ eggs/
+ .eggs/
+ lib/
+ lib64/
+ parts/
+ sdist/
+ var/
+ wheels/
+ share/python-wheels/
+ *.egg-info/
+ .installed.cfg
+ *.egg
+ MANIFEST
+
+ # PyInstaller
+ # Usually these files are written by a python script from a template
+ # before PyInstaller builds the exe, so as to inject date/other infos into it.
+ *.manifest
+ *.spec
+
+ # Installer logs
+ pip-log.txt
+ pip-delete-this-directory.txt
+
+ # Unit test / coverage reports
+ htmlcov/
+ .tox/
+ .nox/
+ .coverage
+ .coverage.*
+ .cache
+ nosetests.xml
+ coverage.xml
+ *.cover
+ *.py,cover
+ .hypothesis/
+ .pytest_cache/
+ cover/
+
+ # Translations
+ *.mo
+ *.pot
+
+ # Django stuff:
+ *.log
+ local_settings.py
+ db.sqlite3
+ db.sqlite3-journal
+
+ # Flask stuff:
+ instance/
+ .webassets-cache
+
+ # Scrapy stuff:
+ .scrapy
+
+ # Sphinx documentation
+ docs/_build/
+
+ # PyBuilder
+ .pybuilder/
+ target/
+
+ # Jupyter Notebook
+ .ipynb_checkpoints
+
+ # IPython
+ profile_default/
+ ipython_config.py
+
+ # pyenv
+ # For a library or package, you might want to ignore these files since the code is
+ # intended to run in multiple environments; otherwise, check them in:
+ # .python-version
+
+ # pipenv
+ # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
+ # However, in case of collaboration, if having platform-specific dependencies or dependencies
+ # having no cross-platform support, pipenv may install dependencies that don't work, or not
+ # install all needed dependencies.
+ #Pipfile.lock
+
+ # poetry
+ # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
+ # This is especially recommended for binary packages to ensure reproducibility, and is more
+ # commonly ignored for libraries.
+ # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
+ #poetry.lock
+
+ # pdm
+ # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
+ #pdm.lock
+ # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
+ # in version control.
+ # https://pdm.fming.dev/#use-with-ide
+ .pdm.toml
+
+ # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
+ __pypackages__/
+
+ # Celery stuff
+ celerybeat-schedule
+ celerybeat.pid
+
+ # SageMath parsed files
+ *.sage.py
+
+ # Environments
+ .env
+ .venv
+ env/
+ venv/
+ ENV/
+ env.bak/
+ venv.bak/
+
+ # Spyder project settings
+ .spyderproject
+ .spyproject
+
+ # Rope project settings
+ .ropeproject
+
+ # mkdocs documentation
+ /site
+
+ # mypy
+ .mypy_cache/
+ .dmypy.json
+ dmypy.json
+
+ # Pyre type checker
+ .pyre/
+
+ # pytype static type analyzer
+ .pytype/
+
+ # Cython debug symbols
+ cython_debug/
+
+ # PyCharm
+ # JetBrains specific template is maintained in a separate JetBrains.gitignore that can
+ # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
+ # and can be added to the global gitignore or merged into this file. For a more nuclear
+ # option (not recommended) you can uncomment the following to ignore the entire idea folder.
+ #.idea/
README.md ADDED
@@ -0,0 +1,11 @@
+ ---
+ title: CondViT LRVSF Demo
+ emoji: 🔎
+ colorFrom: red
+ colorTo: yellow
+ sdk: gradio
+ sdk_version: 3.32.0
+ app_file: app.py
+ pinned: false
+ license: mit
+ ---
app.py ADDED
@@ -0,0 +1,89 @@
+ import pandas as pd
+ import torch
+ import faiss
+ import gradio as gr
+
+ import base64
+
+ from PIL import Image
+ from io import BytesIO
+
+ from src.model import ConditionalViT, B16_Params, categories
+ from src.transform import valid_tf
+ from src.process_images import process_img, make_img_html
+ from src.examples import ExamplesHandler
+ from src.js_loader import JavaScriptLoader
+
+ # Load Model
+ m = ConditionalViT(**B16_Params, n_categories=len(categories))
+ m.load_state_dict(torch.load("./artifacts/cat_condvit_b16.pth", map_location="cpu"))
+ m.eval()
+
+ # Load data
+ index = faiss.read_index("./artifacts/gallery_index.faiss")
+ gal_imgs = pd.read_parquet("./artifacts/gallery_imgs.parquet")
+ tfs = valid_tf((224, 224))
+
+ K = 5
+
+ examples = [
+     ["examples/3.jpg", "Outwear"],
+     ["examples/3.jpg", "Lower Body"],
+     ["examples/3.jpg", "Feet"],
+     ["examples/757.jpg", "Bags"],
+     ["examples/757.jpg", "Upper Body"],
+     ["examples/769.jpg", "Upper Body"],
+     ["examples/1811.jpg", "Lower Body"],
+     ["examples/1811.jpg", "Bags"],
+ ]
+
+ @torch.inference_mode()
+ def retrieval(image, category):
+     if image is None or category is None: return
+
+     q_emb = m(tfs(image).unsqueeze(0), torch.tensor([category]))
+
+     r = index.search(q_emb, K)
+
+     imgs = [process_img(idx, gal_imgs) for idx in r[1][0]]
+
+     html = [make_img_html(i) for i in imgs]
+     html += ["<p></p>"]  # Avoid Gradio's last-child{margin-bottom:0!important;}
+
+     return "\n".join(html)
+
+
+
+ JavaScriptLoader("src/custom_functions.js")
+ with gr.Blocks(css="src/style.css") as demo:
+     with gr.Column():
+         gr.Markdown("""
+         # Conditional ViT Demo
+         [[`Paper`](https://arxiv.org/abs/2306.02928)]
+         [[`Code`](https://github.com/Simon-Lepage/CondViT-LRVSF)]
+         [[`Dataset`](https://huggingface.co/datasets/Slep/LAION-RVS-Fashion)]
+
+         *Running on 2 vCPU, 16GB RAM.*
+
+         - **Model:** Categorical CondViT-B/16
+         - **Gallery:** 93K images.
+         """)
+
+         # Input section
+         with gr.Row():
+             img = gr.Image(label="Query Image", type="pil", elem_id="query_img")
+             with gr.Column():
+                 cat = gr.Dropdown(choices=categories, label="Category", value="Upper Body", type='index', elem_id="dropdown")
+                 submit = gr.Button("Submit")
+
+         # Examples
+         gr.Examples(examples, inputs=[img, cat], fn=retrieval, elem_id="preset_examples", examples_per_page=100)
+         gr.HTML(value=ExamplesHandler(examples).to_html(), label="examples", elem_id="html_examples")
+
+         # Outputs
+         gr.Markdown("# Retrieved Items")
+         out = gr.HTML(label="Results", elem_id="html_output")
+
+     submit.click(fn=retrieval, inputs=[img, cat], outputs=out)
+
+ demo.launch()
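
For reference, a minimal sketch of exercising the same retrieval path outside the Gradio UI. It only reuses modules and artifact paths from this commit; the query image and category are arbitrary picks from the bundled examples, and it assumes the `artifacts/` LFS files have been pulled locally.

```python
# Sketch: standalone retrieval with the committed modules (no Gradio UI).
import faiss
import pandas as pd
import torch
from PIL import Image

from src.model import ConditionalViT, B16_Params, categories
from src.transform import valid_tf
from src.process_images import process_img

model = ConditionalViT(**B16_Params, n_categories=len(categories))
model.load_state_dict(torch.load("./artifacts/cat_condvit_b16.pth", map_location="cpu"))
model.eval()

index = faiss.read_index("./artifacts/gallery_index.faiss")
gallery = pd.read_parquet("./artifacts/gallery_imgs.parquet")
tfs = valid_tf((224, 224))

query = Image.open("examples/3.jpg").convert("RGB")
cond = torch.tensor([categories.index("Outwear")])  # category index, as in the dropdown

with torch.inference_mode():
    emb = model(tfs(query).unsqueeze(0), cond)

# faiss expects float32 row vectors; search returns (distances, indices)
distances, ids = index.search(emb.numpy().astype("float32"), 5)
neighbours = [process_img(i, gallery) for i in ids[0]]  # top-5 gallery images as PIL
```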
artifacts/cat_condvit_b16.pth ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:badeeedf937ce955946ebbbbe7a1d531b171374e8fc58697436e7dfca31aa694
+ size 344860235
artifacts/gallery_imgs.parquet ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:3f633c261f353bd64498c267a7d6744c424471a4f09f69f2ac8f9df5537a259e
+ size 1487433539
artifacts/gallery_index.faiss ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:2ed82e5e12fb2c8d6aaa89336623fc5948f7e29805f961f0e07ce92b0cebd8c2
+ size 192118829
examples/1811.jpg ADDED
examples/3.jpg ADDED
examples/757.jpg ADDED
examples/769.jpg ADDED
requirements.txt ADDED
@@ -0,0 +1,7 @@
+ torch
+ torchvision
+ pillow
+ faiss-cpu
+ pandas
+ pyarrow
+ gradio
src/custom_functions.js ADDED
@@ -0,0 +1,5 @@
+ remap_click = (i) => {
+     let examples = document.getElementById("preset_examples");
+     let rows = examples.getElementsByTagName("tbody")[0].getElementsByTagName("tr");
+     rows[i].click();
+ }
src/custom_js.py ADDED
@@ -0,0 +1,27 @@
+ import gradio
+
+ # Adapted from https://github.com/gradio-app/gradio/discussions/2932
+
+ class JavaScriptLoader:
+     def __init__(self, target):
+         # Copy the template response
+         self.original_template = gradio.routes.templates.TemplateResponse
+         # Prep the js files
+         self.load_js(target)
+         # Reassign the template response to this method, so Gradio calls it instead
+         gradio.routes.templates.TemplateResponse = self.template_response
+
+     def load_js(self, target):
+         with open(target, 'r', encoding="utf-8") as file:
+             self.loaded_script = f"<script>\n{file.read()}\n</script>"
+
+     def template_response(self, *args, **kwargs):
+         """Once Gradio calls this method, call the original template, inject the
+         loaded scripts, and return the modified response.
+         """
+         response = self.original_template(*args, **kwargs)
+         response.body = response.body.replace(
+             '</head>'.encode('utf-8'), (self.loaded_script + "\n</head>").encode("utf-8")
+         )
+         response.init_headers()
+         return response
src/examples.py ADDED
@@ -0,0 +1,21 @@
+ from PIL import Image
+ from .process_images import make_img_html
+
+
+ class ExamplesHandler:
+     def __init__(self, examples):
+         self.examples = examples
+
+     def to_html(self):
+
+         ret = ""
+         for i, (img_path, category) in enumerate(self.examples):
+             ret += f"<figure id='example_{i}' onclick='remap_click({i})'>"
+             img = Image.open(img_path).convert("RGB")
+             ret += make_img_html(img)
+             ret += f"<figcaption>{category}</figcaption>"
+             ret += "</figure>"
+
+         ret += "<p></p>"
+
+         return ret
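
For clarity, a small illustration of the markup `to_html` produces for one entry. The inline `onclick` relies on `remap_click` from `src/custom_functions.js` (injected by `JavaScriptLoader`), which forwards the click to the hidden `gr.Examples` row with the same index.

```python
# Illustration: generated markup for a single example entry.
from src.examples import ExamplesHandler

html = ExamplesHandler([["examples/3.jpg", "Outwear"]]).to_html()
print(html[:60])
# roughly: <figure id='example_0' onclick='remap_click(0)'><img height=...
```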
src/js_loader.py ADDED
@@ -0,0 +1,25 @@
+ import gradio
+
+ class JavaScriptLoader:
+     def __init__(self, target):
+         # Copy the template response
+         self.original_template = gradio.routes.templates.TemplateResponse
+         # Prep the js files
+         self.load_js(target)
+         # Reassign the template response to this method, so Gradio calls it instead
+         gradio.routes.templates.TemplateResponse = self.template_response
+
+     def load_js(self, target):
+         with open(target, 'r', encoding="utf-8") as file:
+             self.loaded_script = f"<script>\n{file.read()}\n</script>"
+
+     def template_response(self, *args, **kwargs):
+         """Once Gradio calls this method, call the original template, inject the
+         loaded scripts, and return the modified response.
+         """
+         response = self.original_template(*args, **kwargs)
+         response.body = response.body.replace(
+             '</head>'.encode('utf-8'), (self.loaded_script + "\n</head>").encode("utf-8")
+         )
+         response.init_headers()
+         return response
src/model.py ADDED
@@ -0,0 +1,213 @@
+ categories = [
+     "Bags",
+     "Feet",
+     "Hands",
+     "Head",
+     "Lower Body",
+     "Neck",
+     "Outwear",
+     "Upper Body",
+     "Waist",
+     "Whole Body",
+ ]
+
+ import torch
+ from torch import nn
+
+ from collections import OrderedDict
+ import logging
+
+ logger = logging.getLogger(__name__)
+
+
+ class LayerNorm(nn.LayerNorm):
+     """Subclass torch's LayerNorm to handle fp16."""
+
+     def forward(self, x: torch.Tensor):
+         if self.weight.dtype != x.dtype:
+             orig_type = x.dtype
+             ret = super().forward(x.type(self.weight.dtype))
+             return ret.type(orig_type)
+         else:
+             return super().forward(x)
+
+
+ class QuickGELU(nn.Module):
+     def forward(self, x: torch.Tensor):
+         return x * torch.sigmoid(1.702 * x)
+
+
+ class ResidualAttentionBlock(nn.Module):
+     def __init__(
+         self,
+         d_model: int,
+         n_head: int,
+         attn_mask: torch.Tensor = None,
+     ):
+         super().__init__()
+
+         self.attn = nn.MultiheadAttention(d_model, n_head)
+         self.ln_1 = LayerNorm(d_model)
+         self.mlp = nn.Sequential(
+             OrderedDict(
+                 [
+                     (
+                         "c_fc",
+                         nn.Linear(d_model, d_model * 4),
+                     ),
+                     ("gelu", QuickGELU()),
+                     (
+                         "c_proj",
+                         nn.Linear(d_model * 4, d_model),
+                     ),
+                 ]
+             )
+         )
+         self.ln_2 = LayerNorm(d_model)
+         self.attn_mask = attn_mask
+
+     def attention(self, x: torch.Tensor):
+         self.attn_mask = (
+             self.attn_mask.to(dtype=x.dtype, device=x.device)
+             if self.attn_mask is not None
+             else None
+         )
+         return self.attn(
+             x,
+             x,
+             x,
+             need_weights=False,
+             attn_mask=self.attn_mask,
+         )[0]
+
+     def forward(self, x: torch.Tensor):
+         x = x + self.attention(self.ln_1(x))
+         x = x + self.mlp(self.ln_2(x))
+         return x
+
+
+ class Transformer(nn.Module):
+     def __init__(
+         self,
+         width: int,
+         layers: int,
+         heads: int,
+         attn_mask: torch.Tensor = None,
+     ):
+         super().__init__()
+         self.width = width
+         self.layers = layers
+         self.resblocks = nn.Sequential(
+             *[ResidualAttentionBlock(width, heads, attn_mask) for _ in range(layers)]
+         )
+
+     def forward(self, x: torch.Tensor):
+         return self.resblocks(x)
+
+
+ class ConditionalViT(nn.Module):
+     def __init__(
+         self,
+         input_resolution: int,
+         patch_size: int,
+         width: int,
+         layers: int,
+         heads: int,
+         output_dim: int,
+         n_categories: int = None,
+         **kwargs,
+     ):
+         if kwargs:
+             logger.warning(f"Got unused kwargs : {kwargs}")
+
+         super().__init__()
+         self.input_resolution = input_resolution
+         self.output_dim = output_dim
+         self.conv1 = nn.Conv2d(
+             in_channels=3,
+             out_channels=width,
+             kernel_size=patch_size,
+             stride=patch_size,
+             bias=False,
+         )
+
+         scale = width**-0.5
+
+         self.class_embedding = nn.Parameter(scale * torch.randn(width))
+
+         self.n_categories = n_categories
+         if self.n_categories:
+             self.c_embedding = nn.Embedding(self.n_categories, width)
+             self.c_pos_embedding = nn.Parameter(scale * torch.randn(1, width))
+
+         self.positional_embedding = nn.Parameter(
+             scale * torch.randn((input_resolution // patch_size) ** 2 + 1, width)
+         )
+         self.ln_pre = LayerNorm(width)
+
+         self.transformer = Transformer(width, layers, heads)
+         self.ln_post = LayerNorm(width)
+         self.logit_scale = torch.nn.Parameter(torch.ones([]) * 4.6052)
+
+         self.proj = nn.Parameter(scale * torch.randn(width, output_dim))
+
+     def forward(self, imgs: torch.Tensor, c: torch.Tensor = None):
+         """
+         imgs : Batch of images
+         c : category indices. 0 = "No given category".
+         """
+
+         x = self.conv1(imgs)  # shape = [*, width, grid, grid]
+         # shape = [*, width, grid ** 2]
+         x = x.reshape(x.shape[0], x.shape[1], -1)
+         x = x.permute(0, 2, 1)  # shape = [*, grid ** 2, width]
+
+         # [CLS, grid] + maybe Categories.
+         tokens = [self.class_embedding.tile(x.shape[0], 1, 1), x]  # NLD
+         pos_embed = [self.positional_embedding]  # LD
+
+         if self.n_categories and c is not None:  # If c is None, we don't add the token
+             tokens += [self.c_embedding(c).unsqueeze(1)]  # ND -> N1D
+             pos_embed += [self.c_pos_embedding]  # 1D
+
+         x = torch.cat(
+             tokens,
+             dim=1,
+         )  # shape = [*, grid ** 2 + 1|2, width] = N(L|L+1)D
+         pos_embed = torch.cat(pos_embed, dim=0).unsqueeze(0)  # 1(L|L+1)D
+
+         x = x + pos_embed
+         x = self.ln_pre(x)
+
+         x = x.permute(1, 0, 2)  # NLD -> LND
+
+         x = self.transformer(x)
+         x = x.permute(1, 0, 2)  # LND -> NLD
+
+         x = self.ln_post(x[:, 0, :])
+
+         x = x @ self.proj
+
+         return x
+
+
+ # SIZES
+ B32_Params = {
+     "input_resolution": 224,
+     "patch_size": 32,
+     "width": 768,
+     "layers": 12,
+     "heads": 12,
+     "output_dim": 512,
+ }
+
+ B16_Params = {
+     "input_resolution": 224,
+     "patch_size": 16,
+     "width": 768,
+     "layers": 12,
+     "heads": 12,
+     "output_dim": 512,
+ }
+
+ params = {"B32": B32_Params, "B16": B16_Params}
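
A short usage sketch of the module above, with shapes taken from the B/16 configuration in this file: the category embedding is appended as one extra token only when `c` is passed, and the output is the projected CLS embedding either way.

```python
# Sketch: conditional vs. unconditional forward pass (B/16 @ 224 -> 512-d embeddings).
import torch
from src.model import ConditionalViT, B16_Params, categories

m = ConditionalViT(**B16_Params, n_categories=len(categories))
m.eval()

imgs = torch.randn(2, 3, 224, 224)  # stand-in for a preprocessed batch

with torch.inference_mode():
    e_plain = m(imgs)                                               # no category token
    e_cond = m(imgs, torch.tensor([categories.index("Bags")] * 2))  # category token appended

print(e_plain.shape, e_cond.shape)  # torch.Size([2, 512]) torch.Size([2, 512])
```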
src/process_images.py ADDED
@@ -0,0 +1,15 @@
+ from PIL import Image
+ from io import BytesIO
+ import base64
+
+ # Index to PIL
+ def process_img(idx, ds):
+     img = Image.open(BytesIO(ds.iloc[idx].jpg)).convert("RGB")
+     return img
+
+ def make_img_html(img):
+     b = BytesIO()
+     img.save(b, format='PNG')
+     buffer = b.getvalue()
+
+     return f'<img height=200px src="data:image/png;base64,{base64.b64encode(buffer).decode()}"></img>'
src/style.css ADDED
@@ -0,0 +1,42 @@
+ /* OUTPUT */
+ #html_output, #html_examples {
+     display: flex;
+     align-items: center;
+     justify-content: center;
+     flex-wrap: wrap;
+ }
+
+ #html_output > img {
+     align-self: center;
+     height: 200px;
+     border: 2px solid;
+     border-color: var(--block-border-color);
+     border-radius: var(--block-radius);
+     margin: 1.5em;
+ }
+
+ /* EXAMPLE */
+ #html_examples > figure > img {
+     align-self: center;
+     height: 100px;
+     border: 2px solid;
+     border-color: var(--block-border-color);
+     border-radius: var(--block-radius);
+     margin: .7em;
+ }
+
+ #html_examples > figure {
+     transition-duration: 0.2s;
+ }
+
+ #html_examples > figure:hover {
+     transform: scale(1.2);
+ }
+
+ #html_examples > figure > figcaption {
+     text-align: center;
+ }
+
+ #preset_examples {
+     display: none;
+ }
src/transform.py ADDED
@@ -0,0 +1,29 @@
+ from torchvision.transforms import transforms as tf
+ import torchvision.transforms.functional as F
+
+
+ class SquarePad:
+     def __init__(self, color):
+         self.col = color
+
+     def __call__(self, image):
+         max_wh = max(image.size)
+         p_left, p_top = [(max_wh - s) // 2 for s in image.size]
+         p_right, p_bottom = [
+             max_wh - (s + pad) for s, pad in zip(image.size, [p_left, p_top])
+         ]
+         padding = (p_left, p_top, p_right, p_bottom)
+         return F.pad(image, padding, self.col, "constant")
+
+ def valid_tf(size):
+     return tf.Compose(
+         [
+             SquarePad(255),
+             tf.Resize(size),
+             tf.ToTensor(),
+             tf.Normalize(
+                 mean=(0.48145466, 0.4578275, 0.40821073),
+                 std=(0.26862954, 0.26130258, 0.27577711),
+             ),
+         ]
+     )
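
A brief usage sketch of the transform above: `SquarePad(255)` pads the image to a white square before resizing, and the normalization constants are the standard CLIP image statistics.

```python
# Sketch: preprocessing an arbitrary RGB image for the model.
from PIL import Image
from src.transform import valid_tf

tfs = valid_tf((224, 224))
img = Image.open("examples/757.jpg").convert("RGB")  # any non-square image works
x = tfs(img)
print(x.shape)  # torch.Size([3, 224, 224]), ready for ConditionalViT
```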