ju4nppp's picture
Update app.py
e27d7b8 verified
import torch
import torch.nn as nn
import numpy as np
import gradio as gr
from PIL import Image
import os
import math
# Define your Generator architecture - with ngf=128 to match your training parameters
class Generator(nn.Module):
    """DCGAN generator: latent tensor (N, nz, 1, 1) -> image tensor (N, nc, 64, 64).

    Four ConvTranspose2d / BatchNorm2d / ReLU stages upsample the spatial size
    1 -> 4 -> 8 -> 16 -> 32. A final ConvTranspose2d + Tanh produces 64 x 64
    output in [-1, 1]. The layer order inside ``self.main`` must not change,
    because saved state_dict keys (``main.0.weight``, ...) refer to
    Sequential indices.

    Args:
        ngpu: Stored on the instance; not used by this class.
        nz: Length of the latent vector (input channels).
        ngf: Base number of generator feature maps.
        nc: Number of output image channels.
    """

    def __init__(self, ngpu=1, nz=100, ngf=128, nc=3):
        super().__init__()
        self.ngpu = ngpu
        # (in_channels, out_channels, stride, padding) of each upsampling stage.
        stage_specs = [
            (nz, ngf * 8, 1, 0),       # state size. (ngf*8) x 4 x 4
            (ngf * 8, ngf * 4, 2, 1),  # state size. (ngf*4) x 8 x 8
            (ngf * 4, ngf * 2, 2, 1),  # state size. (ngf*2) x 16 x 16
            (ngf * 2, ngf, 2, 1),      # state size. (ngf) x 32 x 32
        ]
        layers = []
        for c_in, c_out, stride, pad in stage_specs:
            layers += [
                nn.ConvTranspose2d(c_in, c_out, 4, stride, pad, bias=False),
                nn.BatchNorm2d(c_out),
                nn.ReLU(True),
            ]
        # Output head: state size. (nc) x 64 x 64, squashed into [-1, 1].
        layers += [nn.ConvTranspose2d(ngf, nc, 4, 2, 1, bias=False), nn.Tanh()]
        self.main = nn.Sequential(*layers)

    def forward(self, input):
        return self.main(input)
# Load the model - Update path to point to the models folder
device = torch.device("cpu")
model_path = "models/netG_epoch_246.pth"
# Print file existence for debugging
print(f"Checking if model file exists: {os.path.exists(model_path)}")
print(f"Listing contents of models directory: {os.listdir('models') if os.path.exists('models') else 'models directory not found'}")
# Initialize the model with ngf=128 to match your training parameters
model = Generator(ngf=128).to(device)
# True only once every generator parameter has been filled from the checkpoint.
# The app keeps running either way (best-effort), but when this is False the
# images come from randomly initialised weights.
model_loaded = False
try:
    # Read the checkpoint a single time; both load attempts below reuse it.
    state_dict = torch.load(model_path, map_location=device)
except Exception as e:
    print(f"Error loading model: {e}")
else:
    # Checkpoints saved from an nn.DataParallel-wrapped model prefix every key
    # with "module."; strip it so the keys line up with this bare Generator.
    # Without this, the strict=False fallback would skip every key silently.
    if isinstance(state_dict, dict) and any(
        isinstance(k, str) and k.startswith("module.") for k in state_dict
    ):
        state_dict = {
            (k[len("module."):] if isinstance(k, str) and k.startswith("module.") else k): v
            for k, v in state_dict.items()
        }
    try:
        model.load_state_dict(state_dict)
        model_loaded = True
        print("Model loaded successfully!")
    except Exception as e:
        print(f"Error loading model: {e}")
        # Alternative loading: strict=False tolerates missing/unexpected keys,
        # but it still raises on tensor shape mismatches.
        try:
            result = model.load_state_dict(state_dict, strict=False)
            # Only a full load counts; missing keys keep their random init.
            model_loaded = not result.missing_keys
            print("Model loaded with strict=False")
            if result.missing_keys or result.unexpected_keys:
                print(f"Missing keys: {result.missing_keys}")
                print(f"Unexpected keys: {result.unexpected_keys}")
        except Exception as e2:
            print(f"Error with alternative loading: {e2}")
# Set model to evaluation mode
model.eval()
print(f"Model weights loaded from checkpoint: {model_loaded}")
def create_image_grid(images, rows, cols):
    """Tile PIL images row-major onto a single RGB canvas of rows x cols cells.

    Every cell takes the size of the first image; cells that receive no image
    keep the canvas background.
    """
    tile_w, tile_h = images[0].size
    canvas = Image.new('RGB', size=(cols * tile_w, rows * tile_h))
    for idx, tile in enumerate(images):
        row, col = divmod(idx, cols)
        canvas.paste(tile, box=(col * tile_w, row * tile_h))
    return canvas
def generate_multiple_images(random_seed=42, num_images=4):
    """Generate DCGAN samples with the module-level ``model`` and tile them into one grid.

    Args:
        random_seed: Seed for ``torch.manual_seed``; the same seed reproduces
            the same set of images. Coerced to ``int`` because UI widgets may
            deliver numeric values as floats.
        num_images: Number of images to generate; coerced to ``int``.

    Returns:
        A PIL image grid with ``int(sqrt(num_images))`` rows and as many
        columns as needed to hold every image (see ``create_image_grid``).

    Raises:
        ValueError: If ``num_images`` is less than 1. Otherwise the grid
            sizing below would divide by zero.
    """
    num_images = int(num_images)
    if num_images < 1:
        raise ValueError(f"num_images must be at least 1, got {num_images}")
    # Seed once for reproducibility. Each torch.randn call advances the RNG,
    # so the images differ from one another while the batch stays repeatable.
    torch.manual_seed(int(random_seed))
    images = []
    with torch.no_grad():
        for _ in range(num_images):
            # Latent input of shape (1, 100, 1, 1); 100 matches Generator's default nz.
            # One call per image (rather than one batched call) keeps the exact
            # random stream, and therefore the outputs, of earlier versions.
            noise = torch.randn(1, 100, 1, 1, device=device)
            fake_image = model(noise).cpu()
            # Map the Tanh output from [-1, 1] to [0, 1], then to HWC uint8 for PIL.
            fake_img = fake_image * 0.5 + 0.5
            fake_img = fake_img.squeeze(0).permute(1, 2, 0).numpy()
            fake_img = np.clip(fake_img * 255, 0, 255).astype(np.uint8)
            images.append(Image.fromarray(fake_img))
    # Near-square layout: rows = floor(sqrt(n)), cols = ceil(n / rows).
    rows = int(math.sqrt(num_images))
    cols = int(math.ceil(num_images / rows))
    return create_image_grid(images, rows, cols)
# Create Gradio interface
demo = gr.Interface(
    fn=generate_multiple_images,
    inputs=[
        # Gradio 3+ components take their initial value via ``value=``.
        # The legacy ``default=`` keyword is not a Slider parameter: 3.x drops it
        # with an unused-kwarg warning (the slider starts at its minimum), and
        # newer releases may reject it outright.
        gr.Slider(minimum=1, maximum=100, step=1, value=42, label="Random Seed"),
        gr.Slider(minimum=1, maximum=16, step=1, value=4, label="Number of Images"),
    ],
    outputs=gr.Image(type="pil", label="Generated Computer Mice"),
    title="DCGAN Computer Mouse Generator",
    description="Generate multiple unique computer mouse designs using a DCGAN model.",
    # Each example row is [random_seed, num_images], in the same order as `inputs`.
    examples=[[42, 4], [23, 9], [7, 16]],
)
# Launch the app only when run as a script, not when imported.
if __name__ == "__main__":
    demo.launch()