import argparse
import os

import cv2
import jittor as jt
import numpy as np
from jittor import init
from jittor import nn

jt.flags.use_cuda = 1

# The argument set mirrors the training script; only latent_dim, img_size and
# channels are actually used for sampling below.
parser = argparse.ArgumentParser()
parser.add_argument('--n_epochs', type=int, default=200, help='训练的时期数')
parser.add_argument('--batch_size', type=int, default=64, help='批次大小')
parser.add_argument('--lr', type=float, default=0.0002, help='学习率')
parser.add_argument('--b1', type=float, default=0.5, help='梯度的一阶动量衰减')
# b2 is the decay of the second-order moment (the help text previously said first-order).
parser.add_argument('--b2', type=float, default=0.999, help='梯度的二阶动量衰减')
parser.add_argument('--n_cpu', type=int, default=8, help='批处理生成期间要使用的 cpu 线程数')
parser.add_argument('--latent_dim', type=int, default=100, help='潜在空间的维度')
parser.add_argument('--img_size', type=int, default=28, help='每个图像尺寸的大小')
parser.add_argument('--channels', type=int, default=1, help='图像通道数')
parser.add_argument('--sample_interval', type=int, default=400, help='图像样本之间的间隔')
opt = parser.parse_args()
print(opt)

# (channels, height, width) of one generated image.
img_shape = (opt.channels, opt.img_size, opt.img_size)


class Generator(nn.Module):
    """MLP generator: latent vector of size ``opt.latent_dim`` -> image of shape ``img_shape``."""

    def __init__(self):
        super(Generator, self).__init__()

        def block(in_feat, out_feat, normalize=True):
            """Return the layers Linear -> (optional BatchNorm1d) -> LeakyReLU(0.2)."""
            layers = [nn.Linear(in_feat, out_feat)]
            if normalize:
                # NOTE(review): the second positional argument of BatchNorm1d is
                # presumably eps (as in PyTorch), i.e. eps=0.8 rather than a momentum.
                # Left unchanged: altering it would change how the pretrained
                # weights loaded below behave.
                layers.append(nn.BatchNorm1d(out_feat, 0.8))
            layers.append(nn.LeakyReLU(scale=0.2))
            return layers

        self.model = nn.Sequential(
            *block(opt.latent_dim, 128, normalize=False),
            *block(128, 256),
            *block(256, 512),
            *block(512, 1024),
            nn.Linear(1024, int(np.prod(img_shape))),
            nn.Tanh(),  # outputs in [-1, 1]
        )

    def execute(self, z):
        """Map latent codes ``z`` of shape (N, latent_dim) to images (N, *img_shape)."""
        img = self.model(z)
        img = img.view((img.shape[0], *img_shape))
        return img


class Discriminator(nn.Module):
    """MLP discriminator: image of shape ``img_shape`` -> validity score in (0, 1)."""

    def __init__(self):
        super(Discriminator, self).__init__()
        self.model = nn.Sequential(
            nn.Linear(int(np.prod(img_shape)), 512),
            nn.LeakyReLU(scale=0.2),
            nn.Linear(512, 256),
            nn.LeakyReLU(scale=0.2),
            nn.Linear(256, 1),
            nn.Sigmoid(),
        )

    def execute(self, img):
        """Flatten each image and return its validity score, shape (N, 1)."""
        img_flat = img.view((img.shape[0], (- 1)))
        validity = self.model(img_flat)
        return validity


def deal_image(img, path=None, nrow=None):
    """Tile a batch of images into one grid, optionally save it, then display it.

    Image ``k`` is placed at grid row ``k % nrow`` and grid column ``k // nrow``
    (column-major). This is the same layout the previous reshape/concatenate code
    produced for a full square batch.

    Args:
        img: array-like of shape (N, C, H, W). Values are assumed to be in
            [-1, 1] (the generator ends with Tanh). They are mapped to [0, 255],
            and out-of-range values are clipped.
        path: optional output file path. Missing parent directories are created.
        nrow: number of images stacked in each grid column. Defaults to
            ceil(sqrt(N)). Unused cells (N not a multiple of nrow) are black.

    Returns:
        The uint8 grid of shape (nrow * H, ncols * W, C).

    Raises:
        ValueError: if ``img`` is not 4-D, contains no images, or ``nrow`` is
            not positive.
        OSError: if ``path`` is given and cv2.imwrite reports a failure.
    """
    img = np.asarray(img, dtype=np.float32)
    if img.ndim != 4:
        raise ValueError("expected img of shape (N, C, H, W), got %s" % (img.shape,))
    n, c, h, w = img.shape
    if n == 0:
        raise ValueError("img contains no images")
    if nrow is None:
        nrow = int(np.ceil(np.sqrt(n)))
    if nrow <= 0:
        raise ValueError("nrow must be positive, got %r" % (nrow,))
    ncols = -(-n // nrow)  # ceil(n / nrow)

    # Start from -1 so that unused cells become black (0) after the mapping below.
    grid = np.full((nrow * h, ncols * w, c), -1.0, dtype=np.float32)
    for k in range(n):
        row, col = k % nrow, k // nrow
        # (C, H, W) -> (H, W, C): OpenCV expects channels last.
        grid[row * h:(row + 1) * h, col * w:(col + 1) * w, :] = img[k].transpose(1, 2, 0)

    # Map [-1, 1] -> [0, 255] and convert to uint8. Per the OpenCV docs, imshow
    # multiplies float images by 255 (it assumes [0, 1]). Floats already in
    # [0, 255] would therefore be shown as a nearly white image.
    grid = np.clip(np.rint((grid + 1.0) / 2.0 * 255.0), 0, 255).astype(np.uint8)

    if path:
        out_dir = os.path.dirname(path)
        if out_dir:
            os.makedirs(out_dir, exist_ok=True)
        # cv2.imwrite signals failure by returning False rather than raising.
        if not cv2.imwrite(path, grid):
            raise OSError("failed to write image to %r" % (path,))
    cv2.imshow('1', grid)
    cv2.waitKey(0)
    return grid


# Build the generator and discriminator and load their trained weights.
# Each checkpoint is loaded once. Previously load_parameters(jt.load(p))
# was followed by load(p), which read the same file twice.
generator = Generator()
g_model_path = "saved_models/generator_last.pkl"
generator.load(g_model_path)

# Not used by the sampling code below.
discriminator = Discriminator()
d_model_path = "saved_models/discriminator_last.pkl"
discriminator.load(d_model_path)

# NOTE(review): the generator is not switched to eval mode. Its BatchNorm1d layers
# therefore presumably normalize with the statistics of this batch of 64 — confirm
# that this is intended.
z = jt.array(np.random.normal(0, 1, (64, opt.latent_dim)).astype(np.float32))
gen_imgs = generator(z)
deal_image(gen_imgs.data[:25], "images_test/1.png", nrow=5)