khanrc committed · Commit e0ca513 · 0 parent(s)

initial commit
.gitattributes ADDED
@@ -0,0 +1,34 @@
+ *.7z filter=lfs diff=lfs merge=lfs -text
+ *.arrow filter=lfs diff=lfs merge=lfs -text
+ *.bin filter=lfs diff=lfs merge=lfs -text
+ *.bz2 filter=lfs diff=lfs merge=lfs -text
+ *.ckpt filter=lfs diff=lfs merge=lfs -text
+ *.ftz filter=lfs diff=lfs merge=lfs -text
+ *.gz filter=lfs diff=lfs merge=lfs -text
+ *.h5 filter=lfs diff=lfs merge=lfs -text
+ *.joblib filter=lfs diff=lfs merge=lfs -text
+ *.lfs.* filter=lfs diff=lfs merge=lfs -text
+ *.mlmodel filter=lfs diff=lfs merge=lfs -text
+ *.model filter=lfs diff=lfs merge=lfs -text
+ *.msgpack filter=lfs diff=lfs merge=lfs -text
+ *.npy filter=lfs diff=lfs merge=lfs -text
+ *.npz filter=lfs diff=lfs merge=lfs -text
+ *.onnx filter=lfs diff=lfs merge=lfs -text
+ *.ot filter=lfs diff=lfs merge=lfs -text
+ *.parquet filter=lfs diff=lfs merge=lfs -text
+ *.pb filter=lfs diff=lfs merge=lfs -text
+ *.pickle filter=lfs diff=lfs merge=lfs -text
+ *.pkl filter=lfs diff=lfs merge=lfs -text
+ *.pt filter=lfs diff=lfs merge=lfs -text
+ *.pth filter=lfs diff=lfs merge=lfs -text
+ *.rar filter=lfs diff=lfs merge=lfs -text
+ *.safetensors filter=lfs diff=lfs merge=lfs -text
+ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
+ *.tar.* filter=lfs diff=lfs merge=lfs -text
+ *.tflite filter=lfs diff=lfs merge=lfs -text
+ *.tgz filter=lfs diff=lfs merge=lfs -text
+ *.wasm filter=lfs diff=lfs merge=lfs -text
+ *.xz filter=lfs diff=lfs merge=lfs -text
+ *.zip filter=lfs diff=lfs merge=lfs -text
+ *.zst filter=lfs diff=lfs merge=lfs -text
+ *tfevents* filter=lfs diff=lfs merge=lfs -text
.gitignore ADDED
@@ -0,0 +1,2 @@
+ *.pth
+ __pycache__
.gitmodules ADDED
@@ -0,0 +1,3 @@
+ [submodule "tcl"]
+ path = tcl
+ url = https://github.com/kakaobrain/tcl.git
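Since the TCL code lives in this submodule (app.py only inserts ./tcl/ into sys.path), a fresh clone needs the submodule fetched before the app can import from it. A hypothetical setup step, written in the same os.system style app.py itself uses:

    import os
    # fetch the pinned tcl submodule after a plain "git clone" of this Space
    os.system("git submodule update --init")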
README.md ADDED
@@ -0,0 +1,13 @@
+ ---
+ title: TCL
+ emoji: 💩
+ colorFrom: yellow
+ colorTo: yellow
+ sdk: gradio
+ sdk_version: 3.23.0
+ app_file: app.py
+ app_port: 9718
+ pinned: false
+ ---
+
+ Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
app.py ADDED
@@ -0,0 +1,183 @@
+ import os
+ import sys
+ from importlib.util import find_spec
+
+ print("Prepare demo ...")
+ if not os.path.exists("tcl.pth"):
+     print("Download TCL checkpoint ...")
+     os.system("wget -q https://github.com/kakaobrain/tcl/releases/download/v1.0.0/tcl.pth")
+
+ if not (find_spec("mmcv") and find_spec("mmseg")):
+     print("Install mmcv & mmseg ...")
+     os.system("mim install mmcv-full==1.6.2 mmsegmentation==0.27.0")
+
+ if not find_spec("detectron2"):
+     print("Install detectron2 ...")
+     os.system("pip install git+https://github.com/facebookresearch/detectron2.git")
+
+ sys.path.insert(0, "./tcl/")
+
+ print(" -- done.")
+
+ import json
+ from contextlib import ExitStack
+ import gradio as gr
+ import torch
+ from torch.cuda.amp import autocast
+
+ from detectron2.evaluation import inference_context
+
+ from predictor import build_demo_model
+
+
+ model = build_demo_model()
+ if torch.cuda.is_available():
+     device = torch.device("cuda")
+ else:
+     device = torch.device("cpu")
+
+ print(f"device: {device}")
+ model.to(device)
+
+
+ title = "TCL: Text-grounded Contrastive Learning"
+ description_head = """
+ <p style='text-align: center'> <a href='https://arxiv.org/abs/2212.00785' target='_blank'>Paper</a> | <a href='https://github.com/kakaobrain/tcl' target='_blank'>Code</a> </p>
+ """
+
+ description_body = """
+ Gradio demo for "Learning to Generate Text-grounded Mask for Open-world Semantic Segmentation from Only Image-Text Pairs".
+
+ Explore TCL's capability to perform open-world semantic segmentation **without any mask annotations**. Choose from the provided examples or upload your own image. Use the query format `bg; class1; class2; ...`, with `;` as the separator; the `bg` background query is optional (as in the third example).
+
+ This demo highlights both the strengths and the limitations of unsupervised open-world segmentation methods. Although TCL can handle arbitrary concepts, accurately capturing object boundaries without mask annotations remains a challenge.
+ """
+
+ if device.type == "cpu":
+     description_body += "\nInference takes about 10 seconds since this demo runs on a free CPU device."
+
+ description = description_head + description_body
+
+ article = """
+ <p style='text-align: center'><a href='https://arxiv.org/abs/2212.00785' target='_blank'>Learning to Generate Text-grounded Mask for Open-world Semantic Segmentation from Only Image-Text Pairs</a> | <a href='https://github.com/kakaobrain/tcl' target='_blank'>Github Repo</a></p>
+ """
+
+ voc_examples = [
+     ["examples/voc_59.jpg", "bg; cat; dog"],
+     ["examples/voc_97.jpg", "bg; car"],
+     ["examples/voc_266.jpg", "bg; dog"],
+     ["examples/voc_294.jpg", "bg; bird"],
+     ["examples/voc_864.jpg", "bg; cat"],
+     ["examples/voc_1029.jpg", "bg; bus"],
+ ]
+
+ examples = [
+     [
+         "examples/dogs.jpg",
+         "bg; corgi; shepherd",
+     ],
+     [
+         "examples/dogs.jpg",
+         "bg; dog",
+     ],
+     [
+         "examples/dogs.jpg",
+         "corgi; shepherd; lawn, trees, and fallen leaves",
+     ],
+     [
+         "examples/banana.jpg",
+         "bg; banana",
+     ],
+     [
+         "examples/banana.jpg",
+         "bg; red banana; green banana; yellow banana",
+     ],
+     [
+         "examples/frodo_sam_gollum.jpg",
+         "bg; frodo; gollum; samwise",
+     ],
+     [
+         "examples/frodo_sam_gollum.jpg",
+         "bg; rocks; monster; boys with cape",
+     ],
+     [
+         "examples/mb_mj.jpg",
+         "bg; marlon brando; michael jackson",
+     ],
+ ]
+
+ examples = examples + voc_examples
+
+
+ def inference(img, query):
+     # "bg; class1; class2" -> ["bg", "class1", "class2"]
+     query = query.split(";")
+     query = [v.strip() for v in query]
+
+     with ExitStack() as stack:
+         stack.enter_context(inference_context(model))
+         stack.enter_context(torch.no_grad())
+
+         with autocast():
+             visualized_output = model.forward_vis(img, query)
+
+     return visualized_output
+
+
+ theme = gr.themes.Soft(text_size=gr.themes.sizes.text_md, primary_hue="teal")
+ with gr.Blocks(title=title, theme=theme) as demo:
+     gr.Markdown("<h1 style='text-align: center; margin-bottom: 1rem'>" + title + "</h1>")
+     gr.Markdown(description)
+     input_components = []
+     output_components = []
+
+     with gr.Row():
+         with gr.Column(scale=4, variant="panel"):
+             output_image_gr = gr.outputs.Image(label="Segmentation", type="pil").style(height=300)
+             output_components.append(output_image_gr)
+
+             with gr.Row():
+                 input_gr = gr.inputs.Image(type="pil")
+                 query_gr = gr.inputs.Textbox(default="", label="Query")
+                 input_components.extend([input_gr, query_gr])
+
+             with gr.Row():
+                 clear_btn = gr.Button("Clear")
+                 submit_btn = gr.Button("Submit", variant="primary")
+
+     inputs = [c for c in input_components if not isinstance(c, gr.State)]
+     outputs = [c for c in output_components if not isinstance(c, gr.State)]
+     with gr.Column(scale=2):
+         examples_handler = gr.Examples(
+             examples=examples,
+             inputs=inputs,
+             outputs=outputs,
+             fn=inference,
+             cache_examples=True,
+             examples_per_page=7,
+         )
+
+     gr.Markdown(article)
+
+     submit_btn.click(
+         inference,
+         input_components,
+         output_components,
+         scroll_to_output=True,
+     )
+
+     # Reset all components client-side; cleared_value is used where a component defines it.
+     clear_btn.click(
+         None,
+         [],
+         (input_components + output_components),
+         _js=f"""() => {json.dumps(
+             [component.cleared_value if hasattr(component, "cleared_value") else None
+              for component in input_components + output_components] + (
+                 [gr.Column.update(visible=True)]
+             )
+             + ([gr.Column.update(visible=False)])
+         )}
+         """,
+     )
+
+ demo.launch()
+ # demo.launch(server_name="0.0.0.0", server_port=9718)
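For reference, the whole pipeline boils down to a few calls. A minimal sketch of driving the demo without the Gradio UI, assuming tcl.pth is downloaded and the tcl submodule is on sys.path as app.py arranges (the image path is one of the bundled examples):

    import torch
    from PIL import Image
    from predictor import build_demo_model

    model = build_demo_model()                    # loads ./tcl.pth by default
    img = Image.open("examples/dogs.jpg")
    queries = [q.strip() for q in "bg; corgi; shepherd".split(";")]
    with torch.no_grad():
        result = model.forward_vis(img, queries)  # PIL.Image with masks drawn
    result.save("segmentation.png")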
examples/banana.jpg ADDED
examples/dogs.jpg ADDED
examples/frodo_sam_gollum.jpg ADDED
examples/mb_mj.jpg ADDED
examples/voc_1029.jpg ADDED
examples/voc_1136.jpg ADDED
examples/voc_1296.jpg ADDED
examples/voc_266.jpg ADDED
examples/voc_294.jpg ADDED
examples/voc_296.jpg ADDED
examples/voc_567.jpg ADDED
examples/voc_59.jpg ADDED
examples/voc_84.jpg ADDED
examples/voc_864.jpg ADDED
examples/voc_97.jpg ADDED
packages.txt ADDED
@@ -0,0 +1 @@
+ wget
predictor.py ADDED
@@ -0,0 +1,159 @@
+ import torch
+ import torch.nn as nn
+ from torchvision import transforms as T
+ from omegaconf import OmegaConf
+ from typing import List
+ from mmseg import datasets as mmseg_datasets
+
+ from timm.data.constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
+
+ import numpy as np
+ from PIL import Image
+ from detectron2.data import MetadataCatalog
+ from detectron2.utils.visualizer import Visualizer
+
+ # TCL
+ from models import build_model
+ from models.tcl.pamr import PAMR
+ from datasets.builder import build_text_transform
+ from segmentation.evaluation.builder import build_dataset_class_tokens
+
+ # Concatenate the VOC and COCO-Stuff palettes, then tile so long query lists still get colors.
+ PALETTE = mmseg_datasets.PascalVOCDataset.PALETTE + mmseg_datasets.COCOStuffDataset.PALETTE
+ PALETTE *= 5
+
+
+ def build_demo_model(ckpt_path="./tcl.pth", size=224):
+     # Load TCL model (map_location="cpu" so loading also works on CPU-only hosts;
+     # the caller moves the model to the target device afterwards)
+     print(f"Load {ckpt_path} ...")
+     ckpt = torch.load(ckpt_path, map_location="cpu")
+     cfg = OmegaConf.load("./tcl/configs/tcl.yml")
+     model = build_model(cfg.model)
+
+     # The (minimal) checkpoint contains only the learned parameters;
+     # the frozen CLIP parameters are not included, hence strict=False.
+     model.load_state_dict(ckpt["model"], strict=False)
+     model.eval()
+
+     # build TCLDemo
+     demo = TCLDemo(model, size)
+
+     return demo
+
+
+ def _convert_image_to_rgb(image):
+     return image.convert("RGB")
+
+
+ def _transform(n_px):
+     return T.Compose([
+         T.Resize(n_px, interpolation=T.InterpolationMode.BICUBIC),
+         _convert_image_to_rgb,
+         T.ToTensor(),
+         T.Normalize(IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD),
+     ])
+
+
+ class TCLDemo(nn.Module):
+     """
+     Args:
+         model: TCL model
+         size: resize shorter side of image to `size`
+     """
+     def __init__(self, model, size=224):
+         super().__init__()
+         self.model = model
+         self.size = size
+
+         self.preprocess = _transform(size)
+         self.tokenizer = build_text_transform()
+         self.pamr = PAMR(10, [1, 2, 4, 8, 12, 24]).eval()
+
+     @property
+     def device(self):
+         return next(self.model.parameters()).device
+
+     def build_text_embedding(self, texts: List[str]):
+         text_tokens = build_dataset_class_tokens(self.tokenizer, "custom", texts)
+         text_embeddings = self.model.build_text_embedding(text_tokens)
+         return text_embeddings
+
+     def forward(self, image, texts: List[str], apply_pamr=True):
+         """
+         Args:
+             image: PIL.Image
+             texts: List[str]
+         """
+         with_bg = False
+         if texts[0] in ["bg", "background"]:
+             with_bg = True
+             texts = texts[1:]
+
+         # preprocess
+         image = self.preprocess(image).unsqueeze(0).to(self.device)
+         text_embs = self.build_text_embedding(texts)
+
+         # forward
+         mask, simmap = self.model.generate_masks(
+             image,
+             text_embs,
+         )
+
+         # refinement (pixel-adaptive mask refinement)
+         if apply_pamr:
+             mask = self.pamr(image, mask)
+
+         # optional background: prepend a constant-threshold channel so that
+         # argmax assigns low-confidence pixels to background
+         I, T, H, W = mask.shape
+         if with_bg:
+             bg_thresh = 0.4 if apply_pamr else 0.5
+             bg = torch.full(
+                 [I, 1, H, W],
+                 bg_thresh,
+                 dtype=torch.float,
+                 device=mask.device,
+             )
+             mask = torch.cat([bg, mask], dim=1)
+
+         return mask
+
+     def visualize(self, image, texts, mask):
+         """
+         Args:
+             image (PIL.Image)
+             texts (List[str])
+             mask (Tensor)
+         """
+         with_bg = texts[0] in ["bg", "background"]
+
+         N = len(texts)
+         if with_bg:
+             palette = PALETTE
+         else:
+             palette = PALETTE[1:]
+
+         MetadataCatalog.pop("__unused", None)
+         md = MetadataCatalog.get("__unused")
+         md.set(
+             thing_classes=texts,
+             thing_colors=palette,
+             stuff_classes=texts,
+             stuff_colors=palette,
+         )
+
+         seg_res = mask.squeeze(0).argmax(0).cpu()
+         if with_bg:
+             # remap background (index 0) to an out-of-range id so it is left undrawn
+             seg_res[seg_res == 0] = N + 10
+
+         image = image.resize(mask.shape[2:][::-1])  # PIL expects (W, H)
+         image = np.asarray(image)
+
+         visualizer = Visualizer(image, md)
+         r = visualizer.draw_sem_seg(seg_res)
+
+         res = Image.fromarray(r.get_image())
+
+         return res
+
+     def forward_vis(self, image, texts, apply_pamr=True):
+         mask = self(image, texts, apply_pamr=apply_pamr)
+         res = self.visualize(image, texts, mask)
+
+         return res
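The constant background channel in forward is worth a note: because the channel holds bg_thresh everywhere, the argmax over the class dimension sends any pixel whose best class score falls below the threshold to background. A toy illustration with made-up scores (in the demo, the scores come from model.generate_masks, optionally refined by PAMR):

    import torch

    mask = torch.tensor([[[[0.9, 0.3],
                           [0.2, 0.7]]]])          # shape (I=1, T=1, H=2, W=2)
    bg = torch.full((1, 1, 2, 2), 0.5)             # bg_thresh = 0.5 (the no-PAMR case)
    seg = torch.cat([bg, mask], dim=1).argmax(dim=1)
    print(seg)  # tensor([[[1, 0], [0, 1]]]) -- 0 = background, 1 = the query class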
requirements.txt ADDED
@@ -0,0 +1,16 @@
+ torch==1.12.1
+ torchvision==0.13.1
+
+ webdataset==0.1.103
+ timm==0.6.7
+ einops==0.4.1
+ tqdm==4.62.3
+ wandb==0.12.18
+ regex==2022.6.2
+ braceexpand==0.1.7
+ ftfy==6.1.1
+ numpy==1.21.2
+ omegaconf==2.2.2
+ Pillow==9.3.0
+ termcolor==1.1.0
+ openmim
tcl ADDED
@@ -0,0 +1 @@
+ Subproject commit e8e84cac4f31c3718356137208c9269477aa1ef8