Commit e0ca513: "initial commit"
Parent(s): none

Files changed:
- .gitattributes +34 -0
- .gitignore +2 -0
- .gitmodules +3 -0
- README.md +13 -0
- app.py +183 -0
- examples/banana.jpg +0 -0
- examples/dogs.jpg +0 -0
- examples/frodo_sam_gollum.jpg +0 -0
- examples/mb_mj.jpg +0 -0
- examples/voc_1029.jpg +0 -0
- examples/voc_1136.jpg +0 -0
- examples/voc_1296.jpg +0 -0
- examples/voc_266.jpg +0 -0
- examples/voc_294.jpg +0 -0
- examples/voc_296.jpg +0 -0
- examples/voc_567.jpg +0 -0
- examples/voc_59.jpg +0 -0
- examples/voc_84.jpg +0 -0
- examples/voc_864.jpg +0 -0
- examples/voc_97.jpg +0 -0
- packages.txt +1 -0
- predictor.py +159 -0
- requirements.txt +16 -0
- tcl +1 -0
.gitattributes
ADDED
@@ -0,0 +1,34 @@
+*.7z filter=lfs diff=lfs merge=lfs -text
+*.arrow filter=lfs diff=lfs merge=lfs -text
+*.bin filter=lfs diff=lfs merge=lfs -text
+*.bz2 filter=lfs diff=lfs merge=lfs -text
+*.ckpt filter=lfs diff=lfs merge=lfs -text
+*.ftz filter=lfs diff=lfs merge=lfs -text
+*.gz filter=lfs diff=lfs merge=lfs -text
+*.h5 filter=lfs diff=lfs merge=lfs -text
+*.joblib filter=lfs diff=lfs merge=lfs -text
+*.lfs.* filter=lfs diff=lfs merge=lfs -text
+*.mlmodel filter=lfs diff=lfs merge=lfs -text
+*.model filter=lfs diff=lfs merge=lfs -text
+*.msgpack filter=lfs diff=lfs merge=lfs -text
+*.npy filter=lfs diff=lfs merge=lfs -text
+*.npz filter=lfs diff=lfs merge=lfs -text
+*.onnx filter=lfs diff=lfs merge=lfs -text
+*.ot filter=lfs diff=lfs merge=lfs -text
+*.parquet filter=lfs diff=lfs merge=lfs -text
+*.pb filter=lfs diff=lfs merge=lfs -text
+*.pickle filter=lfs diff=lfs merge=lfs -text
+*.pkl filter=lfs diff=lfs merge=lfs -text
+*.pt filter=lfs diff=lfs merge=lfs -text
+*.pth filter=lfs diff=lfs merge=lfs -text
+*.rar filter=lfs diff=lfs merge=lfs -text
+*.safetensors filter=lfs diff=lfs merge=lfs -text
+saved_model/**/* filter=lfs diff=lfs merge=lfs -text
+*.tar.* filter=lfs diff=lfs merge=lfs -text
+*.tflite filter=lfs diff=lfs merge=lfs -text
+*.tgz filter=lfs diff=lfs merge=lfs -text
+*.wasm filter=lfs diff=lfs merge=lfs -text
+*.xz filter=lfs diff=lfs merge=lfs -text
+*.zip filter=lfs diff=lfs merge=lfs -text
+*.zst filter=lfs diff=lfs merge=lfs -text
+*tfevents* filter=lfs diff=lfs merge=lfs -text
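(These patterns route common archive, model-weight, and serialized-data formats through Git LFS, the Hub's default handling for binary files.)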
.gitignore
ADDED
@@ -0,0 +1,2 @@
+*.pth
+__pycache__
.gitmodules
ADDED
@@ -0,0 +1,3 @@
+[submodule "tcl"]
+	path = tcl
+	url = https://github.com/kakaobrain/tcl.git
README.md
ADDED
@@ -0,0 +1,13 @@
+---
+title: TCL
+emoji: 💩
+colorFrom: yellow
+colorTo: yellow
+sdk: gradio
+sdk_version: 3.23.0
+app_file: app.py
+app_port: 9718
+pinned: false
+---
+
+Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
app.py
ADDED
@@ -0,0 +1,183 @@
+import os
+import sys
+from importlib.util import find_spec
+
+print("Prepare demo ...")
+if not os.path.exists("tcl.pth"):
+    print("Download TCL checkpoint ...")
+    os.system("wget -q https://github.com/kakaobrain/tcl/releases/download/v1.0.0/tcl.pth")
+
+if not (find_spec("mmcv") and find_spec("mmseg")):
+    print("Install mmcv & mmseg ...")
+    os.system("mim install mmcv-full==1.6.2 mmsegmentation==0.27.0")
+
+if not find_spec("detectron2"):
+    print("Install detectron2 ...")
+    os.system("pip install git+https://github.com/facebookresearch/detectron2.git")
+
+sys.path.insert(0, "./tcl/")
+
+print(" -- done.")
+
+import json
+from contextlib import ExitStack
+import gradio as gr
+import torch
+from torch.cuda.amp import autocast
+
+from detectron2.evaluation import inference_context
+
+from predictor import build_demo_model
+
+
+model = build_demo_model()
+if torch.cuda.is_available():
+    device = torch.device("cuda")
+else:
+    device = torch.device("cpu")
+
+print(f"device: {device}")
+model.to(device)
+
+
+title = "TCL: Text-grounded Contrastive Learning"
+description_head = """
+<p style='text-align: center'> <a href='https://arxiv.org/abs/2212.00785' target='_blank'>Paper</a> | <a href='https://github.com/kakaobrain/tcl' target='_blank'>Code</a> </p>
+"""
+
+description_body = """
+Gradio Demo for "Learning to Generate Text-grounded Mask for Open-world Semantic Segmentation from Only Image-Text Pairs".
+
+Explore TCL's capability to perform open-world semantic segmentation **without any mask annotations**. Choose from the provided examples or upload your own image. Use the query format `bg; class1; class2; ...`, with `;` as the separator; the `bg` background query is optional (as in the third example).
+
+This demo highlights the strengths and limitations of unsupervised open-world segmentation methods. Although TCL can handle arbitrary concepts, accurately capturing object boundaries without mask annotation remains a challenge.
+"""
+
+if device.type == "cpu":
+    description_body += "\nInference takes about 10 seconds since this demo is running on the free CPU device."
+
+description = description_head + description_body
+
+article = """
+<p style='text-align: center'><a href='https://arxiv.org/abs/2212.00785' target='_blank'>Learning to Generate Text-grounded Mask for Open-world Semantic Segmentation from Only Image-Text Pairs</a> | <a href='https://github.com/kakaobrain/tcl' target='_blank'>Github Repo</a></p>
+"""
+
+voc_examples = [
+    ["examples/voc_59.jpg", "bg; cat; dog"],
+    ["examples/voc_97.jpg", "bg; car"],
+    ["examples/voc_266.jpg", "bg; dog"],
+    ["examples/voc_294.jpg", "bg; bird"],
+    ["examples/voc_864.jpg", "bg; cat"],
+    ["examples/voc_1029.jpg", "bg; bus"],
+]
+
+examples = [
+    [
+        "examples/dogs.jpg",
+        "bg; corgi; shepherd",
+    ],
+    [
+        "examples/dogs.jpg",
+        "bg; dog",
+    ],
+    [
+        "examples/dogs.jpg",
+        "corgi; shepherd; lawn, trees, and fallen leaves",
+    ],
+    [
+        "examples/banana.jpg",
+        "bg; banana",
+    ],
+    [
+        "examples/banana.jpg",
+        "bg; red banana; green banana; yellow banana",
+    ],
+    [
+        "examples/frodo_sam_gollum.jpg",
+        "bg; frodo; gollum; samwise",
+    ],
+    [
+        "examples/frodo_sam_gollum.jpg",
+        "bg; rocks; monster; boys with cape",
+    ],
+    [
+        "examples/mb_mj.jpg",
+        "bg; marlon brando; michael jackson",
+    ],
+]
+
+examples = examples + voc_examples
+
+
+def inference(img, query):
+    query = query.split(";")
+    query = [v.strip() for v in query]
+
+    with ExitStack() as stack:
+        stack.enter_context(inference_context(model))
+        stack.enter_context(torch.no_grad())
+
+        with autocast():
+            visualized_output = model.forward_vis(img, query)
+
+    return visualized_output
+
+
+theme = gr.themes.Soft(text_size=gr.themes.sizes.text_md, primary_hue="teal")
+with gr.Blocks(title=title, theme=theme) as demo:
+    gr.Markdown("<h1 style='text-align: center; margin-bottom: 1rem'>" + title + "</h1>")
+    gr.Markdown(description)
+    input_components = []
+    output_components = []
+
+    with gr.Row():
+        with gr.Column(scale=4, variant="panel"):
+            output_image_gr = gr.outputs.Image(label="Segmentation", type="pil").style(height=300)
+            output_components.append(output_image_gr)
+
+            with gr.Row():
+                input_gr = gr.inputs.Image(type="pil")
+                query_gr = gr.inputs.Textbox(default="", label="Query")
+                input_components.extend([input_gr, query_gr])
+
+            with gr.Row():
+                clear_btn = gr.Button("Clear")
+                submit_btn = gr.Button("Submit", variant="primary")
+
+        inputs = [c for c in input_components if not isinstance(c, gr.State)]
+        outputs = [c for c in output_components if not isinstance(c, gr.State)]
+        with gr.Column(scale=2):
+            examples_handler = gr.Examples(
+                examples=examples,
+                inputs=inputs,
+                outputs=outputs,
+                fn=inference,
+                cache_examples=True,
+                examples_per_page=7,
+            )
+
+    gr.Markdown(article)
+
+    submit_btn.click(
+        inference,
+        input_components,
+        output_components,
+        scroll_to_output=True,
+    )
+
+    clear_btn.click(
+        None,
+        [],
+        (input_components + output_components),
+        _js=f"""() => {json.dumps(
+            [component.cleared_value if hasattr(component, "cleared_value") else None
+             for component in input_components + output_components] + (
+                [gr.Column.update(visible=True)]
+            )
+            + ([gr.Column.update(visible=False)])
+        )}
+        """,
+    )
+
+demo.launch()
+# demo.launch(server_name="0.0.0.0", server_port=9718)
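For reference, a minimal sketch (not part of the commit) of driving the model without the Gradio UI, assuming `tcl.pth` has been downloaded and the `tcl/` submodule is checked out, as `app.py` arranges:

# Hypothetical standalone use of predictor.build_demo_model / forward_vis.
import sys
sys.path.insert(0, "./tcl/")  # make the submodule's packages importable

import torch
from PIL import Image

from predictor import build_demo_model

model = build_demo_model()  # loads ./tcl.pth by default
model.to("cuda" if torch.cuda.is_available() else "cpu")

image = Image.open("examples/dogs.jpg")
# Same query convention as the demo: optional leading "bg", then ";"-separated classes.
query = [q.strip() for q in "bg; corgi; shepherd".split(";")]

with torch.no_grad():
    result = model.forward_vis(image, query)  # PIL.Image with masks drawn over the input

result.save("segmentation.png")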
examples/banana.jpg
ADDED
examples/dogs.jpg
ADDED
examples/frodo_sam_gollum.jpg
ADDED
examples/mb_mj.jpg
ADDED
examples/voc_1029.jpg
ADDED
examples/voc_1136.jpg
ADDED
examples/voc_1296.jpg
ADDED
examples/voc_266.jpg
ADDED
examples/voc_294.jpg
ADDED
examples/voc_296.jpg
ADDED
examples/voc_567.jpg
ADDED
examples/voc_59.jpg
ADDED
examples/voc_84.jpg
ADDED
examples/voc_864.jpg
ADDED
examples/voc_97.jpg
ADDED
packages.txt
ADDED
@@ -0,0 +1 @@
+wget
predictor.py
ADDED
@@ -0,0 +1,159 @@
+import torch
+import torch.nn as nn
+from torchvision import transforms as T
+from omegaconf import OmegaConf
+from typing import List
+from mmseg import datasets as mmseg_datasets
+
+from timm.data.constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
+
+import numpy as np
+from PIL import Image
+from detectron2.data import MetadataCatalog
+from detectron2.utils.visualizer import Visualizer
+
+# TCL
+from models import build_model
+from models.tcl.pamr import PAMR
+from datasets.builder import build_text_transform
+from segmentation.evaluation.builder import build_dataset_class_tokens
+
+PALETTE = mmseg_datasets.PascalVOCDataset.PALETTE + mmseg_datasets.COCOStuffDataset.PALETTE
+PALETTE *= 5
+
+
+def build_demo_model(ckpt_path="./tcl.pth", size=224):
+    # Load TCL model
+    print(f"Load {ckpt_path} ...")
+    ckpt = torch.load(ckpt_path)
+    cfg = OmegaConf.load("./tcl/configs/tcl.yml")
+    model = build_model(cfg.model)
+
+    # The (minimal) checkpoint only contains learned parameters; frozen CLIP params are not included.
+    model.load_state_dict(ckpt['model'], strict=False)
+    model.eval()
+
+    # build TCLDemo
+    demo = TCLDemo(model, size)
+
+    return demo
+
+
+def _convert_image_to_rgb(image):
+    return image.convert("RGB")
+
+
+def _transform(n_px):
+    return T.Compose([
+        T.Resize(n_px, interpolation=T.InterpolationMode.BICUBIC),
+        _convert_image_to_rgb,
+        T.ToTensor(),
+        T.Normalize(IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD),
+    ])
+
+
+class TCLDemo(nn.Module):
+    """
+    Args:
+        model: TCL model
+        size: resize shorter side of image to `size`
+    """
+    def __init__(self, model, size=224):
+        super().__init__()
+        self.model = model
+        self.size = size
+
+        self.preprocess = _transform(size)
+        self.tokenizer = build_text_transform()
+        self.pamr = PAMR(10, [1, 2, 4, 8, 12, 24]).eval()
+
+    @property
+    def device(self):
+        return next(self.model.parameters()).device
+
+    def build_text_embedding(self, texts: List[str]):
+        text_tokens = build_dataset_class_tokens(self.tokenizer, "custom", texts)
+        text_embeddings = self.model.build_text_embedding(text_tokens)
+        return text_embeddings
+
+    def forward(self, image, texts: List[str], apply_pamr=True):
+        """
+        Args:
+            image: PIL.Image
+            texts: List[str]
+        """
+        with_bg = False
+        if texts[0] in ["bg", "background"]:
+            with_bg = True
+            texts = texts[1:]
+
+        # preprocess
+        image = self.preprocess(image).unsqueeze(0).to(self.device)
+        text_embs = self.build_text_embedding(texts)
+
+        # forward
+        mask, simmap = self.model.generate_masks(
+            image,
+            text_embs,
+        )
+
+        # refinement
+        if apply_pamr:
+            mask = self.pamr(image, mask)
+
+        I, T, H, W = mask.shape
+        if with_bg:
+            bg_thresh = 0.4 if apply_pamr else 0.5
+            bg = torch.full(
+                [I, 1, H, W],
+                bg_thresh,
+                dtype=torch.float,
+                device=mask.device
+            )
+            mask = torch.cat([bg, mask], dim=1)
+
+        return mask
+
+    def visualize(self, image, texts, mask):
+        """
+        Args:
+            image (PIL.Image)
+            texts (List[str])
+            mask (Tensor)
+        """
+        with_bg = texts[0] in ["bg", "background"]
+
+        N = len(texts)
+        if with_bg:
+            palette = PALETTE
+        else:
+            palette = PALETTE[1:]
+
+        MetadataCatalog.pop("__unused", None)
+        md = MetadataCatalog.get("__unused")
+        md.set(
+            thing_classes=texts,
+            thing_colors=palette,
+            stuff_classes=texts,
+            stuff_colors=palette,
+        )
+
+        seg_res = mask.squeeze(0).argmax(0).cpu()
+        if with_bg:
+            seg_res[seg_res == 0] = N + 10
+
+        image = image.resize(mask.shape[2:][::-1])
+        image = np.asarray(image)
+
+        visualizer = Visualizer(image, md)
+        r = visualizer.draw_sem_seg(seg_res)
+
+        res = Image.fromarray(r.get_image())
+
+        return res
+
+    def forward_vis(self, image, texts, apply_pamr=True):
+        mask = self(image, texts, apply_pamr=apply_pamr)
+        res = self.visualize(image, texts, mask)
+
+        return res
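The background handling in `forward` concatenates a constant plane filled with `bg_thresh` in front of the class masks, so a pixel is labeled background exactly when every class score falls below the threshold. A toy illustration of that argmax (a sketch, not part of the commit):

# Toy illustration of the bg-threshold trick used in TCLDemo.forward.
import torch

bg_thresh = 0.4
mask = torch.tensor([[[[0.9, 0.2],
                       [0.3, 0.5]]]])  # [I=1, T=1 class, H=2, W=2]

bg = torch.full((1, 1, 2, 2), bg_thresh)  # constant background plane
seg = torch.cat([bg, mask], dim=1).argmax(dim=1)

print(seg)  # tensor([[[1, 0], [0, 1]]]): scores under 0.4 fall back to label 0 (background)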
requirements.txt
ADDED
@@ -0,0 +1,16 @@
+torch==1.12.1
+torchvision==0.13.1
+
+webdataset==0.1.103
+timm==0.6.7
+einops==0.4.1
+tqdm==4.62.3
+wandb==0.12.18
+regex==2022.6.2
+braceexpand==0.1.7
+ftfy==6.1.1
+numpy==1.21.2
+omegaconf==2.2.2
+Pillow==9.3.0
+termcolor==1.1.0
+openmim
tcl
ADDED
@@ -0,0 +1 @@
+Subproject commit e8e84cac4f31c3718356137208c9269477aa1ef8
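The `tcl` gitlink pins the kakaobrain/tcl submodule at commit e8e84cac; a fresh clone presumably needs `git clone --recurse-submodules` (or `git submodule update --init`) before `app.py` can import from `./tcl/`.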