File size: 4,227 Bytes
79fa8cf
 
 
49c9fa6
 
 
e27d7b8
79fa8cf
49c9fa6
79fa8cf
49c9fa6
79fa8cf
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
49c9fa6
 
 
79fa8cf
49c9fa6
 
 
79fa8cf
49c9fa6
 
79fa8cf
49c9fa6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
e27d7b8
 
 
 
 
 
 
 
 
 
 
49c9fa6
 
 
e27d7b8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
49c9fa6
e27d7b8
 
 
 
49c9fa6
e27d7b8
79fa8cf
 
49c9fa6
e27d7b8
 
 
 
 
 
49c9fa6
e27d7b8
 
49c9fa6
 
 
79fa8cf
49c9fa6
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
import torch
import torch.nn as nn
import numpy as np
import gradio as gr
from PIL import Image
import os
import math

# Define your Generator architecture - with ngf=128 to match your training parameters
class Generator(nn.Module):
    """DCGAN generator that upsamples a latent tensor into an image.

    The network is a stack of transposed convolutions. The first one maps the
    ``nz``-channel input to ``ngf * 8`` feature maps; each later stage uses
    kernel 4, stride 2, padding 1, doubling the spatial size. For a 1x1
    spatial input the output is ``nc`` x 64 x 64, squashed by ``tanh``.

    Args:
        ngpu: Stored on the instance; not otherwise used by this class.
        nz: Number of channels of the latent input.
        ngf: Base number of generator feature maps.
        nc: Number of channels of the generated image.
    """

    def __init__(self, ngpu=1, nz=100, ngf=128, nc=3):
        super().__init__()
        self.ngpu = ngpu

        # Feature-map counts after each hidden stage: 8x, 4x, 2x, 1x of ngf.
        widths = [ngf * 8, ngf * 4, ngf * 2, ngf]

        # Latent projection: (nz) x 1 x 1 -> (ngf*8) x 4 x 4
        layers = [
            nn.ConvTranspose2d(nz, widths[0], 4, 1, 0, bias=False),
            nn.BatchNorm2d(widths[0]),
            nn.ReLU(True),
        ]

        # Upsampling stages: 4x4 -> 8x8 -> 16x16 -> 32x32
        for in_ch, out_ch in zip(widths, widths[1:]):
            layers += [
                nn.ConvTranspose2d(in_ch, out_ch, 4, 2, 1, bias=False),
                nn.BatchNorm2d(out_ch),
                nn.ReLU(True),
            ]

        # Output stage: (ngf) x 32 x 32 -> (nc) x 64 x 64
        layers += [
            nn.ConvTranspose2d(widths[-1], nc, 4, 2, 1, bias=False),
            nn.Tanh(),
        ]

        # Module order (and therefore the state-dict keys main.0 ... main.13)
        # must stay fixed so existing checkpoints keep loading.
        self.main = nn.Sequential(*layers)

    def forward(self, input):
        return self.main(input)

# Load the model - the checkpoint is expected inside the models folder
device = torch.device("cpu")
model_path = "models/netG_epoch_246.pth"

# Print file existence for debugging
print(f"Checking if model file exists: {os.path.exists(model_path)}")
# isdir (rather than exists) so a regular file named "models" cannot make
# os.listdir raise during startup.
print(f"Listing contents of models directory: {os.listdir('models') if os.path.isdir('models') else 'models directory not found'}")

# Initialize the model with ngf=128 to match your training parameters
model = Generator(ngf=128).to(device)

# Read the checkpoint from disk once; both load attempts below reuse it.
# NOTE(review): torch.load unpickles the file, so only point model_path at a
# trusted checkpoint (consider weights_only=True if the installed torch
# version supports it).
weights_loaded = False
try:
    state_dict = torch.load(model_path, map_location=device)
except Exception as e:
    # Best-effort startup: a missing/unreadable checkpoint must not crash the
    # app, and retrying with strict=False could not help in this case.
    state_dict = None
    print(f"Error loading model: {e}")

if state_dict is not None:
    try:
        model.load_state_dict(state_dict)
        weights_loaded = True
        print("Model loaded successfully!")
    except Exception as e:
        print(f"Error loading model: {e}")
        # Fall back to a non-strict load, which tolerates missing or
        # unexpected keys, and report what was ignored.
        try:
            incompatible = model.load_state_dict(state_dict, strict=False)
            weights_loaded = True
            print("Model loaded with strict=False")
            print(f"Ignored key mismatches: {incompatible}")
        except Exception as e2:
            print(f"Error with alternative loading: {e2}")

if not weights_loaded:
    # Make the degraded state visible instead of silently serving noise.
    print("WARNING: no checkpoint weights were loaded; the generator is using randomly initialized weights")

# Set model to evaluation mode
model.eval()
print(f"Model initialized: {model is not None}")

def create_image_grid(images, rows, cols):
    """Paste the given images, row by row, onto one RGB canvas.

    Each cell has the size of the first image; image ``i`` is placed in row
    ``i // cols`` and column ``i % cols``.
    """
    cell_w, cell_h = images[0].size
    canvas = Image.new('RGB', size=(cols * cell_w, rows * cell_h))

    for index, tile in enumerate(images):
        row, col = divmod(index, cols)
        canvas.paste(tile, box=(col * cell_w, row * cell_h))
    return canvas

def generate_multiple_images(random_seed=42, num_images=4):
    """Generate ``num_images`` samples with the DCGAN model and tile them.

    Args:
        random_seed: Seed passed to ``torch.manual_seed`` before the noise
            vectors are drawn.
        num_images: Number of images to generate; must be at least 1.

    Returns:
        A single PIL image containing all generated samples in a grid.

    Raises:
        ValueError: If ``num_images`` is smaller than 1.
    """
    # UI sliders may deliver floats; range() below requires a real int.
    random_seed = int(random_seed)
    num_images = int(num_images)
    if num_images < 1:
        raise ValueError("num_images must be at least 1")

    # One seed for the whole batch; each image uses the next draw from the
    # same random stream.
    torch.manual_seed(random_seed)

    # Read the latent size from the generator's first layer so the noise
    # always matches the model's nz instead of a hard-coded 100.
    latent_dim = model.main[0].in_channels

    images = []
    with torch.no_grad():
        for _ in range(num_images):
            noise = torch.randn(1, latent_dim, 1, 1, device=device)
            fake_image = model(noise).cpu()

            # Map the tanh output from [-1, 1] to [0, 1], then to HWC uint8.
            fake_img = fake_image * 0.5 + 0.5
            fake_img = fake_img.squeeze(0).permute(1, 2, 0).numpy()
            fake_img = np.clip(fake_img * 255, 0, 255).astype(np.uint8)
            images.append(Image.fromarray(fake_img))

    # Near-square layout; rows * cols is always >= num_images.
    rows = int(math.sqrt(num_images))
    cols = math.ceil(num_images / rows)
    return create_image_grid(images, rows, cols)

# Create Gradio interface
# The sliders' initial positions are given with `value=`: `default=` is the
# keyword of the legacy gr.inputs API and is deprecated/rejected by the
# top-level gr.Slider component.
demo = gr.Interface(
    fn=generate_multiple_images,
    inputs=[
        gr.Slider(minimum=1, maximum=100, step=1, value=42, label="Random Seed"),
        gr.Slider(minimum=1, maximum=16, step=1, value=4, label="Number of Images"),
    ],
    outputs=gr.Image(type="pil", label="Generated Computer Mice"),
    title="DCGAN Computer Mouse Generator",
    description="Generate multiple unique computer mouse designs using a DCGAN model.",
    # Each example row is [random seed, number of images].
    examples=[[42, 4], [23, 9], [7, 16]],
)

# Launch the app
if __name__ == "__main__":
    demo.launch()