akhaliq (HF Staff) committed
Commit 77b8b6c · Parent(s): c456be0

update system prompt

Files changed (1):
1. app.py +334 -14
app.py CHANGED
@@ -234,28 +234,35 @@ Functions that typically need @spaces.GPU:
 
 ## Advanced ZeroGPU Optimization (Recommended)
 
- For production Spaces with heavy models, consider ahead-of-time (AoT) compilation for 1.3x-1.8x speedups:
 
 ```python
 import spaces
 import torch
 from diffusers import DiffusionPipeline
 
- pipe = DiffusionPipeline.from_pretrained(..., torch_dtype=torch.bfloat16)
 pipe.to('cuda')
 
- @spaces.GPU(duration=1500)  # Max duration for compilation
 def compile_transformer():
     with spaces.aoti_capture(pipe.transformer) as call:
         pipe("arbitrary example prompt")
 
     exported = torch.export.export(
         pipe.transformer,
         args=call.args,
         kwargs=call.kwargs,
     )
     return spaces.aoti_compile(exported)
 
 compiled_transformer = compile_transformer()
 spaces.aoti_apply(compiled_transformer, pipe.transformer)
 
@@ -264,10 +271,163 @@ def generate(prompt):
     return pipe(prompt).images
 ```
 
- Optional enhancements:
- - FP8 quantization with torchao for additional 1.2x speedup (H200 compatible)
- - Dynamic shapes for variable input sizes
- - FlashAttention-3 via kernels library for attention speedups
 
 ## Complete Gradio API Reference
 
 
 
 ## Advanced ZeroGPU Optimization (Recommended)
 
+ For production Spaces with heavy models, use ahead-of-time (AoT) compilation for 1.3x-1.8x speedups:
 
+ ### Basic AoT Compilation
 ```python
 import spaces
 import torch
 from diffusers import DiffusionPipeline
 
+ MODEL_ID = 'black-forest-labs/FLUX.1-dev'
+ pipe = DiffusionPipeline.from_pretrained(MODEL_ID, torch_dtype=torch.bfloat16)
 pipe.to('cuda')
 
+ @spaces.GPU(duration=1500)  # Maximum duration allowed during startup
 def compile_transformer():
+     # 1. Capture example inputs
     with spaces.aoti_capture(pipe.transformer) as call:
         pipe("arbitrary example prompt")
 
+     # 2. Export the model
     exported = torch.export.export(
         pipe.transformer,
         args=call.args,
         kwargs=call.kwargs,
     )
+ 
+     # 3. Compile the exported model
     return spaces.aoti_compile(exported)
 
+ # 4. Apply compiled model to pipeline
 compiled_transformer = compile_transformer()
 spaces.aoti_apply(compiled_transformer, pipe.transformer)
 
 @spaces.GPU
 def generate(prompt):
     return pipe(prompt).images
 ```
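If AoT compilation fails at startup (for example, an export error after a model update), it is safer to fall back to the uncompiled pipeline than to crash the Space. A minimal sketch of such a guard, reusing `compile_transformer` from the example above; the fallback policy itself is an assumption, not part of the committed prompt:

```python
# Assumed startup guard: keep the eager pipeline if AoT compilation fails.
try:
    compiled_transformer = compile_transformer()
    spaces.aoti_apply(compiled_transformer, pipe.transformer)
    print("AoT-compiled transformer applied")
except Exception as e:
    print(f"AoT compilation failed, serving uncompiled pipeline: {e}")
```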
 
 
+ ### Advanced Optimizations
+ 
+ #### FP8 Quantization (Additional 1.2x speedup on H200)
+ ```python
+ from torchao.quantization import quantize_, Float8DynamicActivationFloat8WeightConfig
+ 
+ @spaces.GPU(duration=1500)
+ def compile_transformer_with_quantization():
+     # Quantize before export for FP8 speedup
+     quantize_(pipe.transformer, Float8DynamicActivationFloat8WeightConfig())
+ 
+     with spaces.aoti_capture(pipe.transformer) as call:
+         pipe("arbitrary example prompt")
+ 
+     exported = torch.export.export(
+         pipe.transformer,
+         args=call.args,
+         kwargs=call.kwargs,
+     )
+     return spaces.aoti_compile(exported)
+ ```
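FP8 is hardware-gated (see Hardware Requirements below), so a Space that might land on older GPUs can check the device before quantizing. A minimal sketch of such a guard, assuming the same `pipe` as above; the guard logic is an assumption, not part of this commit:

```python
import torch
from torchao.quantization import quantize_, Float8DynamicActivationFloat8WeightConfig

# Assumed guard: quantize only when the GPU reports compute capability >= 9.0
# (FP8-capable hardware such as H200); otherwise keep the bfloat16 weights.
def maybe_quantize(transformer):
    major, _ = torch.cuda.get_device_capability()
    if major >= 9:
        quantize_(transformer, Float8DynamicActivationFloat8WeightConfig())
    else:
        print("FP8 unsupported on this GPU; skipping quantization")
```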
+ 
+ #### Dynamic Shapes (Variable input sizes)
+ ```python
+ from torch.utils._pytree import tree_map
+ 
+ @spaces.GPU(duration=1500)
+ def compile_transformer_dynamic():
+     with spaces.aoti_capture(pipe.transformer) as call:
+         pipe("arbitrary example prompt")
+ 
+     # Define dynamic dimension ranges (model-dependent)
+     transformer_hidden_dim = torch.export.Dim('hidden', min=4096, max=8212)
+ 
+     # Map argument names to dynamic dimensions
+     transformer_dynamic_shapes = {
+         "hidden_states": {1: transformer_hidden_dim},
+         "img_ids": {0: transformer_hidden_dim},
+     }
+ 
+     # Create dynamic shapes structure
+     dynamic_shapes = tree_map(lambda v: None, call.kwargs)
+     dynamic_shapes.update(transformer_dynamic_shapes)
+ 
+     exported = torch.export.export(
+         pipe.transformer,
+         args=call.args,
+         kwargs=call.kwargs,
+         dynamic_shapes=dynamic_shapes,
+     )
+     return spaces.aoti_compile(exported)
+ ```
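Which captured arguments actually vary (and therefore need a `Dim`) is model-specific. One way to find out is to print the captured shapes at two different resolutions and compare. A small illustrative sketch, assuming a `call` object captured as above:

```python
import torch
from torch.utils._pytree import tree_map

# Print the shape of every captured tensor kwarg; non-tensors map to None.
# Capture at two resolutions and diff the output to spot the dynamic dims.
shapes = tree_map(lambda v: tuple(v.shape) if torch.is_tensor(v) else None, call.kwargs)
print(shapes)
```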
+ 
+ #### Multi-Compile for Different Resolutions
+ ```python
+ @spaces.GPU(duration=1500)
+ def compile_multiple_resolutions():
+     compiled_models = {}
+     resolutions = [(512, 512), (768, 768), (1024, 1024)]
+ 
+     for width, height in resolutions:
+         # Capture inputs for the specific resolution
+         with spaces.aoti_capture(pipe.transformer) as call:
+             pipe(f"test prompt {width}x{height}", width=width, height=height)
+ 
+         exported = torch.export.export(
+             pipe.transformer,
+             args=call.args,
+             kwargs=call.kwargs,
+         )
+         compiled_models[f"{width}x{height}"] = spaces.aoti_compile(exported)
+ 
+     return compiled_models
+ 
+ # Usage with resolution dispatch
+ compiled_models = compile_multiple_resolutions()
+ 
+ @spaces.GPU
+ def generate_with_resolution(prompt, width=1024, height=1024):
+     resolution_key = f"{width}x{height}"
+     if resolution_key in compiled_models:
+         # Apply the compiled model for this resolution before running
+         spaces.aoti_apply(compiled_models[resolution_key], pipe.transformer)
+     return pipe(prompt, width=width, height=height).images
+ ```
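Requests will not always match a compiled size exactly, so one option is to snap the requested size to the nearest compiled resolution before dispatching. A sketch with a hypothetical `nearest_resolution` helper (not part of the committed prompt), assuming the `compiled_models` dict keyed as above:

```python
# Hypothetical helper: snap an arbitrary requested size to the closest
# compiled resolution so the dispatch above always hits a compiled model.
RESOLUTIONS = [(512, 512), (768, 768), (1024, 1024)]

def nearest_resolution(width, height):
    return min(RESOLUTIONS, key=lambda wh: abs(wh[0] - width) + abs(wh[1] - height))

w, h = nearest_resolution(832, 896)  # -> (768, 768)
images = generate_with_resolution("a prompt", width=w, height=h)
```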
+ 
+ #### FlashAttention-3 Integration
+ ```python
+ from kernels import get_kernel
+ 
+ # Load pre-built FA3 kernel compatible with H200
+ try:
+     vllm_flash_attn3 = get_kernel("kernels-community/vllm-flash-attn3")
+     print("✅ FlashAttention-3 kernel loaded successfully")
+ except Exception as e:
+     print(f"⚠️ FlashAttention-3 not available: {e}")
+ 
+ # Custom attention processor example (schematic: a full processor must also
+ # project q/k/v, reshape heads, and apply the output projection)
+ class FlashAttention3Processor:
+     def __call__(self, attn, hidden_states, encoder_hidden_states=None, attention_mask=None):
+         # Use the FA3 kernel function for the attention computation
+         return vllm_flash_attn3.flash_attn_func(hidden_states, encoder_hidden_states, attention_mask)
+ 
+ # Apply FA3 processor to model
+ if 'vllm_flash_attn3' in locals():
+     for name, module in pipe.transformer.named_modules():
+         if hasattr(module, 'processor'):
+             module.processor = FlashAttention3Processor()
+ ```
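A quick way to confirm the swap took effect is to count how many processors were actually replaced. A small sketch, assuming the loop above has run:

```python
# Sanity check: count attention modules now using the FA3 processor.
n_fa3 = sum(isinstance(getattr(m, "processor", None), FlashAttention3Processor)
            for m in pipe.transformer.modules())
print(f"FA3 processor applied to {n_fa3} attention modules")
```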
+ 
+ ### Complete Optimized Example
+ ```python
+ import spaces
+ import torch
+ from diffusers import DiffusionPipeline
+ from torchao.quantization import quantize_, Float8DynamicActivationFloat8WeightConfig
+ 
+ MODEL_ID = 'black-forest-labs/FLUX.1-dev'
+ pipe = DiffusionPipeline.from_pretrained(MODEL_ID, torch_dtype=torch.bfloat16)
+ pipe.to('cuda')
+ 
+ @spaces.GPU(duration=1500)
+ def compile_optimized_transformer():
+     # Apply FP8 quantization
+     quantize_(pipe.transformer, Float8DynamicActivationFloat8WeightConfig())
+ 
+     # Capture inputs
+     with spaces.aoti_capture(pipe.transformer) as call:
+         pipe("optimization test prompt")
+ 
+     # Export and compile
+     exported = torch.export.export(
+         pipe.transformer,
+         args=call.args,
+         kwargs=call.kwargs,
+     )
+     return spaces.aoti_compile(exported)
+ 
+ # Compile during startup
+ compiled_transformer = compile_optimized_transformer()
+ spaces.aoti_apply(compiled_transformer, pipe.transformer)
+ 
+ @spaces.GPU
+ def generate(prompt):
+     return pipe(prompt).images
+ ```
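To check the claimed gains on real hardware, a simple wall-clock comparison before and after `spaces.aoti_apply` is enough. A rough sketch, assuming the `pipe` above; the timing methodology is an assumption:

```python
import time

@spaces.GPU
def benchmark(prompt, n=3):
    # One warm-up run, then average a few timed runs.
    pipe(prompt)
    start = time.perf_counter()
    for _ in range(n):
        pipe(prompt)
    return (time.perf_counter() - start) / n

print(f"avg seconds/image: {benchmark('benchmark prompt'):.2f}")
```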
+ 
+ **Expected Performance Gains:**
+ - Basic AoT: 1.3x-1.8x speedup
+ - + FP8 quantization: additional 1.2x speedup
+ - + FlashAttention-3: additional attention speedup
+ - Total potential: roughly 2x-3x faster inference (e.g., 1.8 x 1.2 ≈ 2.2x before attention gains)
+ 
+ **Hardware Requirements:**
+ - FP8 quantization requires CUDA compute capability ≥ 9.0 (H200 ✅)
+ - FlashAttention-3 works on H200 hardware via the kernels library
+ - Dynamic shapes add flexibility for variable input sizes
 
 ## Complete Gradio API Reference
 
 