from dataclasses import dataclass
import os
import sys

import torch
import trimesh
from torch import nn
from transformers import AutoModelForCausalLM
from transformers.generation.logits_process import LogitsProcessorList
from einops import rearrange

from modules.bbox_gen.models.image_encoder import DINOv2ImageEncoder
from modules.bbox_gen.config import parse_structured
from modules.bbox_gen.models.bboxopt import BBoxOPT, BBoxOPTConfig
from modules.bbox_gen.utils.bbox_tokenizer import BoundsTokenizerDiag
from modules.bbox_gen.models.bbox_gen_models import (
    GroupEmbedding,
    MultiModalProjector,
    MeshDecodeLogitsProcessor,
    SparseStructureEncoder,
)

# Make the bundled PartField package importable before importing from it.
current_dir = os.path.dirname(os.path.abspath(__file__))
modules_dir = os.path.dirname(os.path.dirname(current_dir))
partfield_dir = os.path.join(modules_dir, 'PartField')
if partfield_dir not in sys.path:
    sys.path.insert(0, partfield_dir)

import importlib.util
from partfield.config import default_argument_parser, setup
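
# BboxGen fuses DINOv2 image tokens (augmented with part-group embeddings) and
# PartField voxel features into conditioning tokens for an OPT-style causal
# decoder (BBoxOPT), which autoregressively emits quantized part bounding-box
# tokens that BoundsTokenizerDiag decodes back into min/max corners.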

class BboxGen(nn.Module):
    @dataclass
    class Config:
        # encoder config
        encoder_dim_feat: int = 3
        encoder_dim: int = 64
        encoder_heads: int = 4
        encoder_token_num: int = 256
        encoder_qkv_bias: bool = False
        encoder_use_ln_post: bool = True
        encoder_use_checkpoint: bool = False
        encoder_num_embed_freqs: int = 8
        encoder_embed_include_pi: bool = False
        encoder_init_scale: float = 0.25
        encoder_random_fps: bool = True
        encoder_learnable_query: bool = False
        encoder_layers: int = 4
        group_embedding_dim: int = 64

        # decoder config
        vocab_size: int = 518
        decoder_hidden_size: int = 1536
        decoder_num_hidden_layers: int = 24
        decoder_ffn_dim: int = 6144
        decoder_heads: int = 16
        decoder_use_flash_attention: bool = True
        decoder_gradient_checkpointing: bool = True

        # data config
        bins: int = 64
        BOS_id: int = 64
        EOS_id: int = 65
        PAD_id: int = 66
        max_length: int = 2187  # bos + 50x2x3 + 1374 + 512
        voxel_token_length: int = 1886
        voxel_token_placeholder: int = -1
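
        # Token budget (inferred from the defaults above, not stated in the code):
        # max_length 2187 = 1 BOS + 50 boxes * 2 corners * 3 coords + 1886 voxel
        # tokens; voxel_token_length 1886 likely splits into 1374 image tokens
        # (DINOv2-large with registers) and 512 volume tokens from the encoded
        # 64^3 grid.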

        # tokenizer config
        max_group_size: int = 50

        # voxel encoder
        partfield_encoder_path: str = ""

    cfg: Config

    def __init__(self, cfg):
        super().__init__()
        self.cfg = parse_structured(self.Config, cfg)

        # Image branch: DINOv2 features concatenated with a per-pixel group
        # embedding, then projected to the decoder width.
        self.image_encoder = DINOv2ImageEncoder(
            model_name="facebook/dinov2-with-registers-large",
        )
        self.image_projector = MultiModalProjector(
            in_features=(1024 + self.cfg.group_embedding_dim),
            out_features=self.cfg.decoder_hidden_size,
        )
        self.group_embedding = GroupEmbedding(
            max_group_size=self.cfg.max_group_size,
            hidden_size=self.cfg.group_embedding_dim,
        )

        # Autoregressive bounding-box decoder (OPT-style causal LM).
        self.decoder_config = BBoxOPTConfig(
            vocab_size=self.cfg.vocab_size,
            hidden_size=self.cfg.decoder_hidden_size,
            num_hidden_layers=self.cfg.decoder_num_hidden_layers,
            ffn_dim=self.cfg.decoder_ffn_dim,
            max_position_embeddings=self.cfg.max_length,
            num_attention_heads=self.cfg.decoder_heads,
            pad_token_id=self.cfg.PAD_id,
            bos_token_id=self.cfg.BOS_id,
            eos_token_id=self.cfg.EOS_id,
            use_cache=True,
            init_std=0.02,
        )
        if self.cfg.decoder_use_flash_attention:
            self.decoder: BBoxOPT = AutoModelForCausalLM.from_config(
                self.decoder_config,
                torch_dtype=torch.bfloat16,
                attn_implementation="flash_attention_2",
            )
        else:
            self.decoder: BBoxOPT = AutoModelForCausalLM.from_config(
                self.decoder_config,
            )
        if self.cfg.decoder_gradient_checkpointing:
            self.decoder.gradient_checkpointing_enable()

        # Constrains sampling to valid bounding-box token sequences.
        self.logits_processor = LogitsProcessorList()
        self.logits_processor.append(MeshDecodeLogitsProcessor(
            bins=self.cfg.bins,
            BOS_id=self.cfg.BOS_id,
            EOS_id=self.cfg.EOS_id,
            PAD_id=self.cfg.PAD_id,
            vertices_num=2,
        ))
        self.tokenizer = BoundsTokenizerDiag(
            bins=self.cfg.bins,
            BOS_id=self.cfg.BOS_id,
            EOS_id=self.cfg.EOS_id,
            PAD_id=self.cfg.PAD_id,
        )

        # Geometry branch: frozen PartField encoder plus a dense voxel encoder
        # over the 64^3 grid (448 PartField channels + 3 xyz channels = 451).
        self._load_partfield_encoder()
        self.partfield_voxel_encoder = SparseStructureEncoder(
            in_channels=451,
            channels=[448, 448, 448, 1024],
            latent_channels=448,
            num_res_blocks=1,
            num_res_blocks_middle=1,
            norm_type="layer",
        )

    def _load_partfield_encoder(self):
        # Import the PartField encoder module directly from the vendored repo,
        # build it from the demo config, and load frozen checkpoint weights.
        model_spec = importlib.util.spec_from_file_location(
            "partfield.partfield_encoder",
            os.path.join(partfield_dir, "partfield", "partfield_encoder.py")
        )
        model_module = importlib.util.module_from_spec(model_spec)
        model_spec.loader.exec_module(model_module)
        Model = model_module.Model

        parser = default_argument_parser()
        args = []
        args.extend(["-c", os.path.join(partfield_dir, "configs/final/demo.yaml")])
        args.append("--opts")
        args.extend(["continue_ckpt", self.cfg.partfield_encoder_path])
        parsed_args = parser.parse_args(args)
        cfg = setup(parsed_args, freeze=False)

        self.partfield_encoder = Model(cfg)
        self.partfield_encoder.eval()
        weights = torch.load(self.cfg.partfield_encoder_path)["state_dict"]
        self.partfield_encoder.load_state_dict(weights)
        for param in self.partfield_encoder.parameters():
            param.requires_grad = False
        print("PartField encoder loaded")

    def _prepare_lm_inputs(self, voxel_token, input_ids):
        # Placeholder positions receive the precomputed voxel/image embeddings;
        # all other positions go through the decoder's token embedding table.
        inputs_embeds = torch.zeros(input_ids.shape[0], input_ids.shape[1], self.cfg.decoder_hidden_size, device=input_ids.device, dtype=voxel_token.dtype)
        voxel_token_mask = (input_ids == self.cfg.voxel_token_placeholder)
        inputs_embeds[voxel_token_mask] = voxel_token.view(-1, self.cfg.decoder_hidden_size)
        inputs_embeds[~voxel_token_mask] = self.decoder.get_input_embeddings()(input_ids[~voxel_token_mask]).to(dtype=inputs_embeds.dtype)
        attention_mask = (input_ids != self.cfg.PAD_id)
        return inputs_embeds, attention_mask.long()
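
    # Expected layout of `input_ids` for training (a sketch inferred from the
    # masks above; the exact sequence construction lives in the data pipeline):
    #   [-1] * voxel_token_length   placeholders filled with voxel_token embeds
    #   [BOS_id]                    64
    #   bbox corner tokens          up to 50 boxes x 2 corners x 3 coords, in [0, 64)
    #   [EOS_id]                    65
    #   [PAD_id] * padding          66, excluded from the attention mask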

    def forward(self, batch):
        # Image tokens: DINOv2 features + per-group (part mask) embeddings.
        image_latents = self.image_encoder(batch['images'])
        masks = batch['masks']
        masks_emb = self.group_embedding(masks)
        masks_emb = rearrange(masks_emb, 'b c h w -> b (h w) c')  # B x Q x C
        group_emb = torch.zeros((image_latents.shape[0], image_latents.shape[1], masks_emb.shape[2]), device=image_latents.device, dtype=image_latents.dtype)
        group_emb[:, :masks_emb.shape[1], :] = masks_emb
        image_latents = torch.cat([image_latents, group_emb], dim=-1)
        image_latents = self.image_projector(image_latents)

        # PartField features: rotate/scale the points into the encoder's frame.
        points = batch['points'][..., :3]
        rot_matrix = torch.tensor([[1, 0, 0], [0, 0, -1], [0, 1, 0]], device=points.device, dtype=points.dtype)
        rot_points = torch.matmul(points, rot_matrix)
        rot_points = rot_points * (2 * 0.9)  # from (-0.5, 0.5) into (-1, 1) with a 0.9 margin
        partfield_feat = self.partfield_encoder.encode(rot_points)

        # Scatter per-point features (and xyz coordinates) into a dense 64^3 volume.
        feat_volume = torch.zeros((points.shape[0], 448, 64, 64, 64), device=partfield_feat.device, dtype=partfield_feat.dtype)
        whole_voxel_index = batch['whole_voxel_index']  # (b, m, 3)
        batch_size, num_points = whole_voxel_index.shape[0], whole_voxel_index.shape[1]
        batch_indices = torch.arange(batch_size, device=whole_voxel_index.device).unsqueeze(1).expand(-1, num_points)  # (b, m)
        batch_flat = batch_indices.flatten()  # (b*m,)
        x_flat = whole_voxel_index[..., 0].flatten()  # (b*m,)
        y_flat = whole_voxel_index[..., 1].flatten()  # (b*m,)
        z_flat = whole_voxel_index[..., 2].flatten()  # (b*m,)
        partfield_feat_flat = partfield_feat.reshape(-1, 448)  # (b*m, 448)
        feat_volume[batch_flat, :, x_flat, y_flat, z_flat] = partfield_feat_flat
        xyz_volume = torch.zeros((points.shape[0], 3, 64, 64, 64), device=points.device, dtype=points.dtype)
        xyz_volume[batch_flat, :, x_flat, y_flat, z_flat] = points.reshape(-1, 3)
        feat_volume = torch.cat([feat_volume, xyz_volume], dim=1)
        feat_volume = self.partfield_voxel_encoder(feat_volume)
        feat_volume = rearrange(feat_volume, 'b c x y z -> b (x y z) c')

        # Conditioning tokens for the decoder.
        voxel_token = torch.cat([image_latents, feat_volume], dim=1)  # B x N x D
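        # Note: voxel_token.shape[1] should match cfg.voxel_token_length (1886),
        # i.e. the number of -1 placeholders in input_ids, so that the view()
        # in _prepare_lm_inputs lines up.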
        input_ids = batch['input_ids']
        inputs_embeds, attention_mask = self._prepare_lm_inputs(voxel_token, input_ids)
        output = self.decoder(
            attention_mask=attention_mask,
            inputs_embeds=inputs_embeds,
            return_dict=True,
        )
        return {
            "logits": output.logits,
        }
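
    # Keys `forward`/`generate` read from `batch` (shapes are a sketch inferred
    # from the indexing above, not a guaranteed interface):
    #   images            conditioning images for the DINOv2 encoder
    #   masks             part-group id map consumed by GroupEmbedding
    #   points            (b, m, >=3) surface points in (-0.5, 0.5)
    #   whole_voxel_index (b, m, 3) integer voxel coordinates in [0, 64)
    #   input_ids         decoder tokens with -1 voxel placeholders (forward only)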

    def gen_mesh_from_bounds(self, bounds, random_color):
        bboxes = []
        for j in range(bounds.shape[0]):
            bbox = trimesh.primitives.Box(bounds=bounds[j])
            color = random_color[j]
            bbox.visual.vertex_colors = color
            bboxes.append(bbox)
        mesh = trimesh.Scene(bboxes)
        return mesh
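
    # Each entry of `bounds` is a (2, 3) min/max corner pair, matching what
    # trimesh.primitives.Box(bounds=...) expects.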

    def generate(self, batch):
        # Conditioning tokens are built exactly as in forward().
        image_latents = self.image_encoder(batch['images'])
        masks = batch['masks']
        masks_emb = self.group_embedding(masks)
        masks_emb = rearrange(masks_emb, 'b c h w -> b (h w) c')  # B x Q x C
        group_emb = torch.zeros((image_latents.shape[0], image_latents.shape[1], masks_emb.shape[2]), device=image_latents.device, dtype=image_latents.dtype)
        group_emb[:, :masks_emb.shape[1], :] = masks_emb
        image_latents = torch.cat([image_latents, group_emb], dim=-1)
        image_latents = self.image_projector(image_latents)

        points = batch['points'][..., :3]
        rot_matrix = torch.tensor([[1, 0, 0], [0, 0, -1], [0, 1, 0]], device=points.device, dtype=points.dtype)
        rot_points = torch.matmul(points, rot_matrix)
        rot_points = rot_points * (2 * 0.9)  # from (-0.5, 0.5) into (-1, 1) with a 0.9 margin
        partfield_feat = self.partfield_encoder.encode(rot_points)

        feat_volume = torch.zeros((points.shape[0], 448, 64, 64, 64), device=partfield_feat.device, dtype=partfield_feat.dtype)
        whole_voxel_index = batch['whole_voxel_index']  # (b, m, 3)
        batch_size, num_points = whole_voxel_index.shape[0], whole_voxel_index.shape[1]
        batch_indices = torch.arange(batch_size, device=whole_voxel_index.device).unsqueeze(1).expand(-1, num_points)  # (b, m)
        batch_flat = batch_indices.flatten()  # (b*m,)
        x_flat = whole_voxel_index[..., 0].flatten()  # (b*m,)
        y_flat = whole_voxel_index[..., 1].flatten()  # (b*m,)
        z_flat = whole_voxel_index[..., 2].flatten()  # (b*m,)
        partfield_feat_flat = partfield_feat.reshape(-1, 448)  # (b*m, 448)
        feat_volume[batch_flat, :, x_flat, y_flat, z_flat] = partfield_feat_flat
        xyz_volume = torch.zeros((points.shape[0], 3, 64, 64, 64), device=points.device, dtype=points.dtype)
        xyz_volume[batch_flat, :, x_flat, y_flat, z_flat] = points.reshape(-1, 3)
        feat_volume = torch.cat([feat_volume, xyz_volume], dim=1)
        feat_volume = self.partfield_voxel_encoder(feat_volume)
        feat_volume = rearrange(feat_volume, 'b c x y z -> b (x y z) c')

        voxel_token = torch.cat([image_latents, feat_volume], dim=1)  # B x N x D

        meshes = []
        mesh_names = []
        bboxes = []
        output = self.decoder.generate(
            inputs_embeds=voxel_token,
            max_new_tokens=self.cfg.max_length - voxel_token.shape[1],
            logits_processor=self.logits_processor,
            do_sample=True,
            top_k=5,
            top_p=0.95,
            temperature=0.5,
            use_cache=True,
        )
        for i in range(output.shape[0]):
            bounds = self.tokenizer.decode(output[i].detach().cpu().numpy(), coord_rg=(-0.5, 0.5))
            # mesh = self.gen_mesh_from_bounds(bounds, batch['random_color'][i])
            # meshes.append(mesh)
            mesh_names.append("topk=5")
            bboxes.append(bounds)
        return {
            # 'meshes': meshes,
            'mesh_names': mesh_names,
            'bboxes': bboxes,
        }
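

# Minimal usage sketch (hypothetical driver, not part of this module; the config
# dict, checkpoint path, and `random_color` below are illustrative assumptions):
#
#   model = BboxGen({"partfield_encoder_path": "checkpoints/partfield.ckpt"})
#   model = model.cuda().eval()
#   with torch.no_grad():
#       out = model.generate(batch)          # batch as described above forward()
#   bounds = out["bboxes"][0]                # per-box (2, 3) min/max corners in (-0.5, 0.5)
#   scene = model.gen_mesh_from_bounds(bounds, random_color)
#   scene.export("bboxes.glb")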