Spaces:
Sleeping
Sleeping
Uploaded 3 files
Browse files
app, requirements and readme added
- README.md +42 -14
- app.py +107 -0
- requirements.txt +5 -0
README.md
CHANGED
@@ -1,14 +1,42 @@
|
|
1 |
-
|
2 |
-
|
3 |
-
|
4 |
-
|
5 |
-
|
6 |
-
|
7 |
-
|
8 |
-
|
9 |
-
|
10 |
-
|
11 |
-
|
12 |
-
|
13 |
-
|
14 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# DCGAN MNIST Generator
|
2 |
+
|
3 |
+
This repository contains a Deep Convolutional GAN (DCGAN) trained on the MNIST dataset. The model generates handwritten-like digit images from random noise.
|
4 |
+
|
5 |
+
## Model Architecture
|
6 |
+
|
7 |
+
The model implementation is based on the paper [Unsupervised Representation Learning with Deep Convolutional Generative Adversarial Networks](http://arxiv.org/abs/1511.06434).
|
8 |
+
|
9 |
+
- Generator architecture: 5 transposed convolutional layers with batch normalization
|
10 |
+
- Latent space dimension: 100
|
11 |
+
- Output: 64x64 grayscale images
|
12 |
+
|
13 |
+
## Demo App
|
14 |
+
|
15 |
+
The included Gradio app allows you to generate new MNIST-like images using the pre-trained model.
|
16 |
+
|
17 |
+
### Running Locally
|
18 |
+
|
19 |
+
1. Install dependencies:
|
20 |
+
```bash
|
21 |
+
pip install -r requirements.txt
|
22 |
+
```
|
23 |
+
|
24 |
+
2. Run the app:
|
25 |
+
```bash
|
26 |
+
python app.py
|
27 |
+
```
|
28 |
+
|
29 |
+
### Features
|
30 |
+
|
31 |
+
- Generate multiple images at once
|
32 |
+
- Set a random seed for reproducible outputs
|
33 |
+
- Visualize the generated images in a grid
|
34 |
+
|
35 |
+
## Training Details
|
36 |
+
|
37 |
+
This model was trained for 25 epochs on the MNIST dataset using PyTorch. For optimal results, the model checkpoint from epoch 21 is used for inference, as it produced the most realistic images without mode collapse.
|
38 |
+
|
39 |
+
## Acknowledgments
|
40 |
+
|
41 |
+
- Original DCGAN implementation based on [PyTorch examples](https://github.com/pytorch/examples/tree/master/dcgan)
|
42 |
+
- Training was tracked using Weights & Biases
|
app.py
ADDED
@@ -0,0 +1,107 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import torch.nn as nn
|
3 |
+
import torchvision.utils as vutils
|
4 |
+
import gradio as gr
|
5 |
+
import numpy as np
|
6 |
+
|
7 |
+
|
8 |
+
# Define Generator architecture - must match the architecture used during training
|
9 |
+
class Generator(nn.Module):
    """DCGAN generator: upsamples a latent vector into a 64x64 image.

    The order of the layers inside ``self.main`` must not change. Parameters
    in a saved ``state_dict`` are addressed by their position in the
    ``nn.Sequential`` (``main.0.weight``, ``main.1.weight``, ...), so any
    reordering would make existing checkpoints fail to load.

    Args:
        ngpu: Number of GPUs. Stored on the instance; not used by this class.
        nz: Channel count of the latent input, expected as ``(N, nz, 1, 1)``.
        ngf: Base width of the generator feature maps.
        nc: Channel count of the generated image.
    """

    def __init__(self, ngpu=1, nz=100, ngf=64, nc=1):
        super().__init__()
        self.ngpu = ngpu
        self.main = nn.Sequential(
            # input is Z, going into a convolution
            nn.ConvTranspose2d(nz, ngf * 8, 4, 1, 0, bias=False),
            nn.BatchNorm2d(ngf * 8),
            nn.ReLU(True),
            # state size. (ngf*8) x 4 x 4
            nn.ConvTranspose2d(ngf * 8, ngf * 4, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ngf * 4),
            nn.ReLU(True),
            # state size. (ngf*4) x 8 x 8
            nn.ConvTranspose2d(ngf * 4, ngf * 2, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ngf * 2),
            nn.ReLU(True),
            # state size. (ngf*2) x 16 x 16
            nn.ConvTranspose2d(ngf * 2, ngf, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ngf),
            nn.ReLU(True),
            # state size. (ngf) x 32 x 32
            nn.ConvTranspose2d(ngf, nc, 4, 2, 1, bias=False),
            # Tanh bounds every output pixel to [-1, 1].
            nn.Tanh(),
            # state size. (nc) x 64 x 64
        )

    def forward(self, input):
        """Run the latent batch ``input`` through the network.

        Returns:
            The output of ``self.main``; for an ``(N, nz, 1, 1)`` input this
            is an ``(N, nc, 64, 64)`` tensor with values in [-1, 1].
        """
        return self.main(input)
|
38 |
+
|
39 |
+
|
40 |
+
# Load the generator
|
41 |
+
# Generators that have already been loaded, keyed by (checkpoint path, device).
# Avoids rebuilding the network and re-reading the checkpoint on every request.
_MODEL_CACHE = {}


def load_model(model_path):
    """Return the generator for ``model_path`` in eval mode, plus its device.

    The network is built and its weights are read from disk only the first
    time a given checkpoint/device pair is requested; later calls return the
    same cached module.

    Args:
        model_path: Path to a checkpoint containing the generator's
            ``state_dict``.

    Returns:
        A ``(netG, device)`` tuple: the generator in evaluation mode and the
        ``torch.device`` it lives on (``cuda:0`` when CUDA is available,
        otherwise ``cpu``).
    """
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    cache_key = (str(model_path), str(device))
    netG = _MODEL_CACHE.get(cache_key)
    if netG is None:
        netG = Generator().to(device)
        # weights_only=True restricts unpickling to tensors and plain
        # containers, so a tampered checkpoint cannot execute arbitrary code.
        state_dict = torch.load(model_path, map_location=device, weights_only=True)
        netG.load_state_dict(state_dict)
        netG.eval()  # Set to evaluation mode
        _MODEL_CACHE[cache_key] = netG

    return netG, device
|
48 |
+
|
49 |
+
|
50 |
+
# Generate images using the model
|
51 |
+
def generate_images(num_images=16, seed=None, model_path="models/netG_epoch_21.pth"):
    """Generate a grid of images with the trained generator.

    Args:
        num_images: How many images to generate. Coerced to ``int`` because
            UI widgets may deliver the value as a float.
        seed: Optional random seed for reproducible output; ``None`` leaves
            the random state untouched.
        model_path: Path to the generator checkpoint to load.

    Returns:
        A NumPy array in height x width x channel layout with values in
        [0, 1], ready to be displayed as an image.

    Raises:
        ValueError: If ``num_images`` is smaller than 1.
    """
    num_images = int(num_images)
    if num_images < 1:
        raise ValueError("num_images must be at least 1")

    # Load the model
    netG, device = load_model(model_path)

    # Set random seed for reproducibility if provided
    if seed is not None:
        seed = int(seed)
        torch.manual_seed(seed)
        # NumPy only accepts seeds in [0, 2**32); wrap so that a negative
        # seed does not raise.
        np.random.seed(seed % (2 ** 32))

    # Generate latent vectors
    nz = 100  # Size of the latent vector (must match the model)
    noise = torch.randn(num_images, nz, 1, 1, device=device)

    # Generate fake images (no_grad already keeps the result out of autograd)
    with torch.no_grad():
        fake_images = netG(noise).cpu()

    # normalize=True min-max scales the batch into [0, 1], so the grid needs
    # no further rescaling; applying (x + 1) / 2 afterwards would squeeze the
    # image into [0.5, 1] and wash it out.
    grid = vutils.make_grid(
        fake_images,
        padding=2,
        normalize=True,
        nrow=int(np.sqrt(num_images)),
    )

    # Convert from a channel-first tensor to a channel-last array for Gradio
    return grid.numpy().transpose((1, 2, 0))
|
78 |
+
|
79 |
+
|
80 |
+
# Create Gradio interface
|
81 |
+
def create_gradio_app():
    """Assemble the Gradio Blocks interface and return it without launching.

    The layout is a two-column row: the controls (image count slider, seed
    field, generate button) on one side and the output image on the other,
    followed by an "About" section. Clicking the button calls
    ``generate_images`` with the slider and seed values.
    """
    about_text = (
        "This model was trained using PyTorch DCGAN implementation on the MNIST dataset. "
        "It generates new handwritten digit-like images from random noise."
    )

    with gr.Blocks(title="DCGAN MNIST Generator") as demo:
        gr.Markdown("# DCGAN MNIST Generator")
        gr.Markdown("Generate MNIST-like digits using a Deep Convolutional GAN")

        with gr.Row():
            # Input controls
            with gr.Column():
                count_slider = gr.Slider(
                    label="Number of Images",
                    minimum=1,
                    maximum=64,
                    step=1,
                    value=16,
                )
                seed_box = gr.Number(
                    label="Random Seed (leave blank for random)",
                    precision=0,
                )
                run_button = gr.Button("Generate Images")

            # Output display
            with gr.Column():
                result_view = gr.Image(label="Generated Images")

        run_button.click(
            fn=generate_images,
            inputs=[count_slider, seed_box],
            outputs=result_view,
        )

        gr.Markdown("## About")
        gr.Markdown(about_text)

    return demo
|
102 |
+
|
103 |
+
|
104 |
+
# Launch the app if the script is run directly
|
105 |
+
if __name__ == "__main__":
    # Build the interface and start serving it; nothing runs on import.
    create_gradio_app().launch()
|
requirements.txt
ADDED
@@ -0,0 +1,5 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
torch>=2.0.0
|
2 |
+
torchvision>=0.15.0
|
3 |
+
gradio>=3.50.0
|
4 |
+
numpy>=1.22.0
|
5 |
+
Pillow>=9.0.0
|