File size: 3,720 Bytes
b6afda1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
a0ab06e
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
import functools

import torch
import torch.nn as nn
import torchvision.utils as vutils
import gradio as gr
import numpy as np


# Define Generator architecture - must match the architecture used during training
class Generator(nn.Module):
    """DCGAN generator: maps a (nz x 1 x 1) latent tensor to an (nc x 64 x 64) image.

    The network is five transposed convolutions. The first projects the 1x1
    input to 4x4 (stride 1, no padding); each following one doubles the spatial
    size (stride 2, padding 1): 4 -> 8 -> 16 -> 32 -> 64. Every hidden stage is
    followed by BatchNorm + ReLU; the output stage ends in Tanh.

    The layer order inside ``self.main`` must match the architecture used during
    training, because the checkpoint's state-dict keys are positional
    (``main.<index>.<param>``).

    Args:
        ngpu: Stored as ``self.ngpu``; not otherwise used by this class.
        nz: Number of channels of the latent input.
        ngf: Base feature-map width; hidden widths are 8, 4, 2 and 1 times this.
        nc: Number of channels in the generated image.
    """

    def __init__(self, ngpu=1, nz=100, ngf=64, nc=1):
        super().__init__()
        self.ngpu = ngpu

        # Channel progression through the hidden stages: nz -> 8ngf -> 4ngf -> 2ngf -> ngf
        widths = [nz, ngf * 8, ngf * 4, ngf * 2, ngf]
        layers = []
        for stage, (c_in, c_out) in enumerate(zip(widths, widths[1:])):
            # Stage 0 turns the 1x1 latent into 4x4; later stages double H and W.
            stride, pad = (1, 0) if stage == 0 else (2, 1)
            layers += [
                nn.ConvTranspose2d(c_in, c_out, 4, stride, pad, bias=False),
                nn.BatchNorm2d(c_out),
                nn.ReLU(True),
            ]

        # Output stage: (ngf) x 32 x 32 -> (nc) x 64 x 64, squashed into [-1, 1].
        layers += [
            nn.ConvTranspose2d(ngf, nc, 4, 2, 1, bias=False),
            nn.Tanh(),
        ]
        self.main = nn.Sequential(*layers)

    def forward(self, input):
        out = self.main(input)
        return out


# Load the generator
@functools.lru_cache(maxsize=4)
def load_model(model_path):
    """Load a Generator checkpoint (a state dict) and return it ready for inference.

    Results are memoized per ``model_path``. Repeated calls, such as one per UI
    button click, reuse the already-loaded network instead of re-reading the file
    and re-allocating the model each time. Because the same module object is
    returned on every call, callers must not switch it back to train mode or
    modify its weights.

    Args:
        model_path: Path (or other hashable source accepted by ``torch.load``)
            of a saved ``Generator`` state dict.

    Returns:
        ``(netG, device)``: the ``Generator`` in eval mode on ``device``, where
        ``device`` is ``cuda:0`` when CUDA is available, otherwise CPU.

    Raises:
        Whatever ``torch.load`` raises for a missing or unreadable file, and
        ``RuntimeError`` from ``load_state_dict`` when the checkpoint's keys or
        shapes do not match ``Generator``. Failed loads are not cached.
    """
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    netG = Generator().to(device)
    # A state dict holds only tensors in plain containers, so weights_only=True
    # is sufficient and prevents the checkpoint from running arbitrary pickled code.
    state_dict = torch.load(model_path, map_location=device, weights_only=True)
    netG.load_state_dict(state_dict)
    netG.eval()  # inference mode: BatchNorm uses its stored running statistics
    return netG, device


# Generate images using the model
def generate_images(num_images=16, seed=None, model_path="models/netG_epoch_29.pth"):
    """Sample digits from the generator and tile them into one image grid.

    Args:
        num_images: Number of samples to draw. Converted with ``int()`` because
            UI widgets may deliver numeric values as floats. Must be >= 1.
        seed: Optional integer seed. When given, the noise is drawn from a
            dedicated ``torch.Generator`` seeded with it, so the same seed
            reproduces the same grid on the same device without mutating the
            process-wide RNG state. When ``None``, the global torch RNG is used.
        model_path: Checkpoint passed to ``load_model``.

    Returns:
        A float32 ``numpy.ndarray`` of shape (H, W, 3) with values in [0, 1],
        laid out as ``ceil(sqrt(num_images))`` columns with 2-pixel black padding.

    Raises:
        ValueError: If ``num_images`` is less than 1.
    """
    num_images = int(num_images)
    if num_images < 1:
        raise ValueError(f"num_images must be at least 1, got {num_images}")

    netG, device = load_model(model_path)

    # A local generator gives reproducible seeded draws without touching the
    # global RNG. (NumPy's RNG is not used for sampling, so it is not seeded.)
    rng = None
    if seed is not None:
        rng = torch.Generator(device=device)
        rng.manual_seed(int(seed))

    nz = 100  # latent size; must match the Generator's nz
    noise = torch.randn(num_images, nz, 1, 1, device=device, generator=rng)

    with torch.no_grad():
        fake_images = netG(noise).cpu()

    # The generator ends in Tanh (range [-1, 1]); map to [0, 1] exactly once.
    # (The previous code combined make_grid(normalize=True), which already yields
    # [0, 1], with another (x + 1) / 2 step, compressing the image into [0.5, 1].)
    fake_images = ((fake_images + 1) / 2).clamp_(0.0, 1.0)

    # ceil(sqrt(n)) columns keeps the grid close to square for any n.
    nrow = int(np.ceil(np.sqrt(num_images)))
    grid = vutils.make_grid(fake_images, nrow=nrow, padding=2, normalize=False)

    # CHW -> HWC for Gradio; make_grid replicates single-channel input to 3 channels.
    return grid.numpy().transpose((1, 2, 0))


# Create Gradio interface
def create_gradio_app():
    """Assemble the Gradio Blocks UI for sampling from the generator.

    The layout has a header, then a row split into two columns. The left column
    holds the controls: an image-count slider, an optional seed box and a
    generate button. The right column shows the generated grid. An "About"
    section follows. Clicking the button calls
    ``generate_images(num_images, seed)`` and displays its return value.

    Returns:
        The constructed ``gr.Blocks`` app. The caller is responsible for ``launch()``.
    """
    with gr.Blocks(title="DCGAN MNIST Generator") as demo:
        # Page header
        gr.Markdown("# DCGAN MNIST Generator")
        gr.Markdown("Generate MNIST-like digits using a Deep Convolutional GAN")

        with gr.Row():
            # Left column: user controls
            with gr.Column():
                count_slider = gr.Slider(
                    label="Number of Images",
                    minimum=1,
                    maximum=64,
                    step=1,
                    value=16,
                )
                seed_box = gr.Number(
                    precision=0,
                    label="Random Seed (leave blank for random)",
                )
                run_btn = gr.Button("Generate Images")

            # Right column: rendered sample grid
            with gr.Column():
                result_img = gr.Image(label="Generated Images")

        # Wired after both columns exist so the output component is already defined.
        run_btn.click(
            fn=generate_images,
            inputs=[count_slider, seed_box],
            outputs=result_img,
        )

        gr.Markdown("## About")
        gr.Markdown(
            "This model was trained using PyTorch DCGAN implementation on the MNIST dataset. "
            "It generates new handwritten digit-like images from random noise."
        )

    return demo


# Script entry point: build the UI and start the Gradio server only when this
# file is executed directly (not when it is imported).
if __name__ == "__main__":
    gradio_app = create_gradio_app()
    gradio_app.launch()