ju4nppp committed on
Commit
e352aba
·
verified ·
1 Parent(s): 4fa9660

Uploaded 3 files

Browse files

Uploaded readme, app, and requirements

Files changed (3) hide show
  1. README.md +47 -0
  2. app.py +126 -0
  3. requirements.txt +8 -0
README.md ADDED
@@ -0,0 +1,47 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ title: Computer Mouse GAN Generator
3
+ emoji: 🖱️
4
+ colorFrom: blue
5
+ colorTo: purple
6
+ sdk: gradio
7
+ sdk_version: 3.50.0
8
+ app_file: app.py
9
+ pinned: false
10
+ ---
11
+
12
+ # Computer Mouse GAN Generator
13
+
14
+ This project uses a Deep Convolutional Generative Adversarial Network (DCGAN) to generate realistic images of computer mice. The model was trained on a dataset of computer mouse images with data augmentation to expand the training set.
15
+
16
+ ## Model Architecture
17
+
18
+ The model is based on the DCGAN architecture from the paper [Unsupervised Representation Learning with Deep Convolutional Generative Adversarial Networks](https://arxiv.org/abs/1511.06434).
19
+
20
+ - Generator: 5 transposed convolutional layers, with batch normalization after the first four and a Tanh output
21
+ - RGB color output (3 channels)
22
+ - Trained with Weights & Biases monitoring
23
+ - Training included data augmentation (flips, rotations, brightness/contrast adjustments)
24
+
25
+ ## Demo App
26
+
27
+ The Gradio interface allows you to:
28
+ - Generate multiple computer mouse images at once
29
+ - Set the number of images to generate (1-64)
30
+ - Use a specific random seed for reproducible results
31
+ - Toggle random seed generation for variety
32
+
33
+ ## Training Process
34
+
35
+ The model was trained on:
36
+ - ~300 original computer mouse images
37
+ - Expanded to ~2,500 training samples through augmentation
38
+ - Trained for 150+ epochs
39
+ - Used CUDA acceleration on an RTX 3070
40
+
41
+ ## Examples
42
+
43
+ Generated images show a variety of computer mouse designs with different colors and shapes. Each image is completely new and generated from random noise - these mice don't exist in the real world!
44
+
45
+ ## Usage
46
+
47
+ Simply adjust the sliders and click "Generate Mice" to create new computer mouse designs.
app.py ADDED
@@ -0,0 +1,126 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import torchvision.utils as vutils
4
+ import gradio as gr
5
+ import numpy as np
6
+ import matplotlib.pyplot as plt
7
+
8
+
9
+ # Define Generator architecture - must match what you used during training
10
class Generator(nn.Module):
    """DCGAN-style generator built from transposed convolutions.

    The network is a single ``nn.Sequential`` stored in ``self.main``:
    four ConvTranspose2d -> BatchNorm2d -> ReLU stages followed by a final
    ConvTranspose2d -> Tanh stage. For a latent input of spatial size 1x1
    the feature maps grow 4x4 -> 8x8 -> 16x16 -> 32x32 -> 64x64.

    Args:
        ngpu: Stored on the instance as ``self.ngpu``; not otherwise used here.
        nz: Number of channels of the latent input.
        ngf: Base number of generator feature maps.
        nc: Number of channels of the produced image.
    """

    def __init__(self, ngpu=1, nz=100, ngf=64, nc=3):
        super().__init__()
        self.ngpu = ngpu

        # (in_channels, out_channels, stride, padding) of every normalised stage.
        # The module order must stay fixed so checkpoint keys ("main.<idx>...")
        # keep lining up with previously saved weights.
        hidden_stages = (
            (nz, ngf * 8, 1, 0),       # Z          -> (ngf*8) x 4 x 4
            (ngf * 8, ngf * 4, 2, 1),  # 4 x 4      -> (ngf*4) x 8 x 8
            (ngf * 4, ngf * 2, 2, 1),  # 8 x 8      -> (ngf*2) x 16 x 16
            (ngf * 2, ngf, 2, 1),      # 16 x 16    -> (ngf) x 32 x 32
        )

        layers = []
        for in_ch, out_ch, stride, padding in hidden_stages:
            layers.append(
                nn.ConvTranspose2d(in_ch, out_ch, 4, stride, padding, bias=False)
            )
            layers.append(nn.BatchNorm2d(out_ch))
            layers.append(nn.ReLU(True))

        # Output stage: (ngf) x 32 x 32 -> (nc) x 64 x 64, squashed by Tanh.
        layers.append(nn.ConvTranspose2d(ngf, nc, 4, 2, 1, bias=False))
        layers.append(nn.Tanh())

        self.main = nn.Sequential(*layers)

    def forward(self, input):
        """Run the latent tensor ``input`` through the layer stack."""
        return self.main(input)
39
+
40
+
41
+ # Load the generator
42
def load_model(model_path="model/netG_best.pth"):
    """Build a ``Generator`` and try to populate it from a saved state dict.

    Args:
        model_path: Path of the checkpoint passed to ``torch.load``.

    Returns:
        A ``(generator, device)`` tuple. ``generator`` is the network in
        evaluation mode, or ``None`` when anything in the loading step raised
        (the error is printed, not re-raised). ``device`` is ``cuda:0`` when
        CUDA is available and ``cpu`` otherwise.
    """
    use_cuda = torch.cuda.is_available()
    device = torch.device("cuda:0" if use_cuda else "cpu")

    generator = Generator(ngpu=1, nz=100, ngf=64, nc=3)
    generator = generator.to(device)

    try:
        state_dict = torch.load(model_path, map_location=device)
        generator.load_state_dict(state_dict)
        generator.eval()  # inference only: freeze batch-norm statistics
        print(f"Model loaded successfully from {model_path}")
    except Exception as e:
        # Best-effort: report the problem and let the caller handle a missing model.
        print(f"Error loading model: {e}")
        return None, device

    return generator, device
55
+
56
+
57
+ # Generate images using the model
58
def generate_images(num_images=16, seed=None, randomize=True):
    """Generate a grid of images with the cached generator.

    The generator is loaded lazily through ``load_model()`` on the first call
    and kept in the module-level ``model`` / ``device`` globals. If loading
    fails, it is retried on the next call.

    Args:
        num_images: How many images to generate. Coerced to ``int`` (UI
            widgets may deliver floats) and clamped to at least 1.
        seed: Seed used for reproducible output. Only applied when it is not
            ``None`` and ``randomize`` is false.
        randomize: When true, ``seed`` is ignored and fresh noise is drawn.

    Returns:
        A NumPy array of shape (H, W, C) with values clipped to [0, 1]
        holding the image grid, or an all-zero (299, 299, 3) array when the
        model could not be loaded.
    """
    global model, device

    # The module pre-binds ``model = None``, so checking whether the *name*
    # exists in globals() would never trigger a load. Check the value instead;
    # ``globals().get`` also covers the case where the name is not bound yet.
    if globals().get("model") is None:
        model, device = load_model()
    if model is None:
        return np.zeros((299, 299, 3))

    num_images = max(1, int(num_images))

    # Seed only when the caller asked for reproducible output.
    if seed is not None and not randomize:
        seed = int(seed)
        torch.manual_seed(seed)
        # np.random.seed only accepts values in [0, 2**32 - 1].
        np.random.seed(seed % (2 ** 32))

    nz = 100  # size of the latent vector; must match the loaded generator
    noise = torch.randn(num_images, nz, 1, 1, device=device)

    with torch.no_grad():
        fake_images = model(noise).cpu()

    # Roughly square layout: about sqrt(n) images per row, never fewer than one.
    images_per_row = max(1, int(np.sqrt(num_images)))
    grid = vutils.make_grid(
        fake_images, padding=2, normalize=True, nrow=images_per_row
    )

    # (C, H, W) tensor -> (H, W, C) array, which is the layout Gradio displays.
    grid_np = grid.numpy().transpose((1, 2, 0))
    return np.clip(grid_np, 0, 1)
89
+
90
+
91
+ # Create Gradio interface
92
def create_gradio_app():
    """Assemble the Gradio Blocks interface and return it without launching.

    The layout is a title, a short description, a row with an input column
    (image count slider, seed field, randomize checkbox, generate button) and
    an output column (image), followed by an "About" section. Clicking the
    button calls ``generate_images`` with the three input values and shows
    the result in the image component.
    """
    about_text = """This model was trained using a PyTorch DCGAN implementation on a dataset of computer mouse images.

The training process used data augmentation to expand a small dataset of 300+ original images into 2,500+ training samples through techniques like flipping, rotation, and brightness/contrast adjustments.

The generator creates brand new, never-before-seen computer mice from random noise!"""

    with gr.Blocks(title="Computer Mouse Generator") as app:
        gr.Markdown("# Computer Mouse GAN Generator")
        gr.Markdown("Generate computer mice using a Deep Convolutional GAN trained on ~2,500 augmented images")

        with gr.Row():
            # Left column: generation controls.
            with gr.Column():
                count_slider = gr.Slider(
                    minimum=1,
                    maximum=64,
                    value=16,
                    step=1,
                    label="Number of Images",
                )
                seed_field = gr.Number(label="Random Seed", value=42, precision=0)
                randomize_box = gr.Checkbox(
                    label="Use Random Seeds (ignore seed value)",
                    value=True,
                )
                run_button = gr.Button("Generate Mice")

            # Right column: generated grid.
            with gr.Column():
                result_image = gr.Image(label="Generated Computer Mice")

        run_button.click(
            fn=generate_images,
            inputs=[count_slider, seed_field, randomize_box],
            outputs=result_image,
        )

        gr.Markdown("## About")
        gr.Markdown(about_text)

    return app
117
+
118
+
119
+ # Initialize global variables
120
# Module-level cache for the generator and the device it lives on; both stay
# None until a model has been loaded.
model, device = None, None

# Build and serve the UI only when this file is executed as a script.
if __name__ == "__main__":
    app = create_gradio_app()
    app.launch()
requirements.txt ADDED
@@ -0,0 +1,8 @@
 
 
 
 
 
 
 
 
 
1
+ torch>=2.0.0
2
+ torchvision>=0.15.0
3
+ matplotlib>=3.5.0
4
+ numpy>=1.22.0
5
+ Pillow>=9.0.0
6
+ gradio>=3.50.0
7
+ tqdm>=4.64.0
8
+ ipython>=8.0.0