# Hugging Face Space status (copied from the Space page): "Running on Zero" (ZeroGPU hardware).
from typing import List | |
import os | |
import spaces | |
import gradio as gr | |
import random | |
from PIL import Image | |
import matplotlib.pyplot as plt | |
import einops | |
import numpy as np | |
import torch | |
from torchvision import transforms | |
import torchvision.transforms.functional as TF | |
from flextok.flextok_wrapper import FlexTokFromHub | |
from flextok.utils.demo import imgs_from_urls, denormalize, batch_to_pil | |
from flextok.utils.misc import detect_bf16_support, get_bf16_context, get_generator | |
# Select the compute device. We recommend running this demo on an A100 GPU.
if torch.cuda.is_available():
    device = "cuda"
    gpu_type = torch.cuda.get_device_name(torch.cuda.current_device())
    power_device = f"{gpu_type}"
    # Return value is unused; kept from the original (may only serve to touch
    # CUDA state early — confirm before removing).
    torch.cuda.max_memory_allocated(device=device)
    # Detect if bf16 is enabled or not
    enable_bf16 = detect_bf16_support()
else:
    device, power_device, enable_bf16 = "cpu", "CPU", False
    # gpu_type must exist on this path too: the print below references it and
    # previously raised NameError on CPU-only machines.
    gpu_type = "N/A"
print(f'Device: {device}, GPU type: {gpu_type}')
print('BF16 enabled:', enable_bf16)
# Permit TF32 tensor-core math: for matmuls this defaults to False since
# PyTorch 1.12, so enable it explicitly; for cuDNN it already defaults to True.
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True

# Turn autograd off by default — this app only runs inference.
# NOTE(review): grad mode is thread-local in PyTorch, so this only affects the
# thread that imports this module; code running on other threads needs its own
# no_grad scope.
torch.set_grad_enabled(False)

# Largest value selectable on the seed slider (max signed 32-bit integer).
MAX_SEED = np.iinfo(np.int32).max
# HF Hub repository of the checkpoint, and the name shown on the run button.
MODEL_ID = 'EPFL-VILAB/flextok_d18_d28_dfn'
MODEL_NAME = 'FlexTok d18-d28 (DFN)'

# Load the FlexTok model from the HF Hub, place it on the chosen device and
# switch it to evaluation mode.
flextok_model = (
    FlexTokFromHub
    .from_pretrained(MODEL_ID)
    .to(device)
    .eval()
)
def img_from_path(
    path: str,
    img_size: int = 256,
    mean: List[float] = [0.5, 0.5, 0.5],
    std: List[float] = [0.5, 0.5, 0.5],
) -> torch.Tensor:
    """Load an image file and preprocess it into a normalized batch of one.

    The image is converted to RGB and resized so its shorter side equals
    ``img_size``. It is then center-cropped to ``img_size x img_size``,
    converted with ``ToTensor`` and normalized per channel with ``mean``/``std``.
    With the default mean/std of 0.5, ``ToTensor``'s [0, 1] output is mapped
    to [-1, 1].

    Args:
        path: Path of the image file to load.
        img_size: Side length of the square output.
        mean: Per-channel mean for normalization (not mutated).
        std: Per-channel std for normalization (not mutated).

    Returns:
        Float tensor of shape ``(1, 3, img_size, img_size)``.
    """
    # Use a context manager so the file handle is released right away.
    # ``convert`` loads the pixels and returns a new image (a copy even when
    # the mode is already RGB), so ``img_pil`` stays valid after closing.
    with Image.open(path) as img:
        img_pil = img.convert("RGB")
    transform = transforms.Compose(
        [
            transforms.Resize(img_size),
            transforms.CenterCrop(img_size),
            transforms.ToTensor(),
            transforms.Normalize(mean=mean, std=std),
        ]
    )
    # Add the leading batch dimension.
    return transform(img_pil).unsqueeze(0)
@spaces.GPU  # Required to get a GPU on ZeroGPU Spaces; no effect elsewhere.
def infer(img_path, seed=0, randomize_seed=False, timesteps=20, cfg_scale=7.5, perform_norm_guidance=True):
    """Autoencode one image with FlexTok and decode truncated token prefixes.

    The image is tokenized once. For each k in 1, 2, 4, ..., 256 the first k
    tokens are kept, and all nine prefixes are detokenized together in one call.

    Args:
        img_path: Filepath from ``gr.Image(type='filepath')``; ``None`` if no
            image was provided.
        seed: Seed passed to ``get_generator`` for the decoder's initial noise.
        randomize_seed: If True, ``seed`` is ignored and ``None`` is passed to
            ``get_generator`` instead.
        timesteps: Number of denoising timesteps forwarded to ``detokenize``.
        cfg_scale: Guidance scale forwarded to ``detokenize``.
        perform_norm_guidance: Forwarded to ``detokenize`` (the Adaptive
            Projected Guidance toggle in the UI).

    Returns:
        List of ``(PIL image, caption)`` pairs, one per prefix length, with
        captions of the form ``'<k> tokens'``.

    Raises:
        gr.Error: If no input image was provided.
    """
    if img_path is None:
        raise gr.Error("Please upload an input image first.")
    # Slider values may arrive as floats; normalize to plain ints.
    seed = None if (randomize_seed or seed is None) else int(seed)
    timesteps = int(timesteps)

    # The module-level torch.set_grad_enabled(False) is thread-local and does
    # not cover the worker thread a Gradio handler typically runs on, so
    # disable autograd explicitly here.
    with torch.no_grad():
        imgs = img_from_path(img_path).to(device)

        # Tokenize the image once.
        with get_bf16_context(enable_bf16):
            tokens_list = flextok_model.tokenize(imgs)

        # Prefix lengths to reconstruct from (coarse -> fine).
        k_keep_list = [1, 2, 4, 8, 16, 32, 64, 128, 256]
        # A single image was tokenized, so tokens_list presumably holds one
        # (1, L) sequence — TODO confirm. Every prefix is cut from it.
        tokens = tokens_list[0]
        subseq_list = [tokens[:, :k_keep].clone() for k_keep in k_keep_list]

        # Detokenize all prefixes in one batched call (batch size 9).
        with get_bf16_context(enable_bf16):
            generator = get_generator(seed=seed, device=device)
            all_reconst = flextok_model.detokenize(
                subseq_list, timesteps=timesteps,
                guidance_scale=cfg_scale, perform_norm_guidance=perform_norm_guidance,
                generator=generator, verbose=False,
            )

    # Map reconstructions back to displayable PIL images with captions.
    return [
        (TF.to_pil_image(denormalize(reconst_k).clamp(0, 1)), f'{k_keep} tokens')
        for reconst_k, k_keep in zip(all_reconst, k_keep_list)
    ]
# Example inputs bundled with the Space: examples/0.png through examples/5.png.
examples = [f'examples/{idx}.png' for idx in range(6)]
# Custom CSS passed to gr.Blocks. Each selector targets an elem_id assigned in
# the UI: the outer container, the input column, the run button and the result
# gallery (forced to a square aspect ratio).
css="""
#col-container {
margin: 0 auto;
max-width: 1500px;
}
#col-input-container {
margin: 0 auto;
max-width: 400px;
}
#run-button {
margin: 0 auto;
}
#gallery {
aspect-ratio: 1/1 !important;
height: auto !important;
}
"""
# Gradio UI, top to bottom:
# - the title;
# - a row with two parts: an input column (description, image upload, run
#   button, advanced decoder settings) and a gallery of reconstructions;
# - clickable example inputs.
# The app is served with a request queue holding at most 10 pending events.
with gr.Blocks(css=css, theme=gr.themes.Base()) as demo:
    with gr.Column(elem_id="col-container"):
        gr.Markdown("""
# FlexTok: Resampling Images into 1D Token Sequences of Flexible Length
""")
        with gr.Row():
            with gr.Column(elem_id="col-input-container"):
                # Raw f-string: the author list escapes Markdown asterisks as
                # "\*", which is an invalid escape sequence in a non-raw string
                # literal (SyntaxWarning since Python 3.12). The rendered text
                # is unchanged.
                gr.Markdown(rf"""
[`Website`](https://flextok.epfl.ch) | [`arXiv`](https://arxiv.org/abs/2502.13967) | [`GitHub`](https://github.com/apple/ml-flextok)

Official demo for: <br>
[**FlexTok: Resampling Images into 1D Token Sequences of Flexible Length**](https://arxiv.org/abs/2502.13967), arXiv 2025 <br>
*[Roman Bachmann](https://roman-bachmann.github.io/)\*, [Jesse Allardice](https://github.com/JesseAllardice)\*, [David Mizrahi](https://dmizrahi.com/)\*, [Enrico Fini](https://scholar.google.com/citations?user=OQMtSKIAAAAJ), [Oğuzhan Fatih Kar](https://ofkar.github.io/), [Elmira Amirloo](https://elamirloo.github.io/), [Alaaeldin El-Nouby](https://aelnouby.github.io/), [Amir Zamir](https://vilab.epfl.ch/zamir/), [Afshin Dehghan](https://scholar.google.com/citations?user=wcX-UW4AAAAJ)*

This demo uses the FlexTok tokenizer to autoencode the given RGB input, using [{MODEL_ID}](https://huggingface.co/{MODEL_ID}), running on *{power_device}*. The FlexTok encoder produces a 1D sequence of discrete tokens that are ordered in a coarse-to-fine manner. We show reconstructions from truncated subsequences, using the first 1, 2, 4, 8, ..., 256 tokens. As you will see, the first tokens capture the high-level semantic content, while subsequent ones add more fine-grained detail.
""")
                img_path = gr.Image(label='RGB input image', type='filepath')
                run_button = gr.Button(f"Autoencode with {MODEL_NAME}", scale=0, elem_id="run-button")
                with gr.Accordion("Advanced Settings", open=False):
                    gr.Markdown("""
The FlexTok decoder is a rectified flow model. The following settings control the seed of the initial noise, the number of denoising timesteps, the guidance scale, and whether to perform [Adaptive Projected Guidance](https://arxiv.org/abs/2410.02416) (we recommend enabling it).
""")
                    seed = gr.Slider(label="Seed", minimum=0, maximum=MAX_SEED, step=1, value=0)
                    randomize_seed = gr.Checkbox(label="Randomize seed", value=False)
                    timesteps = gr.Slider(label="Denoising timesteps", minimum=1, maximum=1000, step=1, value=20)
                    cfg_scale = gr.Slider(label="Guidance Scale", minimum=1.0, maximum=15.0, step=0.1, value=7.5)
                    perform_norm_guidance = gr.Checkbox(label="Perform Adaptive Projected Guidance", value=True)
            result = gr.Gallery(
                label="Reconstructions", show_label=True, elem_id="gallery", type='pil',
                columns=[3], rows=None, object_fit="contain", height=800,
            )
        # Examples only feed the image; infer's defaults apply to the other
        # settings. Outputs are cached lazily on first use.
        gr.Examples(
            examples=examples,
            fn=infer,
            inputs=[img_path],
            outputs=[result],
            cache_examples='lazy',
        )
    run_button.click(
        fn=infer,
        inputs=[img_path, seed, randomize_seed, timesteps, cfg_scale, perform_norm_guidance],
        outputs=[result],
    )

demo.queue(max_size=10).launch(share=True)