|
import argparse
import os

import cv2
import jittor as jt
import numpy as np
from jittor import init
from jittor import nn
|
|
|
# Enable CUDA execution for all Jittor operations in this script.
jt.flags.use_cuda = 1

# Command-line options. In this file only latent_dim, img_size and channels are
# read; the other options appear to be kept for parity with the training
# script's CLI (unverified).
parser = argparse.ArgumentParser()
parser.add_argument('--n_epochs', type=int, default=200, help='训练的时期数')
parser.add_argument('--batch_size', type=int, default=64, help='批次大小')
parser.add_argument('--lr', type=float, default=0.0002, help='学习率')
parser.add_argument('--b1', type=float, default=0.5, help='梯度的一阶动量衰减')
# b2 is the decay rate of the second-order moment estimate.
parser.add_argument('--b2', type=float, default=0.999, help='梯度的二阶动量衰减')
parser.add_argument('--n_cpu', type=int, default=8, help='批处理生成期间要使用的 cpu 线程数')
parser.add_argument('--latent_dim', type=int, default=100, help='潜在空间的维度')
parser.add_argument('--img_size', type=int, default=28, help='每个图像尺寸的大小')
parser.add_argument('--channels', type=int, default=1, help='图像通道数')
parser.add_argument('--sample_interval', type=int, default=400, help='图像样本之间的间隔')

opt = parser.parse_args()
print(opt)

# Shape of one image as produced by Generator: (channels, img_size, img_size).
img_shape = (opt.channels, opt.img_size, opt.img_size)
|
|
|
|
|
class Generator(nn.Module):
    """MLP generator mapping latent vectors to images of shape ``img_shape``.

    Reads the module-level ``opt.latent_dim`` and ``img_shape``. The order of
    layers inside ``self.model`` is significant: checkpoint parameters are
    presumably keyed by their index in the Sequential, so it must not change.
    """

    def __init__(self):
        super(Generator, self).__init__()
        widths = [opt.latent_dim, 128, 256, 512, 1024]
        layers = []
        for idx, (fan_in, fan_out) in enumerate(zip(widths[:-1], widths[1:])):
            layers.append(nn.Linear(fan_in, fan_out))
            # Every hidden layer except the first is batch-normalized.
            if idx > 0:
                # NOTE(review): the 2nd positional argument of BatchNorm1d is
                # eps, so this sets eps=0.8 (likely meant as momentum). Kept
                # because the saved weights were presumably trained with it.
                layers.append(nn.BatchNorm1d(fan_out, 0.8))
            layers.append(nn.LeakyReLU(scale=0.2))
        # Output layer: one value per pixel, squashed into [-1, 1] by Tanh.
        layers.append(nn.Linear(widths[-1], int(np.prod(img_shape))))
        layers.append(nn.Tanh())
        self.model = nn.Sequential(*layers)

    def execute(self, z):
        """Return images of shape (batch, *img_shape) for latent batch ``z``."""
        flat = self.model(z)
        return flat.view((flat.shape[0], *img_shape))
|
|
|
|
|
class Discriminator(nn.Module):
    """MLP discriminator scoring each image with a value in (0, 1) (Sigmoid)."""

    def __init__(self):
        super(Discriminator, self).__init__()
        in_features = int(np.prod(img_shape))
        self.model = nn.Sequential(
            nn.Linear(in_features, 512),
            nn.LeakyReLU(scale=0.2),
            nn.Linear(512, 256),
            nn.LeakyReLU(scale=0.2),
            nn.Linear(256, 1),
            nn.Sigmoid(),
        )

    def execute(self, img):
        """Flatten each image in the batch and return its (batch, 1) score."""
        batch = img.shape[0]
        return self.model(img.view((batch, -1)))
|
|
|
def deal_image(img, path=None, nrow=None, show=True):
    """Tile a batch of images into one grid, optionally save it, and show it.

    Images are laid out column-major: image ``k`` lands in grid row
    ``k % nrow`` and grid column ``k // nrow``. Images 0..nrow-1 therefore fill
    the first column.

    Args:
        img: array-like of shape (N, C, H, W) with values in [-1, 1]
            (e.g. the output of a Tanh layer).
        path: if truthy, the grid is written there with ``cv2.imwrite``; the
            parent directory is created if it does not exist.
        nrow: number of images stacked in each grid column. Defaults to
            ``ceil(sqrt(N))``. Unused grid cells are filled with black.
        show: if true (the default), display the grid with ``cv2.imshow`` and
            block until a key is pressed.

    Returns:
        The grid as a contiguous ``uint8`` array of shape
        (nrow * H, ncol * W, C), where ``ncol = ceil(N / nrow)``.

    Raises:
        ValueError: if ``img`` is not 4-D or ``nrow`` is not positive.
        OSError: if ``cv2.imwrite`` reports that the file was not written.
    """
    batch = np.asarray(img, dtype=np.float64)
    if batch.ndim != 4:
        raise ValueError(f"expected images of shape (N, C, H, W), got {batch.shape}")
    n, c, h, w = batch.shape
    if nrow is None:
        nrow = max(1, int(np.ceil(np.sqrt(n))))
    if nrow < 1:
        raise ValueError(f"nrow must be positive, got {nrow}")
    ncol = max(1, -(-n // nrow))  # ceil(n / nrow)

    # Pad with -1 (which maps to black) so the batch fills the grid exactly.
    pad = nrow * ncol - n
    if pad:
        batch = np.concatenate([batch, np.full((pad, c, h, w), -1.0)], axis=0)

    # (ncol, nrow, C, H, W) -> (C, nrow, H, ncol, W) -> (C, nrow*H, ncol*W):
    # grid[ch, r*H + y, col*W + x] == batch[col*nrow + r, ch, y, x].
    grid = batch.reshape(ncol, nrow, c, h, w).transpose(2, 1, 3, 0, 4)
    grid = grid.reshape(c, nrow * h, ncol * w)

    # Map [-1, 1] to [0, 255] as uint8. cv2.imshow would otherwise multiply a
    # float image by 255 and show it almost entirely white.
    grid = np.clip(np.rint((grid + 1.0) / 2.0 * 255.0), 0, 255).astype(np.uint8)
    # Channels-last layout for OpenCV. NOTE(review): cv2 treats 3-channel data
    # as BGR; RGB input would show red/blue swapped (not handled here).
    grid = np.ascontiguousarray(grid.transpose(1, 2, 0))

    if path:
        directory = os.path.dirname(path)
        if directory:
            os.makedirs(directory, exist_ok=True)
        # cv2.imwrite signals failure by returning False rather than raising.
        if not cv2.imwrite(path, grid):
            raise OSError(f"cv2.imwrite failed to write {path!r}")
    if show:
        cv2.imshow('1', grid)
        cv2.waitKey(0)
    return grid
|
|
|
|
|
# Rebuild the generator and restore its trained weights (loaded once).
generator = Generator()
g_model_path = "saved_models/generator_last.pkl"
generator.load_parameters(jt.load(g_model_path))
# Inference mode: BatchNorm layers use their running statistics instead of the
# statistics of the current batch, so samples do not depend on each other.
generator.eval()

# NOTE(review): the discriminator is restored but not used anywhere below.
discriminator = Discriminator()
d_model_path = "saved_models/discriminator_last.pkl"
discriminator.load_parameters(jt.load(d_model_path))

# Sample exactly as many latent vectors as the n_grid x n_grid preview shows.
n_grid = 5
z = jt.array(np.random.normal(0, 1, (n_grid * n_grid, opt.latent_dim)).astype(np.float32))
# No gradients are needed for pure sampling.
with jt.no_grad():
    gen_imgs = generator(z)

deal_image(gen_imgs.data, "images_test/1.png", nrow=n_grid)