# TripoSG / app.py
import spaces
import os
import gradio as gr
import numpy as np
import torch
from PIL import Image
import trimesh
import random
from transformers import AutoModelForImageSegmentation
from torchvision import transforms
from huggingface_hub import hf_hub_download, snapshot_download
import subprocess
import shutil
# Install additional runtime dependencies
subprocess.run("pip install spandrel==0.4.1 --no-deps", shell=True, check=True)
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
DTYPE = torch.float16
print("DEVICE: ", DEVICE)
print("CUDA DEVICE NAME: ", torch.cuda.get_device_name(torch.cuda.current_device()))
DEFAULT_FACE_NUMBER = 100000
MAX_SEED = np.iinfo(np.int32).max
TRIPOSG_REPO_URL = "https://github.com/VAST-AI-Research/TripoSG.git"
MV_ADAPTER_REPO_URL = "https://github.com/huanngzh/MV-Adapter.git"
RMBG_PRETRAINED_MODEL = "checkpoints/RMBG-1.4"
TRIPOSG_PRETRAINED_MODEL = "checkpoints/TripoSG"
TMP_DIR = os.path.join(os.path.dirname(os.path.abspath(__file__)), "tmp")
os.makedirs(TMP_DIR, exist_ok=True)
TRIPOSG_CODE_DIR = "./triposg"
if not os.path.exists(TRIPOSG_CODE_DIR):
os.system(f"git clone {TRIPOSG_REPO_URL} {TRIPOSG_CODE_DIR}")
MV_ADAPTER_CODE_DIR = "./mv_adapter"
if not os.path.exists(MV_ADAPTER_CODE_DIR):
os.system(f"git clone {MV_ADAPTER_REPO_URL} {MV_ADAPTER_CODE_DIR} && cd {MV_ADAPTER_CODE_DIR} && git checkout 7d37a97e9bc223cdb8fd26a76bd8dd46504c7c3d")
import sys
sys.path.append(TRIPOSG_CODE_DIR)
sys.path.append(os.path.join(TRIPOSG_CODE_DIR, "scripts"))
sys.path.append(MV_ADAPTER_CODE_DIR)
sys.path.append(os.path.join(MV_ADAPTER_CODE_DIR, "scripts"))
# Custom styling constants
NESTLE_BLUE = "#0066b1"
NESTLE_BLUE_DARK = "#004a82"
ACCENT_COLOR = "#10b981"
# TripoSG
from image_process import prepare_image
from briarmbg import BriaRMBG
snapshot_download("briaai/RMBG-1.4", local_dir=RMBG_PRETRAINED_MODEL)
rmbg_net = BriaRMBG.from_pretrained(RMBG_PRETRAINED_MODEL).to(DEVICE)
rmbg_net.eval()
from triposg.pipelines.pipeline_triposg import TripoSGPipeline
snapshot_download("VAST-AI/TripoSG", local_dir=TRIPOSG_PRETRAINED_MODEL)
triposg_pipe = TripoSGPipeline.from_pretrained(TRIPOSG_PRETRAINED_MODEL).to(DEVICE, DTYPE)
# MV-Adapter
NUM_VIEWS = 6
from inference_ig2mv_sdxl import prepare_pipeline, preprocess_image, remove_bg
from mvadapter.utils import get_orthogonal_camera, tensor_to_image, make_image_grid
from mvadapter.utils.render import NVDiffRastContextWrapper, load_mesh, render
mv_adapter_pipe = prepare_pipeline(
base_model="stabilityai/stable-diffusion-xl-base-1.0",
vae_model="madebyollin/sdxl-vae-fp16-fix",
unet_model=None,
lora_model=None,
adapter_path="huanngzh/mv-adapter",
scheduler=None,
num_views=NUM_VIEWS,
device=DEVICE,
dtype=torch.float16,
)
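# BiRefNet segmentation model, used to strip the background from the reference image before texturing.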
birefnet = AutoModelForImageSegmentation.from_pretrained(
"ZhengPeng7/BiRefNet", trust_remote_code=True
)
birefnet.to(DEVICE)
transform_image = transforms.Compose(
[
transforms.Resize((1024, 1024)),
transforms.ToTensor(),
transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
]
)
remove_bg_fn = lambda x: remove_bg(x, birefnet, transform_image, DEVICE)
if not os.path.exists("checkpoints/RealESRGAN_x2plus.pth"):
hf_hub_download("dtarnow/UPscaler", filename="RealESRGAN_x2plus.pth", local_dir="checkpoints")
if not os.path.exists("checkpoints/big-lama.pt"):
subprocess.run("wget -P checkpoints/ https://github.com/Sanster/models/releases/download/add_big_lama/big-lama.pt", shell=True, check=True)
def start_session(req: gr.Request):
save_dir = os.path.join(TMP_DIR, str(req.session_hash))
os.makedirs(save_dir, exist_ok=True)
print("start session, mkdir", save_dir)
def end_session(req: gr.Request):
save_dir = os.path.join(TMP_DIR, str(req.session_hash))
    shutil.rmtree(save_dir, ignore_errors=True)
def get_random_hex():
random_bytes = os.urandom(8)
random_hex = random_bytes.hex()
return random_hex
def get_random_seed(randomize_seed, seed):
if randomize_seed:
seed = random.randint(0, MAX_SEED)
return seed
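# Full pipeline in a single GPU call: background removal -> TripoSG mesh generation ->
# MV-Adapter multi-view synthesis -> texture baking to a GLB.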
@spaces.GPU(duration=180)
def run_full(image: str, req: gr.Request):
seed = 0
num_inference_steps = 50
guidance_scale = 7.5
simplify = True
target_face_num = DEFAULT_FACE_NUMBER
image_seg = prepare_image(image, bg_color=np.array([1.0, 1.0, 1.0]), rmbg_net=rmbg_net)
outputs = triposg_pipe(
image=image_seg,
generator=torch.Generator(device=triposg_pipe.device).manual_seed(seed),
num_inference_steps=num_inference_steps,
guidance_scale=guidance_scale
).samples[0]
print("mesh extraction done")
mesh = trimesh.Trimesh(outputs[0].astype(np.float32), np.ascontiguousarray(outputs[1]))
if simplify:
print("start simplify")
from utils import simplify_mesh
mesh = simplify_mesh(mesh, target_face_num)
save_dir = os.path.join(TMP_DIR, "examples")
os.makedirs(save_dir, exist_ok=True)
mesh_path = os.path.join(save_dir, f"triposg_{get_random_hex()}.glb")
mesh.export(mesh_path)
print("save to ", mesh_path)
torch.cuda.empty_cache()
height, width = 768, 768
    # Prepare six orthogonal cameras: four side views (azimuth 0/90/180/270) plus near-top and near-bottom.
cameras = get_orthogonal_camera(
elevation_deg=[0, 0, 0, 0, 89.99, -89.99],
distance=[1.8] * NUM_VIEWS,
left=-0.55,
right=0.55,
bottom=-0.55,
top=0.55,
azimuth_deg=[x - 90 for x in [0, 90, 180, 270, 180, 180]],
device=DEVICE,
)
ctx = NVDiffRastContextWrapper(device=DEVICE, context_type="cuda")
mesh = load_mesh(mesh_path, rescale=True, device=DEVICE)
render_out = render(
ctx,
mesh,
cameras,
height=height,
width=width,
render_attr=False,
normal_background=0.0,
)
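    # Pack the rendered position and normal maps into 6-channel control images for MV-Adapter.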
control_images = (
torch.cat(
[
(render_out.pos + 0.5).clamp(0, 1),
(render_out.normal / 2 + 0.5).clamp(0, 1),
],
dim=-1,
)
.permute(0, 3, 1, 2)
.to(DEVICE)
)
image = Image.open(image)
image = remove_bg_fn(image)
image = preprocess_image(image, height, width)
pipe_kwargs = {}
if seed != -1 and isinstance(seed, int):
pipe_kwargs["generator"] = torch.Generator(device=DEVICE).manual_seed(seed)
images = mv_adapter_pipe(
"high quality",
height=height,
width=width,
num_inference_steps=15,
guidance_scale=3.0,
num_images_per_prompt=NUM_VIEWS,
control_image=control_images,
control_conditioning_scale=1.0,
reference_image=image,
reference_conditioning_scale=1.0,
negative_prompt="watermark, ugly, deformed, noisy, blurry, low contrast",
cross_attention_kwargs={"scale": 1.0},
**pipe_kwargs,
).images
torch.cuda.empty_cache()
mv_image_path = os.path.join(save_dir, f"mv_adapter_{get_random_hex()}.png")
make_image_grid(images, rows=1).save(mv_image_path)
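    # Bake the multi-view renderings onto the mesh UVs (with Real-ESRGAN upscaling and LaMa inpainting).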
from texture import TexturePipeline, ModProcessConfig
texture_pipe = TexturePipeline(
upscaler_ckpt_path="checkpoints/RealESRGAN_x2plus.pth",
inpaint_ckpt_path="checkpoints/big-lama.pt",
device=DEVICE,
)
textured_glb_path = texture_pipe(
mesh_path=mesh_path,
save_dir=save_dir,
save_name=f"texture_mesh_{get_random_hex()}.glb",
uv_unwarp=True,
uv_size=4096,
rgb_path=mv_image_path,
rgb_process_config=ModProcessConfig(view_upscale=True, inpaint_mode="view"),
camera_azimuth_deg=[x - 90 for x in [0, 90, 180, 270, 180, 180]],
)
return image_seg, mesh_path, textured_glb_path
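# Stage 1: segment the product from the background with BriaRMBG and place it on a white background.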
@spaces.GPU()
@torch.no_grad()
def run_segmentation(image: str):
print("run_segmentation pre!")
image = prepare_image(image, bg_color=np.array([1.0, 1.0, 1.0]), rmbg_net=rmbg_net)
print("run_segmentation pos!")
print("run_segmentation image: ", image)
return image
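# Stage 2: run the TripoSG pipeline on the segmented image, optionally simplify the mesh, and export it as a GLB.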
@spaces.GPU(duration=90)
@torch.no_grad()
def image_to_3d(
image: Image.Image,
seed: int,
num_inference_steps: int,
guidance_scale: float,
simplify: bool,
target_face_num: int,
req: gr.Request
):
outputs = triposg_pipe(
image=image,
generator=torch.Generator(device=triposg_pipe.device).manual_seed(seed),
num_inference_steps=num_inference_steps,
guidance_scale=guidance_scale
).samples[0]
print("mesh extraction done")
mesh = trimesh.Trimesh(outputs[0].astype(np.float32), np.ascontiguousarray(outputs[1]))
if simplify:
print("start simplify")
from utils import simplify_mesh
mesh = simplify_mesh(mesh, target_face_num)
save_dir = os.path.join(TMP_DIR, str(req.session_hash))
mesh_path = os.path.join(save_dir, f"triposg_{get_random_hex()}.glb")
mesh.export(mesh_path)
print("save to ", mesh_path)
torch.cuda.empty_cache()
return mesh_path
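# Stage 3: render geometry maps from six orthogonal views, generate matching textures with
# MV-Adapter (conditioned on the reference image and text prompt), and bake them into a textured GLB.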
@spaces.GPU(duration=120)
@torch.no_grad()
def run_texture(image: str, mesh_path: str, seed: int, text_prompt: str, req: gr.Request):
height, width = 768, 768
    # Prepare six orthogonal cameras: four side views (azimuth 0/90/180/270) plus near-top and near-bottom.
cameras = get_orthogonal_camera(
elevation_deg=[0, 0, 0, 0, 89.99, -89.99],
distance=[1.8] * NUM_VIEWS,
left=-0.55,
right=0.55,
bottom=-0.55,
top=0.55,
azimuth_deg=[x - 90 for x in [0, 90, 180, 270, 180, 180]],
device=DEVICE,
)
ctx = NVDiffRastContextWrapper(device=DEVICE, context_type="cuda")
mesh = load_mesh(mesh_path, rescale=True, device=DEVICE)
render_out = render(
ctx,
mesh,
cameras,
height=height,
width=width,
render_attr=False,
normal_background=0.0,
)
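    # Pack the rendered position and normal maps into 6-channel control images for MV-Adapter.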
control_images = (
torch.cat(
[
(render_out.pos + 0.5).clamp(0, 1),
(render_out.normal / 2 + 0.5).clamp(0, 1),
],
dim=-1,
)
.permute(0, 3, 1, 2)
.to(DEVICE)
)
image = Image.open(image)
image = remove_bg_fn(image)
image = preprocess_image(image, height, width)
pipe_kwargs = {}
if seed != -1 and isinstance(seed, int):
pipe_kwargs["generator"] = torch.Generator(device=DEVICE).manual_seed(seed)
images = mv_adapter_pipe(
text_prompt,
height=height,
width=width,
num_inference_steps=15,
guidance_scale=3.0,
num_images_per_prompt=NUM_VIEWS,
control_image=control_images,
control_conditioning_scale=1.0,
reference_image=image,
reference_conditioning_scale=1.0,
negative_prompt="watermark, ugly, deformed, noisy, blurry, low contrast",
cross_attention_kwargs={"scale": 1.0},
**pipe_kwargs,
).images
torch.cuda.empty_cache()
save_dir = os.path.join(TMP_DIR, str(req.session_hash))
mv_image_path = os.path.join(save_dir, f"mv_adapter_{get_random_hex()}.png")
make_image_grid(images, rows=1).save(mv_image_path)
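    # Bake the multi-view renderings onto the mesh UVs (with Real-ESRGAN upscaling and LaMa inpainting).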
from texture import TexturePipeline, ModProcessConfig
texture_pipe = TexturePipeline(
upscaler_ckpt_path="checkpoints/RealESRGAN_x2plus.pth",
inpaint_ckpt_path="checkpoints/big-lama.pt",
device=DEVICE,
)
textured_glb_path = texture_pipe(
mesh_path=mesh_path,
save_dir=save_dir,
save_name=f"texture_mesh_{get_random_hex()}.glb",
uv_unwarp=True,
uv_size=4096,
rgb_path=mv_image_path,
rgb_process_config=ModProcessConfig(view_upscale=True, inpaint_mode="view"),
camera_azimuth_deg=[x - 90 for x in [0, 90, 180, 270, 180, 180]],
)
return textured_glb_path
# Custom UI components
def create_header():
return f"""
<div class="card" style="background: linear-gradient(135deg, {NESTLE_BLUE} 0%, {NESTLE_BLUE_DARK} 100%); color: white; border: none;">
<div style="display: flex; align-items: center; gap: 20px;">
<div style="background: white; padding: 12px; border-radius: 12px; box-shadow: 0 4px 12px rgba(0, 0, 0, 0.1);">
<img src="https://logodownload.org/wp-content/uploads/2016/11/nestle-logo-1.png"
alt="Nestlé Logo" style="height: 48px; width: auto;">
</div>
<div style="flex: 1;">
<h1 style="margin: 0; font-size: 2.5rem; font-weight: 700; letter-spacing: -0.025em;">
Nestlé 3D Generator
</h1>
<p style="margin: 0.5rem 0 0 0; opacity: 0.9; font-size: 1.1rem;">
Transform your product images into stunning 3D models with AI
</p>
</div>
<div class="badge primary">Beta v2.0</div>
</div>
</div>
"""
def create_tabs():
return """
<div class="tabs-container">
<div class="tabs-list">
<button class="tab-button active" onclick="switchTab('segmentation')">
🔍 Segmentation
</button>
<button class="tab-button" onclick="switchTab('model')">
🎨 3D Model
</button>
<button class="tab-button" onclick="switchTab('textured')">
✨ Textured Model
</button>
</div>
<div id="segmentation-tab" class="tab-content active">
<div style="text-align: center; color: #1e293b;">
<div style="font-size: 4rem; margin-bottom: 1rem;">📤</div>
<p>Upload an image to see segmentation results</p>
</div>
</div>
<div id="model-tab" class="tab-content">
<div style="text-align: center; color: #1e293b;">
<div style="font-size: 4rem; margin-bottom: 1rem;">🎯</div>
<p>3D model will appear here after generation</p>
</div>
</div>
<div id="textured-tab" class="tab-content">
<div style="text-align: center; color: #1e293b;">
<div style="font-size: 4rem; margin-bottom: 1rem;">🎨</div>
<p>Textured model will appear here</p>
</div>
</div>
</div>
"""
def create_progress_bar():
return """
<div class="progress-container" style="display: none;" id="progress-container">
<div class="progress-header">
<span>Generating 3D model...</span>
<span id="progress-text">0%</span>
</div>
<div class="progress-bar-container">
<div class="progress-bar" id="progress-bar"></div>
</div>
</div>
"""
# JavaScript
ADVANCED_JS = """
<script>
// React-like state management simulation
window.appState = {
currentTab: 'segmentation',
isGenerating: false,
progress: 0
};
// Tab switching functionality
function switchTab(tabName) {
window.appState.currentTab = tabName;
// Hide all tab contents
document.querySelectorAll('.tab-content').forEach(el => {
el.style.display = 'none';
});
// Show selected tab
const selectedTab = document.getElementById(tabName + '-tab');
if (selectedTab) {
selectedTab.style.display = 'block';
}
// Update tab buttons
document.querySelectorAll('.tab-button').forEach(btn => {
btn.classList.remove('active');
});
const activeBtn = document.querySelector(`[onclick="switchTab('${tabName}')"]`);
if (activeBtn) {
activeBtn.classList.add('active');
}
}
// Progress simulation
function simulateProgress() {
window.appState.isGenerating = true;
window.appState.progress = 0;
const progressBar = document.getElementById('progress-bar');
const progressText = document.getElementById('progress-text');
const interval = setInterval(() => {
window.appState.progress += 10;
if (progressBar) {
progressBar.style.width = window.appState.progress + '%';
}
if (progressText) {
progressText.textContent = window.appState.progress + '%';
}
if (window.appState.progress >= 100) {
clearInterval(interval);
window.appState.isGenerating = false;
}
}, 300);
}
// Drag and drop simulation
function setupDragDrop() {
const uploadArea = document.querySelector('.upload-area');
if (uploadArea) {
uploadArea.addEventListener('dragover', (e) => {
e.preventDefault();
uploadArea.classList.add('drag-over');
});
uploadArea.addEventListener('dragleave', () => {
uploadArea.classList.remove('drag-over');
});
uploadArea.addEventListener('drop', (e) => {
e.preventDefault();
uploadArea.classList.remove('drag-over');
// Handle file drop
});
}
}
// Initialize when DOM is ready
document.addEventListener('DOMContentLoaded', function() {
setupDragDrop();
switchTab('segmentation');
});
</script>
"""
# CSS
ADVANCED_CSS = f"""
@import url('https://fonts.googleapis.com/css2?family=Inter:wght@300;400;500;600;700&display=swap');
:root {{
--nestle-blue: {NESTLE_BLUE};
--nestle-blue-dark: {NESTLE_BLUE_DARK};
--accent: {ACCENT_COLOR};
--shadow-sm: 0 1px 3px rgba(0, 0, 0, 0.1);
--shadow-md: 0 4px 12px rgba(0, 0, 0, 0.1);
--shadow-lg: 0 10px 25px rgba(0, 0, 0, 0.1);
--border-radius: 12px;
}}
* {{
font-family: 'Inter', sans-serif !important;
}}
body, .gradio-container {{
background: linear-gradient(135deg, #f8fafc 0%, #e2e8f0 100%) !important;
margin: 0 !important;
padding: 0 !important;
min-height: 100vh !important;
color: #ffffff !important;
font-size: 1rem !important;
}}
/* AGGRESSIVE TEXT COLOR FIXES - Higher specificity */
.gradio-container *,
.gradio-container div,
.gradio-container span,
.gradio-container p,
.gradio-container label,
.gradio-container h1,
.gradio-container h2,
.gradio-container h3,
.gradio-container h4,
.gradio-container h5,
.gradio-container h6 {{
color: #ffffff !important;
text-shadow: 0 1px 3px rgba(0, 0, 0, 0.8) !important;
}}
/* Force white text on all Gradio components */
.gr-group *,
.gr-form *,
.gr-block *,
.gr-box *,
div[class*="gr-"] *,
div[class*="svelte-"] *,
span[class*="svelte-"] *,
label[class*="svelte-"] *,
p[class*="svelte-"] * {{
color: #ffffff !important;
text-shadow: 0 1px 3px rgba(0, 0, 0, 0.8) !important;
font-weight: 600 !important;
}}
/* Specific targeting for card descriptions and titles */
.card-description,
.card-title,
div.card-description,
div.card-title,
p.card-description,
h3.card-title {{
color: #ffffff !important;
text-shadow: 0 1px 3px rgba(0, 0, 0, 0.8) !important;
font-weight: 700 !important;
background: rgba(0, 0, 0, 0.3) !important;
padding: 4px 8px !important;
border-radius: 6px !important;
margin: 0.5rem 0 !important;
display: inline-block !important;
}}
/* Card Components */
.card {{
background: white;
border: 1px solid #e2e8f0;
border-radius: var(--border-radius);
box-shadow: var(--shadow-md);
padding: 1.5rem;
transition: all 0.2s ease;
margin-bottom: 1rem;
}}
.card:hover {{
box-shadow: var(--shadow-lg);
transform: translateY(-2px);
}}
.card-header {{
margin-bottom: 1rem;
padding-bottom: 1rem;
border-bottom: 1px solid #e2e8f0;
}}
/* Tabs */
.tabs-container {{
background: white;
border-radius: var(--border-radius);
box-shadow: var(--shadow-md);
overflow: hidden;
}}
.tabs-list {{
display: flex;
background: #f8fafc;
border-bottom: 1px solid #e2e8f0;
}}
.tab-button {{
flex: 1;
padding: 1rem;
background: none;
border: none;
cursor: pointer;
font-weight: 600;
color: #334155 !important;
font-size: 1rem;
transition: all 0.2s ease;
position: relative;
}}
.tab-button:hover {{
background: #f1f5f9;
color: #1e293b !important;
}}
.tab-button.active {{
color: var(--nestle-blue) !important;
background: white;
font-weight: 800;
}}
.tab-button.active::after {{
content: '';
position: absolute;
bottom: 0;
left: 0;
right: 0;
height: 2px;
background: var(--nestle-blue);
}}
.tab-content {{
padding: 2rem;
min-height: 400px;
display: none;
}}
.tab-content.active {{
display: block;
}}
.tab-content * {{
color: #1e293b !important;
text-shadow: none !important;
}}
/* Progress Component */
.progress-container {{
margin: 1rem 0;
padding: 1rem;
background: #f8fafc;
border-radius: var(--border-radius);
border: 1px solid #e2e8f0;
}}
.progress-header {{
display: flex;
justify-content: space-between;
margin-bottom: 0.5rem;
font-size: 1rem;
color: #334155 !important;
font-weight: 600;
}}
.progress-bar-container {{
width: 100%;
height: 8px;
background: #e2e8f0;
border-radius: 4px;
overflow: hidden;
}}
.progress-bar {{
height: 100%;
background: linear-gradient(90deg, var(--nestle-blue) 0%, var(--accent) 100%);
width: 0%;
transition: width 0.3s ease;
border-radius: 4px;
}}
/* Badge */
.badge {{
display: inline-flex;
align-items: center;
padding: 0.25rem 0.75rem;
background: #e2e8f0;
color: #1e293b !important;
border-radius: 9999px;
font-size: 0.85rem;
font-weight: 600;
}}
.badge.primary {{
background: var(--nestle-blue);
color: #fff !important;
}}
/* Button variants */
.btn, .btn-primary, .btn-secondary, .gr-button {{
display: inline-flex;
align-items: center;
justify-content: center;
gap: 0.5rem;
padding: 0.75rem 1.5rem;
border-radius: var(--border-radius);
font-weight: 700 !important;
font-size: 1rem !important;
border: none;
cursor: pointer;
transition: all 0.2s ease;
text-decoration: none;
letter-spacing: -0.01em;
}}
.btn-primary, .gr-button {{
background: linear-gradient(135deg, var(--nestle-blue) 0%, var(--nestle-blue-dark) 100%) !important;
color: white !important;
box-shadow: var(--shadow-sm) !important;
}}
.btn-primary:hover, .gr-button:hover {{
transform: translateY(-1px) !important;
box-shadow: var(--shadow-md) !important;
}}
.btn-secondary {{
background: white !important;
color: #374151 !important;
border: 1px solid #d1d5db !important;
}}
.btn-secondary:hover {{
background: #f9fafb !important;
}}
/* Enhanced Gradio component styling */
.gr-image, .gr-model3d {{
border: 2px solid #e2e8f0 !important;
border-radius: var(--border-radius) !important;
box-shadow: var(--shadow-sm) !important;
transition: all 0.2s ease !important;
}}
.gr-slider .noUi-connect {{
background: linear-gradient(90deg, var(--nestle-blue) 0%, var(--accent) 100%) !important;
}}
.gr-slider .noUi-handle {{
background: white !important;
border: 3px solid var(--nestle-blue) !important;
border-radius: 50% !important;
box-shadow: var(--shadow-md) !important;
}}
/* Responsive design */
@media (max-width: 768px) {{
.tabs-list {{
flex-direction: column;
}}
.card {{
padding: 1rem;
}}
}}
/* SUPER AGGRESSIVE TEXT FIXES */
/* Target every possible Gradio text element */
.gradio-container .gr-group .gr-form label,
.gradio-container .gr-group .gr-form span,
.gradio-container .gr-group .gr-form div,
.gradio-container .gr-group .gr-form p,
.gradio-container .gr-block label,
.gradio-container .gr-block span,
.gradio-container .gr-block div,
.gradio-container .gr-block p,
.gradio-container .gr-box label,
.gradio-container .gr-box span,
.gradio-container .gr-box div,
.gradio-container .gr-box p {{
color: #ffffff !important;
text-shadow: 0 1px 3px rgba(0, 0, 0, 0.8) !important;
font-weight: 600 !important;
opacity: 1 !important;
}}
/* Target Svelte components specifically */
[class*="svelte-"] {{
color: #ffffff !important;
text-shadow: 0 1px 3px rgba(0, 0, 0, 0.8) !important;
}}
/* Target slider labels and info text */
.gr-slider label,
.gr-slider .gr-text,
.gr-slider span,
.gr-checkbox label,
.gr-checkbox span {{
color: #ffffff !important;
text-shadow: 0 1px 3px rgba(0, 0, 0, 0.8) !important;
font-weight: 600 !important;
}}
/* Target info text specifically */
.gr-info,
[class*="info"],
.info {{
color: #ffffff !important;
text-shadow: 0 1px 3px rgba(0, 0, 0, 0.8) !important;
font-weight: 500 !important;
background: rgba(0, 0, 0, 0.2) !important;
padding: 2px 6px !important;
border-radius: 4px !important;
}}
/* Fix for image action icons */
.gr-image .image-button,
.gr-image button,
.gr-image .icon-button,
.gr-image [role="button"],
.gr-image .svelte-1pijsyv,
.gr-image .svelte-1pijsyv button {{
background: rgba(255, 255, 255, 0.95) !important;
border: 1px solid #e2e8f0 !important;
border-radius: 8px !important;
padding: 8px !important;
margin: 2px !important;
box-shadow: 0 2px 8px rgba(0, 0, 0, 0.15) !important;
transition: all 0.2s ease !important;
color: #374151 !important;
font-size: 16px !important;
min-width: 36px !important;
min-height: 36px !important;
display: flex !important;
align-items: center !important;
justify-content: center !important;
}}
.gr-image .image-button:hover,
.gr-image button:hover,
.gr-image .icon-button:hover,
.gr-image [role="button"]:hover,
.gr-image .svelte-1pijsyv:hover,
.gr-image .svelte-1pijsyv button:hover {{
background: rgba(255, 255, 255, 1) !important;
transform: translateY(-1px) !important;
box-shadow: 0 4px 12px rgba(0, 0, 0, 0.2) !important;
color: var(--nestle-blue) !important;
}}
/* Upload area text */
.gr-image .upload-text,
.gr-image .drag-text,
.gr-image .svelte-1ipelgc {{
color: #1e293b !important;
font-weight: 600 !important;
text-shadow: 0 0 4px white !important;
background: rgba(255, 255, 255, 0.9) !important;
padding: 8px 12px !important;
border-radius: 8px !important;
margin: 4px !important;
}}
/* Nuclear option - force all text to be white with shadow */
* {{
color: #ffffff !important;
text-shadow: 0 1px 3px rgba(0, 0, 0, 0.8) !important;
}}
/* But override for specific areas that should be dark */
.tabs-container *,
.tab-content *,
.badge *,
.btn *,
.gr-button *,
.upload-area *,
.gr-image .upload-text *,
.gr-image .drag-text *,
.gr-image .svelte-1ipelgc *,
.progress-container * {{
color: #1e293b !important;
text-shadow: 0 0 2px white !important;
}}
/* Header text should remain white */
.card[style*="linear-gradient"] *,
.card[style*="linear-gradient"] h1,
.card[style*="linear-gradient"] p {{
color: #ffffff !important;
text-shadow: 0 1px 3px rgba(0, 0, 0, 0.5) !important;
}}
"""
# interface
with gr.Blocks(
title="Nestlé 3D Generator",
css=ADVANCED_CSS,
head=ADVANCED_JS,
theme=gr.themes.Soft(
primary_hue="blue",
secondary_hue="slate",
neutral_hue="slate",
font=gr.themes.GoogleFont("Inter")
)
) as demo:
# Header
gr.HTML(create_header())
with gr.Row():
with gr.Column(scale=1):
with gr.Group():
gr.HTML("""
<div class="card-header">
<h3 class="card-title">📤 Product Image Upload</h3>
<p class="card-description">Upload a clear image of your Nestlé product</p>
</div>
""")
image_prompts = gr.Image(
label="",
type="filepath",
show_label=False,
height=350,
elem_classes=["upload-area"]
)
# Settings Card
with gr.Group():
gr.HTML("""
<div class="card-header">
<h3 class="card-title">⚙️ Generation Settings</h3>
<p class="card-description">Configure your 3D model generation</p>
</div>
""")
text_prompt = gr.Textbox(label="Prompt", placeholder="Enter your prompt", value="high quality")
with gr.Row():
randomize_seed = gr.Checkbox(
label="🎲 Randomize Seed",
value=True
)
seed = gr.Slider(
label="Seed Value",
minimum=0,
maximum=MAX_SEED,
step=1,
value=0
)
num_inference_steps = gr.Slider(
label="🔄 Inference Steps",
minimum=8,
maximum=50,
step=1,
value=50,
info="Higher values = better quality, slower generation"
)
guidance_scale = gr.Slider(
label="🎯 Guidance Scale",
minimum=0.0,
maximum=20.0,
step=0.1,
value=7.0,
info="Controls how closely the model follows the input"
)
with gr.Row():
reduce_face = gr.Checkbox(
label="🔧 Optimize Mesh",
value=True,
info="Reduce polygon count for better performance"
)
target_face_num = gr.Slider(
label="Target Faces",
maximum=1_000_000,
minimum=10_000,
value=DEFAULT_FACE_NUMBER,
step=1000
)
with gr.Column(scale=2):
gr.HTML("""
<div class="card-header">
<h3 class="card-title">3D Model Generation</h3>
<p class="card-description">View your generated 3D models and apply textures</p>
</div>
""")
            # Custom tab layout (React-like, driven by the injected JavaScript)
gr.HTML(create_tabs())
            # Progress bar
gr.HTML(create_progress_bar())
# Hidden Gradio components for actual functionality
with gr.Row(visible=False):
seg_image = gr.Image(type="pil", format="png", interactive=False)
model_output = gr.Model3D(interactive=False)
textured_model_output = gr.Model3D(interactive=False)
# Action Buttons
with gr.Row():
gen_button = gr.Button(
"🚀 Generate 3D Model",
variant="primary",
size="lg",
elem_classes=["btn", "btn-primary"]
)
gen_texture_button = gr.Button(
"🎨 Apply Texture",
variant="secondary",
size="lg",
interactive=False,
elem_classes=["btn", "btn-secondary"]
)
download_button = gr.Button(
"💾 Download Model",
variant="secondary",
size="lg",
elem_classes=["btn", "btn-secondary"]
)
status_display = gr.HTML(
"""<div style='text-align: center; padding: 1rem; color: #1e293b;'>
<span style='display: inline-block; width: 8px; height: 8px; border-radius: 50%; background: #10b981; margin-right: 8px;'></span>
Ready to generate your 3D model
</div>"""
)
# Event Handlers with JavaScript integration
gen_button.click(
fn=run_segmentation,
inputs=[image_prompts],
outputs=[seg_image],
js="() => { simulateProgress(); document.getElementById('progress-container').style.display = 'block'; }",
).then(
get_random_seed,
inputs=[randomize_seed, seed],
outputs=[seed],
).then(
image_to_3d,
inputs=[
seg_image,
seed,
num_inference_steps,
guidance_scale,
reduce_face,
target_face_num
],
outputs=[model_output]
).then(
fn=lambda: gr.Button(interactive=True),
outputs=[gen_texture_button]
)
gen_texture_button.click(
run_texture,
inputs=[image_prompts, model_output, seed, text_prompt],
outputs=[textured_model_output]
)
# with gr.Row():
# examples = gr.Examples(
# examples=[
# f"./examples/{image}"
# for image in os.listdir(f"./examples/")
# ],
# fn=run_full,
# inputs=[image_prompts],
# outputs=[seg_image, model_output, textured_model_output],
# cache_examples=False,
# )
demo.load(start_session)
demo.unload(end_session)
if __name__ == "__main__":
demo.launch(share=False, show_error=True)