druvx13 commited on
Commit
9dacd99
·
verified ·
1 Parent(s): 5b0827f

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +17 -10
app.py CHANGED
@@ -1,19 +1,21 @@
1
  import gradio as gr
2
  import torch
3
- from diffusers import DDColorPipeline
 
4
  from PIL import Image
5
  import numpy as np
6
  import os
7
 
8
- # Model loading with optimized settings
9
  cache_dir = "./model_cache"
10
  os.makedirs(cache_dir, exist_ok=True)
11
 
12
- # Load model once at startup
13
- pipe = DDColorPipeline.from_pretrained(
14
  "camenduru/cv_ddcolor_image-colorization",
15
  torch_dtype=torch.float16,
16
- cache_dir=cache_dir
 
17
  ).to("cuda")
18
 
19
  def colorize_image(input_image):
@@ -22,16 +24,21 @@ def colorize_image(input_image):
22
  if input_image.mode != 'L':
23
  input_image = input_image.convert('L')
24
 
25
- # Resize to model's expected input size (based on DDColor paper)
26
- target_size = (256, 256)
27
  resized_image = input_image.resize(target_size)
28
 
29
- # Convert to numpy array for pipeline input
30
- grayscale_array = np.array(resized_image)
31
 
32
  # Generate colorized image
33
  with torch.inference_mode():
34
- result = pipe(grayscale_array).images[0]
 
 
 
 
 
35
 
36
  return result
37
 
 
1
  import gradio as gr
2
  import torch
3
+ from diffusers import AutoPipelineForImage2Image
4
+ from transformers import pipeline
5
  from PIL import Image
6
  import numpy as np
7
  import os
8
 
9
+ # Model loading with dynamic pipeline selection
10
  cache_dir = "./model_cache"
11
  os.makedirs(cache_dir, exist_ok=True)
12
 
13
+ # Load model using AutoPipeline
14
+ pipe = AutoPipelineForImage2Image.from_pretrained(
15
  "camenduru/cv_ddcolor_image-colorization",
16
  torch_dtype=torch.float16,
17
+ cache_dir=cache_dir,
18
+ variant="fp16"
19
  ).to("cuda")
20
 
21
  def colorize_image(input_image):
 
24
  if input_image.mode != 'L':
25
  input_image = input_image.convert('L')
26
 
27
+ # Resize to model's expected input size
28
+ target_size = (512, 512) # Increased resolution for better quality
29
  resized_image = input_image.resize(target_size)
30
 
31
+ # Convert to RGB as required by model
32
+ grayscale_image = resized_image.convert("RGB")
33
 
34
  # Generate colorized image
35
  with torch.inference_mode():
36
+ result = pipe(
37
+ prompt="colorized photo",
38
+ image=grayscale_image,
39
+ num_inference_steps=20,
40
+ strength=0.8
41
+ ).images[0]
42
 
43
  return result
44