# semanticist_tokenizer / tok_demo.py
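"""Gradio demo for the Semanticist tokenizer (DiffuseSlot).

Upload an image and reconstruct it from a varying number of learned tokens.
"""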
import gradio as gr
import spaces
import numpy as np
from PIL import Image
import torch
from omegaconf import OmegaConf
from huggingface_hub import hf_hub_download
from semanticist.engine.trainer_utils import instantiate_from_config
from semanticist.stage1.diffuse_slot import DiffuseSlot
device = "cuda" if torch.cuda.is_available() else "cpu"
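# Fetch the pretrained tokenizer checkpoint from the Hugging Face Hub and load its config.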
ckpt_path = hf_hub_download(repo_id='tennant/semanticist', filename='semanticist_tok_XL.pkl')  # default cache dir; the original hard-coded a machine-specific path
config_path = 'configs/tokenizer_xl.yaml'
cfg = OmegaConf.load(config_path)
ckpt = torch.load(ckpt_path, map_location='cpu')
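# The checkpoint is loaded on CPU first; the model is moved to `device` after construction.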
from semanticist.utils.datasets import vae_transforms
transform = vae_transforms('test')
print(f"Is CUDA available: {torch.cuda.is_available()}")
if device == 'cuda':
    print(f"CUDA device: {torch.cuda.get_device_name(torch.cuda.current_device())}")
def norm_ip(img, low, high):
    img.clamp_(min=low, max=high)
    img.sub_(low).div_(max(high - low, 1e-5))

def norm_range(t, value_range):
    if value_range is not None:
        norm_ip(t, value_range[0], value_range[1])
    else:
        norm_ip(t, float(t.min()), float(t.max()))
def convert_np(img):
    """Convert a [0, 1] CHW tensor to an HWC uint8 numpy array."""
    ndarr = img.mul(255).add_(0.5).clamp_(0, 255) \
        .permute(1, 2, 0).to("cpu", torch.uint8).numpy()
    return ndarr

def convert_PIL(img):
    """Convert a [0, 1] CHW tensor to a PIL image."""
    return Image.fromarray(convert_np(img))
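# Strip the '._orig_mod' infix that torch.compile inserts into state-dict keys.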
ckpt = {k.replace('._orig_mod', ''): v for k, v in ckpt.items()}
model = DiffuseSlot(**cfg['trainer']['params']['model']['params'])
msg = model.load_state_dict(ckpt, strict=False)
model = model.to(device)
model = model.eval()
model.enable_nest = True
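# enable_nest appears to switch on nested token inference, so the model can
# reconstruct from a truncated prefix of tokens (see inference_with_n_slots below).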
@spaces.GPU(duration=300)
def viz_diff_slots(model, img, nums, cfg=1.0):
    """Reconstruct `img` with each token count in `nums`; returns uint8 numpy images."""
    img = img.to(device)  # .to() is not in-place; the result must be reassigned
    n_slots_inf = []
    for i, num_slots_to_inference in enumerate(nums):
        recon_n = model(
            img, sample=True, cfg=cfg,
            inference_with_n_slots=num_slots_to_inference,
        )
        n_slots_inf.append(recon_n)
        print(f"processing {i + 1} / {len(nums)}")
    return [convert_np(recon[0]) for recon in n_slots_inf]
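# Example usage outside the UI (assuming `pil_img` is an RGB PIL image):
#   img = transform(pil_img).unsqueeze(0)
#   previews = viz_diff_slots(model, img, [1, 16, 256], cfg=4.0)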
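# Build the Gradio UI: image input and settings on the left, reconstructions on the right.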
with gr.Blocks() as demo:
    with gr.Row():
        # First column - input and configuration
        with gr.Column(scale=1):
            gr.Markdown("## Input")
            input_image = gr.Image(label="Upload an image", type="numpy")

            with gr.Group():
                gr.Markdown("### Configuration")
                show_gallery = gr.Checkbox(label="Show Gallery", value=True)
                slider = gr.Slider(minimum=0.1, maximum=20.0, value=4.0, label="CFG value")
                labels_input = gr.Textbox(
                    label="Number of tokens to reconstruct (comma-separated)",
                    value="1, 4, 16, 64, 256",
                    placeholder="Enter comma-separated numbers for the number of tokens to use"
                )

        # Second column - output (conditionally rendered)
        with gr.Column(scale=1):
            gr.Markdown("## Output")
            # Container for conditional rendering
            with gr.Group(visible=True) as gallery_container:
                gallery = gr.Gallery(label="Result Gallery", columns=3, height="auto", show_label=True)
            # Always-visible output image
            output_image = gr.Image(label="Processed Image", type="numpy")

    # Handle form submission
    submit_btn = gr.Button("Process")
    # Define the processing logic. Visibility is toggled by returning
    # gr.update(...) as an output, not by mutating the component in place.
    def update_outputs(image, show_gallery_value, slider_value, labels_text):
        try:
            # Parse the comma-separated token counts from the text input
            labels = [int(x.strip()) for x in (labels_text or "").split(",") if x.strip()]
            if not labels:
                labels = [1, 4, 16, 64, 256]
        except ValueError:
            # Fall back to the defaults if the input cannot be parsed
            labels = [1, 4, 16, 64, 256]
        while len(labels) < 3:
            labels.append(256)

        if image is None:
            # Return a placeholder if no image has been uploaded
            placeholder = np.zeros((300, 300, 3), dtype=np.uint8)
            return gr.update(visible=show_gallery_value), [], placeholder

        image = Image.fromarray(image)
        img = transform(image).unsqueeze(0)  # moved to `device` inside viz_diff_slots
        recon = viz_diff_slots(model, img, [256], cfg=slider_value)[0]

        if not show_gallery_value:
            # Gallery hidden: return only the processed image
            return gr.update(visible=False), [], recon

        model_decompose = viz_diff_slots(model, img, labels, cfg=slider_value)
        # Pair the ground-truth image and each reconstruction with a caption
        gallery_images = [(image, 'GT')] + [
            (img_n, f'Recon. with {label} tokens')
            for img_n, label in zip(model_decompose, labels)
        ]
        return gr.update(visible=True), gallery_images, recon
    # Connect the inputs and outputs
    submit_btn.click(
        fn=update_outputs,
        inputs=[input_image, show_gallery, slider, labels_input],
        outputs=[gallery_container, gallery, output_image]
    )

    # Also update visibility when the checkbox changes
    show_gallery.change(
        fn=lambda value: gr.update(visible=value),
        inputs=[show_gallery],
        outputs=[gallery_container]
    )
    # Add examples
    examples = [
        ["examples/city.jpg", True, 4.0, "1,4,16,64,256"],
        ["examples/food.jpg", True, 4.0, "1,4,16,64,256"],
        ["examples/highland.webp", True, 4.0, "1,4,16,64,256"],
    ]
    gr.Examples(
        examples=examples,
        inputs=[input_image, show_gallery, slider, labels_input],
        outputs=[gallery_container, gallery, output_image],
        fn=update_outputs,
        cache_examples=True
    )
# Launch the demo
if __name__ == "__main__":
    demo.launch()