Jittor_LSGAN / lsgan_celebA.py
isLandLZ's picture
Upload lsgan_celebA.py
f9ffbd3
from jittor.dataset.mnist import MNIST
import jittor.transform as transform
from jittor.dataset.dataset import ImageFolder
import jittor as jt
from jittor import nn, Module
import os
import argparse
from time import *
import PIL.Image as Image
import numpy as np
import matplotlib.pyplot as plt
# Use the non-interactive Agg backend so figures can be written to disk
# without a display.
plt.switch_backend('agg')
# Run on the GPU only when this Jittor build reports CUDA support; forcing
# use_cuda = 1 on a machine without CUDA would presumably abort at start-up.
jt.flags.use_cuda = 1 if jt.has_cuda else 0
# Command-line options.
parser = argparse.ArgumentParser()
parser.add_argument('--task', type=str, default='celebA', help='训练数据集类型')
parser.add_argument('--train_dir', type=str, default='D:\\Image_Generation_Learn\\Dataset\\CelebA_train', help='训练数据集地址')
# Help text fixed: this option is the evaluation set location, not the training set.
parser.add_argument('--eval_dir', type=str, default='D:\\Image_Generation_Learn\\Dataset\\CelebA_train', help='验证数据集地址')
parser.add_argument('--n_epochs', type=int, default=100, help='训练的时期数')
parser.add_argument('--batch_size', type=int, default=64, help='批次大小')
parser.add_argument('--lr', type=float, default=0.0002, help='学习率')
parser.add_argument('--b1', type=float, default=0.5, help='梯度的一阶动量衰减')
# Help text fixed: b2 is the second-moment decay of Adam, not the first.
parser.add_argument('--b2', type=float, default=0.999, help='梯度的二阶动量衰减')
parser.add_argument('--img_size', type=int, default=112, help='每个图像尺寸的大小')
parser.add_argument('--celebA_channels', type=int, default=3, help='图像通道数')
parser.add_argument('--mnist_channels', type=int, default=1, help='图像通道数')
# Help text fixed: these two options are the row/column counts of the sample
# grid built by image_compose, not a sampling interval.
parser.add_argument('--img_row', type=int, default=5, help='样本拼图的行数')
parser.add_argument('--img_column', type=int, default=5, help='样本拼图的列数')
# Options that were disabled in the original script (they were wrapped in a
# no-op string literal) and are not read anywhere in this file:
#   --n_cpu (int, default 8): CPU threads used during batch generation
#   --latent_dim (int, default 100): dimensionality of the latent space
#   --sample_interval (int, default 400): interval between image samples
opt = parser.parse_args()
print(opt)
# Data loader factory.
def DataLoader(dataclass, img_size, batch_size, train_dir, eval_dir):
    """Build the training and evaluation loaders for the selected dataset.

    Args:
        dataclass: Dataset name, either 'MNIST' or 'celebA'. Any other value
            prints a hint and asks for a new name on stdin.
        img_size: Size handed to ``transform.Resize``.
        batch_size: Batch size of the training loader and of the celebA
            evaluation loader (the MNIST evaluation loader uses batch size 1).
        train_dir: Root directory of the training data.
        eval_dir: Root directory of the evaluation data.

    Returns:
        A ``(train_loader, eval_loader)`` tuple.
    """
    if dataclass == 'MNIST':
        Transform = transform.Compose([
            transform.Resize(size=img_size),
            transform.Gray(),
            transform.ImageNormalize(mean=[0.5], std=[0.5])])
        train_loader = MNIST(data_root=train_dir, train=True, transform=Transform).set_attrs(batch_size=batch_size, shuffle=True)
        eval_loader = MNIST(data_root=eval_dir, train=False, transform=Transform).set_attrs(batch_size=1, shuffle=True)
        return train_loader, eval_loader
    if dataclass == 'celebA':
        Transform = transform.Compose([
            transform.Resize(size=img_size),
            transform.ImageNormalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])])
        train_loader = ImageFolder(train_dir)\
            .set_attrs(transform=Transform, batch_size=batch_size, shuffle=True)
        eval_loader = ImageFolder(eval_dir)\
            .set_attrs(transform=Transform, batch_size=batch_size, shuffle=True)
        return train_loader, eval_loader
    # Unknown dataset name: ask the user for a valid one and retry.
    print("没有加载%s数据集的程序,请选择MNIST或者celebA!" % dataclass)
    dataclass = input("请输入:MNIST或者celebA:")
    # The retry's result must be returned; discarding it left both loader
    # names unbound and the function failed with UnboundLocalError.
    return DataLoader(dataclass, img_size, batch_size, train_dir, eval_dir)
# Build the training and evaluation loaders from the command-line options.
train_loader, eval_loader = DataLoader(
    dataclass=opt.task,
    img_size=opt.img_size,
    batch_size=opt.batch_size,
    train_dir=opt.train_dir,
    eval_dir=opt.eval_dir,
)
# Generator
class generator(Module):
    """Generator network: noise vector -> image.

    A linear layer projects a 1024-dimensional noise vector to a
    256 x 7 x 7 feature map. Six transposed-convolution stages, each followed
    by batch norm and ReLU, then upsample it; the four stride-2 stages
    (kernel 3, padding 1, output padding 1) each double the spatial size,
    7 -> 14 -> 28 -> 56 -> 112. A final stride-1 transposed convolution with
    tanh produces ``dim`` output channels.

    Args:
        dim: Number of channels of the generated image.
    """

    def __init__(self, dim=3):
        super().__init__()
        # Projection of the noise vector and its normalisation.
        self.fc = nn.Linear(1024, 7*7*256)
        self.fc_bn = nn.BatchNorm(256)
        # Transposed-convolution stages; stride-2 ones upsample by 2.
        self.deconv1 = nn.ConvTranspose(256, 256, 3, 2, 1, 1)
        self.deconv1_bn = nn.BatchNorm(256)
        self.deconv2 = nn.ConvTranspose(256, 256, 3, 1, 1)
        self.deconv2_bn = nn.BatchNorm(256)
        self.deconv3 = nn.ConvTranspose(256, 256, 3, 2, 1, 1)
        self.deconv3_bn = nn.BatchNorm(256)
        self.deconv4 = nn.ConvTranspose(256, 256, 3, 1, 1)
        self.deconv4_bn = nn.BatchNorm(256)
        self.deconv5 = nn.ConvTranspose(256, 128, 3, 2, 1, 1)
        self.deconv5_bn = nn.BatchNorm(128)
        self.deconv6 = nn.ConvTranspose(128, 64, 3, 2, 1, 1)
        self.deconv6_bn = nn.BatchNorm(64)
        # Output layer: no batch norm, tanh activation.
        self.deconv7 = nn.ConvTranspose(64, dim, 3, 1, 1)
        self.relu = nn.ReLU()
        self.tanh = nn.Tanh()

    def execute(self, input):
        """Map noise of shape (N, 1024) to images with ``dim`` channels."""
        feat = self.fc(input)
        feat = feat.reshape((-1, 256, 7, 7))
        feat = self.relu(self.fc_bn(feat))
        # The pairs are built locally (not stored on self) so no extra
        # attributes are added to the module.
        stages = (
            (self.deconv1, self.deconv1_bn),
            (self.deconv2, self.deconv2_bn),
            (self.deconv3, self.deconv3_bn),
            (self.deconv4, self.deconv4_bn),
            (self.deconv5, self.deconv5_bn),
            (self.deconv6, self.deconv6_bn),
        )
        for deconv, norm in stages:
            feat = self.relu(norm(deconv(feat)))
        return self.tanh(self.deconv7(feat))
# Discriminator
class discriminator(nn.Module):
    """Discriminator network: image -> one unbounded score per sample.

    Four stride-2 convolutions (kernel 5, padding 2) each halve the spatial
    size. The final linear layer expects a 512 x 7 x 7 feature map, which is
    what a 112 x 112 input yields (112 -> 56 -> 28 -> 14 -> 7). No sigmoid is
    applied to the output.

    Args:
        dim: Number of channels of the input image.
    """

    def __init__(self, dim=3):
        super().__init__()
        # The first convolution has no batch norm.
        self.conv1 = nn.Conv(dim, 64, 5, 2, 2)
        self.conv2 = nn.Conv(64, 128, 5, 2, 2)
        self.conv2_bn = nn.BatchNorm(128)
        self.conv3 = nn.Conv(128, 256, 5, 2, 2)
        self.conv3_bn = nn.BatchNorm(256)
        self.conv4 = nn.Conv(256, 512, 5, 2, 2)
        self.conv4_bn = nn.BatchNorm(512)
        self.fc = nn.Linear(512*7*7, 1)
        self.leaky_relu = nn.Leaky_relu()

    def execute(self, input):
        """Score a batch of images; returns the output of the final linear layer."""
        # 0.2 is forwarded as the second argument of the leaky-ReLU module
        # (presumably the negative slope - confirm against the Jittor API).
        slope = 0.2
        hidden = self.leaky_relu(self.conv1(input), slope)
        normalised_stages = (
            (self.conv2, self.conv2_bn),
            (self.conv3, self.conv3_bn),
            (self.conv4, self.conv4_bn),
        )
        for conv, norm in normalised_stages:
            hidden = self.leaky_relu(norm(conv(hidden)), slope)
        flat = hidden.reshape((hidden.shape[0], 512*7*7))
        return self.fc(flat)
# Loss function
def ls_loss(x, b):
    """Least-squares GAN loss: mean squared distance of ``x`` from its target.

    Args:
        x: Discriminator scores. Any array-like supporting ``-``, ``*`` and
            ``.mean()`` works (e.g. a Jittor Var or a NumPy array).
        b: Truthy to compare against the "real" label 1, falsy to compare
            against the "fake" label 0.

    Returns:
        The mean of ``(x - target) ** 2`` over every element of ``x``.
    """
    # A scalar target is used on purpose. Subtracting an (N,)-shaped label
    # tensor from (N, 1)-shaped scores broadcasts to an (N, N) intermediate:
    # the mean is the same, but memory and compute grow by a factor of N.
    target = 1.0 if b else 0.0
    diff = x - target
    return (diff * diff).mean()
# Image grid composition
def image_compose(array,IMAGE_SIZE=128,IMAGE_SAVE_PATH='./images_celebA'):
    """Paste randomly chosen images from ``array`` into a grid and save it.

    The grid has ``opt.img_row`` rows and ``opt.img_column`` columns; one
    image is drawn (with replacement) from ``array`` for every cell.

    Args:
        array: Batch of images indexed along axis 0; each item is transposed
            with ``(1, 2, 0)``, i.e. treated as channel-first.
        IMAGE_SIZE: Edge length in pixels of every grid cell.
        IMAGE_SAVE_PATH: Destination passed to ``Image.save``.
            NOTE(review): the default has no file extension, which Pillow
            normally needs to pick an output format - confirm callers always
            pass a full file name.

    Returns:
        The return value of ``Image.save`` (``None``).
    """
    rows, cols = opt.img_row, opt.img_column
    canvas = Image.new('RGB', (cols * IMAGE_SIZE, rows * IMAGE_SIZE))  # blank grid
    # Draw exactly one sample per cell. A fixed count of 25 ran out of images
    # (IndexError on pop) for grids larger than 5 x 5.
    picks = np.random.randint(0, array.shape[0], rows * cols)
    # NOTE(review): values are scaled by 255 and cast to uint8 without
    # shifting or clipping, so negative inputs (e.g. tanh output in [-1, 0))
    # wrap around - confirm the expected input range.
    tiles = [
        Image.fromarray(np.uint8(array[i].transpose((1, 2, 0)) * 255))
        for i in picks
    ]
    # Fill the grid row by row, taking tiles from the end of the list.
    for row in range(rows):
        for col in range(cols):
            # Image.LANCZOS is the same filter as the former Image.ANTIALIAS
            # alias, which no longer exists in Pillow >= 10.
            tile = tiles.pop().resize((IMAGE_SIZE, IMAGE_SIZE), Image.LANCZOS)
            canvas.paste(tile, (col * IMAGE_SIZE, row * IMAGE_SIZE))
    return canvas.save(IMAGE_SAVE_PATH)  # write the composed grid
def save_img_result(num_epoch, G, path = './images_celebA/result.png'):
    """Render a 5 x 5 grid of generator samples and save it as a figure.

    Fresh Gaussian noise of shape (25, 1024) is drawn on every call. ``G`` is
    switched to eval mode for the forward pass and back to train mode right
    after it.

    Args:
        num_epoch: Epoch number written into the figure caption.
        G: Generator called on the noise batch.
        path: Output file handed to ``plt.savefig``.
    """
    grid = 5
    noise = jt.init.gauss((5 * 5, 1024), 'float')
    G.eval()
    samples = G(noise)
    G.train()

    fig, axes = plt.subplots(grid, grid, figsize=(5, 5))
    # Hide the tick axes of every panel.
    for row in range(grid):
        for col in range(grid):
            panel = axes[row, col]
            panel.get_xaxis().set_visible(False)
            panel.get_yaxis().set_visible(False)

    is_mnist = opt.task=="MNIST"
    for k in range(5*5):
        row, col = divmod(k, 5)
        panel = axes[row, col]
        panel.cla()
        # (x + 1) / 2 maps a tanh-range output to [0, 1] for imshow.
        if is_mnist:
            # Single-channel case: show channel 0 in grayscale.
            panel.imshow((samples[k, 0].data+1)/2, cmap='gray')
        else:
            # Channel-first -> channel-last for imshow.
            panel.imshow((samples[k].data.transpose(1, 2, 0)+1)/2)

    caption = 'Epoch {0}'.format(num_epoch)
    fig.text(0.5, 0.04, caption, ha='center')
    plt.savefig(path)
    plt.close()
def train(epoch):
    """Run one LSGAN training epoch over the module-level ``train_loader``.

    For every batch the discriminator ``D`` is stepped first (real images
    pushed towards label 1, generated images towards label 0), then the
    generator ``G`` is stepped so that ``D`` scores its fresh samples as 1.
    Losses are printed every 100 batches.

    Args:
        epoch: Epoch index, used only in the log line.
    """
    for batch_idx, (real_images, _) in enumerate(train_loader):
        batch_n = real_images.shape[0]

        # --- Discriminator step: real -> 1, generated -> 0 ---
        real_scores = D(real_images)                     # one score per image
        d_loss_real = ls_loss(real_scores, True)         # loss on real images
        noise = jt.init.gauss((batch_n, 1024), 'float')  # noise, (batch, 1024)
        fake_images = G(noise)                           # generated images
        fake_scores = D(fake_images)                     # scores of generated images
        d_loss_fake = ls_loss(fake_scores, False)        # loss on generated images
        d_loss = d_loss_real + d_loss_fake
        d_loss.sync()
        D_optim.step(d_loss)

        # --- Generator step: make D score new samples as real ---
        noise = jt.init.gauss((batch_n, 1024), 'float')  # fresh noise
        fake_images = G(noise)
        fake_scores = D(fake_images)
        g_loss = ls_loss(fake_scores, True)              # compare against label 1
        g_loss.sync()
        G_optim.step(g_loss)

        should_log = batch_idx % 100 == 0
        if should_log:
            print("train: epoch{} batch_idx{} D training loss = {} G training loss = {} ".format(
                epoch, batch_idx, d_loss.data.mean(), g_loss.data.mean()))
        # Disabled: periodic dump of the current generated batch as an image grid.
        # if((epoch)%5==0 or epoch==0 and batch_idx==100):
        #     image_compose(G_result.data,128,"./imgs/epoch{}-G_{}.jpg".format(epoch,task))
def validate(epoch):
    """Compute the mean discriminator and generator losses over ``eval_loader``.

    ``G`` and ``D`` are put in eval mode for the pass and always switched
    back to train mode afterwards, even if a batch raises. No optimizer step
    is taken. The averaged losses are printed; nothing is returned.

    Args:
        epoch: Epoch index, used only in the log line.
    """
    D_losses = []
    G_losses = []
    # Index of the last processed batch; stays -1 when the loader is empty so
    # the name is always bound after the loop.
    batch_idx = -1
    G.eval()
    D.eval()
    try:
        for batch_idx, (x_, target) in enumerate(eval_loader):
            mini_batch = x_.shape[0]
            # Discriminator loss: real images vs. label 1, generated vs. label 0.
            D_result = D(x_)
            D_real_loss = ls_loss(D_result, True)
            z_ = jt.init.gauss((mini_batch, 1024), 'float')
            G_result = G(z_)
            D_result_ = D(G_result)
            D_fake_loss = ls_loss(D_result_, False)
            D_train_loss = D_real_loss + D_fake_loss
            D_losses.append(D_train_loss.data.mean())
            # Generator loss: fresh samples vs. label 1.
            z_ = jt.init.gauss((mini_batch, 1024), 'float')
            G_result = G(z_)
            D_result = D(G_result)
            G_train_loss = ls_loss(D_result, True)
            G_losses.append(G_train_loss.data.mean())
    finally:
        # Restore training mode no matter how the loop ended.
        G.train()
        D.train()
    if not D_losses:
        # Nothing was evaluated; averaging an empty list would only yield NaN.
        print("validate: epoch{}\tno evaluation batches".format(epoch))
        return
    print("validate: epoch{}\tbatch_idx{}\tD training loss = {}\tG training loss = {}"
          .format(epoch, batch_idx, str(np.array(D_losses).mean()), str(np.array(G_losses).mean())))
# Build the generator and discriminator with the channel count of the selected
# dataset (the MNIST pipeline yields single-channel images).
channels = opt.mnist_channels if opt.task == 'MNIST' else opt.celebA_channels
G = generator(channels)
D = discriminator(channels)
# Adam optimizers; defaults are lr 0.0002 and betas (0.5, 0.999).
G_optim = jt.nn.Adam(G.parameters(), opt.lr, betas=(opt.b1, opt.b2))
D_optim = jt.nn.Adam(D.parameters(), opt.lr, betas=(opt.b1, opt.b2))
# Output locations.
# NOTE(review): these paths are celebA-specific regardless of --task.
save_img_path = './images_celebA'
save_model_path = './save_model_celebA'
os.makedirs(save_img_path, exist_ok=True)
os.makedirs(save_model_path, exist_ok=True)
generator_ckpt = save_model_path + '/generator_celebA.pkl'
discriminator_ckpt = save_model_path + '/discriminator_celebA.pkl'
# Resume only when both checkpoints exist; loading unconditionally crashed on a
# fresh run because the files had never been written.
if os.path.isfile(generator_ckpt) and os.path.isfile(discriminator_ckpt):
    G.load_parameters(jt.load(generator_ckpt))
    D.load_parameters(jt.load(discriminator_ckpt))
    # NOTE(review): 37 is the resume epoch hard-coded by the original script;
    # it is kept for compatibility but should be stored with the checkpoint.
    start_epoch = 37
else:
    start_epoch = 0
for epoch in range(start_epoch, opt.n_epochs):
    print ('number of epochs', epoch)
    train(epoch)
    # validate(epoch)
    result_img_path = save_img_path + '/' + str(epoch) + '.png'
    save_img_result(epoch, G, path=result_img_path)
    # Save the models every 10 epochs.
    if (epoch+1) % 10 == 0:
        G.save(generator_ckpt)
        D.save(discriminator_ckpt)