# ju4nppp's picture
# Update app.py
# 79fa8cf verified
# raw
# history blame
# 4.51 kB
import torch
import torch.nn as nn
import torchvision.utils as vutils
import gradio as gr
import numpy as np
import matplotlib.pyplot as plt
# Define Generator architecture - must match what you used during training
class Generator(nn.Module):
    """DCGAN generator: maps a latent vector to an image via transposed convolutions.

    The network is stored in ``self.main`` as an ``nn.Sequential``; the module
    order (and therefore the ``state_dict`` keys ``main.0`` ... ``main.12``)
    has to stay the same as the one used when the checkpoint was trained.

    Args:
        ngpu: Number of GPUs; only stored on the instance, not used here.
        nz: Number of channels of the latent input ``(N, nz, 1, 1)``.
        ngf: Base width of the generator feature maps.
        nc: Number of channels of the generated image.
    """

    def __init__(self, ngpu=1, nz=100, ngf=64, nc=3):
        super().__init__()
        self.ngpu = ngpu

        # Output widths of the four ConvTranspose -> BatchNorm -> ReLU stages.
        # Spatial size after each stage: 4x4, 8x8, 16x16, 32x32.
        stage_widths = (ngf * 8, ngf * 4, ngf * 2, ngf)

        layers = []
        in_channels = nz
        for index, out_channels in enumerate(stage_widths):
            # The first stage expands the 1x1 latent to 4x4 (stride 1, no
            # padding); every later stage doubles the spatial size.
            stride, padding = (1, 0) if index == 0 else (2, 1)
            layers.append(
                nn.ConvTranspose2d(
                    in_channels, out_channels, 4, stride, padding, bias=False
                )
            )
            layers.append(nn.BatchNorm2d(out_channels))
            layers.append(nn.ReLU(True))
            in_channels = out_channels

        # Final projection to image channels: (nc) x 64 x 64, squashed by Tanh.
        layers.append(nn.ConvTranspose2d(in_channels, nc, 4, 2, 1, bias=False))
        layers.append(nn.Tanh())

        self.main = nn.Sequential(*layers)

    def forward(self, input):
        """Run the latent batch ``input`` through the generator stack."""
        return self.main(input)
# Load the generator
def load_model(model_path="models/netG_best.pth"):
    """Build a ``Generator`` and load its weights from ``model_path``.

    The first CUDA device is used when available, otherwise the CPU.

    Args:
        model_path: Path of the saved generator ``state_dict``.

    Returns:
        ``(generator, device)`` on success, with the generator in eval mode;
        ``(None, device)`` if anything goes wrong while loading (the error is
        printed rather than raised).
    """
    target = "cuda:0" if torch.cuda.is_available() else "cpu"
    device = torch.device(target)
    generator = Generator(ngpu=1, nz=100, ngf=64, nc=3).to(device)

    try:
        # NOTE(review): depending on the installed PyTorch version's default
        # for ``weights_only``, torch.load may unpickle arbitrary objects --
        # only point this at trusted checkpoint files.
        weights = torch.load(model_path, map_location=device)
        generator.load_state_dict(weights)
        generator.eval()  # inference only: freeze BatchNorm statistics
        print(f"Model loaded successfully from {model_path}")
    except Exception as e:
        print(f"Error loading model: {e}")
        return None, device

    return generator, device
# Generate images using the model
def generate_images(num_images=16, seed=None, randomize=True):
    """Generate a grid of images with the GAN generator.

    The generator is loaded lazily on the first call and cached in the
    module-level ``model`` / ``device`` globals. If loading fails, a black
    placeholder is returned and loading is retried on the next call.

    Args:
        num_images: How many images to generate. UI sliders may deliver this
            as a float, so it is coerced to an int (minimum 1).
        seed: Seed for the random generators; only used when ``randomize`` is
            false. ``None`` means "do not seed".
        randomize: When true, ``seed`` is ignored and fresh noise is drawn.

    Returns:
        A float array of shape ``(H, W, 3)`` with values in ``[0, 1]``: the
        image grid, or a ``299 x 299`` all-zero image if no model is available.
    """
    global model, device

    # The module initialises ``model`` to None, so checking whether the *name*
    # exists would never trigger a load. Check the value instead; ``.get`` also
    # covers the case where the name was never defined.
    if globals().get("model") is None:
        model, device = load_model()
    if model is None:
        return np.zeros((299, 299, 3))

    # Widget values can arrive as floats; torch.randn needs an int size.
    num_images = max(1, int(num_images))

    # Seed only when the caller asked for a reproducible result.
    if seed is not None and not randomize:
        seed = int(seed)
        torch.manual_seed(seed)
        # NumPy only accepts seeds in [0, 2**32); wrap anything outside it.
        np.random.seed(seed % (2 ** 32))

    # Latent vectors, one per requested image.
    nz = 100  # Size of the latent vector
    noise = torch.randn(num_images, nz, 1, 1, device=device)

    with torch.no_grad():
        fake_images = model(noise).cpu()

    # Roughly square grid; normalize=True rescales pixel values into [0, 1].
    nrow = max(1, int(np.sqrt(num_images)))
    grid = vutils.make_grid(fake_images, padding=2, normalize=True, nrow=nrow)

    # CHW tensor -> HWC array for Gradio, clipped defensively to [0, 1].
    grid_np = grid.numpy().transpose((1, 2, 0))
    return np.clip(grid_np, 0, 1)
# Create Gradio interface
def create_gradio_app():
    """Assemble the Gradio Blocks UI and return it (without launching).

    Layout: a row with the input controls and the generate button in one
    column and the output image in the other, followed by an "About" section.
    The button calls ``generate_images`` with the three input widgets.
    """
    about_text = """This model was trained using a PyTorch DCGAN implementation on a dataset of computer mouse images.
The training process used data augmentation to expand a small dataset of 300+ original images into 2,500+ training samples through techniques like flipping, rotation, and brightness/contrast adjustments.
The generator creates brand new, never-before-seen computer mice from random noise!"""

    with gr.Blocks(title="Computer Mouse Generator") as demo:
        # Page header
        gr.Markdown("# Computer Mouse GAN Generator")
        gr.Markdown("Generate computer mice using a Deep Convolutional GAN trained on ~2,500 augmented images")

        with gr.Row():
            # Input controls
            with gr.Column():
                count_slider = gr.Slider(
                    minimum=1,
                    maximum=64,
                    value=16,
                    step=1,
                    label="Number of Images",
                )
                seed_box = gr.Number(label="Random Seed", value=42, precision=0)
                random_toggle = gr.Checkbox(
                    label="Use Random Seeds (ignore seed value)",
                    value=True,
                )
                run_button = gr.Button("Generate Mice")

            # Output
            with gr.Column():
                result_view = gr.Image(label="Generated Computer Mice")

        # Input order must match generate_images(num_images, seed, randomize).
        run_button.click(
            fn=generate_images,
            inputs=[count_slider, seed_box, random_toggle],
            outputs=result_view,
        )

        gr.Markdown("## About")
        gr.Markdown(about_text)

    return demo
# Module-level cache for the generator and its device; both stay None until
# generate_images() performs the first load.
model = device = None

# Build and serve the UI only when this file is executed as a script.
if __name__ == "__main__":
    app = create_gradio_app()
    app.launch()