File size: 1,413 Bytes
c9c230c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44

import torch
from torchvision import transforms
from PIL import Image
import logging

# Configure root logging at import time: INFO level, "time - level - message".
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(levelname)s - %(message)s',
)

# Preprocessing applied to every image before it reaches the model:
#   1. resize to a fixed 1024x1024,
#   2. convert the PIL image to a float tensor,
#   3. normalize each of the three channels with the mean/std values below
#      (the commonly used ImageNet statistics).
_MODEL_INPUT_SIZE = (1024, 1024)
_CHANNEL_MEAN = [0.485, 0.456, 0.406]
_CHANNEL_STD = [0.229, 0.224, 0.225]

transform_image = transforms.Compose([
    transforms.Resize(_MODEL_INPUT_SIZE),
    transforms.ToTensor(),
    transforms.Normalize(_CHANNEL_MEAN, _CHANNEL_STD),
])

def process_image(image, birefnet, device="cuda"):
    """Remove the background of ``image`` using a BiRefNet model.

    The image is preprocessed with the module-level ``transform_image``
    pipeline and run through ``birefnet``. The sigmoid of the model's last
    output is then turned into a grayscale mask, resized back to the original
    image size, and written into the image as its alpha channel.

    Args:
        image (PIL.Image.Image): The image to process. It is modified in
            place: ``putalpha`` adds or replaces its alpha channel, which can
            also change its mode (e.g. RGB -> RGBA).
        birefnet (torch.nn.Module): The BiRefNet model. Presumably already on
            ``device`` and in eval mode; this function does not check that.
        device (str): The device to run the model on (default: "cuda").

    Returns:
        PIL.Image.Image: The same ``image`` object, now carrying the predicted
        foreground mask as its alpha channel.

    Raises:
        Exception: Wraps any error raised during preprocessing, inference or
            post-processing. The original error is chained as ``__cause__``.
    """
    try:
        image_size = image.size

        # transform_image normalizes exactly three channels, so RGBA,
        # grayscale or palette inputs would fail inside Normalize. Give the
        # model an RGB version; the caller's image is only touched by
        # putalpha below.
        rgb_image = image if image.mode == "RGB" else image.convert("RGB")
        input_images = transform_image(rgb_image).unsqueeze(0).to(device)

        # Prediction: take the last element of the model's output and apply
        # sigmoid (assumed to be the final-stage mask logits -- confirm).
        with torch.no_grad():
            preds = birefnet(input_images)[-1].sigmoid().cpu()

        # Assumes preds is (N, 1, H, W), so this yields an (H, W) mask in
        # [0, 1] -- confirm against the model.
        pred = preds[0].squeeze()
        pred_pil = transforms.ToPILImage()(pred)
        mask = pred_pil.resize(image_size)

        image.putalpha(mask)
        logging.info("Image processed successfully.")
        return image
    except Exception as e:
        # logging.exception also records the traceback; chaining with
        # ``from e`` keeps the root cause visible to callers.
        logging.exception("Error processing image: %s", e)
        raise Exception(f"Error processing image: {e}") from e