# Spaces: Sleeping  (page-header residue from the hosting site; not part of the program)
import torch | |
import torch.nn as nn | |
import numpy as np | |
import gradio as gr | |
from PIL import Image | |
import os | |
import math | |
# Generator architecture; ngf defaults to 128 to match the training parameters.
class Generator(nn.Module):
    """DCGAN-style generator built from a stack of transposed convolutions.

    The layer order inside ``self.main`` determines the ``state_dict`` keys
    (``main.0.weight`` ... ``main.12.weight``), so it must stay unchanged for
    existing checkpoints to load.

    Args:
        ngpu: Number of GPUs; stored on the instance but not used by this class.
        nz: Number of channels of the latent input ``Z``.
        ngf: Base number of generator feature maps.
        nc: Number of channels of the generated image.
    """

    def __init__(self, ngpu=1, nz=100, ngf=128, nc=3):
        super().__init__()
        self.ngpu = ngpu
        self.main = nn.Sequential(
            # input is Z, going into a convolution
            nn.ConvTranspose2d(nz, ngf * 8, 4, 1, 0, bias=False),
            nn.BatchNorm2d(ngf * 8),
            nn.ReLU(True),
            # state size. (ngf*8) x 4 x 4
            nn.ConvTranspose2d(ngf * 8, ngf * 4, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ngf * 4),
            nn.ReLU(True),
            # state size. (ngf*4) x 8 x 8
            nn.ConvTranspose2d(ngf * 4, ngf * 2, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ngf * 2),
            nn.ReLU(True),
            # state size. (ngf*2) x 16 x 16
            nn.ConvTranspose2d(ngf * 2, ngf, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ngf),
            nn.ReLU(True),
            # state size. (ngf) x 32 x 32
            nn.ConvTranspose2d(ngf, nc, 4, 2, 1, bias=False),
            nn.Tanh(),
            # state size. (nc) x 64 x 64
        )

    def forward(self, input):
        """Pass ``input`` through the layer stack and return the Tanh output."""
        return self.main(input)
# Load the model from the models folder; inference runs on the CPU.
device = torch.device("cpu")
model_path = "models/netG_epoch_246.pth"

# Print file existence for debugging
print(f"Checking if model file exists: {os.path.exists(model_path)}")
print(f"Listing contents of models directory: {os.listdir('models') if os.path.exists('models') else 'models directory not found'}")


def _load_generator_weights(net, path, map_location):
    """Best-effort load of the checkpoint at ``path`` into ``net``.

    The checkpoint is read from disk once. It is applied strictly first and,
    if that fails, retried with ``strict=False``. Errors are printed rather
    than raised so the app can still start.

    Returns:
        True if a state dict was applied to ``net``, False otherwise.
    """
    # NOTE(review): torch.load unpickles the file, so only load checkpoints
    # from a trusted source (consider weights_only=True where the installed
    # torch version supports it).
    try:
        state_dict = torch.load(path, map_location=map_location)
    except Exception as e:
        # Retrying with strict=False cannot help when the file itself is unreadable.
        print(f"Error loading model: {e}")
        return False

    try:
        net.load_state_dict(state_dict)
        print("Model loaded successfully!")
        return True
    except Exception as e:
        print(f"Error loading model: {e}")

    # Fall back to a non-strict load that tolerates missing/unexpected keys.
    try:
        net.load_state_dict(state_dict, strict=False)
        print("Model loaded with strict=False")
        return True
    except Exception as e2:
        print(f"Error with alternative loading: {e2}")
        return False


# Initialize the model with ngf=128 to match the training parameters
model = Generator(ngf=128).to(device)

if not _load_generator_weights(model, model_path, device):
    # Keep running (best effort), but make the consequence explicit.
    print("Warning: no checkpoint was loaded; the generator is using untrained weights.")

# Set model to evaluation mode
model.eval()
print(f"Model initialized: {model is not None}")
def create_image_grid(images, rows, cols):
    """Paste ``images`` row by row into a single RGB grid image.

    Each cell has the size of the first image. Image ``i`` is placed at
    column ``i % cols`` and row ``i // cols``; unused cells keep the default
    fill of the newly created image.

    Args:
        images: Non-empty sequence of PIL images.
        rows: Number of grid rows.
        cols: Number of grid columns.

    Returns:
        A new PIL image of size ``(cols * w, rows * h)``.

    Raises:
        ValueError: If ``images`` is empty.
    """
    if len(images) == 0:
        raise ValueError("create_image_grid requires at least one image")

    cell_w, cell_h = images[0].size
    grid = Image.new('RGB', size=(cols * cell_w, rows * cell_h))
    for index, image in enumerate(images):
        row, col = divmod(index, cols)
        grid.paste(image, box=(col * cell_w, row * cell_h))
    return grid
def generate_multiple_images(random_seed=42, num_images=4):
    """Generate ``num_images`` images with the DCGAN model and return them as one grid.

    Args:
        random_seed: Value passed to ``torch.manual_seed`` so that the same
            seed and image count draw the same noise vectors.
        num_images: Number of images to generate; must be at least 1.

    Returns:
        A PIL image containing all generated images arranged in a grid.

    Raises:
        ValueError: If ``num_images`` is less than 1.
    """
    # The UI layer may deliver slider values as floats (TODO confirm for the
    # installed Gradio version); range() and torch.manual_seed need ints.
    random_seed = int(random_seed)
    num_images = int(num_images)
    if num_images < 1:
        # Without this guard rows would be 0 and the column computation
        # below would divide by zero.
        raise ValueError("num_images must be at least 1")

    # Set seed for reproducibility
    torch.manual_seed(random_seed)

    images = []
    with torch.no_grad():
        for _ in range(num_images):
            # One seed for the whole call; each iteration draws the next
            # noise vector from the same random stream.
            noise = torch.randn(1, 100, 1, 1, device=device)
            fake_image = model(noise).cpu()

            # Map the Tanh output from [-1, 1] to [0, 1], then to uint8 HWC.
            fake_img = fake_image * 0.5 + 0.5  # unnormalize
            fake_img = fake_img.squeeze(0).permute(1, 2, 0).numpy()
            fake_img = np.clip(fake_img * 255, 0, 255).astype(np.uint8)
            images.append(Image.fromarray(fake_img))

    # Near-square layout: floor(sqrt(n)) rows, enough columns to fit all images.
    rows = int(math.sqrt(num_images))
    cols = int(math.ceil(num_images / rows))
    return create_image_grid(images, rows, cols)
# Create Gradio interface: two sliders (seed, image count) feeding
# generate_multiple_images, whose PIL grid is shown as the output image.
demo = gr.Interface(
    fn=generate_multiple_images,
    inputs=[
        # The top-level gr.Slider component takes its initial value via
        # `value=`; `default=` is the legacy keyword.
        gr.Slider(minimum=1, maximum=100, step=1, value=42, label="Random Seed"),
        gr.Slider(minimum=1, maximum=16, step=1, value=4, label="Number of Images"),
    ],
    outputs=gr.Image(type="pil", label="Generated Computer Mice"),
    title="DCGAN Computer Mouse Generator",
    description="Generate multiple unique computer mouse designs using a DCGAN model.",
    # Each example is [random seed, number of images].
    examples=[[42, 4], [23, 9], [7, 16]],
)

# Launch the app
if __name__ == "__main__":
    demo.launch()