ju4nppp committed on
Commit
a0ab06e
·
verified ·
1 Parent(s): cbfa2e4

Uploaded 3 files

Browse files

app, requirements and readme added

Files changed (3) hide show
  1. README.md +42 -14
  2. app.py +107 -0
  3. requirements.txt +5 -0
README.md CHANGED
@@ -1,14 +1,42 @@
1
- ---
2
- title: Dcgan Mnist Demo
3
- emoji: 🐢
4
- colorFrom: blue
5
- colorTo: blue
6
- sdk: gradio
7
- sdk_version: 5.26.0
8
- app_file: app.py
9
- pinned: false
10
- license: mit
11
- short_description: Space for deploying the dcgan-mnist app
12
- ---
13
-
14
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # DCGAN MNIST Generator
2
+
3
+ This repository contains a Deep Convolutional GAN (DCGAN) trained on the MNIST dataset. The model generates handwritten-like digit images from random noise.
4
+
5
+ ## Model Architecture
6
+
7
+ The model implementation is based on the paper [Unsupervised Representation Learning with Deep Convolutional Generative Adversarial Networks](http://arxiv.org/abs/1511.06434).
8
+
9
+ - Generator architecture: 5 transposed convolutional layers (batch normalization and ReLU after the first four, Tanh on the output layer)
10
+ - Latent space dimension: 100
11
+ - Output: 64x64 grayscale images
12
+
13
+ ## Demo App
14
+
15
+ The included Gradio app allows you to generate new MNIST-like images using the pre-trained model.
16
+
17
+ ### Running Locally
18
+
19
+ 1. Install dependencies:
20
+ ```bash
21
+ pip install -r requirements.txt
22
+ ```
23
+
24
+ 2. Run the app:
25
+ ```bash
26
+ python app.py
27
+ ```
28
+
29
+ ### Features
30
+
31
+ - Generate multiple images at once
32
+ - Set a random seed for reproducible outputs
33
+ - Visualize the generated images in a grid
34
+
35
+ ## Training Details
36
+
37
+ This model was trained for 25 epochs on the MNIST dataset using PyTorch. For optimal results, the model checkpoint from epoch 21 is used for inference, as it produced the most realistic images without mode collapse.
38
+
39
+ ## Acknowledgments
40
+
41
+ - Original DCGAN implementation based on [PyTorch examples](https://github.com/pytorch/examples/tree/master/dcgan)
42
+ - Training was tracked using Weights & Biases
app.py ADDED
@@ -0,0 +1,107 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import torchvision.utils as vutils
4
+ import gradio as gr
5
+ import numpy as np
6
+
7
+
8
# Define Generator architecture - must match the architecture used during training
class Generator(nn.Module):
    """DCGAN generator: maps a latent vector to an image through transposed convolutions.

    The layer stack is stored in ``self.main`` as an ``nn.Sequential`` whose
    order (and therefore its state-dict keys ``main.0`` ... ``main.12``) has
    to stay fixed so that saved checkpoints keep loading.

    Args:
        ngpu: Number of GPUs; only stored on the instance, not used here.
        nz: Number of channels of the latent input.
        ngf: Base width (feature-map count) of the generator.
        nc: Number of channels of the generated image.
    """

    def __init__(self, ngpu=1, nz=100, ngf=64, nc=1):
        super().__init__()
        self.ngpu = ngpu

        # (in_channels, out_channels, stride, padding) of every hidden stage.
        # For a 1x1 latent input the spatial size grows 4 -> 8 -> 16 -> 32.
        hidden_stages = (
            (nz, ngf * 8, 1, 0),
            (ngf * 8, ngf * 4, 2, 1),
            (ngf * 4, ngf * 2, 2, 1),
            (ngf * 2, ngf, 2, 1),
        )

        layers = []
        for in_channels, out_channels, stride, padding in hidden_stages:
            layers.append(
                nn.ConvTranspose2d(in_channels, out_channels, 4, stride, padding, bias=False)
            )
            layers.append(nn.BatchNorm2d(out_channels))
            layers.append(nn.ReLU(True))

        # Output stage: 32x32 -> 64x64 with `nc` channels, squashed by tanh.
        layers.append(nn.ConvTranspose2d(ngf, nc, 4, 2, 1, bias=False))
        layers.append(nn.Tanh())

        self.main = nn.Sequential(*layers)

    def forward(self, input):
        """Run ``input`` through the layer stack and return the result."""
        return self.main(input)
38
+
39
+
40
# Load the generator
def load_model(model_path):
    """Build a ``Generator`` and restore its weights from ``model_path``.

    Args:
        model_path: Path of a checkpoint file holding the generator's
            state dict.

    Returns:
        A ``(netG, device)`` tuple: the generator switched to evaluation
        mode, and the device it was placed on (``cuda:0`` when CUDA is
        available, otherwise the CPU).
    """
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    netG = Generator().to(device)

    # weights_only=True limits unpickling to tensors and plain containers,
    # so a tampered checkpoint file cannot execute arbitrary code on load.
    state_dict = torch.load(model_path, map_location=device, weights_only=True)
    netG.load_state_dict(state_dict)

    netG.eval()  # Set to evaluation mode
    return netG, device
48
+
49
+
50
# Generate images using the model
def generate_images(num_images=16, seed=None, model_path="models/netG_epoch_21.pth"):
    """Sample images from the generator and return them as a single grid.

    Args:
        num_images: Number of images to sample; must be at least 1. The
            value is converted with ``int()`` because UI widgets may hand
            over a float.
        seed: Optional integer seed. When given, the torch and NumPy global
            RNGs are seeded so the same seed reproduces the same grid.
            ``None`` leaves the RNG state untouched.
        model_path: Checkpoint path passed on to ``load_model``.

    Returns:
        The image grid as a channel-last (H, W, C) float NumPy array with
        values in [0, 1].

    Raises:
        ValueError: If ``num_images`` is smaller than 1.
    """
    num_images = int(num_images)
    if num_images < 1:
        raise ValueError(f"num_images must be at least 1, got {num_images}")

    # Load the model
    netG, device = load_model(model_path)

    # Set random seed for reproducibility if provided
    if seed is not None:
        seed = int(seed)
        torch.manual_seed(seed)
        # NumPy only accepts seeds in [0, 2**32 - 1]; wrap the value so a
        # negative or very large seed does not raise.
        np.random.seed(seed % 2**32)

    # Generate latent vectors
    nz = 100  # Size of the latent vector (must match the model)
    noise = torch.randn(num_images, nz, 1, 1, device=device)

    # Generate fake images (no_grad already yields tensors without autograd history)
    with torch.no_grad():
        fake_images = netG(noise).cpu()

    # Build the grid and map the generator's tanh range [-1, 1] to [0, 1]
    # in one step. The result must NOT be rescaled again afterwards: doing
    # so would squeeze the image into [0.5, 1] and wash it out.
    grid = vutils.make_grid(
        fake_images,
        padding=2,
        normalize=True,
        value_range=(-1, 1),
        nrow=int(np.sqrt(num_images)),
    )

    # Convert from channel-first tensor to channel-last numpy array for Gradio
    return grid.numpy().transpose((1, 2, 0))
78
+
79
+
80
# Create Gradio interface
def create_gradio_app():
    """Assemble the Gradio Blocks UI and return it without launching it.

    The layout has a controls column (image count slider, seed field and
    generate button) next to an output column showing the generated grid.
    Clicking the button calls ``generate_images`` with the slider and seed
    values.
    """
    about_text = (
        "This model was trained using PyTorch DCGAN implementation on the MNIST dataset. "
        "It generates new handwritten digit-like images from random noise."
    )

    with gr.Blocks(title="DCGAN MNIST Generator") as demo:
        gr.Markdown("# DCGAN MNIST Generator")
        gr.Markdown("Generate MNIST-like digits using a Deep Convolutional GAN")

        with gr.Row():
            # Left column: user controls
            with gr.Column():
                count_slider = gr.Slider(
                    minimum=1,
                    maximum=64,
                    value=16,
                    step=1,
                    label="Number of Images",
                )
                seed_box = gr.Number(
                    label="Random Seed (leave blank for random)",
                    precision=0,
                )
                run_button = gr.Button("Generate Images")

            # Right column: generated grid
            with gr.Column():
                preview = gr.Image(label="Generated Images")

        run_button.click(
            fn=generate_images,
            inputs=[count_slider, seed_box],
            outputs=preview,
        )

        gr.Markdown("## About")
        gr.Markdown(about_text)

    return demo
102
+
103
+
104
# Entry point: build the UI and start the Gradio server, but only when this
# file is executed as a script (importing it has no side effects).
if __name__ == "__main__":
    app = create_gradio_app()
    app.launch()
requirements.txt ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ torch>=2.0.0
2
+ torchvision>=0.15.0
3
+ gradio>=3.50.0
4
+ numpy>=1.22.0
5
+ Pillow>=9.0.0