Spaces:
Sleeping
Sleeping
File size: 4,227 Bytes
79fa8cf 49c9fa6 e27d7b8 79fa8cf 49c9fa6 79fa8cf 49c9fa6 79fa8cf 49c9fa6 79fa8cf 49c9fa6 79fa8cf 49c9fa6 79fa8cf 49c9fa6 e27d7b8 49c9fa6 e27d7b8 49c9fa6 e27d7b8 49c9fa6 e27d7b8 79fa8cf 49c9fa6 e27d7b8 49c9fa6 e27d7b8 49c9fa6 79fa8cf 49c9fa6 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 |
import torch
import torch.nn as nn
import numpy as np
import gradio as gr
from PIL import Image
import os
import math
# Define your Generator architecture - with ngf=128 to match your training parameters
class Generator(nn.Module):
    """DCGAN generator: latent vector Z -> ``nc`` x 64 x 64 image in [-1, 1].

    The network is four (ConvTranspose2d -> BatchNorm2d -> ReLU) stages that
    upsample 1x1 -> 4x4 -> 8x8 -> 16x16 -> 32x32 while narrowing the channel
    count from ``ngf * 8`` down to ``ngf``. A final ConvTranspose2d + Tanh
    stage produces the 64x64 output.

    Layer order inside ``self.main`` determines the ``main.<idx>.*`` keys of
    the state dict, so it must match the checkpoint being loaded.

    Args:
        ngpu: Stored on the instance; not used by ``forward``.
        nz: Size of the latent vector (input channels of the first layer).
        ngf: Base feature-map width of the generator.
        nc: Number of output image channels.
    """

    def __init__(self, ngpu=1, nz=100, ngf=128, nc=3):
        super().__init__()
        self.ngpu = ngpu

        stage_widths = [ngf * 8, ngf * 4, ngf * 2, ngf]
        stack = []
        in_channels = nz
        for stage, out_channels in enumerate(stage_widths):
            # The first stage projects the 1x1 latent to 4x4 (stride 1, no
            # padding); every later stage doubles the spatial size.
            stride, padding = (1, 0) if stage == 0 else (2, 1)
            stack += [
                nn.ConvTranspose2d(in_channels, out_channels, 4, stride, padding, bias=False),
                nn.BatchNorm2d(out_channels),
                nn.ReLU(True),
            ]
            in_channels = out_channels

        # Output stage: (ngf) x 32 x 32 -> (nc) x 64 x 64, squashed to [-1, 1].
        stack += [
            nn.ConvTranspose2d(in_channels, nc, 4, 2, 1, bias=False),
            nn.Tanh(),
        ]
        self.main = nn.Sequential(*stack)

    def forward(self, input):
        """Run the latent batch through the generator stack."""
        return self.main(input)
# Load the model - Update path to point to the models folder
device = torch.device("cpu")
model_path = "models/netG_epoch_246.pth"

# Print file existence for debugging
print(f"Checking if model file exists: {os.path.exists(model_path)}")
print(f"Listing contents of models directory: {os.listdir('models') if os.path.exists('models') else 'models directory not found'}")

# Initialize the model with ngf=128 to match your training parameters
model = Generator(ngf=128).to(device)


def _strip_module_prefix(state_dict):
    """Return a copy of *state_dict* with a leading ``module.`` removed from its keys.

    Checkpoints saved from an ``nn.DataParallel`` wrapper prefix every key with
    ``module.``. Without stripping it, a non-strict load silently matches nothing.
    """
    prefix = "module."
    return {
        (key[len(prefix):] if key.startswith(prefix) else key): value
        for key, value in state_dict.items()
    }


# Read the checkpoint from disk once and reuse it for both load attempts.
# NOTE: torch.load unpickles the file, so model_path must point at a trusted file.
weights_loaded = False
try:
    state_dict = torch.load(model_path, map_location=device)
except Exception as e:
    state_dict = None
    print(f"Error loading model: {e}")

if state_dict is not None:
    try:
        model.load_state_dict(state_dict)
        weights_loaded = True
        print("Model loaded successfully!")
    except Exception as e:
        print(f"Error loading model: {e}")
        # Fallback: strip a DataParallel prefix and load whatever matches.
        # Report anything that did not match, so that a mostly-random model is
        # not mistaken for the trained one.
        try:
            result = model.load_state_dict(_strip_module_prefix(state_dict), strict=False)
            print("Model loaded with strict=False")
            if result.missing_keys:
                print(f"Warning: missing keys left at random init: {result.missing_keys}")
            if result.unexpected_keys:
                print(f"Warning: unexpected checkpoint keys ignored: {result.unexpected_keys}")
            weights_loaded = not result.missing_keys
        except Exception as e2:
            print(f"Error with alternative loading: {e2}")

if not weights_loaded:
    print("Warning: generator weights were not fully loaded; some or all parameters are still randomly initialized.")

# Set model to evaluation mode (BatchNorm uses its running statistics)
model.eval()
print(f"Model initialized: {model is not None}, weights loaded: {weights_loaded}")
def create_image_grid(images, rows, cols):
    """Paste PIL images into one RGB canvas laid out as a ``rows`` x ``cols`` grid.

    Images are placed row-major: image ``k`` lands in row ``k // cols`` and
    column ``k % cols``. Every cell uses the size of ``images[0]``.

    Args:
        images: Sequence of PIL images; an empty sequence raises ``IndexError``.
        rows: Number of grid rows.
        cols: Number of grid columns.

    Returns:
        A new RGB ``PIL.Image`` of size ``(cols * width, rows * height)``.
    """
    cell_w, cell_h = images[0].size
    canvas = Image.new('RGB', size=(cols * cell_w, rows * cell_h))
    for idx, tile in enumerate(images):
        row, col = divmod(idx, cols)
        canvas.paste(tile, box=(col * cell_w, row * cell_h))
    return canvas
def generate_multiple_images(random_seed=42, num_images=4):
    """Generate a grid of DCGAN samples that is reproducible for a given seed.

    Args:
        random_seed: Seed for the latent noise. Gradio sliders may deliver a
            float, so it is coerced to ``int``.
        num_images: Number of samples to draw. It is coerced to ``int`` and
            clamped to at least 1, which also avoids a division by zero when
            sizing the grid.

    Returns:
        A ``PIL.Image`` with the samples laid out row-major in
        ``int(sqrt(num_images))`` rows.
    """
    seed = int(random_seed)
    count = max(1, int(num_images))

    # Draw from a private, freshly seeded RNG instead of reseeding torch's
    # process-global one. Requests served concurrently therefore cannot
    # interleave draws, so a seed always maps to the same images.
    rng = torch.Generator(device=device)
    rng.manual_seed(seed)

    images = []
    for _ in range(count):
        # One latent vector per image, drawn sequentially from the same seeded
        # stream (not a different seed per image).
        noise = torch.randn(1, 100, 1, 1, generator=rng, device=device)
        with torch.no_grad():
            fake_image = model(noise).detach().cpu()
        # The generator ends in Tanh, so map [-1, 1] -> [0, 1] -> uint8 HWC.
        fake_img = fake_image * 0.5 + 0.5
        fake_img = fake_img.squeeze(0).permute(1, 2, 0).numpy()
        fake_img = np.clip(fake_img * 255, 0, 255).astype(np.uint8)
        images.append(Image.fromarray(fake_img))

    # Near-square layout: floor(sqrt(n)) rows, enough columns to fit all images.
    rows = int(math.sqrt(count))
    cols = int(math.ceil(count / rows))
    return create_image_grid(images, rows, cols)
# Create Gradio interface
demo = gr.Interface(
    fn=generate_multiple_images,
    inputs=[
        # Gradio 3+ components take their initial value via `value=`; the
        # legacy `default=` keyword is not a supported parameter there.
        gr.Slider(minimum=1, maximum=100, step=1, value=42, label="Random Seed"),
        gr.Slider(minimum=1, maximum=16, step=1, value=4, label="Number of Images"),
    ],
    outputs=gr.Image(type="pil", label="Generated Computer Mice"),
    title="DCGAN Computer Mouse Generator",
    description="Generate multiple unique computer mouse designs using a DCGAN model.",
    examples=[[42, 4], [23, 9], [7, 16]],
)

# Launch the app
if __name__ == "__main__":
    demo.launch()