ju4nppp commited on
Commit
79fa8cf
·
verified ·
1 Parent(s): e352aba

Update app.py

Browse files

changed folder name

Files changed (1) hide show
  1. app.py +125 -125
app.py CHANGED
@@ -1,126 +1,126 @@
1
- import torch
2
- import torch.nn as nn
3
- import torchvision.utils as vutils
4
- import gradio as gr
5
- import numpy as np
6
- import matplotlib.pyplot as plt
7
-
8
-
9
- # Define Generator architecture - must match what you used during training
10
- class Generator(nn.Module):
11
- def __init__(self, ngpu=1, nz=100, ngf=64, nc=3):
12
- super(Generator, self).__init__()
13
- self.ngpu = ngpu
14
- self.main = nn.Sequential(
15
- # input is Z, going into a convolution
16
- nn.ConvTranspose2d(nz, ngf * 8, 4, 1, 0, bias=False),
17
- nn.BatchNorm2d(ngf * 8),
18
- nn.ReLU(True),
19
- # state size. (ngf*8) x 4 x 4
20
- nn.ConvTranspose2d(ngf * 8, ngf * 4, 4, 2, 1, bias=False),
21
- nn.BatchNorm2d(ngf * 4),
22
- nn.ReLU(True),
23
- # state size. (ngf*4) x 8 x 8
24
- nn.ConvTranspose2d(ngf * 4, ngf * 2, 4, 2, 1, bias=False),
25
- nn.BatchNorm2d(ngf * 2),
26
- nn.ReLU(True),
27
- # state size. (ngf*2) x 16 x 16
28
- nn.ConvTranspose2d(ngf * 2, ngf, 4, 2, 1, bias=False),
29
- nn.BatchNorm2d(ngf),
30
- nn.ReLU(True),
31
- # state size. (ngf) x 32 x 32
32
- nn.ConvTranspose2d(ngf, nc, 4, 2, 1, bias=False),
33
- nn.Tanh()
34
- # state size. (nc) x 64 x 64
35
- )
36
-
37
- def forward(self, input):
38
- return self.main(input)
39
-
40
-
41
- # Load the generator
42
- def load_model(model_path="model/netG_best.pth"):
43
- # Create the generator and load the saved weights
44
- device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
45
- netG = Generator(ngpu=1, nz=100, ngf=64, nc=3).to(device)
46
-
47
- try:
48
- netG.load_state_dict(torch.load(model_path, map_location=device))
49
- netG.eval() # Set to evaluation mode
50
- print(f"Model loaded successfully from {model_path}")
51
- return netG, device
52
- except Exception as e:
53
- print(f"Error loading model: {e}")
54
- return None, device
55
-
56
-
57
- # Generate images using the model
58
- def generate_images(num_images=16, seed=None, randomize=True):
59
- # Load the model (do this once when needed)
60
- global model, device
61
- if 'model' not in globals():
62
- model, device = load_model()
63
- if model is None:
64
- return np.zeros((299, 299, 3))
65
-
66
- # Set random seed for reproducibility if provided
67
- if seed is not None and not randomize:
68
- torch.manual_seed(seed)
69
- np.random.seed(seed)
70
-
71
- # Generate latent vectors
72
- nz = 100 # Size of the latent vector
73
- noise = torch.randn(num_images, nz, 1, 1, device=device)
74
-
75
- # Generate fake images
76
- with torch.no_grad():
77
- fake_images = model(noise).detach().cpu()
78
-
79
- # Convert to grid for display
80
- grid = vutils.make_grid(fake_images, padding=2, normalize=True, nrow=int(np.sqrt(num_images)))
81
-
82
- # Convert from tensor to numpy array for Gradio
83
- grid_np = grid.numpy().transpose((1, 2, 0))
84
-
85
- # Make sure values are in 0-1 range
86
- grid_np = np.clip(grid_np, 0, 1)
87
-
88
- return grid_np
89
-
90
-
91
- # Create Gradio interface
92
- def create_gradio_app():
93
- with gr.Blocks(title="Computer Mouse Generator") as app:
94
- gr.Markdown("# Computer Mouse GAN Generator")
95
- gr.Markdown("Generate computer mice using a Deep Convolutional GAN trained on ~2,500 augmented images")
96
-
97
- with gr.Row():
98
- with gr.Column():
99
- num_images = gr.Slider(minimum=1, maximum=64, value=16, step=1, label="Number of Images")
100
- seed = gr.Number(label="Random Seed", value=42, precision=0)
101
- randomize = gr.Checkbox(label="Use Random Seeds (ignore seed value)", value=True)
102
- generate_button = gr.Button("Generate Mice")
103
-
104
- with gr.Column():
105
- output_image = gr.Image(label="Generated Computer Mice")
106
-
107
- generate_button.click(fn=generate_images, inputs=[num_images, seed, randomize], outputs=output_image)
108
-
109
- gr.Markdown("## About")
110
- gr.Markdown("""This model was trained using a PyTorch DCGAN implementation on a dataset of computer mouse images.
111
-
112
- The training process used data augmentation to expand a small dataset of 300+ original images into 2,500+ training samples through techniques like flipping, rotation, and brightness/contrast adjustments.
113
-
114
- The generator creates brand new, never-before-seen computer mice from random noise!""")
115
-
116
- return app
117
-
118
-
119
- # Initialize global variables
120
- model = None
121
- device = None
122
-
123
- # Launch the app if the script is run directly
124
- if __name__ == "__main__":
125
- app = create_gradio_app()
126
  app.launch()
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import torchvision.utils as vutils
4
+ import gradio as gr
5
+ import numpy as np
6
+ import matplotlib.pyplot as plt
7
+
8
+
9
+ # Define Generator architecture - must match what you used during training
10
+ class Generator(nn.Module):
11
+ def __init__(self, ngpu=1, nz=100, ngf=64, nc=3):
12
+ super(Generator, self).__init__()
13
+ self.ngpu = ngpu
14
+ self.main = nn.Sequential(
15
+ # input is Z, going into a convolution
16
+ nn.ConvTranspose2d(nz, ngf * 8, 4, 1, 0, bias=False),
17
+ nn.BatchNorm2d(ngf * 8),
18
+ nn.ReLU(True),
19
+ # state size. (ngf*8) x 4 x 4
20
+ nn.ConvTranspose2d(ngf * 8, ngf * 4, 4, 2, 1, bias=False),
21
+ nn.BatchNorm2d(ngf * 4),
22
+ nn.ReLU(True),
23
+ # state size. (ngf*4) x 8 x 8
24
+ nn.ConvTranspose2d(ngf * 4, ngf * 2, 4, 2, 1, bias=False),
25
+ nn.BatchNorm2d(ngf * 2),
26
+ nn.ReLU(True),
27
+ # state size. (ngf*2) x 16 x 16
28
+ nn.ConvTranspose2d(ngf * 2, ngf, 4, 2, 1, bias=False),
29
+ nn.BatchNorm2d(ngf),
30
+ nn.ReLU(True),
31
+ # state size. (ngf) x 32 x 32
32
+ nn.ConvTranspose2d(ngf, nc, 4, 2, 1, bias=False),
33
+ nn.Tanh()
34
+ # state size. (nc) x 64 x 64
35
+ )
36
+
37
+ def forward(self, input):
38
+ return self.main(input)
39
+
40
+
41
+ # Load the generator
42
+ def load_model(model_path="models/netG_best.pth"):
43
+ # Create the generator and load the saved weights
44
+ device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
45
+ netG = Generator(ngpu=1, nz=100, ngf=64, nc=3).to(device)
46
+
47
+ try:
48
+ netG.load_state_dict(torch.load(model_path, map_location=device))
49
+ netG.eval() # Set to evaluation mode
50
+ print(f"Model loaded successfully from {model_path}")
51
+ return netG, device
52
+ except Exception as e:
53
+ print(f"Error loading model: {e}")
54
+ return None, device
55
+
56
+
57
+ # Generate images using the model
58
+ def generate_images(num_images=16, seed=None, randomize=True):
59
+ # Load the model (do this once when needed)
60
+ global model, device
61
+ if 'model' not in globals():
62
+ model, device = load_model()
63
+ if model is None:
64
+ return np.zeros((299, 299, 3))
65
+
66
+ # Set random seed for reproducibility if provided
67
+ if seed is not None and not randomize:
68
+ torch.manual_seed(seed)
69
+ np.random.seed(seed)
70
+
71
+ # Generate latent vectors
72
+ nz = 100 # Size of the latent vector
73
+ noise = torch.randn(num_images, nz, 1, 1, device=device)
74
+
75
+ # Generate fake images
76
+ with torch.no_grad():
77
+ fake_images = model(noise).detach().cpu()
78
+
79
+ # Convert to grid for display
80
+ grid = vutils.make_grid(fake_images, padding=2, normalize=True, nrow=int(np.sqrt(num_images)))
81
+
82
+ # Convert from tensor to numpy array for Gradio
83
+ grid_np = grid.numpy().transpose((1, 2, 0))
84
+
85
+ # Make sure values are in 0-1 range
86
+ grid_np = np.clip(grid_np, 0, 1)
87
+
88
+ return grid_np
89
+
90
+
91
+ # Create Gradio interface
92
+ def create_gradio_app():
93
+ with gr.Blocks(title="Computer Mouse Generator") as app:
94
+ gr.Markdown("# Computer Mouse GAN Generator")
95
+ gr.Markdown("Generate computer mice using a Deep Convolutional GAN trained on ~2,500 augmented images")
96
+
97
+ with gr.Row():
98
+ with gr.Column():
99
+ num_images = gr.Slider(minimum=1, maximum=64, value=16, step=1, label="Number of Images")
100
+ seed = gr.Number(label="Random Seed", value=42, precision=0)
101
+ randomize = gr.Checkbox(label="Use Random Seeds (ignore seed value)", value=True)
102
+ generate_button = gr.Button("Generate Mice")
103
+
104
+ with gr.Column():
105
+ output_image = gr.Image(label="Generated Computer Mice")
106
+
107
+ generate_button.click(fn=generate_images, inputs=[num_images, seed, randomize], outputs=output_image)
108
+
109
+ gr.Markdown("## About")
110
+ gr.Markdown("""This model was trained using a PyTorch DCGAN implementation on a dataset of computer mouse images.
111
+
112
+ The training process used data augmentation to expand a small dataset of 300+ original images into 2,500+ training samples through techniques like flipping, rotation, and brightness/contrast adjustments.
113
+
114
+ The generator creates brand new, never-before-seen computer mice from random noise!""")
115
+
116
+ return app
117
+
118
+
119
+ # Initialize global variables
120
+ model = None
121
+ device = None
122
+
123
+ # Launch the app if the script is run directly
124
+ if __name__ == "__main__":
125
+ app = create_gradio_app()
126
  app.launch()