background-removal / image_processor.py
smolSWE's picture
Refactor app.py for modularity and error handling, and clean up requirements.txt
c9c230c verified
raw
history blame
1.41 kB
import torch
from torchvision import transforms
from PIL import Image
import logging
# Root-logger configuration applied once, when this module is imported.
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(levelname)s - %(message)s',
)

# Fixed spatial size every input is resized to before inference.
_MODEL_INPUT_SIZE = (1024, 1024)
# Per-channel normalization statistics (the usual ImageNet mean/std values).
_NORM_MEAN = [0.485, 0.456, 0.406]
_NORM_STD = [0.229, 0.224, 0.225]

# Preprocessing pipeline used by process_image: resize, convert the PIL image
# to a float tensor, then normalize each of the three channels.
transform_image = transforms.Compose([
    transforms.Resize(_MODEL_INPUT_SIZE),
    transforms.ToTensor(),
    transforms.Normalize(_NORM_MEAN, _NORM_STD),
])
def process_image(image, birefnet, device="cuda"):
    """Remove the background of ``image`` by attaching a predicted alpha mask.

    The image is converted to a 3-channel RGB view, preprocessed with
    ``transform_image`` and passed through ``birefnet``. The sigmoid of the
    model's last output is resized back to the original image size and
    attached to ``image`` as its alpha channel.

    Args:
        image (PIL.Image.Image): The image to process. Any mode that PIL can
            convert to RGB is accepted (e.g. RGB, RGBA, L, P). The image is
            modified in place: ``putalpha`` adds or replaces its alpha channel.
        birefnet (torch.nn.Module): The BiRefNet model. Its output is indexed
            with ``[-1]``, and the first batch element is used as the mask
            logits. This is presumably the finest-resolution prediction;
            confirm against the model.
        device (str): The device to run the model on (default: "cuda").

    Returns:
        PIL.Image.Image: The same ``image`` object, now carrying the predicted
        mask as its alpha channel.

    Raises:
        Exception: Wraps any error raised during processing. The original
            exception is chained as ``__cause__``.
    """
    try:
        image_size = image.size
        # transform_image normalizes with 3-channel statistics, so RGBA / L / P
        # inputs would fail with a channel mismatch. Feed the model an RGB
        # version while still writing the mask onto the caller's image.
        model_input = image if image.mode == "RGB" else image.convert("RGB")
        input_images = transform_image(model_input).unsqueeze(0).to(device)
        # Prediction
        with torch.no_grad():
            preds = birefnet(input_images)[-1].sigmoid().cpu()
        pred = preds[0].squeeze()
        # Sigmoid probabilities in [0, 1] -> 8-bit PIL mask, scaled back to the
        # original resolution because the model ran on the resized input.
        mask = transforms.ToPILImage()(pred).resize(image_size)
        image.putalpha(mask)
        logging.info("Image processed successfully.")
        return image
    except Exception as e:
        # logging.exception also records the traceback; "from e" keeps the
        # original error reachable instead of discarding it.
        logging.exception("Error processing image: %s", e)
        raise Exception(f"Error processing image: {e}") from e