File size: 4,633 Bytes
e352aba
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
import torch
import torch.nn as nn
import torchvision.utils as vutils
import gradio as gr
import numpy as np
import matplotlib.pyplot as plt


# Define Generator architecture - must match what you used during training
class Generator(nn.Module):
    """DCGAN-style generator built from stacked transposed convolutions.

    Args:
        ngpu: Stored on the instance as ``self.ngpu``; not otherwise used
            by this class.
        nz: Number of channels of the latent input.
        ngf: Base width of the generator feature maps.
        nc: Number of channels of the generated image.

    The layers live in ``self.main`` in a fixed order, so the parameter
    names in the state dict (``main.0.weight``, ``main.1.weight``, ...)
    depend on that order and must match the checkpoint being loaded.
    """

    def __init__(self, ngpu=1, nz=100, ngf=64, nc=3):
        super().__init__()
        self.ngpu = ngpu

        # One row per hidden stage: (in_channels, out_channels, stride, padding).
        # Every stage uses a 4x4 kernel and is followed by BatchNorm + ReLU.
        hidden_stages = (
            (nz, ngf * 8, 1, 0),       # latent Z  -> (ngf*8) x 4 x 4
            (ngf * 8, ngf * 4, 2, 1),  #           -> (ngf*4) x 8 x 8
            (ngf * 4, ngf * 2, 2, 1),  #           -> (ngf*2) x 16 x 16
            (ngf * 2, ngf, 2, 1),      #           -> (ngf) x 32 x 32
        )

        layers = []
        for in_ch, out_ch, stride, padding in hidden_stages:
            layers.append(
                nn.ConvTranspose2d(in_ch, out_ch, 4, stride, padding, bias=False)
            )
            layers.append(nn.BatchNorm2d(out_ch))
            layers.append(nn.ReLU(True))

        # Output stage: (ngf) x 32 x 32 -> (nc) x 64 x 64, squashed by tanh.
        layers.append(nn.ConvTranspose2d(ngf, nc, 4, 2, 1, bias=False))
        layers.append(nn.Tanh())

        self.main = nn.Sequential(*layers)

    def forward(self, input):
        """Run ``input`` through the layer stack and return the result."""
        return self.main(input)


# Load the generator
def load_model(model_path="model/netG_best.pth"):
    """Build a ``Generator`` and populate it with weights from ``model_path``.

    Args:
        model_path: Path to a saved generator state dict.

    Returns:
        A ``(generator, device)`` tuple. The generator is in eval mode on
        success; if anything goes wrong while loading, the error is printed
        and ``(None, device)`` is returned instead.
    """
    # Prefer the first CUDA device when one is available.
    if torch.cuda.is_available():
        device = torch.device("cuda:0")
    else:
        device = torch.device("cpu")

    generator = Generator(ngpu=1, nz=100, ngf=64, nc=3).to(device)

    try:
        # NOTE(review): torch.load unpickles the file, so only point this at
        # checkpoints from a trusted source.
        weights = torch.load(model_path, map_location=device)
        generator.load_state_dict(weights)
        generator.eval()  # inference only: freeze BatchNorm statistics
        print(f"Model loaded successfully from {model_path}")
    except Exception as e:
        # Best-effort: report the problem and let the caller handle ``None``.
        print(f"Error loading model: {e}")
        return None, device

    return generator, device


# Generate images using the model
def generate_images(num_images=16, seed=None, randomize=True):
    """Sample images from the generator and return them as one image grid.

    The generator is loaded lazily through ``load_model()`` and cached in the
    module-level ``model`` / ``device`` globals. A failed load is not cached,
    so the next call tries again.

    Args:
        num_images: How many images to sample. Coerced to ``int`` (UI widgets
            may deliver floats) and clamped to at least 1.
        seed: Seed applied to the torch and NumPy global RNGs when
            ``randomize`` is false. ``None`` leaves the RNG state untouched.
        randomize: When true, ``seed`` is ignored.

    Returns:
        A NumPy array laid out as (height, width, channels) with values
        clipped to [0, 1]. If the model cannot be loaded, an all-zero
        ``(299, 299, 3)`` placeholder is returned instead.
    """
    global model, device

    # The module pre-assigns ``model = None``, so testing for the *name* in
    # globals() would never trigger a load; test the *value* instead. Using
    # globals().get also covers the case where the name was never assigned.
    if globals().get("model") is None:
        model, device = load_model()
        if model is None:
            return np.zeros((299, 299, 3))

    num_images = max(1, int(num_images))

    # Set random seed for reproducibility if provided
    if seed is not None and not randomize:
        seed = int(seed)
        torch.manual_seed(seed)
        # np.random.seed only accepts values in [0, 2**32 - 1].
        np.random.seed(seed % (2 ** 32))

    # Generate latent vectors
    nz = 100  # Size of the latent vector; must match the loaded generator
    noise = torch.randn(num_images, nz, 1, 1, device=device)

    # Generate fake images
    with torch.no_grad():
        fake_images = model(noise).detach().cpu()

    # Roughly square grid; never let the row width drop to zero.
    nrow = max(1, int(np.sqrt(num_images)))
    grid = vutils.make_grid(fake_images, padding=2, normalize=True, nrow=nrow)

    # make_grid returns channels-first; move channels last for display.
    grid_np = grid.numpy().transpose((1, 2, 0))

    # Make sure values are in 0-1 range
    return np.clip(grid_np, 0, 1)


# Create Gradio interface
def create_gradio_app():
    """Assemble and return the Gradio Blocks UI for the generator demo.

    The layout is a title and description, a row with the input controls
    on the left and the output image on the right, and an "About" section.
    Clicking the button calls ``generate_images`` with the three control
    values and shows the result in the image component.
    """
    about_text = """This model was trained using a PyTorch DCGAN implementation on a dataset of computer mouse images.



The training process used data augmentation to expand a small dataset of 300+ original images into 2,500+ training samples through techniques like flipping, rotation, and brightness/contrast adjustments.



The generator creates brand new, never-before-seen computer mice from random noise!"""

    with gr.Blocks(title="Computer Mouse Generator") as demo:
        gr.Markdown("# Computer Mouse GAN Generator")
        gr.Markdown("Generate computer mice using a Deep Convolutional GAN trained on ~2,500 augmented images")

        with gr.Row():
            # Left column: input controls.
            with gr.Column():
                count_slider = gr.Slider(
                    minimum=1,
                    maximum=64,
                    value=16,
                    step=1,
                    label="Number of Images",
                )
                seed_box = gr.Number(label="Random Seed", value=42, precision=0)
                randomize_box = gr.Checkbox(
                    label="Use Random Seeds (ignore seed value)",
                    value=True,
                )
                run_button = gr.Button("Generate Mice")

            # Right column: generated grid.
            with gr.Column():
                result_image = gr.Image(label="Generated Computer Mice")

        # Input order matches generate_images(num_images, seed, randomize).
        run_button.click(
            fn=generate_images,
            inputs=[count_slider, seed_box, randomize_box],
            outputs=result_image,
        )

        gr.Markdown("## About")
        gr.Markdown(about_text)

    return demo


# Module-level state shared with generate_images(), which declares both
# names ``global``. Both start out unset.
model, device = None, None

# Script entry point: build the UI and start serving it. Nothing is
# launched when this module is merely imported.
if __name__ == "__main__":
    app = create_gradio_app()
    app.launch()