Spaces:
mashroo
/
Running on Zero

YoussefAnso commited on
Commit
1c2b941
·
1 Parent(s): ee52108

Made functions use gpu

Browse files
Files changed (1) hide show
  1. app.py +7 -3
app.py CHANGED
@@ -68,16 +68,18 @@ def init_model():
68
  dtype=torch.float32
69
  )
70
 
71
- return model, pipeline
72
 
73
  # Global variables to store model and pipeline
74
  model = None
75
  pipeline = None
76
 
 
77
  def get_model():
78
- global model, pipeline
 
79
  if model is None or pipeline is None:
80
- model, pipeline = init_model()
81
  return model, pipeline
82
 
83
  rembg_session = rembg.new_session()
@@ -144,6 +146,7 @@ def add_random_background(image, color):
144
  background = Image.new("RGBA", image.size, color)
145
  return Image.alpha_composite(background, image)
146
 
 
147
  def preprocess_image(input_image, background_choice, foreground_ratio, back_groud_color):
148
  """Preprocess the input image"""
149
  try:
@@ -169,6 +172,7 @@ def preprocess_image(input_image, background_choice, foreground_ratio, back_grou
169
  print(f"Error in preprocess_image: {str(e)}")
170
  raise e
171
 
 
172
  def gen_image(processed_image, seed, scale, step):
173
  """Generate the 3D model"""
174
  try:
 
68
  dtype=torch.float32
69
  )
70
 
71
+ return model, pipeline, args
72
 
73
  # Global variables to store model and pipeline
74
  model = None
75
  pipeline = None
76
 
77
+ @spaces.GPU
78
  def get_model():
79
+ """Lazy initialization of model and pipeline"""
80
+ global model, pipeline, args
81
  if model is None or pipeline is None:
82
+ model, pipeline, args = init_model()
83
  return model, pipeline
84
 
85
  rembg_session = rembg.new_session()
 
146
  background = Image.new("RGBA", image.size, color)
147
  return Image.alpha_composite(background, image)
148
 
149
+ @spaces.GPU
150
  def preprocess_image(input_image, background_choice, foreground_ratio, back_groud_color):
151
  """Preprocess the input image"""
152
  try:
 
172
  print(f"Error in preprocess_image: {str(e)}")
173
  raise e
174
 
175
+ @spaces.GPU
176
  def gen_image(processed_image, seed, scale, step):
177
  """Generate the 3D model"""
178
  try: