import torch from torchvision import transforms from PIL import Image import logging logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s') transform_image = transforms.Compose( [ transforms.Resize((1024, 1024)), transforms.ToTensor(), transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]), ] ) def process_image(image, birefnet, device="cuda"): """Processes the input image to remove the background. Args: image (PIL.Image.Image): The image to process. birefnet (torch.nn.Module): The BiRefNet model. device (str): The device to run the model on (default: "cuda"). Returns: PIL.Image.Image: The processed image with background removed. """ try: image_size = image.size input_images = transform_image(image).unsqueeze(0).to(device) # Prediction with torch.no_grad(): preds = birefnet(input_images)[-1].sigmoid().cpu() pred = preds[0].squeeze() pred_pil = transforms.ToPILImage()(pred) mask = pred_pil.resize(image_size) image.putalpha(mask) logging.info("Image processed successfully.") return image except Exception as e: logging.error(f"Error processing image: {e}") raise Exception(f"Error processing image: {e}")