ajsbsd committed
Commit baa104a · Parent: db89b87
Files changed (1): app.py (+26, -29)

app.py CHANGED
@@ -13,25 +13,23 @@ import tempfile
 # Set environment variable to reduce memory fragmentation
 os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"
 
-# Check if CUDA is available, fallback to CPU
-device = "cuda" if torch.cuda.is_available() else "cpu"
-torch_dtype = torch.float16 if device == "cuda" else torch.float32
+# Initialize pipeline as None - will be loaded in GPU function
+pipe = None
 
-# Load pipeline with error handling for HF Spaces
-try:
-    pipe = StableDiffusionXLImg2ImgPipeline.from_pretrained(
-        "stabilityai/stable-diffusion-xl-refiner-1.0",
-        torch_dtype=torch_dtype,
-        variant="fp16" if device == "cuda" else None,
-        use_safetensors=True
-    )
-
-    # Move to device
-    pipe = pipe.to(device)
-
-    # Enable optimizations based on available hardware
-    if device == "cuda":
-        # Use CPU offloading to reduce VRAM usage on GPU
+def load_pipeline():
+    """Load the pipeline on GPU when needed"""
+    global pipe
+    if pipe is None:
+        print("Loading pipeline...")
+        pipe = StableDiffusionXLImg2ImgPipeline.from_pretrained(
+            "stabilityai/stable-diffusion-xl-refiner-1.0",
+            torch_dtype=torch.float16,
+            variant="fp16",
+            use_safetensors=True,
+            device_map="auto"
+        )
+
+        # Enable memory optimizations
         pipe.enable_model_cpu_offload()
 
         # Try to enable memory efficient attention
@@ -40,13 +38,9 @@ try:
         except (ModuleNotFoundError, ImportError):
             print("xformers not available, using attention slicing")
             pipe.enable_attention_slicing()
-    else:
-        # For CPU inference, enable attention slicing
-        pipe.enable_attention_slicing()
 
-except Exception as e:
-    print(f"Error loading pipeline: {e}")
-    pipe = None
+        print("Pipeline loaded successfully!")
+    return pipe
 
 
 @spaces.GPU
@@ -60,8 +54,11 @@ def img2img(
     num_inference_steps: int = 50,
     seed: int = -1,
 ):
-    if pipe is None:
-        return None, "❌ Model failed to load. Please try again later.", None
+    # Load pipeline inside GPU context
+    try:
+        pipe = load_pipeline()
+    except Exception as e:
+        return None, f"❌ Failed to load model: {str(e)}", None
 
     try:
         # Choose image source
@@ -86,9 +83,9 @@
 
         # Set seed and generator
         if seed == -1:
-            generator = torch.Generator(device=device)
+            generator = torch.Generator(device="cuda")
         else:
-            generator = torch.Generator(device=device).manual_seed(seed)
+            generator = torch.Generator(device="cuda").manual_seed(seed)
 
         # Validate inputs
         if not prompt.strip():
@@ -120,7 +117,7 @@
             "steps": num_inference_steps,
             "width": result.width,
             "height": result.height,
-            "device": device
+            "device": "cuda"
         }
 
         # Save metadata into PNG
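Note on the pattern: this commit swaps eager module-level loading for a lazy singleton that is only built inside the @spaces.GPU handler, where ZeroGPU actually attaches a CUDA device. A minimal sketch of that idea, with the Space's UI and metadata code omitted (the _get_pipe and infer names are illustrative, not part of this commit):

import spaces
import torch
from diffusers import StableDiffusionXLImg2ImgPipeline

_pipe = None  # module-level cache; stays None until the first GPU request


def _get_pipe():
    """Build the pipeline on first use, then reuse it across requests."""
    global _pipe
    if _pipe is None:
        _pipe = StableDiffusionXLImg2ImgPipeline.from_pretrained(
            "stabilityai/stable-diffusion-xl-refiner-1.0",
            torch_dtype=torch.float16,
            variant="fp16",
            use_safetensors=True,
        ).to("cuda")  # sketch uses a plain .to("cuda"); see the caveat below
    return _pipe


@spaces.GPU
def infer(image, prompt):
    # On ZeroGPU Spaces, CUDA is only guaranteed inside this decorated call.
    pipe = _get_pipe()
    return pipe(prompt=prompt, image=image).images[0]

One caveat: the sketch moves the model with .to("cuda"), whereas the commit passes device_map="auto" and then also calls enable_model_cpu_offload(). Support for device_map on diffusers pipelines varies by version, and combining it with CPU offload is not obviously safe, so treat that combination as this commit's experiment rather than a fixed recipe.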
 
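The generator change above hard-codes device="cuda", consistent with running only inside the GPU context. For reference, a small sketch of the seeding convention the handler follows, where -1 means "use a fresh random seed" (the make_generator helper is hypothetical, introduced here for illustration):

import torch

def make_generator(seed: int = -1, device: str = "cuda") -> torch.Generator:
    """Return an RNG for the pipeline; -1 leaves it unseeded, else reproducible."""
    gen = torch.Generator(device=device)
    if seed != -1:
        gen.manual_seed(seed)  # same seed + same inputs -> same output image
    return gen

# Hypothetical usage: pass the generator into the pipeline call:
# image = pipe(prompt=prompt, image=init_image, generator=make_generator(1234)).images[0]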