import torch import torch.nn as nn import numpy as np import gradio as gr from PIL import Image import os import math # Define your Generator architecture - with ngf=128 to match your training parameters class Generator(nn.Module): def __init__(self, ngpu=1, nz=100, ngf=128, nc=3): super(Generator, self).__init__() self.ngpu = ngpu self.main = nn.Sequential( # input is Z, going into a convolution nn.ConvTranspose2d(nz, ngf * 8, 4, 1, 0, bias=False), nn.BatchNorm2d(ngf * 8), nn.ReLU(True), # state size. (ngf*8) x 4 x 4 nn.ConvTranspose2d(ngf * 8, ngf * 4, 4, 2, 1, bias=False), nn.BatchNorm2d(ngf * 4), nn.ReLU(True), # state size. (ngf*4) x 8 x 8 nn.ConvTranspose2d(ngf * 4, ngf * 2, 4, 2, 1, bias=False), nn.BatchNorm2d(ngf * 2), nn.ReLU(True), # state size. (ngf*2) x 16 x 16 nn.ConvTranspose2d(ngf * 2, ngf, 4, 2, 1, bias=False), nn.BatchNorm2d(ngf), nn.ReLU(True), # state size. (ngf) x 32 x 32 nn.ConvTranspose2d(ngf, nc, 4, 2, 1, bias=False), nn.Tanh() # state size. (nc) x 64 x 64 ) def forward(self, input): return self.main(input) # Load the model - Update path to point to the models folder device = torch.device("cpu") model_path = "models/netG_epoch_246.pth" # Print file existence for debugging print(f"Checking if model file exists: {os.path.exists(model_path)}") print(f"Listing contents of models directory: {os.listdir('models') if os.path.exists('models') else 'models directory not found'}") # Initialize the model with ngf=128 to match your training parameters model = Generator(ngf=128).to(device) # Try loading with error handling try: model.load_state_dict(torch.load(model_path, map_location=device)) print("Model loaded successfully!") except Exception as e: print(f"Error loading model: {e}") # Try alternative loading methods if the first fails try: model.load_state_dict(torch.load(model_path, map_location=device), strict=False) print("Model loaded with strict=False") except Exception as e2: print(f"Error with alternative loading: {e2}") # Set model to 
# Set model to evaluation mode
model.eval()
print(f"Model initialized: {model is not None}")


def create_image_grid(images, rows, cols):
    """Create a grid of images.

    Tiles are placed row-major. Each occupies the size of ``images[0]``, and
    all images are assumed to share that size. Cells left over when
    ``rows * cols > len(images)`` stay black.

    Args:
        images: Non-empty sequence of PIL images.
        rows: Number of grid rows.
        cols: Number of grid columns.

    Returns:
        A new RGB PIL image of size ``(cols * w, rows * h)``.

    Raises:
        ValueError: If ``images`` is empty, or if the grid has fewer cells
            than there are images (the extra tiles would fall outside the
            canvas).
    """
    if not images:
        raise ValueError("create_image_grid needs at least one image")
    if rows * cols < len(images):
        raise ValueError(
            f"a {rows}x{cols} grid cannot hold {len(images)} images"
        )
    w, h = images[0].size
    grid = Image.new('RGB', size=(cols * w, rows * h))
    for i, image in enumerate(images):
        row, col = divmod(i, cols)
        grid.paste(image, box=(col * w, row * h))
    return grid


def generate_multiple_images(random_seed=42, num_images=4):
    """Generate multiple images using the DCGAN model and tile them into a grid.

    Args:
        random_seed: Seed for the latent noise; the same seed and count give
            the same grid. Cast to ``int`` because Gradio sliders may deliver
            floats.
        num_images: Number of images to generate (must be >= 1). Also cast
            to ``int``.

    Returns:
        A PIL RGB image with the samples in a near-square grid: ``rows`` is
        ``floor(sqrt(num_images))`` and ``cols`` is
        ``ceil(num_images / rows)``.

    Raises:
        ValueError: If ``num_images`` is less than 1.
    """
    num_images = int(num_images)
    if num_images < 1:
        raise ValueError(f"num_images must be at least 1, got {num_images}")

    # Private RNG seeded like torch.manual_seed: same latent stream for a
    # given seed, without mutating torch's global RNG state.
    rng = torch.Generator(device=device).manual_seed(int(random_seed))
    nz = getattr(model, "nz", 100)

    # One latent vector per image, drawn in sequence from the single seeded
    # RNG, then run through the generator as one batch.
    noise = torch.cat(
        [
            torch.randn(1, nz, 1, 1, generator=rng, device=device)
            for _ in range(num_images)
        ]
    )
    with torch.no_grad():
        fake = model(noise).cpu()

    # Tanh output lies in [-1, 1]; map to [0, 255] uint8 HWC arrays for PIL.
    pixels = (fake * 0.5 + 0.5).permute(0, 2, 3, 1).numpy()
    pixels = np.clip(pixels * 255, 0, 255).astype(np.uint8)
    images = [Image.fromarray(p) for p in pixels]

    # Create a grid of images
    rows = int(math.sqrt(num_images))
    cols = int(math.ceil(num_images / rows))
    return create_image_grid(images, rows, cols)


# Create Gradio interface
demo = gr.Interface(
    fn=generate_multiple_images,
    inputs=[
        # Gradio 3/4 components take their initial value via `value=`.
        gr.Slider(minimum=1, maximum=100, step=1, value=42, label="Random Seed"),
        gr.Slider(minimum=1, maximum=16, step=1, value=4, label="Number of Images"),
    ],
    outputs=gr.Image(type="pil", label="Generated Computer Mice"),
    title="DCGAN Computer Mouse Generator",
    description="Generate multiple unique computer mouse designs using a DCGAN model.",
    examples=[[42, 4], [23, 9], [7, 16]],
)

# Launch the app
if __name__ == "__main__":
    demo.launch()