File size: 10,094 Bytes
0d0e451 7872317 ada67da 0d0e451 933cc55 0d0e451 ada67da 0d0e451 ada67da 0d0e451 ada67da 8a2c015 7872317 0d0e451 ada67da 0d0e451 ada67da 0d0e451 ada67da 0d0e451 ada67da ba1da16 ada67da ba1da16 ada67da ba1da16 ada67da ba1da16 ada67da ba1da16 ada67da ba1da16 ada67da ba1da16 ada67da ba1da16 ada67da 0d0e451 ada67da 0d0e451 ada67da 0d0e451 ada67da 0d0e451 ada67da 0d0e451 933cc55 ada67da 0d0e451 ada67da 933cc55 ada67da 61d2ddd ada67da 0d0e451 ada67da 0d0e451 ada67da 0d0e451 ada67da 0d0e451 ada67da |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 |
import os
import gradio as gr
import plotly.graph_objects as go
import sys
import torch
from huggingface_hub import hf_hub_download
import numpy as np
import random
# import argparse # Not strictly needed for weights_only=False, but good practice if dealing with argparse.Namespace
# The model definitions are not packaged; they are imported straight from a
# checkout of the upstream repository that is placed next to this script.
_REPO_DIR = "diffusion-point-cloud"
if not os.path.isdir(os.path.join(_REPO_DIR, "models")):
    # Only clone when the checkout is missing: re-running `git clone` into an
    # existing, non-empty directory just fails and prints an error on every start.
    if os.system("git clone https://github.com/luost26/diffusion-point-cloud") != 0:
        # Stay best-effort (as before), but say why the imports below may fail.
        print(f"WARNING: could not clone the model repository into '{_REPO_DIR}'.")
sys.path.append(_REPO_DIR)
#Codes reference : https://github.com/luost26/diffusion-point-cloud
from models.vae_gaussian import *
from models.vae_flow import *
# Airplane weights are fetched from the Hugging Face Hub.
airplane_model_path = hf_hub_download("SerdarHelli/diffusion-point-cloud", filename="GEN_airplane.pt", revision="main")

# IMPORTANT: GEN_chair.pt is NOT downloaded by this script. It has to be placed
# manually in the directory the script is run from. The original repository
# (https://github.com/luost26/diffusion-point-cloud) mentions downloading
# checkpoints from Google Drive.
chair_model_path = "./GEN_chair.pt"

device = 'cuda' if torch.cuda.is_available() else 'cpu'

# SECURITY NOTE: weights_only=False lets torch.load unpickle arbitrary Python
# objects, which can execute code contained in the file. It is used here because
# predict() reads a non-tensor 'args' object from the checkpoint. Only load
# checkpoints from sources you trust.
# A stricter alternative on PyTorch 2.6+ would be to allow-list the required
# classes, e.g. torch.serialization.add_safe_globals([argparse.Namespace]), and
# load with the default weights_only=True -- presumably Namespace is the only
# extra class needed; verify against the actual checkpoint before switching.
ckpt_airplane = torch.load(airplane_model_path, map_location=torch.device(device), weights_only=False)

# The chair checkpoint is optional: a missing file must not crash the whole app
# at import time. ckpt_chair is None when the file is absent.
if os.path.exists(chair_model_path):
    ckpt_chair = torch.load(chair_model_path, map_location=torch.device(device), weights_only=False)
else:
    ckpt_chair = None
def seed_all(seed):
    """Seed the PyTorch, NumPy and Python ``random`` generators with ``seed``."""
    # Same order as before: torch first, then NumPy, then the stdlib generator.
    for seeder in (torch.manual_seed, np.random.seed, random.seed):
        seeder(seed)
def normalize_point_clouds(pcs, mode):
    """Normalize every point cloud of a batch in place.

    Args:
        pcs: Tensor indexed as ``pcs[i]`` -> one cloud of shape (N, 3).
        mode: ``'shape_unit'`` (zero mean, unit std over all coordinates),
            ``'shape_bbox'`` (bounding box centred at the origin, longest
            side scaled to [-1, 1]), or ``None`` for no normalization.
            Any other value leaves the points unchanged.

    Returns:
        The same ``pcs`` tensor object, modified in place.
    """
    # None and unrecognized modes are both no-ops; decide once, not per shape.
    if mode not in ('shape_unit', 'shape_bbox'):
        return pcs

    for i in range(pcs.size(0)):
        pc = pcs[i]
        if mode == 'shape_unit':
            shift = pc.mean(dim=0).reshape(1, 3)
            scale = pc.flatten().std().reshape(1, 1)
        else:  # 'shape_bbox'
            pc_max, _ = pc.max(dim=0, keepdim=True)  # (1, 3)
            pc_min, _ = pc.min(dim=0, keepdim=True)  # (1, 3)
            shift = ((pc_min + pc_max) / 2).view(1, 3)
            scale = (pc_max - pc_min).max().reshape(1, 1) / 2

        # Degenerate cloud (all points identical): avoid dividing by ~0.
        # The replacement is created on pc's device/dtype; a plain CPU (1, 1)
        # tensor cannot be combined with a CUDA tensor.
        if scale < 1e-8:
            scale = torch.tensor(1.0, device=pc.device, dtype=pc.dtype).reshape(1, 1)

        pcs[i] = (pc - shift) / scale
    return pcs
def _resolve_model_args(ckpt):
    """Return the hyper-parameter object used to rebuild the model in ``ckpt``.

    Prefers ``ckpt['args']`` when it exists and has a ``model`` attribute.
    Otherwise builds a stand-in object from top-level checkpoint keys, using
    the defaults below for missing ones.
    """
    if 'args' in ckpt and hasattr(ckpt['args'], 'model'):
        print("Using 'args' found in checkpoint.")
        return ckpt['args']

    print("Warning: 'args' not found or 'args.model' missing in checkpoint. Constructing mock_args.")
    # NOTE(review): these defaults are guesses, not values taken from the
    # training scripts. The model classes may require further attributes --
    # verify against a checkpoint from the original repository.
    defaults = {
        'model': 'gaussian',
        'latent_dim': 128,
        'hyper': None,
        'residual': True,
        'flow_depth': 10,
        'flow_hidden_dim': 256,
        'num_points': 2048,
        'flexibility': 0.0,
    }
    # Some checkpoints might store these values flat at the top level.
    fields = {name: ckpt.get(name, fallback) for name, fallback in defaults.items()}
    return type('Args', (), fields)()


def predict(Seed, ckpt):
    """Sample one point cloud from the model stored in ``ckpt``.

    Args:
        Seed: Value used to seed all RNGs; ``None`` is treated as 777.
        ckpt: Loaded checkpoint mapping with a ``'state_dict'`` entry and,
            ideally, an ``'args'`` object describing the model.

    Returns:
        The first generated cloud after ``'shape_bbox'`` normalization, as a
        CPU tensor (presumably of shape (num_points, 3) -- depends on
        ``model.sample``).

    Raises:
        ValueError: If the model type is neither ``'gaussian'`` nor ``'flow'``.
    """
    if Seed is None:
        Seed = 777
    # Seed before the model is built: construction may itself draw from the
    # RNG, and keeping this order keeps outputs for a given seed unchanged.
    seed_all(int(Seed))

    actual_args = _resolve_model_args(ckpt)

    if actual_args.model == 'gaussian':
        model = GaussianVAE(actual_args).to(device)
    elif actual_args.model == 'flow':
        model = FlowVAE(actual_args).to(device)
    else:
        raise ValueError(f"Unknown model type: {actual_args.model}")

    model.load_state_dict(ckpt['state_dict'])
    model.eval()

    # Args stored in a checkpoint are not guaranteed to carry the sampling
    # settings; fall back to the same defaults the stand-in args use instead
    # of failing with AttributeError.
    num_points = getattr(actual_args, 'num_points', 2048)
    flexibility = getattr(actual_args, 'flexibility', 0.0)

    with torch.no_grad():
        z = torch.randn([1, actual_args.latent_dim]).to(device)
        x = model.sample(z, num_points, flexibility=flexibility)

    # Keep only the first sample. Normalization works in place, so it gets its own copy.
    generated = x.detach().cpu()[:1]
    normalized = normalize_point_clouds(generated.clone(), mode="shape_bbox")
    return normalized[0]
def generate(seed, value):
    """Sample a point cloud for the chosen model and return it as a 3D scatter plot.

    Args:
        seed: Seed coming from the UI. ``None`` or ``0`` requests a randomly
            chosen seed, as advertised by the slider label
            ('Seed (0 for random)'). Any other value is used as-is.
        value: Model name, "Airplane" or "Chair". Unknown names fall back to
            the airplane model with a printed warning.

    Returns:
        A ``plotly.graph_objects.Figure`` whose title shows the seed actually used.

    Raises:
        FileNotFoundError: If the selected checkpoint object is ``None``.
    """
    if value == "Airplane":
        ckpt = ckpt_airplane
    elif value == "Chair":
        ckpt = ckpt_chair
    else:
        # Unexpected selection: keep the app usable by defaulting to airplane.
        print(f"Warning: Unknown model type '{value}'. Defaulting to Airplane.")
        ckpt = ckpt_airplane

    # Give a clear message instead of a TypeError deep inside predict().
    if ckpt is None:
        raise FileNotFoundError(f"No checkpoint is loaded for model type '{value}'.")

    if seed is None or int(seed) == 0:
        # seed_all() reseeds the global `random` generator on every call, so
        # drawing from it would give a predictable chain of "random" seeds.
        # SystemRandom does not depend on that state.
        current_seed = random.SystemRandom().randint(1, 2**16 - 1)
    else:
        current_seed = int(seed)

    points = predict(current_seed, ckpt)

    colors = (238, 75, 43)  # RGB components
    marker_color = f'rgb({colors[0]},{colors[1]},{colors[2]})'  # plotly expects an rgb string

    fig = go.Figure(
        data=[
            go.Scatter3d(
                x=points[:, 0], y=points[:, 1], z=points[:, 2],
                mode='markers',
                marker=dict(size=2, color=marker_color)
            )
        ],
        layout=dict(
            scene=dict(
                xaxis=dict(visible=True, title='X', backgroundcolor="rgb(230, 230,230)", gridcolor="white", zerolinecolor="white"),
                yaxis=dict(visible=True, title='Y', backgroundcolor="rgb(230, 230,230)", gridcolor="white", zerolinecolor="white"),
                zaxis=dict(visible=True, title='Z', backgroundcolor="rgb(230, 230,230)", gridcolor="white", zerolinecolor="white"),
                aspectmode='data'  # Ensures proportional axes
            ),
            margin=dict(l=0, r=0, b=0, t=40),  # Adjust margins
            title=f"Generated {value} (Seed: {current_seed})"
        )
    )
    return fig
# Header text shown at the top of the page via gr.Markdown. It is an f-string:
# the only interpolated value is the upper-cased device name ('CUDA' or 'CPU').
markdown = f'''
# Diffusion Probabilistic Models for 3D Point Cloud Generation
[The space demo for the CVPR 2021 paper "Diffusion Probabilistic Models for 3D Point Cloud Generation".](https://arxiv.org/abs/2103.01458)
[For the official implementation.](https://github.com/luost26/diffusion-point-cloud)
### Future Work based on interest
- Adding new models for new type objects
- New Customization
It is running on **{device.upper()}**
---
**Note:** The `GEN_chair.pt` file must be manually placed in the root directory for the "Chair" model to work.
It is not downloaded automatically by this script.
Check the [original repository's instructions](https://github.com/luost26/diffusion-point-cloud#pretrained-models) for downloading checkpoints.
---
'''
# Gradio UI: header text, seed and model controls, a trigger button and the plot output.
# NOTE(review): the nesting below (button and plot inside the column, click
# handler registered inside the Blocks context) is a reconstruction -- confirm
# it matches the intended layout.
with gr.Blocks(theme=gr.themes.Soft()) as demo:
    with gr.Column():
        with gr.Row():
            gr.Markdown(markdown)
        with gr.Row():
            # NOTE(review): the label advertises 0 as "random"; the slider itself
            # just passes 0 through, so generate() has to implement that.
            seed_slider = gr.Slider(minimum=0, maximum=2**16 - 1, step=1, label='Seed (0 for random)', value=777) # Initial value
            model_dropdown = gr.Dropdown(choices=["Airplane", "Chair"], label="Choose Model Type", value="Airplane") # Initial value
        btn = gr.Button(value="Generate Point Cloud")
        point_cloud_plot = gr.Plot() # Receives the figure returned by generate
    # demo.load(generate, [seed_slider, model_dropdown], point_cloud_plot) # demo.load usually runs on page load
    # Generation only happens on an explicit button press.
    btn.click(generate, [seed_slider, model_dropdown], point_cloud_plot)
if __name__ == "__main__":
    # Tell the operator up front when the manually supplied chair checkpoint
    # is absent, since the "Chair" choice is always offered in the UI.
    if not os.path.exists(chair_model_path):
        print(f"WARNING: Chair model checkpoint '{chair_model_path}' not found.")
        print("The 'Chair' option in the UI may not work unless this file is present.")
        print(f"Please download it from the original project repository and place it at '{chair_model_path}'.")
    demo.launch()