# app.py - BLURB Space: background-blur demo built on image segmentation and depth estimation.
import gradio as gr
from transformers import pipeline
from PIL import Image, ImageFilter, ImageOps
import numpy as np
import requests
import cv2
# Dictionary of available segmentation models
SEGMENTATION_MODELS = {
"NVIDIA SegFormer (Cityscapes)": "nvidia/segformer-b1-finetuned-cityscapes-1024-1024",
"NVIDIA SegFormer (ADE20K)": "nvidia/segformer-b0-finetuned-ade-512-512",
"Facebook MaskFormer (COCO)": "facebook/maskformer-swin-base-ade",
"OneFormer (COCO)": "shi-labs/oneformer_coco_swin_large",
"NVIDIA SegFormer (B5)": "nvidia/segformer-b5-finetuned-cityscapes-1024-1024"
}
# Dictionary of available depth estimation models
DEPTH_MODELS = {
"Intel ZoeDepth (NYU-KITTI)": "Intel/zoedepth-nyu-kitti",
"DPT (Large)": "Intel/dpt-large",
"DPT (Hybrid)": "Intel/dpt-hybrid-midas",
"GLPDepth": "vinvino02/glpn-nyu"
}
# Initialize model placeholders
segmentation_model = None
depth_estimator = None
def load_segmentation_model(model_name):
"""Load the selected segmentation model"""
global segmentation_model
model_path = SEGMENTATION_MODELS[model_name]
print(f"Loading segmentation model: {model_path}...")
segmentation_model = pipeline("image-segmentation", model=model_path)
return f"Loaded segmentation model: {model_name}"
def load_depth_model(model_name):
"""Load the selected depth estimation model"""
global depth_estimator
model_path = DEPTH_MODELS[model_name]
print(f"Loading depth estimation model: {model_path}...")
depth_estimator = pipeline("depth-estimation", model=model_path)
return f"Loaded depth model: {model_name}"
def lens_blur(image, radius):
"""
Apply a more realistic lens blur (bokeh effect) using OpenCV.
"""
    # Gradio sliders may deliver floats; the kernel construction below needs an int radius.
    radius = int(radius)
    if radius < 1:
        return image
# Convert PIL image to OpenCV format
img_np = np.array(image)
# Create a circular kernel for the bokeh effect
kernel_size = 2 * radius + 1
kernel = np.zeros((kernel_size, kernel_size), dtype=np.float32)
center = radius
for i in range(kernel_size):
for j in range(kernel_size):
# Create circular kernel
if np.sqrt((i - center) ** 2 + (j - center) ** 2) <= radius:
kernel[i, j] = 1.0
# Normalize the kernel
if kernel.sum() != 0:
kernel = kernel / kernel.sum()
# Apply the filter to each channel separately
channels = cv2.split(img_np)
blurred_channels = []
for channel in channels:
blurred_channel = cv2.filter2D(channel, -1, kernel)
blurred_channels.append(blurred_channel)
# Merge the channels back
blurred_img = cv2.merge(blurred_channels)
# Convert back to PIL image
return Image.fromarray(blurred_img)
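# Sketch (not called by the app): the same circular kernel built with NumPy broadcasting
# instead of the nested Python loops above, which is noticeably faster for large radii.
# cv2.filter2D also accepts a 3-channel image directly, so no per-channel split is needed.
def lens_blur_vectorized(image, radius):
    """Bokeh-style blur equivalent to lens_blur, with a vectorized kernel construction."""
    radius = int(radius)
    if radius < 1:
        return image
    img_np = np.array(image)
    kernel_size = 2 * radius + 1
    # Boolean disc mask: True inside the circle of the given radius, False outside.
    y, x = np.ogrid[:kernel_size, :kernel_size]
    kernel = ((x - radius) ** 2 + (y - radius) ** 2 <= radius ** 2).astype(np.float32)
    kernel /= kernel.sum()  # normalize so overall brightness is preserved
    return Image.fromarray(cv2.filter2D(img_np, -1, kernel))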
def process_image(input_image, method, blur_intensity, blur_type):
"""
Process the input image using one of two methods:
1. Segmented Background Blur:
- Uses segmentation to extract a foreground mask.
- Applies the selected blur (Gaussian or Lens) to the background.
- Composites the final image.
2. Depth-based Variable Blur:
- Uses depth estimation to generate a depth map.
- Normalizes the depth map to be used as a blending mask.
- Blends a fully blurred version (using the selected blur) with the original image.
Returns:
- output_image: final composited image.
- mask_image: the mask used (binary for segmentation, normalized depth for depth-based).
"""
    # Bail out early if there is no image yet or the models are still loading.
    if input_image is None:
        return None, None
    if segmentation_model is None or depth_estimator is None:
        return input_image, input_image.convert("L")
    # Ensure the image is in RGB mode.
    input_image = input_image.convert("RGB")
# Select blur function based on blur_type
if blur_type == "Gaussian Blur":
blur_fn = lambda img, rad: img.filter(ImageFilter.GaussianBlur(radius=rad))
elif blur_type == "Lens Blur":
blur_fn = lens_blur
else:
blur_fn = lambda img, rad: img.filter(ImageFilter.GaussianBlur(radius=rad))
if method == "Segmented Background Blur":
# Use segmentation to obtain a foreground mask.
results = segmentation_model(input_image)
# Assume the last result is the main foreground object.
foreground_mask = results[-1]["mask"]
# Ensure the mask is grayscale.
foreground_mask = foreground_mask.convert("L")
# Threshold to create a binary mask.
binary_mask = foreground_mask.point(lambda p: 255 if p > 128 else 0)
# Blur the background using the selected blur function.
blurred_background = blur_fn(input_image, blur_intensity)
# Composite the final image: keep foreground and use blurred background elsewhere.
output_image = Image.composite(input_image, blurred_background, binary_mask)
mask_image = binary_mask
elif method == "Depth-based Variable Blur":
# Generate depth map.
depth_results = depth_estimator(input_image)
depth_map = depth_results["depth"]
# Convert depth map to numpy array and normalize to [0, 255]
depth_array = np.array(depth_map).astype(np.float32)
norm = (depth_array - depth_array.min()) / (depth_array.max() - depth_array.min() + 1e-8)
normalized_depth = (norm * 255).astype(np.uint8)
mask_image = Image.fromarray(normalized_depth)
# Create fully blurred version using the selected blur function.
blurred_image = blur_fn(input_image, blur_intensity)
# Convert images to arrays for blending.
orig_np = np.array(input_image).astype(np.float32)
blur_np = np.array(blurred_image).astype(np.float32)
# Reshape mask for broadcasting.
alpha = normalized_depth[..., np.newaxis] / 255.0
# Blend pixels: 0 = original; 1 = fully blurred.
blended_np = (1 - alpha) * orig_np + alpha * blur_np
blended_np = np.clip(blended_np, 0, 255).astype(np.uint8)
output_image = Image.fromarray(blended_np)
else:
output_image = input_image
mask_image = input_image.convert("L")
return output_image, mask_image
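# Headless usage sketch (assumes the model loaders above have already been called and
# that a local file named "example.jpg" exists; both are assumptions for illustration only):
#
#   load_segmentation_model("NVIDIA SegFormer B1 (Cityscapes)")
#   load_depth_model("Intel ZoeDepth (NYU-KITTI)")
#   img = Image.open("example.jpg")
#   out, mask = process_image(img, "Depth-based Variable Blur", blur_intensity=15, blur_type="Lens Blur")
#   out.save("blurred.jpg"); mask.save("depth_mask.png")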
# Build a Gradio interface
with gr.Blocks() as demo:
gr.Markdown("## Image Processing App: Segmentation & Depth-based Blur")
with gr.Tab("Model Selection"):
with gr.Row():
with gr.Column():
seg_model_dropdown = gr.Dropdown(
label="Segmentation Model",
choices=list(SEGMENTATION_MODELS.keys()),
value=list(SEGMENTATION_MODELS.keys())[0]
)
seg_model_load_btn = gr.Button("Load Segmentation Model")
seg_model_status = gr.Textbox(label="Status", value="No model loaded")
with gr.Column():
depth_model_dropdown = gr.Dropdown(
label="Depth Estimation Model",
choices=list(DEPTH_MODELS.keys()),
value=list(DEPTH_MODELS.keys())[0]
)
depth_model_load_btn = gr.Button("Load Depth Model")
depth_model_status = gr.Textbox(label="Status", value="No model loaded")
with gr.Tab("Image Processing"):
with gr.Row():
with gr.Column():
input_image = gr.Image(label="Input Image", type="pil")
method = gr.Radio(label="Processing Method",
choices=["Segmented Background Blur", "Depth-based Variable Blur"],
value="Segmented Background Blur")
blur_intensity = gr.Slider(label="Blur Intensity (Maximum Blur Radius)",
minimum=1, maximum=30, step=1, value=15)
blur_type = gr.Dropdown(label="Blur Type",
choices=["Gaussian Blur", "Lens Blur"],
value="Gaussian Blur")
run_button = gr.Button("Process Image")
with gr.Column():
output_image = gr.Image(label="Output Image")
mask_output = gr.Image(label="Mask")
# Set up event handlers
seg_model_load_btn.click(
fn=load_segmentation_model,
inputs=[seg_model_dropdown],
outputs=[seg_model_status]
)
depth_model_load_btn.click(
fn=load_depth_model,
inputs=[depth_model_dropdown],
outputs=[depth_model_status]
)
run_button.click(
fn=process_image,
inputs=[input_image, method, blur_intensity, blur_type],
outputs=[output_image, mask_output]
)
# Load default models on startup
demo.load(
fn=lambda: (
load_segmentation_model(list(SEGMENTATION_MODELS.keys())[0]),
load_depth_model(list(DEPTH_MODELS.keys())[0])
),
inputs=None,
outputs=[seg_model_status, depth_model_status]
)
# Launch the app
demo.launch()
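# Note: demo.launch() above uses Gradio's defaults. For local debugging one might instead
# call demo.launch(server_name="0.0.0.0", server_port=7860, debug=True); these are standard
# Gradio launch arguments, and the specific port here is just an illustrative assumption.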