Refactor app.py for modularity and error handling, and clean up requirements.txt

#18
by smolSWE - opened
Files changed (4) hide show
  1. app.py +87 -71
  2. image_loader.py +74 -0
  3. image_processor.py +43 -0
  4. requirements.txt +18 -17
app.py CHANGED
@@ -1,71 +1,87 @@
1
- import gradio as gr
2
- from loadimg import load_img
3
- import spaces
4
- from transformers import AutoModelForImageSegmentation
5
- import torch
6
- from torchvision import transforms
7
-
8
- torch.set_float32_matmul_precision(["high", "highest"][0])
9
-
10
- birefnet = AutoModelForImageSegmentation.from_pretrained(
11
- "ZhengPeng7/BiRefNet", trust_remote_code=True
12
- )
13
- birefnet.to("cuda")
14
-
15
- transform_image = transforms.Compose(
16
- [
17
- transforms.Resize((1024, 1024)),
18
- transforms.ToTensor(),
19
- transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
20
- ]
21
- )
22
-
23
- def fn(image):
24
- im = load_img(image, output_type="pil")
25
- im = im.convert("RGB")
26
- origin = im.copy()
27
- processed_image = process(im)
28
- return (processed_image, origin)
29
-
30
- @spaces.GPU
31
- def process(image):
32
- image_size = image.size
33
- input_images = transform_image(image).unsqueeze(0).to("cuda")
34
- # Prediction
35
- with torch.no_grad():
36
- preds = birefnet(input_images)[-1].sigmoid().cpu()
37
- pred = preds[0].squeeze()
38
- pred_pil = transforms.ToPILImage()(pred)
39
- mask = pred_pil.resize(image_size)
40
- image.putalpha(mask)
41
- return image
42
-
43
- def process_file(f):
44
- name_path = f.rsplit(".", 1)[0] + ".png"
45
- im = load_img(f, output_type="pil")
46
- im = im.convert("RGB")
47
- transparent = process(im)
48
- transparent.save(name_path)
49
- return name_path
50
-
51
- slider1 = gr.ImageSlider(label="Processed Image", type="pil", format="png")
52
- slider2 = gr.ImageSlider(label="Processed Image from URL", type="pil", format="png")
53
- image_upload = gr.Image(label="Upload an image")
54
- image_file_upload = gr.Image(label="Upload an image", type="filepath")
55
- url_input = gr.Textbox(label="Paste an image URL")
56
- output_file = gr.File(label="Output PNG File")
57
-
58
- # Example images
59
- chameleon = load_img("butterfly.jpg", output_type="pil")
60
- url_example = "https://hips.hearstapps.com/hmg-prod/images/gettyimages-1229892983-square.jpg"
61
-
62
- tab1 = gr.Interface(fn, inputs=image_upload, outputs=slider1, examples=[chameleon], api_name="image")
63
- tab2 = gr.Interface(fn, inputs=url_input, outputs=slider2, examples=[url_example], api_name="text")
64
- tab3 = gr.Interface(process_file, inputs=image_file_upload, outputs=output_file, examples=["butterfly.jpg"], api_name="png")
65
-
66
- demo = gr.TabbedInterface(
67
- [tab1, tab2, tab3], ["Image Upload", "URL Input", "File Output"], title="Background Removal Tool"
68
- )
69
-
70
- if __name__ == "__main__":
71
- demo.launch(show_error=True)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ import gradio as gr
3
+ import spaces
4
+ import torch
5
+ from image_loader import load_image_from_url, load_image_from_file
6
+ from image_processor import process_image
7
+ import logging
8
+
9
+ # Configure logging
10
+ logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
11
+
12
+ torch.set_float32_matmul_precision(["high", "highest"][0])
13
+
14
+ try:
15
+ birefnet = AutoModelForImageSegmentation.from_pretrained(
16
+ "ZhengPeng7/BiRefNet", trust_remote_code=True
17
+ )
18
+ birefnet.to("cuda")
19
+ logging.info("BiRefNet model loaded successfully.")
20
+ except Exception as e:
21
+ logging.error(f"Error loading BiRefNet model: {e}")
22
+ raise Exception(f"Error loading BiRefNet model: {e}")
23
+
24
+ def fn(image_input):
25
+ try:
26
+ if isinstance(image_input, str): # URL input
27
+ img = load_image_from_url(image_input)
28
+ else: # File upload
29
+ img = load_image_from_file(image_input)
30
+
31
+ img = img.convert("RGB")
32
+ origin = img.copy()
33
+ processed_image = process(img)
34
+ return (processed_image, origin)
35
+ except Exception as e:
36
+ logging.error(f"Error in fn function: {e}")
37
+ return None, None # Return None or a placeholder image
38
+
39
+ @spaces.GPU
40
+ def process(image):
41
+ try:
42
+ processed_image = process_image(image, birefnet)
43
+ return processed_image
44
+ except Exception as e:
45
+ logging.error(f"Error in process function: {e}")
46
+ raise gr.Error(f"Error processing image: {e}")
47
+
48
+
49
+ def process_file(file_path):
50
+ try:
51
+ name_path = file_path.rsplit(".", 1)[0] + ".png"
52
+ img = load_image_from_file(file_path)
53
+ img = img.convert("RGB")
54
+ transparent = process(img)
55
+ transparent.save(name_path)
56
+ logging.info(f"Processed image saved to: {name_path}")
57
+ return name_path
58
+ except Exception as e:
59
+ logging.error(f"Error in process_file function: {e}")
60
+ raise gr.Error(f"Error processing file: {e}")
61
+
62
+ slider1 = gr.ImageSlider(label="Processed Image", type="pil", format="png")
63
+ slider2 = gr.ImageSlider(label="Processed Image from URL", type="pil", format="png")
64
+ image_upload = gr.Image(label="Upload an image")
65
+ image_file_upload = gr.Image(label="Upload an image", type="filepath")
66
+ url_input = gr.Textbox(label="Paste an image URL")
67
+ output_file = gr.File(label="Output PNG File")
68
+
69
+ # Example images
70
+ try:
71
+ chameleon = load_image_from_file("butterfly.jpg")
72
+ except Exception as e:
73
+ logging.error(f"Error loading example image: {e}")
74
+ chameleon = None # Or a placeholder image
75
+
76
+ url_example = "https://hips.hearstapps.com/hmg-prod/images/gettyimages-1229892983-square.jpg"
77
+
78
+ tab1 = gr.Interface(fn, inputs=image_upload, outputs=slider1, examples=[chameleon], api_name="image")
79
+ tab2 = gr.Interface(fn, inputs=url_input, outputs=slider2, examples=[url_example], api_name="text")
80
+ tab3 = gr.Interface(process_file, inputs=image_file_upload, outputs=output_file, examples=["butterfly.jpg"], api_name="png")
81
+
82
+ demo = gr.TabbedInterface(
83
+ [tab1, tab2, tab3], ["Image Upload", "URL Input", "File Output"], title="Background Removal Tool"
84
+ )
85
+
86
+ if __name__ == "__main__":
87
+ demo.launch(show_error=True)
image_loader.py ADDED
@@ -0,0 +1,74 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ from PIL import Image
3
+ import requests
4
+ from io import BytesIO
5
+ import logging
6
+
7
+ logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
8
+
9
+ def load_image_from_url(url):
10
+ """Loads an image from a URL.
11
+
12
+ Args:
13
+ url (str): The URL of the image.
14
+
15
+ Returns:
16
+ PIL.Image.Image: The loaded image.
17
+
18
+ Raises:
19
+ Exception: If the image cannot be loaded from the URL.
20
+ """
21
+ try:
22
+ response = requests.get(url, stream=True)
23
+ response.raise_for_status() # Raise HTTPError for bad responses (4xx or 5xx)
24
+ image = Image.open(BytesIO(response.content))
25
+ logging.info(f"Image loaded successfully from URL: {url}")
26
+ return image
27
+ except requests.exceptions.RequestException as e:
28
+ logging.error(f"Error loading image from URL: {url} - {e}")
29
+ raise Exception(f"Error loading image from URL: {url} - {e}")
30
+ except Exception as e:
31
+ logging.error(f"Error opening image from URL: {url} - {e}")
32
+ raise Exception(f"Error opening image from URL: {url} - {e}")
33
+
34
+
35
+ def load_image_from_file(file_path):
36
+ """Loads an image from a file.
37
+
38
+ Args:
39
+ file_path (str): The path to the image file.
40
+
41
+ Returns:
42
+ PIL.Image.Image: The loaded image.
43
+
44
+ Raises:
45
+ Exception: If the image cannot be loaded from the file.
46
+ """
47
+ try:
48
+ image = Image.open(file_path)
49
+ logging.info(f"Image loaded successfully from file: {file_path}")
50
+ return image
51
+ except FileNotFoundError:
52
+ logging.error(f"File not found: {file_path}")
53
+ raise Exception(f"File not found: {file_path}")
54
+ except Exception as e:
55
+ logging.error(f"Error loading image from file: {file_path} - {e}")
56
+ raise Exception(f"Error loading image from file: {file_path} - {e}")
57
+
58
+ if __name__ == '__main__':
59
+ # Example Usage
60
+ try:
61
+ image_url = "https://hips.hearstapps.com/hmg-prod/images/gettyimages-1229892983-square.jpg"
62
+ image_from_url = load_image_from_url(image_url)
63
+ print("Image loaded from URL successfully!")
64
+ # image_from_url.show() # Display the image (optional)
65
+ except Exception as e:
66
+ print(e)
67
+
68
+ try:
69
+ image_path = "butterfly.jpg"
70
+ image_from_file = load_image_from_file(image_path)
71
+ print("Image loaded from file successfully!")
72
+ # image_from_file.show() # Display the image (optional)
73
+ except Exception as e:
74
+ print(e)
image_processor.py ADDED
@@ -0,0 +1,43 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ import torch
3
+ from torchvision import transforms
4
+ from PIL import Image
5
+ import logging
6
+
7
+ logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
8
+
9
+ transform_image = transforms.Compose(
10
+ [
11
+ transforms.Resize((1024, 1024)),
12
+ transforms.ToTensor(),
13
+ transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
14
+ ]
15
+ )
16
+
17
+ def process_image(image, birefnet, device="cuda"):
18
+ """Processes the input image to remove the background.
19
+
20
+ Args:
21
+ image (PIL.Image.Image): The image to process.
22
+ birefnet (torch.nn.Module): The BiRefNet model.
23
+ device (str): The device to run the model on (default: "cuda").
24
+
25
+ Returns:
26
+ PIL.Image.Image: The processed image with background removed.
27
+ """
28
+ try:
29
+ image_size = image.size
30
+ input_images = transform_image(image).unsqueeze(0).to(device)
31
+
32
+ # Prediction
33
+ with torch.no_grad():
34
+ preds = birefnet(input_images)[-1].sigmoid().cpu()
35
+ pred = preds[0].squeeze()
36
+ pred_pil = transforms.ToPILImage()(pred)
37
+ mask = pred_pil.resize(image_size)
38
+ image.putalpha(mask)
39
+ logging.info("Image processed successfully.")
40
+ return image
41
+ except Exception as e:
42
+ logging.error(f"Error processing image: {e}")
43
+ raise Exception(f"Error processing image: {e}")
requirements.txt CHANGED
@@ -1,17 +1,18 @@
1
- torch
2
- accelerate
3
- opencv-python
4
- spaces
5
- pillow
6
- numpy
7
- timm
8
- kornia
9
- prettytable
10
- typing
11
- scikit-image
12
- huggingface_hub
13
- transformers>=4.39.1
14
- gradio
15
- gradio_imageslider
16
- loadimg>=0.1.1
17
- einops
 
 
1
+
2
+ accelerate==0.27.2
3
+ einops==0.7.0
4
+ gradio==4.16.0
5
+ gradio_imageslider==0.2.0
6
+ huggingface_hub==0.20.3
7
+ kornia==0.7.1
8
+ loadimg==0.1.1
9
+ numpy==1.26.4
10
+ opencv-python==4.9.0.54
11
+ pillow==10.2.0
12
+ prettytable==4.0.0
13
+ scikit-image==0.23.0
14
+ spaces==0.35.0
15
+ timm==0.9.12
16
+ torch==2.2.0
17
+ transformers==4.39.1
18
+ typing==3.7.4.3