"""Gradio demo for "Colorful Diffuse Intrinsic Image Decomposition in the Wild".

Every supported model version is loaded once at import time. The UI decomposes
an uploaded image into albedo, diffuse shading and a diffuse image, and re-runs
whenever the input image or the selected model version changes.
"""

import spaces
import gradio as gr
import numpy as np
import torch
from chrislib.general import uninvert, invert, view, view_scale  # NOTE: uninvert is currently unused
from intrinsic.pipeline import load_models, run_pipeline

DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Model versions offered in the UI; the first entry is the default selection.
# Keeping them in one place ensures the dropdown and the loaded weights agree.
MODEL_VERSIONS = ('v2', 'v2.1')

# All versions are loaded eagerly at startup, so switching versions in the UI
# is only a dictionary lookup (no lazy loading happens at request time).
model_cache = {version: load_models(version, device=DEVICE) for version in MODEL_VERSIONS}


def generate_pipeline(models):
    """Return a callable that runs the intrinsic pipeline with *models* bound.

    The returned function forwards ``image`` and any keyword arguments to
    ``run_pipeline(models, image, **kwargs)`` and returns its result unchanged.
    """
    def pipeline_func(image, **kwargs):
        return run_pipeline(models, image, **kwargs)

    return pipeline_func


@spaces.GPU
def process_image(image, model_version):
    """Decompose *image* with the selected model version.

    Args:
        image: NumPy array from the ``gr.Image`` input, or ``None`` when the
            input is empty/cleared. Presumably uint8 RGB in [0, 255] (Gradio's
            default numpy output) -- it is scaled by 1/255 below; confirm if the
            input component's type/mode changes.
        model_version: Key into ``model_cache`` (one of ``MODEL_VERSIONS``).

    Returns:
        A 3-item list ``[albedo, diffuse_shading, diffuse_image]`` matching the
        three output ``gr.Image`` components, or ``[None, None, None]`` when no
        image is provided (e.g. the model version changed before an upload).

    Raises:
        gr.Error: If ``model_version`` has no loaded weights.
    """
    if image is None:
        return [None, None, None]

    try:
        models = model_cache[model_version]
    except KeyError:
        # Surface a readable message in the UI instead of a raw KeyError.
        raise gr.Error(f"Unknown model version: {model_version}") from None

    print(f"Processing with model version: {model_version}")
    print(image.shape)

    image = image.astype(np.single) / 255.
    pipeline_func = generate_pipeline(models)
    result = pipeline_func(image, device=DEVICE, resize_conf=1024)

    return [
        view(result['hr_alb']),
        # NOTE(review): assumes chrislib's invert() maps the shading into a
        # bounded range so that 1 - invert(...) is displayable -- confirm.
        1 - invert(result['dif_shd']),
        view_scale(result['pos_res']),
    ]


_CSS = """
#download {
    height: 118px;
}
.slider .inner {
    width: 5px;
    background: #FFF;
}
.viewport {
    aspect-ratio: 4/3;
}
.tabs button.selected {
    font-size: 20px !important;
    color: crimson !important;
}
h1 {
    text-align: center;
    display: block;
}
h2 {
    text-align: center;
    display: block;
}
h3 {
    text-align: center;
    display: block;
}
.md_feedback li {
    margin-bottom: 0px !important;
}
.image-gallery {
    display: flex;
    flex-wrap: wrap;
    gap: 10px;
    justify-content: center;
}
.image-gallery > * {
    flex: 1;
    min-width: 200px;
}
"""

with gr.Blocks(css=_CSS) as demo:
    gr.Markdown(
        """
        # Colorful Diffuse Intrinsic Image Decomposition in the Wild
        """
    )

    # Model version selector with an information panel next to it.
    with gr.Row():
        model_version = gr.Dropdown(
            choices=list(MODEL_VERSIONS),
            value=MODEL_VERSIONS[0],
            label="Model Version",
            info="Select which model weights to use",
            scale=1,
        )
        # NOTE(review): "badge-github-stars" looks like leftover alt text from a
        # stripped GitHub-stars badge; restore the original badge markup/link.
        gr.Markdown(
            """
            badge-github-stars

            **V2**: Original weights from the paper.

            **V2.1**: More albedo detail and improved diffuse shading estimation.
            """
        )

    # Gallery-style layout: the input followed by the three decomposition outputs.
    with gr.Row(elem_classes="image-gallery"):
        input_img = gr.Image(label="Input Image")
        alb_img = gr.Image(label="Albedo")
        shd_img = gr.Image(label="Diffuse Shading")
        dif_img = gr.Image(label="Diffuse Image")

    # Re-run the decomposition whenever the image or the model version changes.
    gr.on(
        triggers=[input_img.change, model_version.change],
        fn=process_image,
        inputs=[input_img, model_version],
        outputs=[alb_img, shd_img, dif_img],
    )

demo.launch(show_error=True, ssr_mode=False)