from jittor.dataset.mnist import MNIST
import jittor.transform as transform
from jittor.dataset.dataset import ImageFolder
import jittor as jt
from jittor import nn, Module
import os
import argparse
from time import *
import PIL.Image as Image
import numpy as np
import matplotlib.pyplot as plt

# Use the non-interactive Agg backend so figures can be saved without a display.
plt.switch_backend('agg')

# Run on GPU; set to 0 to fall back to CPU.
jt.flags.use_cuda = 1

parser = argparse.ArgumentParser()
parser.add_argument('--task', type=str, default='celebA', help='dataset to train on (MNIST or celebA)')
parser.add_argument('--train_dir', type=str, default='D:\\Image_Generation_Learn\\Dataset\\CelebA_train', help='path to the training dataset')
parser.add_argument('--eval_dir', type=str, default='D:\\Image_Generation_Learn\\Dataset\\CelebA_train', help='path to the evaluation dataset')
parser.add_argument('--n_epochs', type=int, default=100, help='number of training epochs')
parser.add_argument('--batch_size', type=int, default=64, help='batch size')
parser.add_argument('--lr', type=float, default=0.0002, help='learning rate')
parser.add_argument('--b1', type=float, default=0.5, help='decay of first-order momentum of gradient (Adam beta1)')
parser.add_argument('--b2', type=float, default=0.999, help='decay of second-order momentum of gradient (Adam beta2)')
parser.add_argument('--img_size', type=int, default=112, help='size of each image dimension')
parser.add_argument('--celebA_channels', type=int, default=3, help='number of image channels for celebA')
parser.add_argument('--mnist_channels', type=int, default=1, help='number of image channels for MNIST')
parser.add_argument('--img_row', type=int, default=5, help='number of rows in the saved sample grid')
parser.add_argument('--img_column', type=int, default=5, help='number of columns in the saved sample grid')
'''
parser.add_argument('--n_cpu', type=int, default=8, help='number of cpu threads to use during batch generation')
parser.add_argument('--latent_dim', type=int, default=100, help='dimensionality of the latent space')
parser.add_argument('--sample_interval', type=int, default=400, help='interval between image samples')
'''
opt = parser.parse_args()
print(opt)

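# DataLoader builds the train/eval pipelines for the chosen dataset.
# MNIST images are resized, reduced to a single gray channel and normalized to [-1, 1];
# celebA images are read from an ImageFolder directory and normalized per RGB channel.
# Normalizing to [-1, 1] matches the tanh output range of the generator below.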
def DataLoader(dataclass, img_size, batch_size, train_dir, eval_dir):
    if dataclass == 'MNIST':
        Transform = transform.Compose([
            transform.Resize(size=img_size),
            transform.Gray(),
            transform.ImageNormalize(mean=[0.5], std=[0.5])])
        train_loader = MNIST(data_root=train_dir, train=True, transform=Transform).set_attrs(batch_size=batch_size, shuffle=True)
        eval_loader = MNIST(data_root=eval_dir, train=False, transform=Transform).set_attrs(batch_size=1, shuffle=True)
    elif dataclass == 'celebA':
        Transform = transform.Compose([
            transform.Resize(size=img_size),
            transform.ImageNormalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])])
        train_loader = ImageFolder(train_dir)\
            .set_attrs(transform=Transform, batch_size=batch_size, shuffle=True)
        eval_loader = ImageFolder(eval_dir)\
            .set_attrs(transform=Transform, batch_size=batch_size, shuffle=True)
    else:
        print("No loader is implemented for the %s dataset, please choose MNIST or celebA!" % dataclass)
        dataclass = input("Please enter MNIST or celebA: ")
        # Retry with the corrected name and return its loaders
        # (the original recursive call discarded the result, leaving the loaders undefined).
        return DataLoader(dataclass, img_size, batch_size, train_dir, eval_dir)

    return train_loader, eval_loader


train_loader, eval_loader = DataLoader(dataclass=opt.task, img_size=opt.img_size,
                                       batch_size=opt.batch_size,
                                       train_dir=opt.train_dir, eval_dir=opt.eval_dir)

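# Generator: maps a 1024-dimensional Gaussian noise vector to an image.
# A fully connected layer produces a 7x7x256 feature map, which the stack of
# transposed convolutions upsamples 7 -> 14 -> 28 -> 56 -> 112, matching the
# default --img_size of 112. tanh keeps the pixel values in [-1, 1].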
class generator(Module):
    def __init__(self, dim=3):
        super(generator, self).__init__()
        self.fc = nn.Linear(1024, 7*7*256)
        self.fc_bn = nn.BatchNorm(256)
        self.deconv1 = nn.ConvTranspose(256, 256, 3, 2, 1, 1)
        self.deconv1_bn = nn.BatchNorm(256)
        self.deconv2 = nn.ConvTranspose(256, 256, 3, 1, 1)
        self.deconv2_bn = nn.BatchNorm(256)
        self.deconv3 = nn.ConvTranspose(256, 256, 3, 2, 1, 1)
        self.deconv3_bn = nn.BatchNorm(256)
        self.deconv4 = nn.ConvTranspose(256, 256, 3, 1, 1)
        self.deconv4_bn = nn.BatchNorm(256)
        self.deconv5 = nn.ConvTranspose(256, 128, 3, 2, 1, 1)
        self.deconv5_bn = nn.BatchNorm(128)
        self.deconv6 = nn.ConvTranspose(128, 64, 3, 2, 1, 1)
        self.deconv6_bn = nn.BatchNorm(64)
        self.deconv7 = nn.ConvTranspose(64, dim, 3, 1, 1)
        self.relu = nn.ReLU()
        self.tanh = nn.Tanh()

    def execute(self, input):
        x = self.fc(input).reshape((-1, 256, 7, 7))
        x = self.relu(self.fc_bn(x))
        x = self.relu(self.deconv1_bn(self.deconv1(x)))
        x = self.relu(self.deconv2_bn(self.deconv2(x)))
        x = self.relu(self.deconv3_bn(self.deconv3(x)))
        x = self.relu(self.deconv4_bn(self.deconv4(x)))
        x = self.relu(self.deconv5_bn(self.deconv5(x)))
        x = self.relu(self.deconv6_bn(self.deconv6(x)))
        x = self.tanh(self.deconv7(x))
        return x

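# Discriminator: four 5x5 stride-2 convolutions halve the spatial size
# 112 -> 56 -> 28 -> 14 -> 7 while widening the channels to 512, followed by a
# single-logit linear layer. No sigmoid is applied because the least-squares
# GAN loss below works on raw scores.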
class discriminator(nn.Module):
    def __init__(self, dim=3):
        super(discriminator, self).__init__()
        self.conv1 = nn.Conv(dim, 64, 5, 2, 2)
        self.conv2 = nn.Conv(64, 128, 5, 2, 2)
        self.conv2_bn = nn.BatchNorm(128)
        self.conv3 = nn.Conv(128, 256, 5, 2, 2)
        self.conv3_bn = nn.BatchNorm(256)
        self.conv4 = nn.Conv(256, 512, 5, 2, 2)
        self.conv4_bn = nn.BatchNorm(512)
        self.fc = nn.Linear(512*7*7, 1)
        self.leaky_relu = nn.Leaky_relu()

    def execute(self, input):
        x = self.leaky_relu(self.conv1(input), 0.2)
        x = self.leaky_relu(self.conv2_bn(self.conv2(x)), 0.2)
        x = self.leaky_relu(self.conv3_bn(self.conv3(x)), 0.2)
        x = self.leaky_relu(self.conv4_bn(self.conv4(x)), 0.2)
        x = x.reshape((x.shape[0], 512*7*7))
        x = self.fc(x)
        return x

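# Least-squares GAN loss (LSGAN): the target is 1 for real samples and 0 for
# fake ones, so D minimizes mean((D(x) - 1)^2) + mean(D(G(z))^2) and
# G minimizes mean((D(G(z)) - 1)^2). The flag `b` selects the real/fake target.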
def ls_loss(x, b):
    mini_batch = x.shape[0]
    y_real_ = jt.ones((mini_batch,))
    y_fake_ = jt.zeros((mini_batch,))
    if b:
        return (x-y_real_).sqr().mean()
    else:
        return (x-y_fake_).sqr().mean()

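# image_compose stitches randomly chosen samples from `array` into an
# img_row x img_column grid and writes it to IMAGE_SAVE_PATH. It assumes the
# array values are already scaled to [0, 1] in channel-first layout, since
# they are multiplied by 255 before conversion to uint8.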
def image_compose(array, IMAGE_SIZE=128, IMAGE_SAVE_PATH='./images_celebA'):
    to_image = Image.new('RGB', (opt.img_column * IMAGE_SIZE, opt.img_row * IMAGE_SIZE))
    # Draw exactly as many random indices as the grid needs (the original hard-coded 25).
    randomList = np.random.randint(0, array.shape[0], opt.img_row * opt.img_column)
    img_list = list()
    for i in randomList:
        img = Image.fromarray(np.uint8(array[i].transpose((1, 2, 0)) * 255))
        img_list.append(img)

    for y in range(1, opt.img_row + 1):
        for x in range(1, opt.img_column + 1):
            # Image.ANTIALIAS was renamed Image.LANCZOS in newer Pillow releases.
            from_image = img_list.pop().resize((IMAGE_SIZE, IMAGE_SIZE), Image.ANTIALIAS)
            to_image.paste(from_image, ((x - 1) * IMAGE_SIZE, (y - 1) * IMAGE_SIZE))
    return to_image.save(IMAGE_SAVE_PATH)

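# save_img_result renders a 5x5 grid of generator samples with matplotlib and
# saves it to `path`. Note that despite its name, fixed_z_ is re-sampled on
# every call, so each epoch's preview uses fresh noise. The tanh outputs in
# [-1, 1] are rescaled to [0, 1] for display.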
def save_img_result(num_epoch, G, path='./images_celebA/result.png'):
    fixed_z_ = jt.init.gauss((5 * 5, 1024), 'float')
    z_ = fixed_z_
    G.eval()
    test_images = G(z_)
    G.train()
    size_figure_grid = 5
    fig, ax = plt.subplots(size_figure_grid, size_figure_grid, figsize=(5, 5))
    for i in range(size_figure_grid):
        for j in range(size_figure_grid):
            ax[i, j].get_xaxis().set_visible(False)
            ax[i, j].get_yaxis().set_visible(False)

    for k in range(5*5):
        i = k // 5
        j = k % 5
        ax[i, j].cla()
        if opt.task == "MNIST":
            ax[i, j].imshow((test_images[k, 0].data+1)/2, cmap='gray')
        else:
            ax[i, j].imshow((test_images[k].data.transpose(1, 2, 0)+1)/2)

    label = 'Epoch {0}'.format(num_epoch)
    fig.text(0.5, 0.04, label, ha='center')
    plt.savefig(path)
    plt.close()

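# One training epoch: for every batch the discriminator is updated first on the
# real batch plus a freshly generated fake batch, then the generator is updated
# on a new noise batch so that D scores its samples as real. Each update is
# driven by its own loss via Jittor's optimizer.step(loss).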
def train(epoch):
    for batch_idx, (x_, target) in enumerate(train_loader):
        mini_batch = x_.shape[0]

        # Discriminator step: push real scores towards 1 and fake scores towards 0.
        D_result = D(x_)
        D_real_loss = ls_loss(D_result, True)
        z_ = jt.init.gauss((mini_batch, 1024), 'float')
        G_result = G(z_)
        D_result_ = D(G_result)
        D_fake_loss = ls_loss(D_result_, False)
        D_train_loss = D_real_loss + D_fake_loss
        D_train_loss.sync()
        D_optim.step(D_train_loss)

        # Generator step: push D's score on fresh fakes towards 1.
        z_ = jt.init.gauss((mini_batch, 1024), 'float')
        G_result = G(z_)
        D_result = D(G_result)
        G_train_loss = ls_loss(D_result, True)
        G_train_loss.sync()
        G_optim.step(G_train_loss)

        if batch_idx % 100 == 0:
            print("train: epoch{} batch_idx{} D training loss = {} G training loss = {} ".format(
                epoch, batch_idx, D_train_loss.data.mean(), G_train_loss.data.mean()))

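# Validation: both networks are switched to eval mode, the same two losses are
# computed over the evaluation loader without optimizer steps, and the averages
# are printed before switching back to train mode. (It is defined here but not
# invoked in the training loop below.)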
def validate(epoch):
    D_losses = []
    G_losses = []
    G.eval()
    D.eval()
    for batch_idx, (x_, target) in enumerate(eval_loader):
        mini_batch = x_.shape[0]

        D_result = D(x_)
        D_real_loss = ls_loss(D_result, True)
        z_ = jt.init.gauss((mini_batch, 1024), 'float')
        G_result = G(z_)
        D_result_ = D(G_result)
        D_fake_loss = ls_loss(D_result_, False)
        D_train_loss = D_real_loss + D_fake_loss
        D_losses.append(D_train_loss.data.mean())

        z_ = jt.init.gauss((mini_batch, 1024), 'float')
        G_result = G(z_)
        D_result = D(G_result)
        G_train_loss = ls_loss(D_result, True)
        G_losses.append(G_train_loss.data.mean())
    G.train()
    D.train()
    print("validate: epoch{}\tbatch_idx{}\tD training loss = {}\tG training loss = {}"
          .format(epoch, batch_idx, str(np.array(D_losses).mean()), str(np.array(G_losses).mean())))

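# Build the models and Adam optimizers. The channel count is fixed to the
# celebA setting here; for --task MNIST, opt.mnist_channels would be the
# matching choice.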
G = generator(opt.celebA_channels)
D = discriminator(opt.celebA_channels)

G_optim = jt.nn.Adam(G.parameters(), opt.lr, betas=(opt.b1, opt.b2))
D_optim = jt.nn.Adam(D.parameters(), opt.lr, betas=(opt.b1, opt.b2))

save_img_path = './images_celebA'
save_model_path = './save_model_celebA'
os.makedirs(save_img_path, exist_ok=True)
os.makedirs(save_model_path, exist_ok=True)

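# Resume training from the checkpoints saved below; the epoch counter starts at
# 37 because this run was apparently restarted partway through. For a fresh run,
# start the range at 0 (the existence check keeps a first run from crashing on
# missing checkpoint files).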
if os.path.exists(save_model_path + '/generator_celebA.pkl'):
    G.load_parameters(jt.load(save_model_path + '/generator_celebA.pkl'))
    D.load_parameters(jt.load(save_model_path + '/discriminator_celebA.pkl'))

for epoch in range(37, opt.n_epochs):
    print('number of epochs', epoch)
    train(epoch)

    # Save a sample grid for this epoch.
    result_img_path = save_img_path + '/' + str(epoch) + '.png'
    save_img_result(epoch, G, path=result_img_path)

    # Checkpoint both networks every 10 epochs.
    if (epoch + 1) % 10 == 0:
        G.save(save_model_path + "/generator_celebA.pkl")
        D.save(save_model_path + "/discriminator_celebA.pkl")