Commit
·
1c2b941
1
Parent(s):
ee52108
Made functions use gpu
Browse files
app.py
CHANGED
@@ -68,16 +68,18 @@ def init_model():
|
|
68 |
dtype=torch.float32
|
69 |
)
|
70 |
|
71 |
-
return model, pipeline
|
72 |
|
73 |
# Global variables to store model and pipeline
|
74 |
model = None
|
75 |
pipeline = None
|
76 |
|
|
|
77 |
def get_model():
|
78 |
-
|
|
|
79 |
if model is None or pipeline is None:
|
80 |
-
model, pipeline = init_model()
|
81 |
return model, pipeline
|
82 |
|
83 |
rembg_session = rembg.new_session()
|
@@ -144,6 +146,7 @@ def add_random_background(image, color):
|
|
144 |
background = Image.new("RGBA", image.size, color)
|
145 |
return Image.alpha_composite(background, image)
|
146 |
|
|
|
147 |
def preprocess_image(input_image, background_choice, foreground_ratio, back_groud_color):
|
148 |
"""Preprocess the input image"""
|
149 |
try:
|
@@ -169,6 +172,7 @@ def preprocess_image(input_image, background_choice, foreground_ratio, back_grou
|
|
169 |
print(f"Error in preprocess_image: {str(e)}")
|
170 |
raise e
|
171 |
|
|
|
172 |
def gen_image(processed_image, seed, scale, step):
|
173 |
"""Generate the 3D model"""
|
174 |
try:
|
|
|
68 |
dtype=torch.float32
|
69 |
)
|
70 |
|
71 |
+
return model, pipeline, args
|
72 |
|
73 |
# Global variables to store model and pipeline
|
74 |
model = None
|
75 |
pipeline = None
|
76 |
|
77 |
+
@spaces.GPU
|
78 |
def get_model():
|
79 |
+
"""Lazy initialization of model and pipeline"""
|
80 |
+
global model, pipeline, args
|
81 |
if model is None or pipeline is None:
|
82 |
+
model, pipeline, args = init_model()
|
83 |
return model, pipeline
|
84 |
|
85 |
rembg_session = rembg.new_session()
|
|
|
146 |
background = Image.new("RGBA", image.size, color)
|
147 |
return Image.alpha_composite(background, image)
|
148 |
|
149 |
+
@spaces.GPU
|
150 |
def preprocess_image(input_image, background_choice, foreground_ratio, back_groud_color):
|
151 |
"""Preprocess the input image"""
|
152 |
try:
|
|
|
172 |
print(f"Error in preprocess_image: {str(e)}")
|
173 |
raise e
|
174 |
|
175 |
+
@spaces.GPU
|
176 |
def gen_image(processed_image, seed, scale, step):
|
177 |
"""Generate the 3D model"""
|
178 |
try:
|