import streamlit as st
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
from torchvision.utils import make_grid
import matplotlib.pyplot as plt

# Set device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Hyperparameters
z_dim = 64           # dimensionality of the noise vector fed to the generator
image_dim = 28 * 28  # flattened MNIST image size
batch_size = 32
lr = 3e-4

# Load data: normalize to [-1, 1] to match the generator's Tanh output range
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,)),
])
dataset = torchvision.datasets.MNIST(root='dataset/', transform=transform, download=True)
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)


# Generator: maps a z_dim noise vector to a flattened 28x28 image in [-1, 1]
class Generator(nn.Module):
    def __init__(self, z_dim, img_dim):
        super().__init__()
        self.gen = nn.Sequential(
            nn.Linear(z_dim, 256),
            nn.ReLU(),
            nn.Linear(256, 512),
            nn.ReLU(),
            nn.Linear(512, 1024),
            nn.ReLU(),
            nn.Linear(1024, img_dim),
            nn.Tanh(),
        )

    def forward(self, x):
        return self.gen(x)


# Discriminator: maps a flattened image to a probability of being real
class Discriminator(nn.Module):
    def __init__(self, img_dim):
        super().__init__()
        self.disc = nn.Sequential(
            nn.Linear(img_dim, 1024),
            nn.ReLU(),
            nn.Linear(1024, 512),
            nn.ReLU(),
            nn.Linear(512, 256),
            nn.ReLU(),
            nn.Linear(256, 1),
            nn.Sigmoid(),
        )

    def forward(self, x):
        return self.disc(x)


# Initialize generator and discriminator
gen = Generator(z_dim, image_dim).to(device)
disc = Discriminator(image_dim).to(device)

# Optimizers
opt_gen = optim.Adam(gen.parameters(), lr=lr)
opt_disc = optim.Adam(disc.parameters(), lr=lr)

# Loss function
criterion = nn.BCELoss()


# Train the GAN for the requested number of epochs and return the last batch of fakes
def train_gan(epochs):
    for epoch in range(epochs):
        for batch_idx, (real, _) in enumerate(dataloader):
            real = real.view(-1, image_dim).to(device)
            cur_batch_size = real.shape[0]

            # Train discriminator: maximize log(D(real)) + log(1 - D(G(z)))
            noise = torch.randn(cur_batch_size, z_dim).to(device)
            fake = gen(noise)
            disc_real = disc(real).view(-1)
            lossD_real = criterion(disc_real, torch.ones_like(disc_real))
            disc_fake = disc(fake).view(-1)
            lossD_fake = criterion(disc_fake, torch.zeros_like(disc_fake))
            lossD = (lossD_real + lossD_fake) / 2
            disc.zero_grad()
            # retain_graph=True keeps the graph through `fake` alive so the
            # generator update below can backpropagate through it again
            lossD.backward(retain_graph=True)
            opt_disc.step()

            # Train generator: maximize log(D(G(z)))
            output = disc(fake).view(-1)
            lossG = criterion(output, torch.ones_like(output))
            gen.zero_grad()
            lossG.backward()
            opt_gen.step()

        st.write(f"Epoch [{epoch + 1}/{epochs}] Loss D: {lossD:.4f}, Loss G: {lossG:.4f}")
    return fake


# Streamlit interface
st.title("Simple GAN with Epoch Slider")
epochs = st.slider("Number of Epochs", 1, 100, 1)

if st.button("Train GAN"):
    fake_images = train_gan(epochs)
    fake_images = fake_images.view(-1, 1, 28, 28)
    fake_images = make_grid(fake_images, nrow=8, normalize=True)
    plt.imshow(fake_images.permute(1, 2, 0).detach().cpu().numpy(), cmap='gray')
    plt.axis('off')
    st.pyplot(plt.gcf())
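
# Usage (sketch): the original script does not specify a filename, so
# "gan_app.py" below is an assumption. Save this file and launch it with:
#     streamlit run gan_app.py
# The slider selects the number of training epochs; clicking "Train GAN" trains
# the model from scratch and displays a grid of generated samples taken from the
# final training batch.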