Aatricks committed
Commit 8a58a6e · verified · 1 Parent(s): f83908e

Upload folder using huggingface_hub

.dockerignore ADDED
@@ -0,0 +1,84 @@
1
+ # Python cache files
2
+ __pycache__/
3
+ *.py[cod]
4
+ *$py.class
5
+ *.so
6
+ .Python
7
+ build/
8
+ develop-eggs/
9
+ dist/
10
+ downloads/
11
+ eggs/
12
+ .eggs/
13
+ lib/
14
+ lib64/
15
+ parts/
16
+ sdist/
17
+ var/
18
+ wheels/
19
+ share/python-wheels/
20
+ *.egg-info/
21
+ .installed.cfg
22
+ *.egg
23
+ MANIFEST
24
+
25
+ # Virtual environments
26
+ .env
27
+ .venv
28
+ env/
29
+ venv/
30
+ ENV/
31
+ env.bak/
32
+ venv.bak/
33
+
34
+ # IDE files
35
+ .vscode/
36
+ .idea/
37
+ *.swp
38
+ *.swo
39
+ *~
40
+
41
+ # OS files
42
+ .DS_Store
43
+ .DS_Store?
44
+ ._*
45
+ .Spotlight-V100
46
+ .Trashes
47
+ ehthumbs.db
48
+ Thumbs.db
49
+
50
+ # Git
51
+ .git/
52
+ .gitignore
53
+
54
+ # Docker
55
+ Dockerfile*
56
+ .dockerignore
57
+ docker-compose*.yml
58
+
59
+ # Documentation
60
+ *.md
61
+ docs/
62
+
63
+ # Large model files (these should be downloaded at runtime)
64
+ *.safetensors
65
+ *.ckpt
66
+ *.pt
67
+ *.pth
68
+ *.bin
69
+ *.gguf
70
+
71
+ # Logs
72
+ *.log
73
+ logs/
74
+
75
+ # Temporary files
76
+ tmp/
77
+ temp/
78
+ *.tmp
79
+
80
+ # Generated images (these will be created at runtime)
81
+ _internal/output/
82
+
83
+ # Large dependencies that will be installed via pip
84
+ stable_fast-*.whl
Dockerfile ADDED
@@ -0,0 +1,87 @@
1
+ # Use Python 3.10 base image
2
+ FROM python:3.10-slim-bullseye
3
+
4
+ # Set environment variables
5
+ ENV DEBIAN_FRONTEND=noninteractive
6
+ ENV PYTHONUNBUFFERED=1
7
+ ENV PYTHONDONTWRITEBYTECODE=1
8
+
9
+ # Install system dependencies
10
+ RUN apt-get update && apt-get install -y \
11
+ python3-dev \
12
+ python3-venv \
13
+ python3-pip \
14
+ python3-tk \
15
+ git \
16
+ wget \
17
+ curl \
18
+ build-essential \
19
+ libgl1-mesa-glx \
20
+ libglib2.0-0 \
21
+ libsm6 \
22
+ libxext6 \
23
+ libxrender-dev \
24
+ libgomp1 \
25
+ software-properties-common \
26
+ && rm -rf /var/lib/apt/lists/*
27
+
28
+
29
+ # Set working directory
30
+ WORKDIR /app
31
+
32
+ # Copy requirements first to leverage Docker cache
33
+ COPY requirements.txt .
34
+
35
+ # Upgrade pip and install uv for faster package installation
36
+ RUN python3 -m pip install --upgrade pip
37
+ RUN python3 -m pip install uv
38
+
39
+ # Install PyTorch with CUDA support
40
+ RUN python3 -m uv pip install --system --index-url https://download.pytorch.org/whl/cu128 \
41
+ torch torchvision "triton>=2.1.0"
42
+
43
+ # Install numpy with version constraint
44
+ RUN python3 -m uv pip install --system "numpy<2.0.0"
45
+
46
+ # Install Python dependencies
47
+ RUN python3 -m uv pip install --system -r requirements.txt
48
+
49
+ # Copy the entire project
50
+ COPY . .
51
+
52
+ # Create necessary directories
53
+ RUN mkdir -p ./_internal/output/classic \
54
+ ./_internal/output/Flux \
55
+ ./_internal/output/HiresFix \
56
+ ./_internal/output/Img2Img \
57
+ ./_internal/output/Adetailer \
58
+ ./_internal/checkpoints \
59
+ ./_internal/clip \
60
+ ./_internal/embeddings \
61
+ ./_internal/ESRGAN \
62
+ ./_internal/loras \
63
+ ./_internal/sd1_tokenizer \
64
+ ./_internal/unet \
65
+ ./_internal/vae \
66
+ ./_internal/vae_approx \
67
+ ./_internal/yolos
68
+
69
+ # Create last_seed.txt if it doesn't exist
70
+ RUN echo "42" > ./_internal/last_seed.txt
71
+
72
+ # Create prompt.txt if it doesn't exist
73
+ RUN echo "A beautiful landscape" > ./_internal/prompt.txt
74
+
75
+ # Expose the port that Gradio will run on
76
+ EXPOSE 7860
77
+
78
+ # Set environment variable to indicate this is running in a container
79
+ ENV GRADIO_SERVER_NAME=0.0.0.0
80
+ ENV GRADIO_SERVER_PORT=7860
81
+
82
+ # Health check
83
+ HEALTHCHECK --interval=30s --timeout=30s --start-period=60s --retries=3 \
84
+ CMD curl -f http://localhost:7860/ || exit 1
85
+
86
+ # Run the Gradio app
87
+ CMD ["python3", "app.py"]
README.md CHANGED
@@ -2,7 +2,7 @@
2
  title: LightDiffusion-Next
3
  app_file: app.py
4
  sdk: gradio
5
- sdk_version: 5.20.0
6
  ---
7
  <div align="center">
8
 
@@ -40,7 +40,7 @@ That's when the first version of LightDiffusion was born which only counted [300
40
 
41
  Advanced users can take advantage of features like **attention syntax**, **Hires-Fix** or **ADetailer**. These tools provide better quality and flexibility for generating complex and high-resolution outputs.
42
 
43
- **LightDiffusion-Next** is fine-tuned for **performance**. Features such as **Xformers** acceleration, **BFloat16** precision support, **WaveSpeed** dynamic caching, and **Stable-Fast** model compilation (which offers up to a 70% speed boost) ensure smooth and efficient operation, even on demanding workloads.
44
 
45
  ---
46
 
@@ -49,7 +49,7 @@ Advanced users can take advantage of features like **attention syntax**, **Hires
49
  Here’s what makes LightDiffusion-Next stand out:
50
 
51
  - **Speed and Efficiency**:
52
- Enjoy industry-leading performance with built-in Xformers, Pytorch, Wavespeed and Stable-Fast optimizations, achieving up to 30% faster speeds compared to the rest of the AI image generation backends in SD1.5 and up to 2x for Flux.
53
 
54
  - **Automatic Detailing**:
55
  Effortlessly enhance faces and body details with AI-driven tools based on the [Impact Pack](https://github.com/ltdrdata/ComfyUI-Impact-Pack).
@@ -109,6 +109,33 @@ With its unmatched speed and efficiency, LightDiffusion-Next sets the benchmark
109
  2. Run `run.bat` in a terminal.
110
  3. Start creating!
111
 
112
  ### Command-Line Pipeline
113
 
114
  For a GUI-free experience, use the pipeline:
 
2
  title: LightDiffusion-Next
3
  app_file: app.py
4
  sdk: gradio
5
+ sdk_version: 5.38.0
6
  ---
7
  <div align="center">
8
 
 
40
 
41
  Advanced users can take advantage of features like **attention syntax**, **Hires-Fix** or **ADetailer**. These tools provide better quality and flexibility for generating complex and high-resolution outputs.
42
 
43
+ **LightDiffusion-Next** is fine-tuned for **performance**. Features such as **Xformers** acceleration, **BFloat16** precision support, **WaveSpeed** dynamic caching, **Multi-scale diffusion**, and **Stable-Fast** model compilation (which offers up to a 70% speed boost) ensure smooth and efficient operation, even on demanding workloads.
44
 
45
  ---
46
 
 
49
  Here’s what makes LightDiffusion-Next stand out:
50
 
51
  - **Speed and Efficiency**:
52
+ Enjoy industry-leading performance with built-in Xformers, Pytorch, Wavespeed and Stable-Fast optimizations, plus Multi-scale diffusion, achieving up to 30% faster speeds than other AI image generation backends in SD1.5 and up to 2x for Flux.
53
 
54
  - **Automatic Detailing**:
55
  Effortlessly enhance faces and body details with AI-driven tools based on the [Impact Pack](https://github.com/ltdrdata/ComfyUI-Impact-Pack).
 
109
  2. Run `run.bat` in a terminal.
110
  3. Start creating!
111
 
112
+ ### 🐳 Docker Setup
113
+
114
+ Run LightDiffusion-Next in a containerized environment with GPU acceleration:
115
+
116
+ **Prerequisites:**
117
+ - Docker with NVIDIA Container Toolkit installed
118
+ - NVIDIA GPU with CUDA support
119
+
120
+ **Quick Start with Docker:**
121
+ ```bash
122
+ # Build and run with docker-compose (recommended)
123
+ docker-compose up --build
124
+
125
+ # Or build and run manually
126
+ docker build -t lightdiffusion-next .
127
+ docker run --gpus all -p 7860:7860 -v ./output:/app/_internal/output lightdiffusion-next
128
+ ```
129
+
130
+ **Access the Gradio Web Interface:**
131
+ Open your browser and navigate to `http://localhost:7860`
132
+
133
+ **Volume Mounts:**
134
+ - `./output:/app/_internal/output` - Persist generated images
135
+ - `./checkpoints:/app/_internal/checkpoints` - Store model files
136
+ - `./loras:/app/_internal/loras` - Store LoRA files
137
+ - `./embeddings:/app/_internal/embeddings` - Store embeddings
138
+
139
  ### Command-Line Pipeline
140
 
141
  For a GUI-free experience, use the pipeline:
app.py CHANGED
@@ -1,216 +1,372 @@
1
- import glob
2
- import gradio as gr
3
- import sys
4
- import os
5
- from PIL import Image
6
- import numpy as np
7
- import spaces
8
-
9
- sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "../..")))
10
-
11
- from modules.user.pipeline import pipeline
12
- import torch
13
-
14
-
15
- def load_generated_images():
16
- """Load generated images with given prefix from disk"""
17
- image_files = glob.glob("./_internal/output/**/*.png")
18
-
19
- # If there are no image files, return
20
- if not image_files:
21
- return []
22
-
23
- # Sort files by modification time in descending order
24
- image_files.sort(key=os.path.getmtime, reverse=True)
25
-
26
- # Get most recent timestamp
27
- latest_time = os.path.getmtime(image_files[0])
28
-
29
- # Get all images from same batch (within 1 second of most recent)
30
- batch_images = []
31
- for file in image_files:
32
- if abs(os.path.getmtime(file) - latest_time) < 1.0:
33
- try:
34
- img = Image.open(file)
35
- batch_images.append(img)
36
- except:
37
- continue
38
-
39
- if not batch_images:
40
- return []
41
- return batch_images
42
-
43
-
44
- @spaces.GPU(duration=120)
45
- def generate_images(
46
- prompt: str,
47
- width: int = 512,
48
- height: int = 512,
49
- num_images: int = 1,
50
- batch_size: int = 1,
51
- hires_fix: bool = False,
52
- adetailer: bool = False,
53
- enhance_prompt: bool = False,
54
- img2img_enabled: bool = False,
55
- img2img_image: str = None,
56
- stable_fast: bool = False,
57
- reuse_seed: bool = False,
58
- flux_enabled: bool = False,
59
- prio_speed: bool = False,
60
- realistic_model: bool = False,
61
- progress=gr.Progress(),
62
- ):
63
- """Generate images using the LightDiffusion pipeline"""
64
- try:
65
- if img2img_enabled and img2img_image is not None:
66
- # Convert numpy array to PIL Image
67
- if isinstance(img2img_image, np.ndarray):
68
- img_pil = Image.fromarray(img2img_image)
69
- img_pil.save("temp_img2img.png")
70
- prompt = "temp_img2img.png"
71
-
72
- # Run pipeline and capture saved images
73
- with torch.inference_mode():
74
- pipeline(
75
- prompt=prompt,
76
- w=width,
77
- h=height,
78
- number=num_images,
79
- batch=batch_size,
80
- hires_fix=hires_fix,
81
- adetailer=adetailer,
82
- enhance_prompt=enhance_prompt,
83
- img2img=img2img_enabled,
84
- stable_fast=stable_fast,
85
- reuse_seed=reuse_seed,
86
- flux_enabled=flux_enabled,
87
- prio_speed=prio_speed,
88
- autohdr=True,
89
- realistic_model=realistic_model,
90
- )
91
-
92
- # Clean up temporary file if it exists
93
- if os.path.exists("temp_img2img.png"):
94
- os.remove("temp_img2img.png")
95
-
96
- return load_generated_images()
97
-
98
- except Exception:
99
- import traceback
100
-
101
- print(traceback.format_exc())
102
- # Clean up temporary file if it exists
103
- if os.path.exists("temp_img2img.png"):
104
- os.remove("temp_img2img.png")
105
- return [Image.new("RGB", (512, 512), color="black")]
106
-
107
-
108
- # Create Gradio interface
109
- with gr.Blocks(title="LightDiffusion Web UI") as demo:
110
- gr.Markdown("# LightDiffusion Web UI")
111
- gr.Markdown("Generate AI images using LightDiffusion")
112
- gr.Markdown(
113
- "This is the demo for LightDiffusion, the fastest diffusion backend for generating images. https://github.com/LightDiffusion/LightDiffusion-Next"
114
- )
115
-
116
- with gr.Row():
117
- with gr.Column():
118
- # Input components
119
- prompt = gr.Textbox(label="Prompt", placeholder="Enter your prompt here...")
120
-
121
- with gr.Row():
122
- width = gr.Slider(
123
- minimum=64, maximum=2048, value=512, step=64, label="Width"
124
- )
125
- height = gr.Slider(
126
- minimum=64, maximum=2048, value=512, step=64, label="Height"
127
- )
128
-
129
- with gr.Row():
130
- num_images = gr.Slider(
131
- minimum=1, maximum=10, value=1, step=1, label="Number of Images"
132
- )
133
- batch_size = gr.Slider(
134
- minimum=1, maximum=4, value=1, step=1, label="Batch Size"
135
- )
136
-
137
- with gr.Row():
138
- hires_fix = gr.Checkbox(label="HiRes Fix")
139
- adetailer = gr.Checkbox(label="Auto Face/Body Enhancement")
140
- enhance_prompt = gr.Checkbox(label="Enhance Prompt")
141
- stable_fast = gr.Checkbox(label="Stable Fast Mode")
142
-
143
- with gr.Row():
144
- reuse_seed = gr.Checkbox(label="Reuse Seed")
145
- flux_enabled = gr.Checkbox(label="Flux Mode")
146
- prio_speed = gr.Checkbox(label="Prioritize Speed")
147
- realistic_model = gr.Checkbox(label="Realistic Model")
148
-
149
- with gr.Row():
150
- img2img_enabled = gr.Checkbox(label="Image to Image Mode")
151
- img2img_image = gr.Image(label="Input Image for img2img", visible=False)
152
-
153
- # Make input image visible only when img2img is enabled
154
- img2img_enabled.change(
155
- fn=lambda x: gr.update(visible=x),
156
- inputs=[img2img_enabled],
157
- outputs=[img2img_image],
158
- )
159
-
160
- generate_btn = gr.Button("Generate")
161
-
162
- # Output gallery
163
- gallery = gr.Gallery(
164
- label="Generated Images",
165
- show_label=True,
166
- elem_id="gallery",
167
- columns=[2],
168
- rows=[2],
169
- object_fit="contain",
170
- height="auto",
171
- )
172
-
173
- # Connect generate button to pipeline
174
- generate_btn.click(
175
- fn=generate_images,
176
- inputs=[
177
- prompt,
178
- width,
179
- height,
180
- num_images,
181
- batch_size,
182
- hires_fix,
183
- adetailer,
184
- enhance_prompt,
185
- img2img_enabled,
186
- img2img_image,
187
- stable_fast,
188
- reuse_seed,
189
- flux_enabled,
190
- prio_speed,
191
- realistic_model,
192
- ],
193
- outputs=gallery,
194
- )
195
-
196
-
197
- def is_huggingface_space():
198
- return "SPACE_ID" in os.environ
199
-
200
-
201
- # For local testing
202
- if __name__ == "__main__":
203
- if is_huggingface_space():
204
- demo.launch(
205
- debug=False,
206
- server_name="0.0.0.0",
207
- server_port=7860, # Standard HF Spaces port
208
- )
209
- else:
210
- demo.launch(
211
- server_name="0.0.0.0",
212
- server_port=8000,
213
- auth=None,
214
- share=True, # Only enable sharing locally
215
- debug=True,
216
- )
1
+ import glob
2
+ import gradio as gr
3
+ import sys
4
+ import os
5
+ from PIL import Image
6
+ import numpy as np
7
+ import spaces
8
+
9
+ sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "../..")))
10
+
11
+ from modules.user.pipeline import pipeline
12
+ import torch
13
+
14
+
15
+ def load_generated_images():
16
+ """Load generated images with given prefix from disk"""
17
+ image_files = glob.glob("./_internal/output/**/*.png")
18
+
19
+ # If there are no image files, return
20
+ if not image_files:
21
+ return []
22
+
23
+ # Sort files by modification time in descending order
24
+ image_files.sort(key=os.path.getmtime, reverse=True)
25
+
26
+ # Get most recent timestamp
27
+ latest_time = os.path.getmtime(image_files[0])
28
+
29
+ # Get all images from same batch (within 1 second of most recent)
30
+ batch_images = []
31
+ for file in image_files:
32
+ if abs(os.path.getmtime(file) - latest_time) < 1.0:
33
+ try:
34
+ img = Image.open(file)
35
+ batch_images.append(img)
36
+ except:
37
+ continue
38
+
39
+ if not batch_images:
40
+ return []
41
+ return batch_images
42
+
43
+
44
+ @spaces.GPU
45
+ def generate_images(
46
+ prompt: str,
47
+ width: int = 512,
48
+ height: int = 512,
49
+ num_images: int = 1,
50
+ batch_size: int = 1,
51
+ hires_fix: bool = False,
52
+ adetailer: bool = False,
53
+ enhance_prompt: bool = False,
54
+ img2img_enabled: bool = False,
55
+ img2img_image: str = None,
56
+ stable_fast: bool = False,
57
+ reuse_seed: bool = False,
58
+ flux_enabled: bool = False,
59
+ prio_speed: bool = False,
60
+ realistic_model: bool = False,
61
+ multiscale_enabled: bool = True,
62
+ multiscale_intermittent: bool = False,
63
+ multiscale_factor: float = 0.5,
64
+ multiscale_fullres_start: int = 3,
65
+ multiscale_fullres_end: int = 8,
66
+ keep_models_loaded: bool = True,
67
+ progress=gr.Progress(),
68
+ ):
69
+ """Generate images using the LightDiffusion pipeline"""
70
+ try:
71
+ # Set model persistence preference
72
+ from modules.Device.ModelCache import set_keep_models_loaded
73
+
74
+ set_keep_models_loaded(keep_models_loaded)
75
+
76
+ if img2img_enabled and img2img_image is not None:
77
+ # Convert numpy array to PIL Image
78
+ if isinstance(img2img_image, np.ndarray):
79
+ img_pil = Image.fromarray(img2img_image)
80
+ img_pil.save("temp_img2img.png")
81
+ prompt = "temp_img2img.png"
82
+
83
+ # Run pipeline and capture saved images
84
+ with torch.inference_mode():
85
+ pipeline(
86
+ prompt=prompt,
87
+ w=width,
88
+ h=height,
89
+ number=num_images,
90
+ batch=batch_size,
91
+ hires_fix=hires_fix,
92
+ adetailer=adetailer,
93
+ enhance_prompt=enhance_prompt,
94
+ img2img=img2img_enabled,
95
+ stable_fast=stable_fast,
96
+ reuse_seed=reuse_seed,
97
+ flux_enabled=flux_enabled,
98
+ prio_speed=prio_speed,
99
+ autohdr=True,
100
+ realistic_model=realistic_model,
101
+ enable_multiscale=multiscale_enabled,
102
+ multiscale_intermittent_fullres=multiscale_intermittent,
103
+ multiscale_factor=multiscale_factor,
104
+ multiscale_fullres_start=multiscale_fullres_start,
105
+ multiscale_fullres_end=multiscale_fullres_end,
106
+ )
107
+
108
+ # Clean up temporary file if it exists
109
+ if os.path.exists("temp_img2img.png"):
110
+ os.remove("temp_img2img.png")
111
+
112
+ return load_generated_images()
113
+
114
+ except Exception:
115
+ import traceback
116
+
117
+ print(traceback.format_exc())
118
+ # Clean up temporary file if it exists
119
+ if os.path.exists("temp_img2img.png"):
120
+ os.remove("temp_img2img.png")
121
+ return [Image.new("RGB", (512, 512), color="black")]
122
+
123
+
124
+ def get_vram_info():
125
+ """Get VRAM usage information"""
126
+ try:
127
+ from modules.Device.ModelCache import get_memory_info
128
+
129
+ info = get_memory_info()
130
+ return f"""
131
+ **VRAM Usage:**
132
+ - Total: {info["total_vram"]:.1f} GB
133
+ - Used: {info["used_vram"]:.1f} GB
134
+ - Free: {info["free_vram"]:.1f} GB
135
+ - Keep Models Loaded: {info["keep_loaded"]}
136
+ - Has Cached Checkpoint: {info["has_cached_checkpoint"]}
137
+ """
138
+ except Exception as e:
139
+ return f"Error getting VRAM info: {e}"
140
+
141
+
142
+ def clear_model_cache_ui():
143
+ """Clear model cache from UI"""
144
+ try:
145
+ from modules.Device.ModelCache import clear_model_cache
146
+
147
+ clear_model_cache()
148
+ return "✅ Model cache cleared successfully!"
149
+ except Exception as e:
150
+ return f"❌ Error clearing cache: {e}"
151
+
152
+
153
+ def apply_multiscale_preset(preset_name):
154
+ """Apply multiscale preset values to the UI components"""
155
+ if preset_name == "None":
156
+ return gr.update(), gr.update(), gr.update(), gr.update(), gr.update()
157
+
158
+ try:
159
+ from modules.sample.multiscale_presets import get_preset_parameters
160
+
161
+ params = get_preset_parameters(preset_name)
162
+
163
+ return (
164
+ gr.update(value=params["enable_multiscale"]),
165
+ gr.update(value=params["multiscale_factor"]),
166
+ gr.update(value=params["multiscale_fullres_start"]),
167
+ gr.update(value=params["multiscale_fullres_end"]),
168
+ gr.update(value=params["multiscale_intermittent_fullres"]),
169
+ )
170
+ except Exception as e:
171
+ print(f"Error applying preset {preset_name}: {e}")
172
+ return gr.update(), gr.update(), gr.update(), gr.update(), gr.update()
173
+
174
+
175
+ # Create Gradio interface
176
+ with gr.Blocks(title="LightDiffusion Web UI") as demo:
177
+ gr.Markdown("# LightDiffusion Web UI")
178
+ gr.Markdown("Generate AI images using LightDiffusion")
179
+ gr.Markdown(
180
+ "This is the demo for LightDiffusion, the fastest diffusion backend for generating images. https://github.com/LightDiffusion/LightDiffusion-Next"
181
+ )
182
+
183
+ with gr.Row():
184
+ with gr.Column():
185
+ # Input components
186
+ prompt = gr.Textbox(label="Prompt", placeholder="Enter your prompt here...")
187
+
188
+ with gr.Row():
189
+ width = gr.Slider(
190
+ minimum=64, maximum=2048, value=512, step=64, label="Width"
191
+ )
192
+ height = gr.Slider(
193
+ minimum=64, maximum=2048, value=512, step=64, label="Height"
194
+ )
195
+
196
+ with gr.Row():
197
+ num_images = gr.Slider(
198
+ minimum=1, maximum=10, value=1, step=1, label="Number of Images"
199
+ )
200
+ batch_size = gr.Slider(
201
+ minimum=1, maximum=4, value=1, step=1, label="Batch Size"
202
+ )
203
+
204
+ with gr.Row():
205
+ hires_fix = gr.Checkbox(label="HiRes Fix")
206
+ adetailer = gr.Checkbox(label="Auto Face/Body Enhancement")
207
+ enhance_prompt = gr.Checkbox(label="Enhance Prompt")
208
+ stable_fast = gr.Checkbox(label="Stable Fast Mode")
209
+
210
+ with gr.Row():
211
+ reuse_seed = gr.Checkbox(label="Reuse Seed")
212
+ flux_enabled = gr.Checkbox(label="Flux Mode")
213
+ prio_speed = gr.Checkbox(label="Prioritize Speed")
214
+ realistic_model = gr.Checkbox(label="Realistic Model")
215
+
216
+ with gr.Row():
217
+ multiscale_enabled = gr.Checkbox(
218
+ label="Multi-Scale Diffusion", value=True
219
+ )
220
+ img2img_enabled = gr.Checkbox(label="Image to Image Mode")
221
+ keep_models_loaded = gr.Checkbox(
222
+ label="Keep Models in VRAM",
223
+ value=True,
224
+ info="Keep models loaded for instant reuse (faster but uses more VRAM)",
225
+ )
226
+
227
+ img2img_image = gr.Image(label="Input Image for img2img", visible=False)
228
+
229
+ # Multi-scale preset selection
230
+ with gr.Row():
231
+ multiscale_preset = gr.Dropdown(
232
+ label="Multi-Scale Preset",
233
+ choices=["None", "quality", "performance", "balanced", "disabled"],
234
+ value="None",
235
+ info="Select a preset to automatically configure multi-scale settings",
236
+ )
237
+ multiscale_intermittent = gr.Checkbox(
238
+ label="Intermittent Full-Res",
239
+ value=False,
240
+ info="Enable intermittent full-resolution rendering in low-res region",
241
+ )
242
+
243
+ with gr.Row():
244
+ multiscale_factor = gr.Slider(
245
+ minimum=0.1,
246
+ maximum=1.0,
247
+ value=0.5,
248
+ step=0.1,
249
+ label="Multi-Scale Factor",
250
+ )
251
+ multiscale_fullres_start = gr.Slider(
252
+ minimum=0, maximum=10, value=3, step=1, label="Full-Res Start Steps"
253
+ )
254
+ multiscale_fullres_end = gr.Slider(
255
+ minimum=0, maximum=20, value=8, step=1, label="Full-Res End Steps"
256
+ )
257
+
258
+ # Make input image visible only when img2img is enabled
259
+ img2img_enabled.change(
260
+ fn=lambda x: gr.update(visible=x),
261
+ inputs=[img2img_enabled],
262
+ outputs=[img2img_image],
263
+ )
264
+
265
+ # Handle preset changes
266
+ multiscale_preset.change(
267
+ fn=apply_multiscale_preset,
268
+ inputs=[multiscale_preset],
269
+ outputs=[
270
+ multiscale_enabled,
271
+ multiscale_factor,
272
+ multiscale_fullres_start,
273
+ multiscale_fullres_end,
274
+ multiscale_intermittent,
275
+ ],
276
+ )
277
+
278
+ generate_btn = gr.Button("Generate")
279
+
280
+ # Model Cache Management
281
+ with gr.Accordion("Model Cache Management", open=False):
282
+ with gr.Row():
283
+ vram_info_btn = gr.Button("🔍 Check VRAM Usage")
284
+ clear_cache_btn = gr.Button("🗑️ Clear Model Cache")
285
+ vram_info_display = gr.Markdown("")
286
+ cache_status_display = gr.Markdown("")
287
+
288
+ # Output gallery
289
+ gallery = gr.Gallery(
290
+ label="Generated Images",
291
+ show_label=True,
292
+ elem_id="gallery",
293
+ columns=[2],
294
+ rows=[2],
295
+ object_fit="contain",
296
+ height="auto",
297
+ )
298
+
299
+ # Connect generate button to pipeline
300
+ generate_btn.click(
301
+ fn=generate_images,
302
+ inputs=[
303
+ prompt,
304
+ width,
305
+ height,
306
+ num_images,
307
+ batch_size,
308
+ hires_fix,
309
+ adetailer,
310
+ enhance_prompt,
311
+ img2img_enabled,
312
+ img2img_image,
313
+ stable_fast,
314
+ reuse_seed,
315
+ flux_enabled,
316
+ prio_speed,
317
+ realistic_model,
318
+ multiscale_enabled,
319
+ multiscale_intermittent,
320
+ multiscale_factor,
321
+ multiscale_fullres_start,
322
+ multiscale_fullres_end,
323
+ keep_models_loaded,
324
+ ],
325
+ outputs=gallery,
326
+ )
327
+
328
+ # Connect VRAM info and cache management buttons
329
+ vram_info_btn.click(
330
+ fn=get_vram_info,
331
+ outputs=vram_info_display,
332
+ )
333
+
334
+ clear_cache_btn.click(
335
+ fn=clear_model_cache_ui,
336
+ outputs=cache_status_display,
337
+ )
338
+
339
+
340
+ def is_huggingface_space():
341
+ return "SPACE_ID" in os.environ
342
+
343
+
344
+ def is_docker_environment():
345
+ return "GRADIO_SERVER_PORT" in os.environ and "GRADIO_SERVER_NAME" in os.environ
346
+
347
+
348
+ # For local testing
349
+ if __name__ == "__main__":
350
+ if is_huggingface_space():
351
+ demo.launch(
352
+ debug=False,
353
+ server_name="0.0.0.0",
354
+ server_port=7860, # Standard HF Spaces port
355
+ )
356
+ elif is_docker_environment():
357
+ # Docker environment - use environment variables
358
+ server_name = os.environ.get("GRADIO_SERVER_NAME", "0.0.0.0")
359
+ server_port = int(os.environ.get("GRADIO_SERVER_PORT", 7860))
360
+ demo.launch(
361
+ debug=False,
362
+ server_name=server_name,
363
+ server_port=server_port,
364
+ )
365
+ else:
366
+ demo.launch(
367
+ server_name="0.0.0.0",
368
+ server_port=8000,
369
+ auth=None,
370
+ share=True, # Only enable sharing locally
371
+ debug=True,
372
+ )
docker-compose.yml ADDED
@@ -0,0 +1,31 @@
1
+ services:
2
+ lightdiffusion:
3
+ build:
4
+ context: .
5
+ dockerfile: Dockerfile
6
+ ports:
7
+ - "7860:7860"
8
+ volumes:
9
+ # Mount output directory to persist generated images
10
+ - ./output:/app/_internal/output
11
+ # Mount checkpoints directory for model files
12
+ - ./checkpoints:/app/_internal/checkpoints
13
+ # Mount other model directories
14
+ - ./loras:/app/_internal/loras
15
+ - ./embeddings:/app/_internal/embeddings
16
+ - ./ESRGAN:/app/_internal/ESRGAN
17
+ - ./yolos:/app/_internal/yolos
18
+ environment:
19
+ - GRADIO_SERVER_NAME=0.0.0.0
20
+ - GRADIO_SERVER_PORT=7860
21
+ - CUDA_VISIBLE_DEVICES=0
22
+ deploy:
23
+ resources:
24
+ reservations:
25
+ devices:
26
+ - driver: nvidia
27
+ count: 1
28
+ capabilities: [gpu]
29
+ restart: unless-stopped
30
+ stdin_open: true
31
+ tty: true
modules/Device/ModelCache.py ADDED
@@ -0,0 +1,169 @@
1
+ """
2
+ Model Persistence Manager for LightDiffusion
3
+ Keeps models loaded in VRAM for instant reuse between generations
4
+ """
5
+
6
+ from typing import Dict, Optional, Any, Tuple, List
7
+ import logging
8
+ from modules.Device import Device
9
+
10
+
11
+ class ModelCache:
12
+ """Global model cache to keep models loaded in VRAM"""
13
+
14
+ def __init__(self):
15
+ self._cached_models: Dict[str, Any] = {}
16
+ self._cached_clip: Optional[Any] = None
17
+ self._cached_vae: Optional[Any] = None
18
+ self._cached_model_patcher: Optional[Any] = None
19
+ self._cached_conditions: Dict[str, Any] = {}
20
+ self._last_checkpoint_path: Optional[str] = None
21
+ self._keep_models_loaded: bool = True
22
+ self._loaded_models_list: List[Any] = []
23
+
24
+ def set_keep_models_loaded(self, keep_loaded: bool) -> None:
25
+ """Enable or disable keeping models loaded in VRAM"""
26
+ self._keep_models_loaded = keep_loaded
27
+ if not keep_loaded:
28
+ self.clear_cache()
29
+
30
+ def get_keep_models_loaded(self) -> bool:
31
+ """Check if models should be kept loaded"""
32
+ return self._keep_models_loaded
33
+
34
+ def cache_checkpoint(
35
+ self, checkpoint_path: str, model_patcher: Any, clip: Any, vae: Any
36
+ ) -> None:
37
+ """Cache a loaded checkpoint"""
38
+ if not self._keep_models_loaded:
39
+ return
40
+
41
+ self._last_checkpoint_path = checkpoint_path
42
+ self._cached_model_patcher = model_patcher
43
+ self._cached_clip = clip
44
+ self._cached_vae = vae
45
+ logging.info(f"Cached checkpoint: {checkpoint_path}")
46
+
47
+ def get_cached_checkpoint(
48
+ self, checkpoint_path: str
49
+ ) -> Optional[Tuple[Any, Any, Any]]:
50
+ """Get cached checkpoint if available"""
51
+ if not self._keep_models_loaded:
52
+ return None
53
+
54
+ if (
55
+ self._last_checkpoint_path == checkpoint_path
56
+ and self._cached_model_patcher is not None
57
+ and self._cached_clip is not None
58
+ and self._cached_vae is not None
59
+ ):
60
+ logging.info(f"Using cached checkpoint: {checkpoint_path}")
61
+ return self._cached_model_patcher, self._cached_clip, self._cached_vae
62
+ return None
63
+
64
+ def cache_sampling_models(self, models: List[Any]) -> None:
65
+ """Cache models used during sampling"""
66
+ if not self._keep_models_loaded:
67
+ return
68
+
69
+ self._loaded_models_list = models.copy()
70
+
71
+ def get_cached_sampling_models(self) -> List[Any]:
72
+ """Get cached sampling models"""
73
+ if not self._keep_models_loaded:
74
+ return []
75
+ return self._loaded_models_list
76
+
77
+ def prevent_model_cleanup(self, conds: Dict[str, Any], models: List[Any]) -> None:
78
+ """Prevent models from being cleaned up if caching is enabled"""
79
+ if not self._keep_models_loaded:
80
+ # Original cleanup behavior
81
+ from modules.cond import cond_util
82
+
83
+ cond_util.cleanup_additional_models(models)
84
+
85
+ control_cleanup = []
86
+ for k in conds:
87
+ control_cleanup += cond_util.get_models_from_cond(conds[k], "control")
88
+ cond_util.cleanup_additional_models(set(control_cleanup))
89
+ else:
90
+ # Keep models loaded - only cleanup control models that aren't main models
91
+ control_cleanup = []
92
+ for k in conds:
93
+ from modules.cond import cond_util
94
+
95
+ control_cleanup += cond_util.get_models_from_cond(conds[k], "control")
96
+
97
+ # Only cleanup control models, not the main models
98
+ from modules.cond import cond_util
99
+
100
+ cond_util.cleanup_additional_models(set(control_cleanup))
101
+ logging.info("Kept main models loaded in VRAM for reuse")
102
+
103
+ def clear_cache(self) -> None:
104
+ """Clear all cached models"""
105
+ if self._cached_model_patcher is not None:
106
+ try:
107
+ # Properly unload the cached models
108
+ if hasattr(self._cached_model_patcher, "model_unload"):
109
+ self._cached_model_patcher.model_unload()
110
+ except Exception as e:
111
+ logging.warning(f"Error unloading cached model: {e}")
112
+
113
+ self._cached_models.clear()
114
+ self._cached_clip = None
115
+ self._cached_vae = None
116
+ self._cached_model_patcher = None
117
+ self._cached_conditions.clear()
118
+ self._last_checkpoint_path = None
119
+ self._loaded_models_list.clear()
120
+
121
+ # Force cleanup
122
+ Device.cleanup_models(keep_clone_weights_loaded=False)
123
+ Device.soft_empty_cache(force=True)
124
+ logging.info("Cleared model cache and freed VRAM")
125
+
126
+ def get_memory_info(self) -> Dict[str, Any]:
127
+ """Get memory usage information"""
128
+ device = Device.get_torch_device()
129
+ total_mem = Device.get_total_memory(device)
130
+ free_mem = Device.get_free_memory(device)
131
+ used_mem = total_mem - free_mem
132
+
133
+ return {
134
+ "total_vram": total_mem / (1024 * 1024 * 1024), # GB
135
+ "used_vram": used_mem / (1024 * 1024 * 1024), # GB
136
+ "free_vram": free_mem / (1024 * 1024 * 1024), # GB
137
+ "cached_models": len(self._cached_models),
138
+ "keep_loaded": self._keep_models_loaded,
139
+ "has_cached_checkpoint": self._cached_model_patcher is not None,
140
+ }
141
+
142
+
143
+ # Global model cache instance
144
+ model_cache = ModelCache()
145
+
146
+
147
+ def get_model_cache() -> ModelCache:
148
+ """Get the global model cache instance"""
149
+ return model_cache
150
+
151
+
152
+ def set_keep_models_loaded(keep_loaded: bool) -> None:
153
+ """Global function to enable/disable model persistence"""
154
+ model_cache.set_keep_models_loaded(keep_loaded)
155
+
156
+
157
+ def get_keep_models_loaded() -> bool:
158
+ """Global function to check if models should be kept loaded"""
159
+ return model_cache.get_keep_models_loaded()
160
+
161
+
162
+ def clear_model_cache() -> None:
163
+ """Global function to clear model cache"""
164
+ model_cache.clear_cache()
165
+
166
+
167
+ def get_memory_info() -> Dict[str, Any]:
168
+ """Global function to get memory info"""
169
+ return model_cache.get_memory_info()
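
A brief, illustrative usage sketch (not part of this commit's code) of the module-level helpers exposed above, mirroring how `app.py` calls them; it assumes the repository root is on `sys.path` so the `modules` package is importable:

```python
# Illustrative sketch only — uses the helpers defined in modules/Device/ModelCache.py.
from modules.Device.ModelCache import (
    set_keep_models_loaded,
    get_memory_info,
    clear_model_cache,
)

# Keep checkpoints resident in VRAM between generations (the default).
set_keep_models_loaded(True)

# Inspect VRAM usage and cache state; values are reported in GB.
info = get_memory_info()
print(
    f"VRAM: {info['used_vram']:.1f}/{info['total_vram']:.1f} GB used, "
    f"cached checkpoint: {info['has_cached_checkpoint']}"
)

# Release cached models when VRAM is needed elsewhere.
clear_model_cache()
```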
modules/FileManaging/Loader.py CHANGED
@@ -128,6 +128,19 @@ class CheckpointLoaderSimple:
128
  - `tuple`: The model patcher, CLIP, and VAE.
129
  """
130
  ckpt_path = f"{ckpt_name}"
131
  out = load_checkpoint_guess_config(
132
  ckpt_path,
133
  output_vae=output_vae,
@@ -135,4 +148,9 @@ class CheckpointLoaderSimple:
135
  embedding_directory="./_internal/embeddings/",
136
  )
137
  print("loading", ckpt_path)
138
  return out[:3]
 
128
  - `tuple`: The model patcher, CLIP, and VAE.
129
  """
130
  ckpt_path = f"{ckpt_name}"
131
+
132
+ # Check if model is already cached
133
+ from modules.Device.ModelCache import get_model_cache
134
+
135
+ cache = get_model_cache()
136
+ cached_result = cache.get_cached_checkpoint(ckpt_path)
137
+
138
+ if cached_result is not None:
139
+ model_patcher, clip, vae = cached_result
140
+ print("using cached", ckpt_path)
141
+ return (model_patcher, clip, vae)
142
+
143
+ # Load normally if not cached
144
  out = load_checkpoint_guess_config(
145
  ckpt_path,
146
  output_vae=output_vae,
 
148
  embedding_directory="./_internal/embeddings/",
149
  )
150
  print("loading", ckpt_path)
151
+
152
+ # Cache the loaded checkpoint
153
+ if len(out) >= 3:
154
+ cache.cache_checkpoint(ckpt_path, out[0], out[1], out[2])
155
+
156
  return out[:3]
modules/UltimateSDUpscale/image_util.py CHANGED
@@ -140,9 +140,41 @@ def tensor_to_pil(img_tensor: torch.Tensor, batch_index: int = 0) -> Image.Image
140
  #### Returns:
141
  - `Image.Image`: The converted PIL image.
142
  """
143
- img_tensor = img_tensor[batch_index].unsqueeze(0)
144
- i = 255.0 * img_tensor.cpu().numpy()
145
- img = Image.fromarray(np.clip(i, 0, 255).astype(np.uint8).squeeze())
146
  return img
147
 
148
 
@@ -155,9 +187,20 @@ def pil_to_tensor(image: Image.Image) -> torch.Tensor:
155
  #### Returns:
156
  - `torch.Tensor`: The converted tensor.
157
  """
158
- image = np.array(image).astype(np.float32) / 255.0
159
- image = torch.from_numpy(image).unsqueeze(0)
160
- return image
161
 
162
 
163
  def get_crop_region(mask: Image.Image, pad: int = 0) -> tuple:
@@ -260,6 +303,6 @@ def crop_cond(cond: list, region: tuple, init_size: tuple, canvas_size: tuple, t
260
  cropped = []
261
  for emb, x in cond:
262
  cond_dict = x.copy()
263
- n = [emb, cond_dict]
264
- cropped.append(n)
265
- return cropped
 
140
  #### Returns:
141
  - `Image.Image`: The converted PIL image.
142
  """
143
+ # Get the tensor for the specified batch index
144
+ tensor = img_tensor[batch_index]
145
+
146
+ # Handle different tensor dimensions
147
+ # The upscaler outputs in [H, W, C] format after movedim(-3, -1)
148
+ if tensor.dim() == 3: # [H, W, C] - already in correct format
149
+ pass
150
+ elif tensor.dim() == 2: # [H, W] - grayscale
151
+ pass
152
+ else:
153
+ raise ValueError(f"Unexpected tensor dimensions: {tensor.shape}")
154
+
155
+ # Clamp values to valid range [0, 1] and convert to numpy
156
+ tensor = torch.clamp(tensor, 0.0, 1.0)
157
+ numpy_array = (tensor.cpu().numpy() * 255.0).astype(np.uint8)
158
+
159
+ # Handle different channel configurations
160
+ if numpy_array.ndim == 3:
161
+ if numpy_array.shape[2] == 3:
162
+ img = Image.fromarray(numpy_array, 'RGB')
163
+ elif numpy_array.shape[2] == 1:
164
+ img = Image.fromarray(numpy_array.squeeze(axis=2), 'L')
165
+ elif numpy_array.shape[2] == 4:
166
+ img = Image.fromarray(numpy_array, 'RGBA')
167
+ else:
168
+ # Fallback: take first 3 channels if more than 3, or convert single channel to grayscale
169
+ if numpy_array.shape[2] >= 3:
170
+ img = Image.fromarray(numpy_array[:, :, :3], 'RGB')
171
+ else:
172
+ img = Image.fromarray(numpy_array.squeeze(axis=2), 'L')
173
+ elif numpy_array.ndim == 2:
174
+ img = Image.fromarray(numpy_array, 'L')
175
+ else:
176
+ raise ValueError(f"Cannot convert array with shape {numpy_array.shape} to PIL image")
177
+
178
  return img
179
 
180
 
 
187
  #### Returns:
188
  - `torch.Tensor`: The converted tensor.
189
  """
190
+ # Convert RGBA to RGB if necessary (upscaler models expect 3 channels)
191
+ if image.mode == 'RGBA':
192
+ # Create a white background for transparency
193
+ background = Image.new('RGB', image.size, (255, 255, 255))
194
+ background.paste(image, mask=image.split()[-1]) # Use alpha channel as mask
195
+ image = background
196
+ elif image.mode != 'RGB':
197
+ image = image.convert('RGB')
198
+ # Convert to numpy array and normalize
199
+ image_array = np.array(image).astype(np.float32) / 255.0
200
+
201
+ # Convert to tensor and add batch dimension: [H, W, C] -> [1, H, W, C]
202
+ tensor = torch.from_numpy(image_array).unsqueeze(0)
203
+ return tensor
204
 
205
 
206
  def get_crop_region(mask: Image.Image, pad: int = 0) -> tuple:
 
303
  cropped = []
304
  for emb, x in cond:
305
  cond_dict = x.copy()
306
+ cond_entry = [emb, cond_dict]
307
+ cropped.append(cond_entry)
308
+ return cropped
modules/sample/CFG.py CHANGED
@@ -347,7 +347,10 @@ class CFGGuider:
347
  pipeline=pipeline,
348
  )
349
 
350
- cond_util.cleanup_models(self.conds, self.loaded_models)
351
  del self.inner_model
352
  del self.conds
353
  del self.loaded_models
 
347
  pipeline=pipeline,
348
  )
349
 
350
+ # Use model cache to prevent cleanup if models should stay loaded
351
+ from modules.Device.ModelCache import get_model_cache
352
+ get_model_cache().prevent_model_cleanup(self.conds, self.loaded_models)
353
+
354
  del self.inner_model
355
  del self.conds
356
  del self.loaded_models
modules/sample/multiscale_presets.py ADDED
@@ -0,0 +1,143 @@
1
+ """
2
+ Multi-scale diffusion presets for quality and performance optimization.
3
+
4
+ This module provides predefined configurations for multi-scale diffusion that balance
5
+ quality and performance based on different use cases.
6
+ """
7
+
8
+ from typing import NamedTuple, Dict, Any
9
+
10
+
11
+ class MultiscalePreset(NamedTuple):
12
+ """#### Class representing a multi-scale diffusion preset.
13
+
14
+ #### Args:
15
+ - `name` (str): The name of the preset.
16
+ - `description` (str): Description of the preset's purpose.
17
+ - `enable_multiscale` (bool): Whether multi-scale diffusion is enabled.
18
+ - `multiscale_factor` (float): Scale factor for intermediate steps (0.1-1.0).
19
+ - `multiscale_fullres_start` (int): Number of first steps at full resolution.
20
+ - `multiscale_fullres_end` (int): Number of last steps at full resolution.
21
+ - `multiscale_intermittent_fullres` (bool): Whether to use intermittent full-res.
22
+ """
23
+
24
+ name: str
25
+ description: str
26
+ enable_multiscale: bool
27
+ multiscale_factor: float
28
+ multiscale_fullres_start: int
29
+ multiscale_fullres_end: int
30
+ multiscale_intermittent_fullres: bool
31
+
32
+ @property
33
+ def as_dict(self) -> Dict[str, Any]:
34
+ """#### Convert the preset to a dictionary.
35
+
36
+ #### Returns:
37
+ - `Dict[str, Any]`: The preset parameters as a dictionary.
38
+ """
39
+ return {
40
+ "enable_multiscale": self.enable_multiscale,
41
+ "multiscale_factor": self.multiscale_factor,
42
+ "multiscale_fullres_start": self.multiscale_fullres_start,
43
+ "multiscale_fullres_end": self.multiscale_fullres_end,
44
+ "multiscale_intermittent_fullres": self.multiscale_intermittent_fullres,
45
+ }
46
+
47
+
48
+ # Predefined multi-scale diffusion presets
49
+ MULTISCALE_PRESETS = {
50
+ "quality": MultiscalePreset(
51
+ name="Quality",
52
+ description="High quality preset with intermittent full-res for best image quality",
53
+ enable_multiscale=True,
54
+ multiscale_factor=0.5,
55
+ multiscale_fullres_start=10,
56
+ multiscale_fullres_end=8,
57
+ multiscale_intermittent_fullres=True,
58
+ ),
59
+ "performance": MultiscalePreset(
60
+ name="Performance",
61
+ description="Performance-oriented preset with aggressive downscaling for maximum speed",
62
+ enable_multiscale=True,
63
+ multiscale_factor=0.25,
64
+ multiscale_fullres_start=5,
65
+ multiscale_fullres_end=8,
66
+ multiscale_intermittent_fullres=True,
67
+ ),
68
+ "balanced": MultiscalePreset(
69
+ name="Balanced",
70
+ description="Balanced preset offering good quality and performance",
71
+ enable_multiscale=True,
72
+ multiscale_factor=0.5,
73
+ multiscale_fullres_start=5,
74
+ multiscale_fullres_end=8,
75
+ multiscale_intermittent_fullres=True,
76
+ ),
77
+ "disabled": MultiscalePreset(
78
+ name="Disabled",
79
+ description="Multi-scale diffusion disabled - full resolution throughout",
80
+ enable_multiscale=False,
81
+ multiscale_factor=1.0,
82
+ multiscale_fullres_start=0,
83
+ multiscale_fullres_end=0,
84
+ multiscale_intermittent_fullres=False,
85
+ ),
86
+ }
87
+
88
+
89
+ def get_preset(preset_name: str) -> MultiscalePreset:
90
+ """#### Get a multi-scale diffusion preset by name.
91
+
92
+ #### Args:
93
+ - `preset_name` (str): The name of the preset to retrieve.
94
+
95
+ #### Returns:
96
+ - `MultiscalePreset`: The requested preset.
97
+
98
+ #### Raises:
99
+ - `KeyError`: If the preset name is not found.
100
+ """
101
+ if preset_name not in MULTISCALE_PRESETS:
102
+ available_presets = ", ".join(MULTISCALE_PRESETS.keys())
103
+ raise KeyError(
104
+ f"Preset '{preset_name}' not found. Available presets: {available_presets}"
105
+ )
106
+
107
+ return MULTISCALE_PRESETS[preset_name]
108
+
109
+
110
+ def get_preset_parameters(preset_name: str) -> Dict[str, Any]:
111
+ """#### Get multi-scale diffusion parameters for a preset.
112
+
113
+ #### Args:
114
+ - `preset_name` (str): The name of the preset.
115
+
116
+ #### Returns:
117
+ - `Dict[str, Any]`: The preset parameters.
118
+ """
119
+ return get_preset(preset_name).as_dict
120
+
121
+
122
+ def list_presets() -> Dict[str, str]:
123
+ """#### List all available multi-scale diffusion presets.
124
+
125
+ #### Returns:
126
+ - `Dict[str, str]`: Dictionary mapping preset names to descriptions.
127
+ """
128
+ return {name: preset.description for name, preset in MULTISCALE_PRESETS.items()}
129
+
130
+
131
+ def apply_preset_to_kwargs(preset_name: str, kwargs: Dict[str, Any]) -> Dict[str, Any]:
132
+ """#### Apply a multi-scale preset to keyword arguments.
133
+
134
+ #### Args:
135
+ - `preset_name` (str): The name of the preset to apply.
136
+ - `kwargs` (Dict[str, Any]): Existing keyword arguments.
137
+
138
+ #### Returns:
139
+ - `Dict[str, Any]`: Updated keyword arguments with preset parameters.
140
+ """
141
+ preset_params = get_preset_parameters(preset_name)
142
+ kwargs.update(preset_params)
143
+ return kwargs
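
A minimal usage sketch (not part of this commit's code) for the preset helpers defined above, using the preset names registered in `MULTISCALE_PRESETS`:

```python
# Illustrative sketch only — uses the helpers defined in modules/sample/multiscale_presets.py.
from modules.sample.multiscale_presets import (
    list_presets,
    get_preset_parameters,
    apply_preset_to_kwargs,
)

# Mapping of preset name -> description: "quality", "performance", "balanced", "disabled".
print(list_presets())

# Fetch the raw parameter dict for one preset...
params = get_preset_parameters("balanced")

# ...or merge a preset into existing sampler keyword arguments.
sampler_kwargs = apply_preset_to_kwargs("performance", {"s_noise": 1.0})
```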
modules/sample/samplers.py CHANGED
@@ -21,6 +21,12 @@ def sample_euler_ancestral(
21
  s_noise=1.0,
22
  noise_sampler=None,
23
  pipeline=False,
24
  ):
25
  # Pre-calculate common values
26
  device = x.device
@@ -31,6 +37,90 @@ def sample_euler_ancestral(
31
  from modules.AutoEncoders import taesd
32
  from modules.user import app_instance
33
 
34
  # Pre-allocate tensors and init noise sampler
35
  s_in = torch.ones((x.shape[0],), device=device)
36
  noise_sampler = (
@@ -50,8 +140,24 @@ def sample_euler_ancestral(
50
  if not pipeline:
51
  app_instance.app.progress.set(i / (len(sigmas) - 1))
52
 
53
- # Combined model inference and step calculation
54
- denoised = model(x, sigmas[i] * s_in, **(extra_args or {}))
55
  sigma_down, sigma_up = sampling_util.get_ancestral_step(
56
  sigmas[i], sigmas[i + 1], eta=eta
57
  )
@@ -83,6 +189,12 @@ def sample_euler(
83
  s_tmax=float("inf"),
84
  s_noise=1.0,
85
  pipeline=False,
86
  ):
87
  # Pre-calculate common values
88
  device = x.device
@@ -93,6 +205,90 @@ def sample_euler(
93
  from modules.AutoEncoders import taesd
94
  from modules.user import app_instance
95
 
96
  # Pre-allocate tensors and cache parameters
97
  s_in = torch.ones((x.shape[0],), device=device)
98
  gamma_max = min(s_churn / (len(sigmas) - 1), 2**0.5 - 1) if s_churn > 0 else 0
@@ -108,6 +304,9 @@ def sample_euler(
108
  if not pipeline:
109
  app_instance.app.progress.set(i / (len(sigmas) - 1))
110
 
111
  # Combined sigma calculation and update
112
  sigma_hat = (
113
  sigmas[i] * (1 + (gamma_max if s_tmin <= sigmas[i] <= s_tmax else 0))
@@ -121,8 +320,21 @@ def sample_euler(
121
  + torch.randn_like(x) * s_noise * (sigma_hat**2 - sigmas[i] ** 2) ** 0.5
122
  )
123
 
124
- # Fused model inference and update step
125
- denoised = model(x, sigma_hat * s_in, **(extra_args or {}))
126
  x = x + util.to_d(x, sigma_hat, denoised) * (sigmas[i + 1] - sigma_hat)
127
 
128
  if callback is not None:
@@ -588,6 +800,12 @@ def sample_dpmpp_2m_cfgpp(
588
  cfg_x0_scale=1.0,
589
  cfg_s_scale=1.0,
590
  cfg_min=1.0,
591
  ):
592
  """DPM-Solver++(2M) sampler with CFG++ optimizations"""
593
  # Pre-calculate common values and setup
@@ -599,11 +817,94 @@ def sample_dpmpp_2m_cfgpp(
599
  from modules.AutoEncoders import taesd
600
  from modules.user import app_instance
601
 
602
  # Pre-allocate tensors and transform sigmas
603
  s_in = torch.ones((x.shape[0],), device=device)
604
  t_steps = -torch.log(sigmas) # Fused calculation
605
  n_steps = len(sigmas) - 1
606
 
607
  # Pre-calculate all needed values in one go
608
  sigma_steps = torch.exp(-t_steps) # Fused calculation
609
  ratios = sigma_steps[1:] / sigma_steps[:-1]
@@ -639,15 +940,31 @@ def sample_dpmpp_2m_cfgpp(
639
  if not pipeline:
640
  app_instance.app.progress.set(i / n_steps)
641
 
642
  # Use pre-calculated CFG scale
643
  current_cfg = cfg_values[i]
644
 
645
- # Fused model inference and update calculations
646
- denoised = model(x, sigmas[i] * s_in, **extra_args)
647
  uncond_denoised = extra_args.get("model_options", {}).get(
648
  "sampler_post_cfg_function", []
649
  )[-1]({"denoised": denoised, "uncond_denoised": None})
650
 
651
  if callback is not None:
652
  callback(
653
  {
@@ -713,8 +1030,14 @@ def sample_dpmpp_sde_cfgpp(
713
  cfg_x0_scale=1.0,
714
  cfg_s_scale=1.0,
715
  cfg_min=1.0,
716
  ):
717
- """DPM-Solver++ (SDE) with CFG++ optimizations"""
718
  # Pre-calculate common values
719
  device = x.device
720
  global disable_gui
@@ -728,9 +1051,91 @@ def sample_dpmpp_sde_cfgpp(
728
  if len(sigmas) <= 1:
729
  return x
730
 
731
- # Pre-allocate tensors and values
732
- s_in = torch.ones((x.shape[0],), device=device)
733
  n_steps = len(sigmas) - 1
734
  extra_args = {} if extra_args is None else extra_args
735
 
736
  # CFG++ scheduling
@@ -751,7 +1156,7 @@ def sample_dpmpp_sde_cfgpp(
751
  x, sigmas[sigmas > 0].min(), sigmas.max(), seed=seed, cpu=True
752
  )
753
 
754
- # Track previous predictions
755
  old_denoised = None
756
  old_uncond_denoised = None
757
 
@@ -776,15 +1181,31 @@ def sample_dpmpp_sde_cfgpp(
776
  if not pipeline:
777
  app_instance.app.progress.set(i / n_steps)
778
 
779
  # Get current CFG scale
780
  current_cfg = get_cfg_scale(i)
781
 
782
- # Model inference
783
- denoised = model(x, sigmas[i] * s_in, **extra_args)
784
  uncond_denoised = extra_args.get("model_options", {}).get(
785
  "sampler_post_cfg_function", []
786
  )[-1]({"denoised": denoised, "uncond_denoised": None})
787
 
788
  if callback is not None:
789
  callback(
790
  {
@@ -815,10 +1236,10 @@ def sample_dpmpp_sde_cfgpp(
815
  uncond_denoised + (denoised - uncond_denoised) * current_cfg
816
  )
817
  else:
818
- # CFG++ with momentum
819
  x0_coeff = cfg_x0_scale * current_cfg
820
 
821
- # Calculate momentum terms
822
  h_ratio = (t - s_) / (2 * (t - t_next))
823
  momentum = (1 + h_ratio) * denoised - h_ratio * old_denoised
824
  uncond_momentum = (
@@ -828,18 +1249,35 @@ def sample_dpmpp_sde_cfgpp(
828
  # Combine with CFG++ scaling
829
  cfg_denoised = uncond_momentum + (momentum - uncond_momentum) * x0_coeff
830
 
 
 
831
  x_2 = (
832
  (sigma_fn(s_) / sigma_fn(t)) * x
833
  - (t - s_).expm1() * cfg_denoised
834
- + noise_sampler(sigma_fn(t), sigma_fn(s)) * s_noise * su
835
  )
836
 
 
 
 
 
837
  # Step 2 inference
838
- denoised_2 = model(x_2, sigma_fn(s) * s_in, **extra_args)
839
  uncond_denoised_2 = extra_args.get("model_options", {}).get(
840
  "sampler_post_cfg_function", []
841
  )[-1]({"denoised": denoised_2, "uncond_denoised": None})
842
 
843
  # Step 2 CFG++ combination
844
  if old_uncond_denoised is None:
845
  cfg_denoised_2 = (
@@ -861,11 +1299,12 @@ def sample_dpmpp_sde_cfgpp(
861
  t_next_ = t_fn(sd)
862
 
863
  # Combined update with both predictions
 
864
  x = (
865
  (sigma_fn(t_next_) / sigma_fn(t)) * x
866
  - (t - t_next_).expm1()
867
  * ((1 - 1 / (2 * r)) * cfg_denoised + (1 / (2 * r)) * cfg_denoised_2)
868
- + noise_sampler(sigma_fn(t), sigma_fn(t_next)) * s_noise * su
869
  )
870
 
871
  old_denoised = denoised
 
21
  s_noise=1.0,
22
  noise_sampler=None,
23
  pipeline=False,
24
+ # Multi-scale parameters
25
+ enable_multiscale=True,
26
+ multiscale_factor=0.5,
27
+ multiscale_fullres_start=3,
28
+ multiscale_fullres_end=8,
29
+ multiscale_intermittent_fullres=False,
30
  ):
31
  # Pre-calculate common values
32
  device = x.device
 
37
  from modules.AutoEncoders import taesd
38
  from modules.user import app_instance
39
 
40
+ # Multi-scale setup with validation
41
+ original_shape = x.shape
42
+ batch_size, channels, orig_h, orig_w = original_shape
43
+
44
+ # Validate multi-scale parameters
45
+ if enable_multiscale:
46
+ if not (0.1 <= multiscale_factor <= 1.0):
47
+ print(
48
+ f"Warning: multiscale_factor {multiscale_factor} out of range [0.1, 1.0], disabling multi-scale"
49
+ )
50
+ enable_multiscale = False
51
+ if multiscale_fullres_start < 0 or multiscale_fullres_end < 0:
52
+ print("Warning: Invalid fullres step counts, disabling multi-scale")
53
+ enable_multiscale = False
54
+
55
+ # Calculate scaled dimensions (must be multiples of 8 for VAE compatibility)
56
+ scale_h = (
57
+ int(max(8, ((orig_h * multiscale_factor) // 8) * 8))
58
+ if enable_multiscale
59
+ else orig_h
60
+ )
61
+ scale_w = (
62
+ int(max(8, ((orig_w * multiscale_factor) // 8) * 8))
63
+ if enable_multiscale
64
+ else orig_w
65
+ )
66
+
67
+ # Disable multi-scale for small images or short step counts
68
+ n_steps = len(sigmas) - 1
69
+ multiscale_active = (
70
+ enable_multiscale
71
+ and orig_h > 64
72
+ and orig_w > 64
73
+ and n_steps > (multiscale_fullres_start + multiscale_fullres_end)
74
+ and (scale_h != orig_h or scale_w != orig_w)
75
+ )
76
+
77
+ if enable_multiscale and not multiscale_active:
78
+ print(
79
+ f"Multi-scale disabled: image too small ({orig_h}x{orig_w}) or insufficient steps ({n_steps})"
80
+ )
81
+ elif multiscale_active:
82
+ print(
83
+ f"Multi-scale active: {orig_h}x{orig_w} -> {scale_h}x{scale_w} (factor: {multiscale_factor})"
84
+ )
85
+
86
+ def downscale_tensor(tensor):
87
+ """Downscale tensor using bilinear interpolation"""
88
+ if not multiscale_active or tensor.shape[-2:] == (scale_h, scale_w):
89
+ return tensor
90
+ return torch.nn.functional.interpolate(
91
+ tensor, size=(scale_h, scale_w), mode="bilinear", align_corners=False
92
+ )
93
+
94
+ def upscale_tensor(tensor):
95
+ """Upscale tensor using bilinear interpolation"""
96
+ if not multiscale_active or tensor.shape[-2:] == (orig_h, orig_w):
97
+ return tensor
98
+ return torch.nn.functional.interpolate(
99
+ tensor, size=(orig_h, orig_w), mode="bilinear", align_corners=False
100
+ )
101
+
102
+ def should_use_fullres(step):
103
+ """Determine if this step should use full resolution"""
104
+ if not multiscale_active:
105
+ return True
106
+
107
+ # Always use full resolution for start and end steps
108
+ if step < multiscale_fullres_start or step >= n_steps - multiscale_fullres_end:
109
+ return True
110
+
111
+ # Intermittent full-res: every 2nd step in low-res region if enabled
112
+ if multiscale_intermittent_fullres:
113
+ # Check if we're in the low-res region
114
+ low_res_region_start = multiscale_fullres_start
115
+ low_res_region_end = n_steps - multiscale_fullres_end
116
+ if low_res_region_start <= step < low_res_region_end:
117
+ # Calculate position within low-res region
118
+ relative_step = step - low_res_region_start
119
+ # Use full-res every 2nd step (0, 2, 4, ...)
120
+ return relative_step % 2 == 0
121
+
122
+ return False
123
+
124
  # Pre-allocate tensors and init noise sampler
125
  s_in = torch.ones((x.shape[0],), device=device)
126
  noise_sampler = (
 
140
  if not pipeline:
141
  app_instance.app.progress.set(i / (len(sigmas) - 1))
142
 
143
+ # Determine resolution for this step
144
+ use_fullres = should_use_fullres(i)
145
+
146
+ # Scale input for processing
147
+ if use_fullres:
148
+ x_process = x
149
+ s_in_process = s_in
150
+ else:
151
+ x_process = downscale_tensor(x)
152
+ s_in_process = torch.ones((x_process.shape[0],), device=device)
153
+
154
+ # Model inference at appropriate resolution
155
+ denoised = model(x_process, sigmas[i] * s_in_process, **(extra_args or {}))
156
+
157
+ # Scale predictions back to original resolution if needed
158
+ if not use_fullres:
159
+ denoised = upscale_tensor(denoised)
160
+
161
  sigma_down, sigma_up = sampling_util.get_ancestral_step(
162
  sigmas[i], sigmas[i + 1], eta=eta
163
  )
 
189
  s_tmax=float("inf"),
190
  s_noise=1.0,
191
  pipeline=False,
192
+ # Multi-scale parameters
193
+ enable_multiscale=True,
194
+ multiscale_factor=0.5,
195
+ multiscale_fullres_start=3,
196
+ multiscale_fullres_end=8,
197
+ multiscale_intermittent_fullres=False,
198
  ):
199
  # Pre-calculate common values
200
  device = x.device
 
205
  from modules.AutoEncoders import taesd
206
  from modules.user import app_instance
207
 
208
+ # Multi-scale setup with validation
209
+ original_shape = x.shape
210
+ batch_size, channels, orig_h, orig_w = original_shape
211
+
212
+ # Validate multi-scale parameters
213
+ if enable_multiscale:
214
+ if not (0.1 <= multiscale_factor <= 1.0):
215
+ print(
216
+ f"Warning: multiscale_factor {multiscale_factor} out of range [0.1, 1.0], disabling multi-scale"
217
+ )
218
+ enable_multiscale = False
219
+ if multiscale_fullres_start < 0 or multiscale_fullres_end < 0:
220
+ print("Warning: Invalid fullres step counts, disabling multi-scale")
221
+ enable_multiscale = False
222
+
223
+ # Calculate scaled dimensions (must be multiples of 8 for VAE compatibility)
224
+ scale_h = (
225
+ int(max(8, ((orig_h * multiscale_factor) // 8) * 8))
226
+ if enable_multiscale
227
+ else orig_h
228
+ )
229
+ scale_w = (
230
+ int(max(8, ((orig_w * multiscale_factor) // 8) * 8))
231
+ if enable_multiscale
232
+ else orig_w
233
+ )
234
+
235
+ # Disable multi-scale for small images or short step counts
236
+ n_steps = len(sigmas) - 1
237
+ multiscale_active = (
238
+ enable_multiscale
239
+ and orig_h > 64
240
+ and orig_w > 64
241
+ and n_steps > (multiscale_fullres_start + multiscale_fullres_end)
242
+ and (scale_h != orig_h or scale_w != orig_w)
243
+ )
244
+
245
+ if enable_multiscale and not multiscale_active:
246
+ print(
247
+ f"Multi-scale disabled: image too small ({orig_h}x{orig_w}) or insufficient steps ({n_steps})"
248
+ )
249
+ elif multiscale_active:
250
+ print(
251
+ f"Multi-scale active: {orig_h}x{orig_w} -> {scale_h}x{scale_w} (factor: {multiscale_factor})"
252
+ )
253
+
254
+ def downscale_tensor(tensor):
255
+ """Downscale tensor using bilinear interpolation"""
256
+ if not multiscale_active or tensor.shape[-2:] == (scale_h, scale_w):
257
+ return tensor
258
+ return torch.nn.functional.interpolate(
259
+ tensor, size=(scale_h, scale_w), mode="bilinear", align_corners=False
260
+ )
261
+
262
+ def upscale_tensor(tensor):
263
+ """Upscale tensor using bilinear interpolation"""
264
+ if not multiscale_active or tensor.shape[-2:] == (orig_h, orig_w):
265
+ return tensor
266
+ return torch.nn.functional.interpolate(
267
+ tensor, size=(orig_h, orig_w), mode="bilinear", align_corners=False
268
+ )
269
+
270
+ def should_use_fullres(step):
271
+ """Determine if this step should use full resolution"""
272
+ if not multiscale_active:
273
+ return True
274
+
275
+ # Always use full resolution for start and end steps
276
+ if step < multiscale_fullres_start or step >= n_steps - multiscale_fullres_end:
277
+ return True
278
+
279
+ # Intermittent full-res: every 2nd step in low-res region if enabled
280
+ if multiscale_intermittent_fullres:
281
+ # Check if we're in the low-res region
282
+ low_res_region_start = multiscale_fullres_start
283
+ low_res_region_end = n_steps - multiscale_fullres_end
284
+ if low_res_region_start <= step < low_res_region_end:
285
+ # Calculate position within low-res region
286
+ relative_step = step - low_res_region_start
287
+ # Use full-res every 2nd step (0, 2, 4, ...)
288
+ return relative_step % 2 == 0
289
+
290
+ return False
291
+
292
  # Pre-allocate tensors and cache parameters
293
  s_in = torch.ones((x.shape[0],), device=device)
294
  gamma_max = min(s_churn / (len(sigmas) - 1), 2**0.5 - 1) if s_churn > 0 else 0
 
304
  if not pipeline:
305
  app_instance.app.progress.set(i / (len(sigmas) - 1))
306
 
307
+ # Determine resolution for this step
308
+ use_fullres = should_use_fullres(i)
309
+
310
  # Combined sigma calculation and update
311
  sigma_hat = (
312
  sigmas[i] * (1 + (gamma_max if s_tmin <= sigmas[i] <= s_tmax else 0))
 
320
  + torch.randn_like(x) * s_noise * (sigma_hat**2 - sigmas[i] ** 2) ** 0.5
321
  )
322
 
323
+ # Scale input for processing
324
+ if use_fullres:
325
+ x_process = x
326
+ s_in_process = s_in
327
+ else:
328
+ x_process = downscale_tensor(x)
329
+ s_in_process = torch.ones((x_process.shape[0],), device=device)
330
+
331
+ # Model inference at appropriate resolution
332
+ denoised = model(x_process, sigma_hat * s_in_process, **(extra_args or {}))
333
+
334
+ # Scale predictions back to original resolution if needed
335
+ if not use_fullres:
336
+ denoised = upscale_tensor(denoised)
337
+
338
  x = x + util.to_d(x, sigma_hat, denoised) * (sigmas[i + 1] - sigma_hat)
339
 
340
  if callback is not None:
 
800
  cfg_x0_scale=1.0,
801
  cfg_s_scale=1.0,
802
  cfg_min=1.0,
803
+ # Multi-scale parameters
804
+ enable_multiscale=True,
805
+ multiscale_factor=0.5,
806
+ multiscale_fullres_start=5,
807
+ multiscale_fullres_end=8,
808
+ multiscale_intermittent_fullres=True,
809
  ):
810
  """DPM-Solver++(2M) sampler with CFG++ optimizations"""
811
  # Pre-calculate common values and setup
 
817
  from modules.AutoEncoders import taesd
818
  from modules.user import app_instance
819
 
820
+ # Multi-scale setup with validation
821
+ original_shape = x.shape
822
+ batch_size, channels, orig_h, orig_w = original_shape
823
+
824
+ # Validate multi-scale parameters
825
+ if enable_multiscale:
826
+ if not (0.1 <= multiscale_factor <= 1.0):
827
+ print(
828
+ f"Warning: multiscale_factor {multiscale_factor} out of range [0.1, 1.0], disabling multi-scale"
829
+ )
830
+ enable_multiscale = False
831
+ if multiscale_fullres_start < 0 or multiscale_fullres_end < 0:
832
+ print("Warning: Invalid fullres step counts, disabling multi-scale")
833
+ enable_multiscale = False
834
+
835
+ # Calculate scaled dimensions (must be multiples of 8 for VAE compatibility)
836
+ scale_h = (
837
+ int(max(8, ((orig_h * multiscale_factor) // 8) * 8))
838
+ if enable_multiscale
839
+ else orig_h
840
+ )
841
+ scale_w = (
842
+ int(max(8, ((orig_w * multiscale_factor) // 8) * 8))
843
+ if enable_multiscale
844
+ else orig_w
845
+ )
846
+
847
  # Pre-allocate tensors and transform sigmas
848
  s_in = torch.ones((x.shape[0],), device=device)
849
  t_steps = -torch.log(sigmas) # Fused calculation
850
  n_steps = len(sigmas) - 1
851
 
852
+ # Disable multi-scale for small images or short step counts
853
+ multiscale_active = (
854
+ enable_multiscale
855
+ and orig_h > 64
856
+ and orig_w > 64
857
+ and n_steps > (multiscale_fullres_start + multiscale_fullres_end)
858
+ and (scale_h != orig_h or scale_w != orig_w)
859
+ )
860
+
861
+ if enable_multiscale and not multiscale_active:
862
+ print(
863
+ f"Multi-scale disabled: image too small ({orig_h}x{orig_w}) or insufficient steps ({n_steps})"
864
+ )
865
+ elif multiscale_active:
866
+ print(
867
+ f"Multi-scale active: {orig_h}x{orig_w} -> {scale_h}x{scale_w} (factor: {multiscale_factor})"
868
+ )
869
+
870
+ def downscale_tensor(tensor):
871
+ """Downscale tensor using bilinear interpolation"""
872
+ if not multiscale_active or tensor.shape[-2:] == (scale_h, scale_w):
873
+ return tensor
874
+ return torch.nn.functional.interpolate(
875
+ tensor, size=(scale_h, scale_w), mode="bilinear", align_corners=False
876
+ )
877
+
878
+ def upscale_tensor(tensor):
879
+ """Upscale tensor using bilinear interpolation"""
880
+ if not multiscale_active or tensor.shape[-2:] == (orig_h, orig_w):
881
+ return tensor
882
+ return torch.nn.functional.interpolate(
883
+ tensor, size=(orig_h, orig_w), mode="bilinear", align_corners=False
884
+ )
885
+
886
+ def should_use_fullres(step):
887
+ """Determine if this step should use full resolution"""
888
+ if not multiscale_active:
889
+ return True
890
+
891
+ # Always use full resolution for start and end steps
892
+ if step < multiscale_fullres_start or step >= n_steps - multiscale_fullres_end:
893
+ return True
894
+
895
+ # Intermittent full-res: every 2nd step in low-res region if enabled
896
+ if multiscale_intermittent_fullres:
897
+ # Check if we're in the low-res region
898
+ low_res_region_start = multiscale_fullres_start
899
+ low_res_region_end = n_steps - multiscale_fullres_end
900
+ if low_res_region_start <= step < low_res_region_end:
901
+ # Calculate position within low-res region
902
+ relative_step = step - low_res_region_start
903
+ # Use full-res every 2nd step (0, 2, 4, ...)
904
+ return relative_step % 2 == 0
905
+
906
+ return False
907
+
908
  # Pre-calculate all needed values in one go
909
  sigma_steps = torch.exp(-t_steps) # Fused calculation
910
  ratios = sigma_steps[1:] / sigma_steps[:-1]
 
940
  if not pipeline:
941
  app_instance.app.progress.set(i / n_steps)
942
 
943
+ # Determine resolution for this step
944
+ use_fullres = should_use_fullres(i)
945
+
946
+ # Scale input for processing
947
+ if use_fullres:
948
+ x_process = x
949
+ s_in_process = s_in
950
+ else:
951
+ x_process = downscale_tensor(x)
952
+ s_in_process = torch.ones((x_process.shape[0],), device=device)
953
+
954
  # Use pre-calculated CFG scale
955
  current_cfg = cfg_values[i]
956
 
957
+ # Model inference at appropriate resolution
958
+ denoised = model(x_process, sigmas[i] * s_in_process, **extra_args)
959
  uncond_denoised = extra_args.get("model_options", {}).get(
960
  "sampler_post_cfg_function", []
961
  )[-1]({"denoised": denoised, "uncond_denoised": None})
962
 
963
+ # Scale predictions back to original resolution if needed
964
+ if not use_fullres:
965
+ denoised = upscale_tensor(denoised)
966
+ uncond_denoised = upscale_tensor(uncond_denoised)
967
+
968
  if callback is not None:
969
  callback(
970
  {
 
1030
  cfg_x0_scale=1.0,
1031
  cfg_s_scale=1.0,
1032
  cfg_min=1.0,
1033
+ # Multi-scale parameters
1034
+ enable_multiscale=True,
1035
+ multiscale_factor=0.5,
1036
+ multiscale_fullres_start=5,
1037
+ multiscale_fullres_end=8,
1038
+ multiscale_intermittent_fullres=False,
1039
  ):
1040
+ """DPM-Solver++ (SDE) with CFG++ optimizations and multi-scale diffusion"""
1041
  # Pre-calculate common values
1042
  device = x.device
1043
  global disable_gui
 
1051
  if len(sigmas) <= 1:
1052
  return x
1053
 
1054
+ # Multi-scale setup with validation
1055
+ original_shape = x.shape
1056
+ batch_size, channels, orig_h, orig_w = original_shape
1057
+
1058
+ # Validate multi-scale parameters
1059
+ if enable_multiscale:
1060
+ if not (0.1 <= multiscale_factor <= 1.0):
1061
+ print(
1062
+ f"Warning: multiscale_factor {multiscale_factor} out of range [0.1, 1.0], disabling multi-scale"
1063
+ )
1064
+ enable_multiscale = False
1065
+ if multiscale_fullres_start < 0 or multiscale_fullres_end < 0:
1066
+ print("Warning: Invalid fullres step counts, disabling multi-scale")
1067
+ enable_multiscale = False
1068
+
1069
+ # Calculate scaled dimensions (must be multiples of 8 for VAE compatibility)
1070
+ scale_h = (
1071
+ int(max(8, ((orig_h * multiscale_factor) // 8) * 8))
1072
+ if enable_multiscale
1073
+ else orig_h
1074
+ )
1075
+ scale_w = (
1076
+ int(max(8, ((orig_w * multiscale_factor) // 8) * 8))
1077
+ if enable_multiscale
1078
+ else orig_w
1079
+ )
1080
+
1081
+ # Disable multi-scale for small images or short step counts
1082
  n_steps = len(sigmas) - 1
1083
+ multiscale_active = (
1084
+ enable_multiscale
1085
+ and orig_h > 64
1086
+ and orig_w > 64
1087
+ and n_steps > (multiscale_fullres_start + multiscale_fullres_end)
1088
+ and (scale_h != orig_h or scale_w != orig_w)
1089
+ )
1090
+
1091
+ if enable_multiscale and not multiscale_active:
1092
+ print(
1093
+ f"Multi-scale disabled: image too small ({orig_h}x{orig_w}) or insufficient steps ({n_steps})"
1094
+ )
1095
+ elif multiscale_active:
1096
+ print(
1097
+ f"Multi-scale active: {orig_h}x{orig_w} -> {scale_h}x{scale_w} (factor: {multiscale_factor})"
1098
+ )
1099
+
1100
+ def downscale_tensor(tensor):
1101
+ """Downscale tensor using bilinear interpolation"""
1102
+ if not multiscale_active or tensor.shape[-2:] == (scale_h, scale_w):
1103
+ return tensor
1104
+ return torch.nn.functional.interpolate(
1105
+ tensor, size=(scale_h, scale_w), mode="bilinear", align_corners=False
1106
+ )
1107
+
1108
+ def upscale_tensor(tensor):
1109
+ """Upscale tensor using bilinear interpolation"""
1110
+ if not multiscale_active or tensor.shape[-2:] == (orig_h, orig_w):
1111
+ return tensor
1112
+ return torch.nn.functional.interpolate(
1113
+ tensor, size=(orig_h, orig_w), mode="bilinear", align_corners=False
1114
+ )
1115
+
1116
+ def should_use_fullres(step):
1117
+ """Determine if this step should use full resolution"""
1118
+ if not multiscale_active:
1119
+ return True
1120
+
1121
+ # Always use full resolution for start and end steps
1122
+ if step < multiscale_fullres_start or step >= n_steps - multiscale_fullres_end:
1123
+ return True
1124
+
1125
+ # Intermittent full-res: every 2nd step in low-res region if enabled
1126
+ if multiscale_intermittent_fullres:
1127
+ # Check if we're in the low-res region
1128
+ low_res_region_start = multiscale_fullres_start
1129
+ low_res_region_end = n_steps - multiscale_fullres_end
1130
+ if low_res_region_start <= step < low_res_region_end:
1131
+ # Calculate position within low-res region
1132
+ relative_step = step - low_res_region_start
1133
+ # Use full-res every 2nd step (0, 2, 4, ...)
1134
+ return relative_step % 2 == 0
1135
+
1136
+ return False
1137
+
1138
+ # Pre-allocate tensors and values
+ s_in = torch.ones((x.shape[0],), device=device)
1139
  extra_args = {} if extra_args is None else extra_args
1140
 
1141
  # CFG++ scheduling
 
1156
  x, sigmas[sigmas > 0].min(), sigmas.max(), seed=seed, cpu=True
1157
  )
1158
 
1159
+ # Track previous predictions for momentum (stored at original resolution)
1160
  old_denoised = None
1161
  old_uncond_denoised = None
1162
 
 
1181
  if not pipeline:
1182
  app_instance.app.progress.set(i / n_steps)
1183
 
1184
+ # Determine resolution for this step
1185
+ use_fullres = should_use_fullres(i)
1186
+
1187
+ # Scale input for processing
1188
+ if use_fullres:
1189
+ x_process = x
1190
+ s_in_process = s_in
1191
+ else:
1192
+ x_process = downscale_tensor(x)
1193
+ s_in_process = torch.ones((x_process.shape[0],), device=device)
1194
+
1195
  # Get current CFG scale
1196
  current_cfg = get_cfg_scale(i)
1197
 
1198
+ # Model inference at appropriate resolution
1199
+ denoised = model(x_process, sigmas[i] * s_in_process, **extra_args)
1200
  uncond_denoised = extra_args.get("model_options", {}).get(
1201
  "sampler_post_cfg_function", []
1202
  )[-1]({"denoised": denoised, "uncond_denoised": None})
1203
 
1204
+ # Scale predictions back to original resolution if needed
1205
+ if not use_fullres:
1206
+ denoised = upscale_tensor(denoised)
1207
+ uncond_denoised = upscale_tensor(uncond_denoised)
1208
+
1209
  if callback is not None:
1210
  callback(
1211
  {
 
1236
  uncond_denoised + (denoised - uncond_denoised) * current_cfg
1237
  )
1238
  else:
1239
+ # CFG++ with momentum (using properly scaled momentum terms)
1240
  x0_coeff = cfg_x0_scale * current_cfg
1241
 
1242
+ # Calculate momentum terms at original resolution
1243
  h_ratio = (t - s_) / (2 * (t - t_next))
1244
  momentum = (1 + h_ratio) * denoised - h_ratio * old_denoised
1245
  uncond_momentum = (
 
1249
  # Combine with CFG++ scaling
1250
  cfg_denoised = uncond_momentum + (momentum - uncond_momentum) * x0_coeff
1251
 
1252
+ # Calculate x_2 for step 2
1253
+ noise_step1 = noise_sampler(sigma_fn(t), sigma_fn(s)) * s_noise * su
1254
  x_2 = (
1255
  (sigma_fn(s_) / sigma_fn(t)) * x
1256
  - (t - s_).expm1() * cfg_denoised
1257
+ + noise_step1
1258
  )
1259
 
1260
+ # Step 2 inference - determine resolution
1261
+ use_fullres_step2 = should_use_fullres(i) if multiscale_active else True
1262
+
1263
+ if use_fullres_step2:
1264
+ x_2_process = x_2
1265
+ s_in_process_2 = s_in
1266
+ else:
1267
+ x_2_process = downscale_tensor(x_2)
1268
+ s_in_process_2 = torch.ones((x_2_process.shape[0],), device=device)
1269
+
1270
  # Step 2 inference
1271
+ denoised_2 = model(x_2_process, sigma_fn(s) * s_in_process_2, **extra_args)
1272
  uncond_denoised_2 = extra_args.get("model_options", {}).get(
1273
  "sampler_post_cfg_function", []
1274
  )[-1]({"denoised": denoised_2, "uncond_denoised": None})
1275
 
1276
+ # Scale step 2 predictions back if needed
1277
+ if not use_fullres_step2:
1278
+ denoised_2 = upscale_tensor(denoised_2)
1279
+ uncond_denoised_2 = upscale_tensor(uncond_denoised_2)
1280
+
1281
  # Step 2 CFG++ combination
1282
  if old_uncond_denoised is None:
1283
  cfg_denoised_2 = (
 
1299
  t_next_ = t_fn(sd)
1300
 
1301
  # Combined update with both predictions
1302
+ noise_final = noise_sampler(sigma_fn(t), sigma_fn(t_next)) * s_noise * su
1303
  x = (
1304
  (sigma_fn(t_next_) / sigma_fn(t)) * x
1305
  - (t - t_next_).expm1()
1306
  * ((1 - 1 / (2 * r)) * cfg_denoised + (1 / (2 * r)) * cfg_denoised_2)
1307
+ + noise_final
1308
  )
1309
 
1310
  old_denoised = denoised
modules/sample/sampling.py CHANGED
@@ -788,6 +788,12 @@ class KSampler:
788
  disable_noise: bool = False,
789
  pipeline: bool = False,
790
  flux: bool = False,
 
 
 
 
 
 
791
  ) -> tuple:
792
  """Unified sampling interface that works both as direct sampling and through the common_ksampler.
793
 
@@ -870,6 +876,11 @@ class KSampler:
870
  force_full_denoise,
871
  pipeline or self.pipeline,
872
  flux,
 
 
 
 
 
873
  )
874
 
875
 
@@ -896,6 +907,12 @@ def sample1(
896
  seed: int = None,
897
  pipeline: bool = False,
898
  flux: bool = False,
 
 
 
 
 
 
899
  ) -> torch.Tensor:
900
  """Sample using the given parameters with the unified KSampler.
901
 
@@ -925,32 +942,102 @@ def sample1(
925
  Returns:
926
  torch.Tensor: The sampled tensor.
927
  """
928
- sampler = KSampler(
929
- model=model,
930
- steps=steps,
931
- sampler=sampler_name,
932
- scheduler=scheduler,
933
- denoise=denoise,
934
- model_options=model.model_options,
935
- pipeline=pipeline,
936
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
937
 
938
- samples = sampler.direct_sample(
939
- noise,
940
- positive,
941
- negative,
942
- cfg=cfg,
943
- latent_image=latent_image,
944
- start_step=start_step,
945
- last_step=last_step,
946
- force_full_denoise=force_full_denoise,
947
- denoise_mask=noise_mask,
948
- sigmas=sigmas,
949
- callback=callback,
950
- disable_pbar=disable_pbar,
951
- seed=seed,
952
- flux=flux,
953
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
954
  samples = samples.to(Device.intermediate_device())
955
  return samples
956
 
@@ -1066,6 +1153,12 @@ def common_ksampler(
1066
  force_full_denoise: bool = False,
1067
  pipeline: bool = False,
1068
  flux: bool = False,
 
 
 
 
 
 
1069
  ) -> tuple:
1070
  """Common ksampler function.
1071
 
@@ -1126,6 +1219,11 @@ def common_ksampler(
1126
  seed=seed,
1127
  pipeline=pipeline,
1128
  flux=flux,
 
 
 
 
 
1129
  )
1130
  out = latent.copy()
1131
  out["samples"] = samples
 
788
  disable_noise: bool = False,
789
  pipeline: bool = False,
790
  flux: bool = False,
791
+ # Multi-scale diffusion parameters
792
+ enable_multiscale: bool = True,
793
+ multiscale_factor: float = 0.5,
794
+ multiscale_fullres_start: int = 3,
795
+ multiscale_fullres_end: int = 8,
796
+ multiscale_intermittent_fullres: bool = False,
797
  ) -> tuple:
798
  """Unified sampling interface that works both as direct sampling and through the common_ksampler.
799
 
 
876
  force_full_denoise,
877
  pipeline or self.pipeline,
878
  flux,
879
+ enable_multiscale,
880
+ multiscale_factor,
881
+ multiscale_fullres_start,
882
+ multiscale_fullres_end,
883
+ multiscale_intermittent_fullres,
884
  )
885
 
886
 
 
907
  seed: int = None,
908
  pipeline: bool = False,
909
  flux: bool = False,
910
+ # Multi-scale diffusion parameters
911
+ enable_multiscale: bool = True,
912
+ multiscale_factor: float = 0.5,
913
+ multiscale_fullres_start: int = 3,
914
+ multiscale_fullres_end: int = 8,
915
+ multiscale_intermittent_fullres: bool = False,
916
  ) -> torch.Tensor:
917
  """Sample using the given parameters with the unified KSampler.
918
 
 
942
  Returns:
943
  torch.Tensor: The sampled tensor.
944
  """
945
+ # Create extra options for multi-scale diffusion (for supported samplers)
946
+ multiscale_supported_samplers = [
947
+ "dpmpp_sde_cfgpp",
948
+ "sample_euler_ancestral",
949
+ "sample_euler",
950
+ "sample_dpmpp_2m_cfgpp",
951
+ ]
952
+
953
+ extra_options = {}
954
+ if sampler_name in multiscale_supported_samplers:
955
+ extra_options = {
956
+ "enable_multiscale": enable_multiscale,
957
+ "multiscale_factor": multiscale_factor,
958
+ "multiscale_fullres_start": multiscale_fullres_start,
959
+ "multiscale_fullres_end": multiscale_fullres_end,
960
+ "multiscale_intermittent_fullres": multiscale_intermittent_fullres,
961
+ }
962
+
963
+ # Use custom sampler with extra options for supported samplers
964
+ if sampler_name in multiscale_supported_samplers and extra_options:
965
+ # Create a custom sampler with multi-scale options
966
+ sampler_obj = ksampler(
967
+ sampler_name, pipeline=pipeline, extra_options=extra_options
968
+ )
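+ # Assumption: the ksampler factory forwards extra_options as keyword
+ # arguments to the underlying sampler function, which is how the
+ # multi-scale parameters reach the supported samplers listed above.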
969
+ sigmas = ksampler_util.calculate_sigmas(
970
+ model.get_model_object("model_sampling"), scheduler, steps
971
+ )
972
+ if denoise is None or denoise > 0.9999:
973
+ pass # Use full sigmas
974
+ else:
975
+ if denoise <= 0.0:
976
+ sigmas = torch.FloatTensor([])
977
+ else:
978
+ new_steps = int(steps / denoise)
979
+ sigmas_full = ksampler_util.calculate_sigmas(
980
+ model.get_model_object("model_sampling"), scheduler, new_steps
981
+ )
982
+ sigmas = sigmas_full[-(steps + 1) :]
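+ # Illustrative example: steps=20 with denoise=0.5 recomputes the schedule
+ # for 40 steps and keeps only the last 21 sigmas, i.e. the low-noise part.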
983
 
984
+ # Process sigmas for start/end steps
985
+ if last_step is not None and last_step < (len(sigmas) - 1):
986
+ sigmas = sigmas[: last_step + 1]
987
+ if force_full_denoise:
988
+ sigmas[-1] = 0
989
+ if start_step is not None and start_step < (len(sigmas) - 1):
990
+ sigmas = sigmas[start_step:]
991
+
992
+ sigmas = sigmas.to(model.load_device)
993
+
994
+ # Use the custom sample function directly
995
+ samples = sample(
996
+ model,
997
+ noise,
998
+ positive,
999
+ negative,
1000
+ cfg,
1001
+ model.load_device,
1002
+ sampler_obj,
1003
+ sigmas,
1004
+ model.model_options,
1005
+ latent_image=latent_image,
1006
+ denoise_mask=noise_mask,
1007
+ callback=callback,
1008
+ disable_pbar=disable_pbar,
1009
+ seed=seed,
1010
+ pipeline=pipeline,
1011
+ flux=flux,
1012
+ )
1013
+ else:
1014
+ # Use the standard KSampler for other samplers
1015
+ sampler = KSampler(
1016
+ model=model,
1017
+ steps=steps,
1018
+ sampler=sampler_name,
1019
+ scheduler=scheduler,
1020
+ denoise=denoise,
1021
+ model_options=model.model_options,
1022
+ pipeline=pipeline,
1023
+ )
1024
+
1025
+ samples = sampler.direct_sample(
1026
+ noise,
1027
+ positive,
1028
+ negative,
1029
+ cfg=cfg,
1030
+ latent_image=latent_image,
1031
+ start_step=start_step,
1032
+ last_step=last_step,
1033
+ force_full_denoise=force_full_denoise,
1034
+ denoise_mask=noise_mask,
1035
+ sigmas=sigmas,
1036
+ callback=callback,
1037
+ disable_pbar=disable_pbar,
1038
+ seed=seed,
1039
+ flux=flux,
1040
+ )
1041
  samples = samples.to(Device.intermediate_device())
1042
  return samples
1043
 
 
1153
  force_full_denoise: bool = False,
1154
  pipeline: bool = False,
1155
  flux: bool = False,
1156
+ # Multi-scale diffusion parameters
1157
+ enable_multiscale: bool = True,
1158
+ multiscale_factor: float = 0.5,
1159
+ multiscale_fullres_start: int = 3,
1160
+ multiscale_fullres_end: int = 8,
1161
+ multiscale_intermittent_fullres: bool = False,
1162
  ) -> tuple:
1163
  """Common ksampler function.
1164
 
 
1219
  seed=seed,
1220
  pipeline=pipeline,
1221
  flux=flux,
1222
+ enable_multiscale=enable_multiscale,
1223
+ multiscale_factor=multiscale_factor,
1224
+ multiscale_fullres_start=multiscale_fullres_start,
1225
+ multiscale_fullres_end=multiscale_fullres_end,
1226
+ multiscale_intermittent_fullres=multiscale_intermittent_fullres,
1227
  )
1228
  out = latent.copy()
1229
  out["samples"] = samples
modules/user/GUI.py CHANGED
@@ -32,6 +32,7 @@ from modules.Quantize import Quantizer
32
  from modules.WaveSpeed import fbcache_nodes
33
  from modules.hidiffusion import msw_msa_attention
34
  from modules.AutoHDR import ahdr
 
35
 
36
  Downloader.CheckAndDownload()
37
 
@@ -64,7 +65,10 @@ class App(tk.Tk):
64
  """Initialize the App class."""
65
  super().__init__()
66
  self.title("LightDiffusion")
67
- self.geometry("900x750")
 
 
 
68
 
69
  # Configure main window grid
70
  self.grid_columnconfigure(1, weight=1)
@@ -238,6 +242,9 @@ class App(tk.Tk):
238
  self.checkbox_frame.grid_rowconfigure(0, weight=1)
239
  self.checkbox_frame.grid_rowconfigure(1, weight=1)
240
  self.checkbox_frame.grid_rowconfigure(2, weight=1)
 
 
 
241
 
242
  # checkbox for hiresfix
243
  self.hires_fix_var = tk.BooleanVar()
@@ -296,6 +303,43 @@ class App(tk.Tk):
296
  row=2, column=0, padx=(75, 5), pady=5, sticky="nsew"
297
  )
298
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
299
  # Button to launch the generation
300
  self.generate_button = ctk.CTkButton(
301
  self.sidebar,
@@ -432,6 +476,32 @@ class App(tk.Tk):
432
  else "dpmpp_2m_cfgpp"
433
  )
434
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
435
  def _img2img(self, file_path: str) -> None:
436
  """Perform img2img on the selected image.
437
 
@@ -516,7 +586,7 @@ class App(tk.Tk):
516
  pass
517
  ultimatesdupscale_250 = ultimatesdupscale.upscale(
518
  upscale_by=2,
519
- seed=random.randint(1, 2**64),
520
  steps=8,
521
  cfg=6,
522
  sampler_name=self.sampler,
@@ -576,6 +646,47 @@ class App(tk.Tk):
576
  else:
577
  print("Adetailer is OFF")
578
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
579
  def print_previewer(self) -> None:
580
  """Print the status of the previewer checkbox."""
581
  if self.previewer_var.get() is True:
@@ -754,8 +865,12 @@ class App(tk.Tk):
754
  emptylatentimage_244 = emptylatentimage.generate(
755
  width=w, height=h, batch_size=int(self.batch_slider.get())
756
  )
 
 
 
 
757
  ksampler_239 = ksampler_instance.sample(
758
- seed=random.randint(1, 2**64),
759
  steps=20,
760
  cfg=cfg,
761
  sampler_name=self.sampler,
@@ -767,6 +882,7 @@ class App(tk.Tk):
767
  positive=cliptextencode_242[0],
768
  negative=cliptextencode_243[0],
769
  latent_image=emptylatentimage_244[0],
 
770
  )
771
  self.progress.set(0.4)
772
  if self.hires_fix_var.get() is True:
@@ -776,7 +892,7 @@ class App(tk.Tk):
776
  samples=ksampler_239[0],
777
  )
778
  ksampler_253 = ksampler_instance.sample(
779
- seed=random.randint(1, 2**64),
780
  steps=10,
781
  cfg=8,
782
  sampler_name="euler_ancestral_cfgpp",
@@ -788,6 +904,7 @@ class App(tk.Tk):
788
  positive=cliptextencode_242[0],
789
  negative=cliptextencode_243[0],
790
  latent_image=latentupscale_254[0],
 
791
  )
792
  vaedecode_240 = vaedecode.decode(
793
  samples=ksampler_253[0],
@@ -839,7 +956,7 @@ class App(tk.Tk):
839
  guide_size=512,
840
  guide_size_for=False,
841
  max_size=768,
842
- seed=random.randint(1, 2**64),
843
  steps=20,
844
  cfg=6.5,
845
  sampler_name=self.sampler,
@@ -893,7 +1010,7 @@ class App(tk.Tk):
893
  guide_size=512,
894
  guide_size_for=False,
895
  max_size=768,
896
- seed=random.randint(1, 2**64),
897
  steps=20,
898
  cfg=6.5,
899
  sampler_name=self.sampler,
@@ -994,7 +1111,7 @@ class App(tk.Tk):
994
  # except ImportError:
995
  # print("Triton not found, skipping compilation")
996
  ksampler_3 = ksampler.sample(
997
- seed=random.randint(1, 2**64),
998
  steps=20,
999
  cfg=1,
1000
  sampler_name="euler_cfgpp",
 
32
  from modules.WaveSpeed import fbcache_nodes
33
  from modules.hidiffusion import msw_msa_attention
34
  from modules.AutoHDR import ahdr
35
+ from modules.sample.multiscale_presets import MULTISCALE_PRESETS, get_preset_parameters
36
 
37
  Downloader.CheckAndDownload()
38
 
 
65
  """Initialize the App class."""
66
  super().__init__()
67
  self.title("LightDiffusion")
68
+ self.geometry("900x850")
69
+
70
+ # Initialize last seed
71
+ self.last_seed = self._get_last_seed()
72
 
73
  # Configure main window grid
74
  self.grid_columnconfigure(1, weight=1)
 
242
  self.checkbox_frame.grid_rowconfigure(0, weight=1)
243
  self.checkbox_frame.grid_rowconfigure(1, weight=1)
244
  self.checkbox_frame.grid_rowconfigure(2, weight=1)
245
+ self.checkbox_frame.grid_rowconfigure(3, weight=1)
246
+ self.checkbox_frame.grid_rowconfigure(4, weight=1)
247
+ self.checkbox_frame.grid_rowconfigure(5, weight=1)
248
 
249
  # checkbox for hiresfix
250
  self.hires_fix_var = tk.BooleanVar()
 
303
  row=2, column=0, padx=(75, 5), pady=5, sticky="nsew"
304
  )
305
 
306
+ # checkbox to enable multi-scale diffusion
307
+ self.multiscale_var = tk.BooleanVar(value=True)
308
+ self.multiscale_checkbox = ctk.CTkCheckBox(
309
+ self.checkbox_frame,
310
+ text="Multi-Scale",
311
+ variable=self.multiscale_var,
312
+ text_color="black",
313
+ )
314
+ self.multiscale_checkbox.grid(row=2, column=1, padx=5, pady=5, sticky="nsew")
315
+
316
+ # checkbox to enable reuse last seed
317
+ self.reuse_seed_var = tk.BooleanVar()
318
+ self.reuse_seed_checkbox = ctk.CTkCheckBox(
319
+ self.checkbox_frame,
320
+ text="Reuse Last Seed",
321
+ variable=self.reuse_seed_var,
322
+ text_color="black",
323
+ )
324
+ self.reuse_seed_checkbox.grid(
325
+ row=3, column=0, padx=(75, 5), pady=5, sticky="nsew"
326
+ )
327
+
328
+ # Multiscale preset dropdown
329
+ preset_names = [preset.name for preset in MULTISCALE_PRESETS.values()]
330
+ self.multiscale_preset_var = tk.StringVar(value="Quality")
331
+ self.multiscale_preset_dropdown = ctk.CTkOptionMenu(
332
+ self.checkbox_frame,
333
+ values=preset_names,
334
+ variable=self.multiscale_preset_var,
335
+ fg_color="#F5EFFF",
336
+ text_color="black",
337
+ command=self.on_preset_selected,
338
+ )
339
+ self.multiscale_preset_dropdown.grid(
340
+ row=4, column=0, columnspan=2, padx=5, pady=2, sticky="ew"
341
+ )
342
+
343
  # Button to launch the generation
344
  self.generate_button = ctk.CTkButton(
345
  self.sidebar,
 
476
  else "dpmpp_2m_cfgpp"
477
  )
478
 
479
+ def _get_last_seed(self) -> int:
480
+ """Get the last used seed from file."""
481
+ try:
482
+ with open(os.path.join("./_internal/", "last_seed.txt"), "r") as f:
483
+ return int(f.read().strip())
484
+ except (FileNotFoundError, ValueError):
485
+ return random.randint(1, 2**64)
486
+
487
+ def _save_last_seed(self, seed: int) -> None:
488
+ """Save the seed to file."""
489
+ try:
490
+ with open(os.path.join("./_internal/", "last_seed.txt"), "w") as f:
491
+ f.write(str(seed))
492
+ self.last_seed = seed
493
+ except Exception as e:
494
+ print(f"Error saving seed: {e}")
495
+
496
+ def _get_seed(self) -> int:
497
+ """Get seed based on reuse_seed checkbox."""
498
+ if self.reuse_seed_var.get():
499
+ return self.last_seed
500
+ else:
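+ # Each call with "Reuse Last Seed" unchecked draws and persists a fresh
+ # seed, so successive sampler calls within one generation use different
+ # seeds; when checked, the persisted seed is returned unchanged.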
501
+ new_seed = random.randint(1, 2**64)
502
+ self._save_last_seed(new_seed)
503
+ return new_seed
504
+
505
  def _img2img(self, file_path: str) -> None:
506
  """Perform img2img on the selected image.
507
 
 
586
  pass
587
  ultimatesdupscale_250 = ultimatesdupscale.upscale(
588
  upscale_by=2,
589
+ seed=self._get_seed(),
590
  steps=8,
591
  cfg=6,
592
  sampler_name=self.sampler,
 
646
  else:
647
  print("Adetailer is OFF")
648
 
649
+ def on_preset_selected(self, preset_name: str) -> None:
650
+ """Handle multiscale preset selection."""
651
+ # Find preset by name (case-insensitive)
652
+ preset = None
653
+ preset_key = None
654
+ for key, p in MULTISCALE_PRESETS.items():
655
+ if p.name.lower() == preset_name.lower():
656
+ preset = p
657
+ preset_key = key
658
+ break
659
+
660
+ if preset:
661
+ print(f"Selected multiscale preset: {preset_name}")
662
+ print(f" Factor: {preset.multiscale_factor}")
663
+ print(f" Start steps: {preset.multiscale_fullres_start}")
664
+ print(f" End steps: {preset.multiscale_fullres_end}")
665
+ print(f" Intermittent full-res: {preset.multiscale_intermittent_fullres}")
666
+
667
+ # Enable multiscale if not disabled preset
668
+ if preset_name.lower() != "disabled":
669
+ self.multiscale_var.set(True)
670
+ else:
671
+ self.multiscale_var.set(False)
672
+
673
+ def get_multiscale_params(self) -> dict:
674
+ """Get current multiscale parameters from selected preset."""
675
+ preset_name = self.multiscale_preset_var.get()
676
+
677
+ # Find preset by name (case-insensitive)
678
+ for key, preset in MULTISCALE_PRESETS.items():
679
+ if preset.name.lower() == preset_name.lower():
680
+ params = get_preset_parameters(key)
681
+ # Override enable_multiscale based on checkbox
682
+ params["enable_multiscale"] = self.multiscale_var.get()
683
+ return params
684
+
685
+ # Fallback to quality preset if not found
686
+ params = get_preset_parameters("quality")
687
+ params["enable_multiscale"] = self.multiscale_var.get()
688
+ return params
689
+
690
  def print_previewer(self) -> None:
691
  """Print the status of the previewer checkbox."""
692
  if self.previewer_var.get() is True:
 
865
  emptylatentimage_244 = emptylatentimage.generate(
866
  width=w, height=h, batch_size=int(self.batch_slider.get())
867
  )
868
+
869
+ # Get multiscale parameters from selected preset
870
+ multiscale_params = self.get_multiscale_params()
871
+
872
  ksampler_239 = ksampler_instance.sample(
873
+ seed=self._get_seed(),
874
  steps=20,
875
  cfg=cfg,
876
  sampler_name=self.sampler,
 
882
  positive=cliptextencode_242[0],
883
  negative=cliptextencode_243[0],
884
  latent_image=emptylatentimage_244[0],
885
+ **multiscale_params,
886
  )
887
  self.progress.set(0.4)
888
  if self.hires_fix_var.get() is True:
 
892
  samples=ksampler_239[0],
893
  )
894
  ksampler_253 = ksampler_instance.sample(
895
+ seed=self._get_seed(),
896
  steps=10,
897
  cfg=8,
898
  sampler_name="euler_ancestral_cfgpp",
 
904
  positive=cliptextencode_242[0],
905
  negative=cliptextencode_243[0],
906
  latent_image=latentupscale_254[0],
907
+ **multiscale_params,
908
  )
909
  vaedecode_240 = vaedecode.decode(
910
  samples=ksampler_253[0],
 
956
  guide_size=512,
957
  guide_size_for=False,
958
  max_size=768,
959
+ seed=self._get_seed(),
960
  steps=20,
961
  cfg=6.5,
962
  sampler_name=self.sampler,
 
1010
  guide_size=512,
1011
  guide_size_for=False,
1012
  max_size=768,
1013
+ seed=self._get_seed(),
1014
  steps=20,
1015
  cfg=6.5,
1016
  sampler_name=self.sampler,
 
1111
  # except ImportError:
1112
  # print("Triton not found, skipping compilation")
1113
  ksampler_3 = ksampler.sample(
1114
+ seed=self._get_seed(),
1115
  steps=20,
1116
  cfg=1,
1117
  sampler_name="euler_cfgpp",
modules/user/pipeline.py CHANGED
@@ -44,6 +44,13 @@ def pipeline(
44
  prio_speed: bool = False,
45
  autohdr: bool = False,
46
  realistic_model: bool = False,
 
 
 
 
 
 
 
47
  ) -> None:
48
  """#### Run the LightDiffusion pipeline.
49
 
@@ -61,8 +68,29 @@ def pipeline(
61
  - `prio_speed` (bool, optional): Prioritize speed over quality. Defaults to False.
62
  - `autohdr` (bool, optional): Enable the AutoHDR mode. Defaults to False.
63
  - `realistic_model` (bool, optional): Use the realistic model. Defaults to False.
 
 
 
 
 
 
64
  """
65
  global last_seed
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
66
  if reuse_seed:
67
  seed = last_seed
68
 
@@ -288,6 +316,7 @@ def pipeline(
288
  # applystablefast_158, "diffusion_model", 0.120
289
  # )
290
 
 
291
  ksampler_239 = ksampler_instance.sample(
292
  seed=seed,
293
  steps=20,
@@ -302,6 +331,11 @@ def pipeline(
302
  positive=cliptextencode_242[0],
303
  negative=cliptextencode_243[0],
304
  latent_image=emptylatentimage_244[0],
 
 
 
 
 
305
  )
306
  if hires_fix:
307
  latentupscale_254 = latent_upscale.upscale(
@@ -537,6 +571,41 @@ if __name__ == "__main__":
537
  action="store_true",
538
  help="Use the realistic model.",
539
  )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
540
  args = parser.parse_args()
541
 
542
  pipeline(
@@ -555,4 +624,10 @@ if __name__ == "__main__":
555
  args.prio_speed,
556
  args.autohdr,
557
  args.realistic_model,
 
 
 
 
 
 
558
  )
 
44
  prio_speed: bool = False,
45
  autohdr: bool = False,
46
  realistic_model: bool = False,
47
+ # Multi-scale diffusion parameters
48
+ multiscale_preset: str = None,
49
+ enable_multiscale: bool = True,
50
+ multiscale_factor: float = 0.5,
51
+ multiscale_fullres_start: int = 3,
52
+ multiscale_fullres_end: int = 8,
53
+ multiscale_intermittent_fullres: bool = False,
54
  ) -> None:
55
  """#### Run the LightDiffusion pipeline.
56
 
 
68
  - `prio_speed` (bool, optional): Prioritize speed over quality. Defaults to False.
69
  - `autohdr` (bool, optional): Enable the AutoHDR mode. Defaults to False.
70
  - `realistic_model` (bool, optional): Use the realistic model. Defaults to False.
71
+ - `multiscale_preset` (str, optional): Predefined multiscale preset ('quality', 'performance', 'balanced', 'disabled'). Overrides individual multiscale parameters. Defaults to None.
72
+ - `enable_multiscale` (bool, optional): Enable multi-scale diffusion for performance optimization. Defaults to True.
73
+ - `multiscale_factor` (float, optional): Scale factor for intermediate steps (0.1-1.0). Defaults to 0.5.
74
+ - `multiscale_fullres_start` (int, optional): Number of first steps at full resolution. Defaults to 3.
75
+ - `multiscale_fullres_end` (int, optional): Number of last steps at full resolution. Defaults to 8.
76
+ - `multiscale_intermittent_fullres` (bool, optional): Enable intermittent full-res rendering in low-res region. Defaults to False.
77
  """
78
  global last_seed
79
+
80
+ # Apply multiscale preset if specified (overrides individual parameters)
81
+ if multiscale_preset is not None:
82
+ from modules.sample.multiscale_presets import get_preset_parameters
83
+
84
+ preset_params = get_preset_parameters(multiscale_preset)
85
+ enable_multiscale = preset_params["enable_multiscale"]
86
+ multiscale_factor = preset_params["multiscale_factor"]
87
+ multiscale_fullres_start = preset_params["multiscale_fullres_start"]
88
+ multiscale_fullres_end = preset_params["multiscale_fullres_end"]
89
+ multiscale_intermittent_fullres = preset_params[
90
+ "multiscale_intermittent_fullres"
91
+ ]
92
+ print(f"Applied multiscale preset: {multiscale_preset}")
93
+
94
  if reuse_seed:
95
  seed = last_seed
96
 
 
316
  # applystablefast_158, "diffusion_model", 0.120
317
  # )
318
 
319
+ # Create sampler with multi-scale options
320
  ksampler_239 = ksampler_instance.sample(
321
  seed=seed,
322
  steps=20,
 
331
  positive=cliptextencode_242[0],
332
  negative=cliptextencode_243[0],
333
  latent_image=emptylatentimage_244[0],
334
+ enable_multiscale=enable_multiscale,
335
+ multiscale_factor=multiscale_factor,
336
+ multiscale_fullres_start=multiscale_fullres_start,
337
+ multiscale_fullres_end=multiscale_fullres_end,
338
+ multiscale_intermittent_fullres=multiscale_intermittent_fullres,
339
  )
340
  if hires_fix:
341
  latentupscale_254 = latent_upscale.upscale(
 
571
  action="store_true",
572
  help="Use the realistic model.",
573
  )
574
+ parser.add_argument(
575
+ "--multiscale-preset",
576
+ type=str,
577
+ choices=["quality", "performance", "balanced", "disabled"],
578
+ help="Predefined multiscale preset ('quality', 'performance', 'balanced', 'disabled'). Overrides individual multiscale parameters.",
579
+ )
580
+ parser.add_argument(
581
+ "--enable-multiscale",
582
+ action="store_true",
583
+ default=True,
584
+ help="Enable multi-scale diffusion for performance optimization.",
585
+ )
586
+ parser.add_argument(
587
+ "--multiscale-factor",
588
+ type=float,
589
+ default=0.5,
590
+ help="Scale factor for intermediate steps (0.1-1.0).",
591
+ )
592
+ parser.add_argument(
593
+ "--multiscale-fullres-start",
594
+ type=int,
595
+ default=3,
596
+ help="Number of first steps at full resolution.",
597
+ )
598
+ parser.add_argument(
599
+ "--multiscale-fullres-end",
600
+ type=int,
601
+ default=8,
602
+ help="Number of last steps at full resolution.",
603
+ )
604
+ parser.add_argument(
605
+ "--multiscale-intermittent-fullres",
606
+ action="store_true",
607
+ help="Enable intermittent full-res rendering in low-res region.",
608
+ )
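+ # Illustrative usage: appending "--multiscale-preset performance" to the
+ # usual command line applies that preset and overrides the individual
+ # --multiscale-* flags defined above.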
609
  args = parser.parse_args()
610
 
611
  pipeline(
 
624
  args.prio_speed,
625
  args.autohdr,
626
  args.realistic_model,
627
+ args.multiscale_preset,
628
+ args.enable_multiscale,
629
+ args.multiscale_factor,
630
+ args.multiscale_fullres_start,
631
+ args.multiscale_fullres_end,
632
+ args.multiscale_intermittent_fullres,
633
  )
pipeline.bat CHANGED
@@ -24,7 +24,7 @@ FOR /F "delims=" %%i IN ('nvidia-smi 2^>^&1') DO (
24
  )
25
  IF NOT ERRORLEVEL 1 (
26
  echo NVIDIA GPU detected, installing GPU dependencies...
27
- uv pip install xformers torch torchvision --index-url https://download.pytorch.org/whl/cu126
28
  ) ELSE (
29
  echo No NVIDIA GPU detected, installing CPU dependencies...
30
  uv pip install torch torchvision --index-url https://download.pytorch.org/whl/cpu
 
24
  )
25
  IF NOT ERRORLEVEL 1 (
26
  echo NVIDIA GPU detected, installing GPU dependencies...
27
+ uv pip install xformers torch torchvision --index-url https://download.pytorch.org/whl/cu128
28
  ) ELSE (
29
  echo No NVIDIA GPU detected, installing CPU dependencies...
30
  uv pip install torch torchvision --index-url https://download.pytorch.org/whl/cpu
run.bat CHANGED
@@ -24,7 +24,7 @@ FOR /F "delims=" %%i IN ('nvidia-smi 2^>^&1') DO (
24
  )
25
  IF NOT ERRORLEVEL 1 (
26
  echo NVIDIA GPU detected, installing GPU dependencies...
27
- uv pip install xformers torch torchvision --index-url https://download.pytorch.org/whl/cu126
28
  ) ELSE (
29
  echo No NVIDIA GPU detected, installing CPU dependencies...
30
  uv pip install torch torchvision --index-url https://download.pytorch.org/whl/cpu
 
24
  )
25
  IF NOT ERRORLEVEL 1 (
26
  echo NVIDIA GPU detected, installing GPU dependencies...
27
+ uv pip install torch torchvision --index-url https://download.pytorch.org/whl/cu128
28
  ) ELSE (
29
  echo No NVIDIA GPU detected, installing CPU dependencies...
30
  uv pip install torch torchvision --index-url https://download.pytorch.org/whl/cpu
run_web.bat CHANGED
@@ -24,7 +24,7 @@ FOR /F "delims=" %%i IN ('nvidia-smi 2^>^&1') DO (
24
  )
25
  IF NOT ERRORLEVEL 1 (
26
  echo NVIDIA GPU detected, installing GPU dependencies...
27
- uv pip install xformers torch torchvision --index-url https://download.pytorch.org/whl/cu126
28
  ) ELSE (
29
  echo No NVIDIA GPU detected, installing CPU dependencies...
30
  uv pip install torch torchvision --index-url https://download.pytorch.org/whl/cpu
 
24
  )
25
  IF NOT ERRORLEVEL 1 (
26
  echo NVIDIA GPU detected, installing GPU dependencies...
27
+ uv pip install xformers torch torchvision --index-url https://download.pytorch.org/whl/cu128
28
  ) ELSE (
29
  echo No NVIDIA GPU detected, installing CPU dependencies...
30
  uv pip install torch torchvision --index-url https://download.pytorch.org/whl/cpu