mac9087 commited on
Commit
ffe4279
·
verified ·
1 Parent(s): 6dd39d3

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +44 -15
app.py CHANGED
@@ -1,3 +1,4 @@
 
1
  import os
2
  import torch
3
  import time
@@ -17,7 +18,6 @@ import trimesh
17
  from transformers import pipeline
18
  from scipy.ndimage import gaussian_filter
19
  import open3d as o3d
20
- from rembg import remove
21
  import cv2
22
 
23
  # Force CPU usage
@@ -98,26 +98,55 @@ def process_with_timeout(function, args, timeout):
98
  def allowed_file(filename):
99
  return '.' in filename and filename.rsplit('.', 1)[1].lower() in ALLOWED_EXTENSIONS
100
 
101
- # Image preprocessing: Remove background and resize
102
  def preprocess_image(image_path):
103
  try:
104
  # Load image
105
  with Image.open(image_path) as img:
106
- # Remove background using rembg
107
- img_no_bg = remove(img)
108
- # Convert to RGB if it has an alpha channel
109
- if img_no_bg.mode == 'RGBA':
110
- img_no_bg = img_no_bg.convert('RGB')
 
 
 
 
 
111
  # Resize to 512x512
112
- img_no_bg = img_no_bg.resize((512, 512), Image.LANCZOS)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
113
 
114
- # Use cv2 for additional segmentation
115
- img_array = np.array(img_no_bg)
116
- gray = cv2.cvtColor(img_array, cv2.COLOR_RGB2GRAY)
117
- _, mask = cv2.threshold(gray, 10, 255, cv2.THRESH_BINARY)
118
- img_array = cv2.bitwise_and(img_array, img_array, mask=mask)
119
 
120
- return Image.fromarray(img_array)
121
  except Exception as e:
122
  raise Exception(f"Error preprocessing image: {str(e)}")
123
 
@@ -516,4 +545,4 @@ def index():
516
  if __name__ == '__main__':
517
  cleanup_old_jobs()
518
  port = int(os.environ.get('PORT', 7860))
519
- app.run(host='0.0.0.0', port=port)
 
1
+
2
  import os
3
  import torch
4
  import time
 
18
  from transformers import pipeline
19
  from scipy.ndimage import gaussian_filter
20
  import open3d as o3d
 
21
  import cv2
22
 
23
  # Force CPU usage
 
98
  def allowed_file(filename):
99
  return '.' in filename and filename.rsplit('.', 1)[1].lower() in ALLOWED_EXTENSIONS
100
 
101
+ # Image preprocessing: Remove background using cv2
102
  def preprocess_image(image_path):
103
  try:
104
  # Load image
105
  with Image.open(image_path) as img:
106
+ # Convert to RGB or handle transparency
107
+ if img.mode == 'RGBA':
108
+ # Use alpha channel as initial mask
109
+ img_array = np.array(img)
110
+ alpha = img_array[:, :, 3]
111
+ img_rgb = img_array[:, :, :3]
112
+ else:
113
+ img_rgb = np.array(img.convert('RGB'))
114
+ alpha = None
115
+
116
  # Resize to 512x512
117
+ img_rgb = cv2.resize(img_rgb, (512, 512), interpolation=cv2.INTER_LANCZOS4)
118
+
119
+ # Convert to grayscale
120
+ gray = cv2.cvtColor(img_rgb, cv2.COLOR_RGB2GRAY)
121
+
122
+ # Adaptive thresholding for initial mask
123
+ thresh = cv2.adaptiveThreshold(
124
+ gray, 255, cv2.ADAPTIVE_THRESH_GAUSSIAN_C, cv2.THRESH_BINARY_INV, 11, 2
125
+ )
126
+
127
+ # If alpha channel exists, combine with threshold
128
+ if alpha is not None:
129
+ alpha_resized = cv2.resize(alpha, (512, 512), interpolation=cv2.INTER_LANCZOS4)
130
+ thresh = cv2.bitwise_and(thresh, alpha_resized)
131
+
132
+ # Refine with GrabCut
133
+ mask = np.zeros((512, 512), np.uint8)
134
+ mask[thresh == 255] = cv2.GC_PR_FGD # Probable foreground
135
+ mask[thresh == 0] = cv2.GC_PR_BGD # Probable background
136
+
137
+ bgdModel = np.zeros((1, 65), np.float64)
138
+ fgdModel = np.zeros((1, 65), np.float64)
139
+
140
+ rect = (10, 10, 492, 492) # ROI for GrabCut
141
+ cv2.grabCut(img_rgb, mask, rect, bgdModel, fgdModel, 5, cv2.GC_INIT_WITH_MASK)
142
+
143
+ # Create final mask (foreground = 1, background = 0)
144
+ mask2 = np.where((mask == cv2.GC_FGD) | (mask == cv2.GC_PR_FGD), 255, 0).astype('uint8')
145
 
146
+ # Apply mask to image
147
+ img_foreground = cv2.bitwise_and(img_rgb, img_rgb, mask=mask2)
 
 
 
148
 
149
+ return Image.fromarray(img_foreground)
150
  except Exception as e:
151
  raise Exception(f"Error preprocessing image: {str(e)}")
152
 
 
545
  if __name__ == '__main__':
546
  cleanup_old_jobs()
547
  port = int(os.environ.get('PORT', 7860))
548
+ app.run(host='0.0.0.0', port=port)