ju4nppp commited on
Commit
b6afda1
·
verified ·
1 Parent(s): 8f3fcec

updated app

Browse files

addited the model path

Files changed (1) hide show
  1. app.py +106 -106
app.py CHANGED
@@ -1,107 +1,107 @@
1
- import torch
2
- import torch.nn as nn
3
- import torchvision.utils as vutils
4
- import gradio as gr
5
- import numpy as np
6
-
7
-
8
- # Define Generator architecture - must match the architecture used during training
9
- class Generator(nn.Module):
10
- def __init__(self, ngpu=1, nz=100, ngf=64, nc=1):
11
- super(Generator, self).__init__()
12
- self.ngpu = ngpu
13
- self.main = nn.Sequential(
14
- # input is Z, going into a convolution
15
- nn.ConvTranspose2d(nz, ngf * 8, 4, 1, 0, bias=False),
16
- nn.BatchNorm2d(ngf * 8),
17
- nn.ReLU(True),
18
- # state size. (ngf*8) x 4 x 4
19
- nn.ConvTranspose2d(ngf * 8, ngf * 4, 4, 2, 1, bias=False),
20
- nn.BatchNorm2d(ngf * 4),
21
- nn.ReLU(True),
22
- # state size. (ngf*4) x 8 x 8
23
- nn.ConvTranspose2d(ngf * 4, ngf * 2, 4, 2, 1, bias=False),
24
- nn.BatchNorm2d(ngf * 2),
25
- nn.ReLU(True),
26
- # state size. (ngf*2) x 16 x 16
27
- nn.ConvTranspose2d(ngf * 2, ngf, 4, 2, 1, bias=False),
28
- nn.BatchNorm2d(ngf),
29
- nn.ReLU(True),
30
- # state size. (ngf) x 32 x 32
31
- nn.ConvTranspose2d(ngf, nc, 4, 2, 1, bias=False),
32
- nn.Tanh()
33
- # state size. (nc) x 64 x 64
34
- )
35
-
36
- def forward(self, input):
37
- return self.main(input)
38
-
39
-
40
- # Load the generator
41
- def load_model(model_path):
42
- # Create the generator and load the saved weights
43
- device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
44
- netG = Generator().to(device)
45
- netG.load_state_dict(torch.load(model_path, map_location=device))
46
- netG.eval() # Set to evaluation mode
47
- return netG, device
48
-
49
-
50
- # Generate images using the model
51
- def generate_images(num_images=16, seed=None, model_path="models/netG_epoch_21.pth"):
52
- # Load the model
53
- netG, device = load_model(model_path)
54
-
55
- # Set random seed for reproducibility if provided
56
- if seed is not None:
57
- torch.manual_seed(seed)
58
- np.random.seed(seed)
59
-
60
- # Generate latent vectors
61
- nz = 100 # Size of the latent vector (must match the model)
62
- noise = torch.randn(num_images, nz, 1, 1, device=device)
63
-
64
- # Generate fake images
65
- with torch.no_grad():
66
- fake_images = netG(noise).detach().cpu()
67
-
68
- # Convert to grid for display
69
- grid = vutils.make_grid(fake_images, padding=2, normalize=True, nrow=int(np.sqrt(num_images)))
70
-
71
- # Convert from tensor to numpy array for Gradio
72
- grid_np = grid.numpy().transpose((1, 2, 0))
73
-
74
- # Convert from [-1, 1] to [0, 1] range for display
75
- grid_np = (grid_np + 1) / 2.0
76
-
77
- return grid_np
78
-
79
-
80
- # Create Gradio interface
81
- def create_gradio_app():
82
- with gr.Blocks(title="DCGAN MNIST Generator") as app:
83
- gr.Markdown("# DCGAN MNIST Generator")
84
- gr.Markdown("Generate MNIST-like digits using a Deep Convolutional GAN")
85
-
86
- with gr.Row():
87
- with gr.Column():
88
- num_images = gr.Slider(minimum=1, maximum=64, value=16, step=1, label="Number of Images")
89
- seed = gr.Number(label="Random Seed (leave blank for random)", precision=0)
90
- generate_button = gr.Button("Generate Images")
91
-
92
- with gr.Column():
93
- output_image = gr.Image(label="Generated Images")
94
-
95
- generate_button.click(fn=generate_images, inputs=[num_images, seed], outputs=output_image)
96
-
97
- gr.Markdown("## About")
98
- gr.Markdown("This model was trained using PyTorch DCGAN implementation on the MNIST dataset. "
99
- "It generates new handwritten digit-like images from random noise.")
100
-
101
- return app
102
-
103
-
104
- # Launch the app if the script is run directly
105
- if __name__ == "__main__":
106
- app = create_gradio_app()
107
  app.launch()
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import torchvision.utils as vutils
4
+ import gradio as gr
5
+ import numpy as np
6
+
7
+
8
+ # Define Generator architecture - must match the architecture used during training
9
+ class Generator(nn.Module):
10
+ def __init__(self, ngpu=1, nz=100, ngf=64, nc=1):
11
+ super(Generator, self).__init__()
12
+ self.ngpu = ngpu
13
+ self.main = nn.Sequential(
14
+ # input is Z, going into a convolution
15
+ nn.ConvTranspose2d(nz, ngf * 8, 4, 1, 0, bias=False),
16
+ nn.BatchNorm2d(ngf * 8),
17
+ nn.ReLU(True),
18
+ # state size. (ngf*8) x 4 x 4
19
+ nn.ConvTranspose2d(ngf * 8, ngf * 4, 4, 2, 1, bias=False),
20
+ nn.BatchNorm2d(ngf * 4),
21
+ nn.ReLU(True),
22
+ # state size. (ngf*4) x 8 x 8
23
+ nn.ConvTranspose2d(ngf * 4, ngf * 2, 4, 2, 1, bias=False),
24
+ nn.BatchNorm2d(ngf * 2),
25
+ nn.ReLU(True),
26
+ # state size. (ngf*2) x 16 x 16
27
+ nn.ConvTranspose2d(ngf * 2, ngf, 4, 2, 1, bias=False),
28
+ nn.BatchNorm2d(ngf),
29
+ nn.ReLU(True),
30
+ # state size. (ngf) x 32 x 32
31
+ nn.ConvTranspose2d(ngf, nc, 4, 2, 1, bias=False),
32
+ nn.Tanh()
33
+ # state size. (nc) x 64 x 64
34
+ )
35
+
36
+ def forward(self, input):
37
+ return self.main(input)
38
+
39
+
40
+ # Load the generator
41
+ def load_model(model_path):
42
+ # Create the generator and load the saved weights
43
+ device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
44
+ netG = Generator().to(device)
45
+ netG.load_state_dict(torch.load(model_path, map_location=device))
46
+ netG.eval() # Set to evaluation mode
47
+ return netG, device
48
+
49
+
50
+ # Generate images using the model
51
+ def generate_images(num_images=16, seed=None, model_path="models/netG_epoch_29.pth"):
52
+ # Load the model
53
+ netG, device = load_model(model_path)
54
+
55
+ # Set random seed for reproducibility if provided
56
+ if seed is not None:
57
+ torch.manual_seed(seed)
58
+ np.random.seed(seed)
59
+
60
+ # Generate latent vectors
61
+ nz = 100 # Size of the latent vector (must match the model)
62
+ noise = torch.randn(num_images, nz, 1, 1, device=device)
63
+
64
+ # Generate fake images
65
+ with torch.no_grad():
66
+ fake_images = netG(noise).detach().cpu()
67
+
68
+ # Convert to grid for display
69
+ grid = vutils.make_grid(fake_images, padding=2, normalize=True, nrow=int(np.sqrt(num_images)))
70
+
71
+ # Convert from tensor to numpy array for Gradio
72
+ grid_np = grid.numpy().transpose((1, 2, 0))
73
+
74
+ # Convert from [-1, 1] to [0, 1] range for display
75
+ grid_np = (grid_np + 1) / 2.0
76
+
77
+ return grid_np
78
+
79
+
80
+ # Create Gradio interface
81
+ def create_gradio_app():
82
+ with gr.Blocks(title="DCGAN MNIST Generator") as app:
83
+ gr.Markdown("# DCGAN MNIST Generator")
84
+ gr.Markdown("Generate MNIST-like digits using a Deep Convolutional GAN")
85
+
86
+ with gr.Row():
87
+ with gr.Column():
88
+ num_images = gr.Slider(minimum=1, maximum=64, value=16, step=1, label="Number of Images")
89
+ seed = gr.Number(label="Random Seed (leave blank for random)", precision=0)
90
+ generate_button = gr.Button("Generate Images")
91
+
92
+ with gr.Column():
93
+ output_image = gr.Image(label="Generated Images")
94
+
95
+ generate_button.click(fn=generate_images, inputs=[num_images, seed], outputs=output_image)
96
+
97
+ gr.Markdown("## About")
98
+ gr.Markdown("This model was trained using PyTorch DCGAN implementation on the MNIST dataset. "
99
+ "It generates new handwritten digit-like images from random noise.")
100
+
101
+ return app
102
+
103
+
104
+ # Launch the app if the script is run directly
105
+ if __name__ == "__main__":
106
+ app = create_gradio_app()
107
  app.launch()