shimu0215 committed on
Commit
dc0de8d
·
verified ·
1 Parent(s): 0290391

Upload app.py

Browse files
Files changed (1) hide show
  1. app.py +307 -0
app.py ADDED
@@ -0,0 +1,307 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import cv2
3
+ import gradio as gr
4
+ import numpy as np
5
+ import spaces
6
+ import torch
7
+ import torch.nn.functional as F
8
+ from gradio.themes.utils import sizes
9
+ from torchvision import transforms
10
+ from PIL import Image
11
+ import tempfile
12
+ from classes_and_palettes import GOLIATH_PALETTE, GOLIATH_CLASSES
13
+
14
class Config:
    """Static locations of the demo's bundled assets and model checkpoints."""

    # Everything is resolved relative to the directory containing this file.
    ASSETS_DIR = os.path.join(os.path.dirname(__file__), "assets")
    CHECKPOINTS_DIR = os.path.join(ASSETS_DIR, "checkpoints")

    # Maps the model-size label shown in the UI to its checkpoint file name
    # inside CHECKPOINTS_DIR.
    CHECKPOINTS = {
        "0.3b": "sapiens_0.3b_goliath_best_goliath_mIoU_7673_epoch_194_torchscript.pt2",
        "0.6b": "sapiens_0.6b_goliath_best_goliath_mIoU_7777_epoch_178_torchscript.pt2",
        "1b": "sapiens_1b_goliath_best_goliath_mIoU_7994_epoch_151_torchscript.pt2",
    }
22
+
23
class ModelManager:
    """Helpers for loading a TorchScript checkpoint and running it."""

    @staticmethod
    def load_model(checkpoint_name: str):
        """Load the checkpoint registered under ``checkpoint_name``.

        Args:
            checkpoint_name: A key of ``Config.CHECKPOINTS``.

        Returns:
            The module returned by ``torch.jit.load``, switched to eval mode
            and moved to the ``"cuda"`` device.

        Raises:
            KeyError: If ``checkpoint_name`` is not in ``Config.CHECKPOINTS``.
        """
        filename = Config.CHECKPOINTS[checkpoint_name]
        scripted = torch.jit.load(os.path.join(Config.CHECKPOINTS_DIR, filename))
        scripted.eval()
        scripted.to("cuda")
        return scripted

    @staticmethod
    @torch.inference_mode()
    def run_model(model, input_tensor, height, width):
        """Run ``model`` and return per-pixel class indices at the given size.

        Args:
            model: Callable applied to ``input_tensor``.
            input_tensor: Batched model input.
            height: Output height in pixels.
            width: Output width in pixels.

        Returns:
            Tensor of indices of the maximum along dim 1 of the resized
            model output (presumably the class/channel axis - confirm
            against the checkpoint).
        """
        # Bring the raw output to the requested resolution before taking
        # the per-pixel maximum.
        logits = F.interpolate(
            model(input_tensor),
            size=(height, width),
            mode="bilinear",
            align_corners=False,
        )
        return torch.max(logits, 1).indices
39
+
40
class ImageProcessor:
    """Preprocess an image, run segmentation, and render the result."""

    def __init__(self):
        # Fixed-size resize, tensor conversion, then per-channel
        # normalization. Mean/std are written on the 0-255 scale and divided
        # by 255 to match ToTensor's output.
        self.transform_fn = transforms.Compose([
            transforms.Resize((1024, 768)),
            transforms.ToTensor(),
            transforms.Normalize(
                mean=[123.5 / 255, 116.5 / 255, 103.5 / 255],
                std=[58.5 / 255, 57.0 / 255, 57.5 / 255],
            ),
        ])

    @spaces.GPU
    def process_image(self, image: Image.Image, model_name: str):
        """Segment ``image`` with the selected model.

        Args:
            image: Input PIL image; it is converted to RGB before use.
            model_name: Key of ``Config.CHECKPOINTS`` selecting the model.

        Returns:
            A ``(blended_image, npy_path)`` tuple: the input blended with the
            class-colour overlay, and the path of a ``.npy`` file holding the
            predicted label mask at the input image's resolution.
        """
        # Normalize is configured for exactly three channels, so feed the
        # model the same RGB view that the overlay rendering uses.
        image = image.convert("RGB")

        model = ModelManager.load_model(model_name)
        input_tensor = self.transform_fn(image).unsqueeze(0).to("cuda")

        preds = ModelManager.run_model(model, input_tensor, image.height, image.width)
        mask = preds.squeeze(0).cpu().numpy()

        # Visualize the segmentation
        blended_image = self.visualize_pred_with_overlay(image, mask)

        # Create the downloadable .npy file. NamedTemporaryFile creates the
        # file atomically, unlike the deprecated, race-prone tempfile.mktemp.
        # delete=False keeps it on disk so it can be served after return.
        with tempfile.NamedTemporaryFile(suffix=".npy", delete=False) as npy_file:
            np.save(npy_file, mask)

        return blended_image, npy_file.name

    @staticmethod
    def visualize_pred_with_overlay(img, sem_seg, alpha=0.5):
        """Blend ``img`` with a colour overlay of the label map ``sem_seg``.

        Args:
            img: PIL image; converted to RGB for blending.
            sem_seg: Array-like of integer labels, one per pixel of ``img``.
            alpha: Weight of the overlay; the image gets ``1 - alpha``.

        Returns:
            The blended result as a PIL image.
        """
        img_np = np.array(img.convert("RGB"))
        sem_seg = np.asarray(sem_seg)

        # Only labels with an entry in the class list are coloured; anything
        # else stays black in the overlay.
        num_classes = len(GOLIATH_CLASSES)
        ids = np.unique(sem_seg)[::-1]
        labels = ids[ids < num_classes].astype(np.int64)

        overlay = np.zeros((*sem_seg.shape, 3), dtype=np.uint8)
        for label in labels:
            overlay[sem_seg == label, :] = GOLIATH_PALETTE[label]

        blended = np.uint8(img_np * (1 - alpha) + overlay * alpha)
        return Image.fromarray(blended)
85
+
86
class GradioInterface:
    """Builds the Gradio Blocks UI around an ``ImageProcessor``."""

    def __init__(self):
        self.image_processor = ImageProcessor()

    def create_interface(self):
        """Assemble the demo page and wire the Run button.

        Returns:
            The constructed ``gr.Blocks`` object (not yet launched).

        Raises:
            OSError: If the example-image directory under
                ``Config.ASSETS_DIR`` cannot be listed.
        """
        app_styles = """
        <style>
            /* Global Styles */
            body, #root {
                font-family: Helvetica, Arial, sans-serif;
                background-color: #1a1a1a;
                color: #fafafa;
            }
            /* Header Styles */
            .app-header {
                background: linear-gradient(45deg, #1a1a1a 0%, #333333 100%);
                padding: 24px;
                border-radius: 8px;
                margin-bottom: 24px;
                text-align: center;
            }
            .app-title {
                font-size: 48px;
                margin: 0;
                color: #fafafa;
            }
            .app-subtitle {
                font-size: 24px;
                margin: 8px 0 16px;
                color: #fafafa;
            }
            .app-description {
                font-size: 16px;
                line-height: 1.6;
                opacity: 0.8;
                margin-bottom: 24px;
            }
            /* Button Styles */
            .publication-links {
                display: flex;
                justify-content: center;
                flex-wrap: wrap;
                gap: 8px;
                margin-bottom: 16px;
            }
            .publication-link {
                display: inline-flex;
                align-items: center;
                padding: 8px 16px;
                background-color: #333;
                color: #fff !important;
                text-decoration: none !important;
                border-radius: 20px;
                font-size: 14px;
                transition: background-color 0.3s;
            }
            .publication-link:hover {
                background-color: #555;
            }
            .publication-link i {
                margin-right: 8px;
            }
            /* Content Styles */
            .content-container {
                background-color: #2a2a2a;
                border-radius: 8px;
                padding: 24px;
                margin-bottom: 24px;
            }
            /* Image Styles */
            /* Updated Image Styles */
            .image-preview img {
                max-width: 512px;
                max-height: 512px;
                margin: 0 auto;
                border-radius: 4px;
                display: block;
                object-fit: contain;
            }

            /* Control Styles */
            .control-panel {
                background-color: #333;
                padding: 16px;
                border-radius: 8px;
                margin-top: 16px;
            }
            /* Gradio Component Overrides */
            .gr-button {
                background-color: #4a4a4a;
                color: #fff;
                border: none;
                border-radius: 4px;
                padding: 8px 16px;
                cursor: pointer;
                transition: background-color 0.3s;
            }
            .gr-button:hover {
                background-color: #5a5a5a;
            }
            .gr-input, .gr-dropdown {
                background-color: #3a3a3a;
                color: #fff;
                border: 1px solid #4a4a4a;
                border-radius: 4px;
                padding: 8px;
            }
            .gr-form {
                background-color: transparent;
            }
            .gr-panel {
                border: none;
                background-color: transparent;
            }
            /* Override any conflicting styles from Bulma */
            .button.is-normal.is-rounded.is-dark {
                color: #fff !important;
                text-decoration: none !important;
            }
        </style>
        """

        # The CSS above is spliced into the header so a single gr.HTML
        # component carries both the styles and the markup.
        header_html = f"""
        <link rel="stylesheet" href="https://cdn.jsdelivr.net/npm/bulma@0.9.3/css/bulma.min.css">
        <link rel="stylesheet" href="https://use.fontawesome.com/releases/v5.15.4/css/all.css">
        {app_styles}
        <div class="app-header">
            <h1 class="app-title">Sapiens:Body-Part Segmentation</h1>
            <h2 class="app-subtitle">ECCV 2024 (Oral)</h2>
            <p class="app-description">
                Meta presents Sapiens, foundation models for human tasks pretrained on 300 million human images.
                This demo showcases the finetuned body-part segmentation model. <br>
            </p>
            <div class="publication-links">
                <a href="https://arxiv.org/abs/2408.12569" class="publication-link">
                    <i class="fas fa-file-pdf"></i>arXiv
                </a>
                <a href="https://github.com/facebookresearch/sapiens" class="publication-link">
                    <i class="fab fa-github"></i>Code
                </a>
                <a href="https://about.meta.com/realitylabs/codecavatars/sapiens/" class="publication-link">
                    <i class="fas fa-globe"></i>Meta
                </a>
                <a href="https://rawalkhirodkar.github.io/sapiens" class="publication-link">
                    <i class="fas fa-chart-bar"></i>Results
                </a>
            </div>
            <div class="publication-links">
                <a href="https://huggingface.co/spaces/facebook/sapiens_pose" class="publication-link">
                    <i class="fas fa-user"></i>Demo-Pose
                </a>
                <a href="https://huggingface.co/spaces/facebook/sapiens_seg" class="publication-link">
                    <i class="fas fa-puzzle-piece"></i>Demo-Seg
                </a>
                <a href="https://huggingface.co/spaces/facebook/sapiens_depth" class="publication-link">
                    <i class="fas fa-cube"></i>Demo-Depth
                </a>
                <a href="https://huggingface.co/spaces/facebook/sapiens_normal" class="publication-link">
                    <i class="fas fa-vector-square"></i>Demo-Normal
                </a>
            </div>
        </div>
        """

        # Page script: reloads with ?__theme=dark unless that query
        # parameter is already set to 'dark'.
        js_func = """
        function refresh() {
            const url = new URL(window.location);
            if (url.searchParams.get('__theme') !== 'dark') {
                url.searchParams.set('__theme', 'dark');
                window.location.href = url.href;
            }
        }
        """

        def process_image(image, model_name):
            # Clicking Run with no image delivers None; report that to the
            # user instead of failing inside preprocessing.
            if image is None:
                raise gr.Error("Please upload or select an input image first.")
            return self.image_processor.process_image(image, model_name)

        # Sort so the example gallery order does not depend on the
        # platform-specific ordering of os.listdir.
        images_dir = os.path.join(Config.ASSETS_DIR, "images")
        example_paths = [
            os.path.join(images_dir, name) for name in sorted(os.listdir(images_dir))
        ]

        with gr.Blocks(js=js_func, theme=gr.themes.Default()) as demo:
            gr.HTML(header_html)
            with gr.Row(elem_classes="content-container"):
                with gr.Column():
                    input_image = gr.Image(label="Input Image", type="pil", format="png", elem_classes="image-preview")
                    model_name = gr.Dropdown(
                        label="Model Size",
                        choices=list(Config.CHECKPOINTS.keys()),
                        value="1b",
                    )
                    gr.Examples(
                        inputs=input_image,
                        examples_per_page=14,
                        examples=example_paths,
                    )
                with gr.Column():
                    result_image = gr.Image(label="Segmentation Result", type="pil", elem_classes="image-preview")
                    npy_output = gr.File(label="Segmentation (.npy)")
                    run_button = gr.Button("Run")
                    gr.Image(os.path.join(Config.ASSETS_DIR, "palette.jpg"), label="Class Palette", type="filepath", elem_classes="image-preview")

            run_button.click(
                fn=process_image,
                inputs=[input_image, model_name],
                outputs=[result_image, npy_output],
            )

        return demo
295
+
296
def main():
    """Enable TF32 where applicable, then build and launch the demo."""
    # TF32 is switched on only when CUDA is present and device 0 reports a
    # compute-capability major version of at least 8.
    if torch.cuda.is_available():
        if torch.cuda.get_device_properties(0).major >= 8:
            torch.backends.cuda.matmul.allow_tf32 = True
            torch.backends.cudnn.allow_tf32 = True

    GradioInterface().create_interface().launch(share=False)


if __name__ == "__main__":
    main()