FlexTok / app.py
roman-bachmann's picture
Initial commit
a4f1fc6
raw
history blame
7.13 kB
import os
import random
from typing import List, Sequence

import spaces
import gradio as gr
from PIL import Image
import matplotlib.pyplot as plt
import einops
import numpy as np
import torch
from torchvision import transforms
import torchvision.transforms.functional as TF

from flextok.flextok_wrapper import FlexTokFromHub
from flextok.utils.demo import imgs_from_urls, denormalize, batch_to_pil
from flextok.utils.misc import detect_bf16_support, get_bf16_context, get_generator
# Select the compute device. We recommend running this demo on an A100 GPU.
if torch.cuda.is_available():
    device = "cuda"
    gpu_type = torch.cuda.get_device_name(torch.cuda.current_device())
    power_device = f"{gpu_type}"
    # Return value is not used; the call is kept unchanged from the original setup.
    torch.cuda.max_memory_allocated(device=device)
    # Detect if bf16 is enabled or not (the flag is later passed to get_bf16_context).
    enable_bf16 = detect_bf16_support()
else:
    # gpu_type must be defined on this path too: the log line below reads it
    # unconditionally, and leaving it unset raised a NameError on CPU-only hosts.
    device, gpu_type, power_device, enable_bf16 = "cpu", None, "CPU", False
print(f'Device: {device}, GPU type: {gpu_type}')
print('BF16 enabled:', enable_bf16)
# Permit TF32 math for CUDA matmuls (off by default since PyTorch 1.12) and for
# cuDNN (on by default).
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True

# Inference-only demo: switch autograd off. Note that PyTorch grad mode is
# thread-local, so this only applies to the thread importing this module.
torch.set_grad_enabled(False)

# Upper bound of the seed slider (largest signed 32-bit integer).
MAX_SEED = np.iinfo(np.int32).max

# Hugging Face Hub checkpoint id and the label shown for it in the UI.
MODEL_ID, MODEL_NAME = 'EPFL-VILAB/flextok_d18_d28_dfn', 'FlexTok d18-d28 (DFN)'
# Fetch the pretrained FlexTok tokenizer from the Hugging Face Hub, then move it
# to the selected device and put it in evaluation mode.
flextok_model = FlexTokFromHub.from_pretrained(MODEL_ID)
flextok_model = flextok_model.to(device).eval()
def img_from_path(
    path: str,
    img_size: int = 256,
    mean: Sequence[float] = (0.5, 0.5, 0.5),
    std: Sequence[float] = (0.5, 0.5, 0.5),
) -> torch.Tensor:
    """Load an image file as a normalized tensor with a leading batch dimension.

    The image is converted to RGB, resized (torchvision ``Resize`` with an int
    size scales the shorter side to ``img_size``), center-cropped to
    ``img_size`` x ``img_size``, converted to a tensor and normalized
    channel-wise with ``mean`` / ``std``.

    Args:
        path: Path of the image file to load.
        img_size: Target side length of the square output.
        mean: Per-channel mean used for normalization.
        std: Per-channel standard deviation used for normalization.

    Returns:
        The transformed image with a batch dimension of size 1 prepended.
    """
    # Close the file handle right away instead of leaving it open until the
    # PIL image is garbage-collected; convert() loads the pixel data first.
    with Image.open(path) as img:
        img_pil = img.convert("RGB")
    transform = transforms.Compose(
        [
            transforms.Resize(img_size),
            transforms.CenterCrop(img_size),
            transforms.ToTensor(),
            transforms.Normalize(mean=mean, std=std),
        ]
    )
    return transform(img_pil).unsqueeze(0)
@spaces.GPU(duration=20)
def infer(img_path, seed=0, randomize_seed=False, timesteps=20, cfg_scale=7.5, perform_norm_guidance=True):
    """Autoencode one image with FlexTok and decode truncated token prefixes.

    The image is tokenized once. The first k tokens, for k in 1, 2, 4, ..., 256,
    are then decoded together as one batch, so the gallery shows how detail is
    added as more tokens are kept.

    Args:
        img_path: Path of the input image (Gradio ``type='filepath'``); ``None``
            when the user has not provided an image.
        seed: Seed for the decoder's initial noise.
        randomize_seed: If True, ``seed`` is ignored and ``seed=None`` is passed
            to ``get_generator``.
        timesteps: Number of denoising timesteps of the decoder.
        cfg_scale: Guidance scale.
        perform_norm_guidance: Whether to use Adaptive Projected Guidance.

    Returns:
        List of ``(PIL image, caption)`` pairs for ``gr.Gallery``.

    Raises:
        gr.Error: If no input image was provided.
    """
    if img_path is None:
        raise gr.Error("Please upload an input image first.")
    # Gradio sliders may deliver numbers as floats; seed and step count are integral.
    if randomize_seed or seed is None:
        seed = None
    else:
        seed = int(seed)
    timesteps = int(timesteps)

    # The module-level torch.set_grad_enabled(False) is thread-local and does not
    # reach the worker thread Gradio runs this handler in, so disable autograd here.
    with torch.no_grad():
        imgs = img_from_path(img_path).to(device)

        # Tokenize the image once.
        with get_bf16_context(enable_bf16):
            tokens_list = flextok_model.tokenize(imgs)

        # Create all token subsequences. tokenize presumably returns one token
        # sequence per input image (exactly one here), so repeating the list
        # yields one copy per truncation length.
        k_keep_list = [1, 2, 4, 8, 16, 32, 64, 128, 256]
        tokens_list = tokens_list * len(k_keep_list)
        subseq_list = [seq[:, :k_keep].clone() for seq, k_keep in zip(tokens_list, k_keep_list)]

        # Detokenize the various subsequences in parallel (batch size 9).
        with get_bf16_context(enable_bf16):
            generator = get_generator(seed=seed, device=device)
            all_reconst = flextok_model.detokenize(
                subseq_list,
                timesteps=timesteps,
                guidance_scale=cfg_scale,
                perform_norm_guidance=perform_norm_guidance,
                generator=generator,
                verbose=False,
            )

    # Transform to PIL images, captioned with the prefix length.
    return [
        (
            TF.to_pil_image(denormalize(reconst_k).clamp(0, 1)),
            f"{k_keep} token{'' if k_keep == 1 else 's'}",
        )
        for reconst_k, k_keep in zip(all_reconst, k_keep_list)
    ]
# Bundled example inputs shown below the demo: examples/0.png ... examples/5.png.
examples = [f'examples/{idx}.png' for idx in range(6)]
# Custom CSS for the Gradio layout: centers the outer container and the input
# column (with maximum widths), centers the run button, and keeps the gallery
# square.
css = "\n".join([
    "",
    "#col-container {",
    "margin: 0 auto;",
    "max-width: 1500px;",
    "}",
    "#col-input-container {",
    "margin: 0 auto;",
    "max-width: 400px;",
    "}",
    "#run-button {",
    "margin: 0 auto;",
    "}",
    "#gallery {",
    "aspect-ratio: 1/1 !important;",
    "height: auto !important;",
    "}",
    "",
])
# Gradio UI: an input column (image, run button, advanced decoder settings) next
# to a gallery of reconstructions, followed by clickable examples.
with gr.Blocks(css=css, theme=gr.themes.Base()) as demo:
    with gr.Column(elem_id="col-container"):
        gr.Markdown("""
# FlexTok: Resampling Images into 1D Token Sequences of Flexible Length
""")
        with gr.Row():
            with gr.Column(elem_id="col-input-container"):
                # Raw f-string: the author list escapes Markdown asterisks as `\*`,
                # which is an invalid escape sequence in a non-raw string literal
                # (SyntaxWarning on Python 3.12+). The resulting text is unchanged.
                gr.Markdown(rf"""
[`Website`](https://flextok.epfl.ch) | [`arXiv`](https://arxiv.org/abs/2502.13967) | [`GitHub`](https://github.com/apple/ml-flextok)
Official demo for: <br>
[**FlexTok: Resampling Images into 1D Token Sequences of Flexible Length**](https://arxiv.org/abs/2502.13967), arXiv 2025 <br>
*[Roman Bachmann](https://roman-bachmann.github.io/)\*, [Jesse Allardice](https://github.com/JesseAllardice)\*, [David Mizrahi](https://dmizrahi.com/)\*, [Enrico Fini](https://scholar.google.com/citations?user=OQMtSKIAAAAJ), [Oğuzhan Fatih Kar](https://ofkar.github.io/), [Elmira Amirloo](https://elamirloo.github.io/), [Alaaeldin El-Nouby](https://aelnouby.github.io/), [Amir Zamir](https://vilab.epfl.ch/zamir/), [Afshin Dehghan](https://scholar.google.com/citations?user=wcX-UW4AAAAJ)*
This demo uses the FlexTok tokenizer to autoencode the given RGB input, using [{MODEL_ID}](https://huggingface.co/{MODEL_ID}), running on *{power_device}*. The FlexTok encoder produces a 1D sequence of discrete tokens that are ordered in a coarse-to-fine manner. We show reconstructions from truncated subsequences, using the first 1, 2, 4, 8, ..., 256 tokens. As you will see, the first tokens capture the high-level semantic content, while subsequent ones add more fine-grained detail.
""")
                img_path = gr.Image(label='RGB input image', type='filepath')
                run_button = gr.Button(f"Autoencode with {MODEL_NAME}", scale=0, elem_id="run-button")
                with gr.Accordion("Advanced Settings", open=False):
                    gr.Markdown("""
The FlexTok decoder is a rectified flow model. The following settings control the seed of the initial noise, the number of denoising timesteps, the guidance scale, and whether to perform [Adaptive Projected Guidance](https://arxiv.org/abs/2410.02416) (we recommend enabling it).
""")
                    seed = gr.Slider(label="Seed", minimum=0, maximum=MAX_SEED, step=1, value=0)
                    randomize_seed = gr.Checkbox(label="Randomize seed", value=False)
                    timesteps = gr.Slider(label="Denoising timesteps", minimum=1, maximum=1000, step=1, value=20)
                    cfg_scale = gr.Slider(label="Guidance Scale", minimum=1.0, maximum=15.0, step=0.1, value=7.5)
                    perform_norm_guidance = gr.Checkbox(label="Perform Adaptive Projected Guidance", value=True)
            result = gr.Gallery(
                label="Reconstructions", show_label=True, elem_id="gallery", type='pil',
                columns=[3], rows=None, object_fit="contain", height=800,
            )
        # Examples only wire the image input; infer's defaults supply the rest.
        gr.Examples(
            examples=examples,
            fn=infer,
            inputs=[img_path],
            outputs=[result],
            cache_examples='lazy',
        )
    run_button.click(
        fn=infer,
        inputs=[img_path, seed, randomize_seed, timesteps, cfg_scale, perform_norm_guidance],
        outputs=[result],
    )
# Keep at most 10 pending events in the request queue, then start the app
# (share=True asks Gradio for a public share link).
demo.queue(max_size=10)
demo.launch(share=True)