File size: 3,915 Bytes
71fd799
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
import jittor as jt
from jittor import init
from jittor import nn

import argparse
import numpy as np
import cv2

# Use the GPU when this Jittor build reports CUDA support; otherwise run on the
# CPU instead of failing at startup on a machine without CUDA.
jt.flags.use_cuda = 1 if jt.has_cuda else 0

# Command-line options (help strings are user-facing text and kept as-is).
# Only --latent_dim, --img_size and --channels appear to be used by this
# sampling script; the remaining options look like training hyper-parameters.
parser = argparse.ArgumentParser()
# number of training epochs
parser.add_argument('--n_epochs', type=int, default=200, help='训练的时期数')
# batch size
parser.add_argument('--batch_size', type=int, default=64, help='批次大小')
# learning rate
parser.add_argument('--lr', type=float, default=0.0002, help='学习率')
# decay of the first-order moment of the gradient
parser.add_argument('--b1', type=float, default=0.5, help='梯度的一阶动量衰减')
# decay of the second-order moment of the gradient
parser.add_argument('--b2', type=float, default=0.999, help='梯度的二阶动量衰减')
# number of CPU threads used during batch generation
parser.add_argument('--n_cpu', type=int, default=8, help='批处理生成期间要使用的 cpu 线程数')
# dimensionality of the latent space
parser.add_argument('--latent_dim', type=int, default=100, help='潜在空间的维度')
# side length of each (square) image
parser.add_argument('--img_size', type=int, default=28, help='每个图像尺寸的大小')
# number of image channels
parser.add_argument('--channels', type=int, default=1, help='图像通道数')
# interval between image samples
parser.add_argument('--sample_interval', type=int, default=400, help='图像样本之间的间隔')

opt = parser.parse_args()
print(opt)
# Shape of one image as (channels, height, width); both networks are built from it.
img_shape = (opt.channels, opt.img_size, opt.img_size)

# Generator
class Generator(nn.Module):
    """Fully connected generator mapping latent vectors to images.

    Layer sequence inside ``self.model``:
    Linear(latent_dim, 128) -> LeakyReLU(0.2), then three
    Linear -> BatchNorm1d -> LeakyReLU(0.2) stages widening to 256, 512 and
    1024, then Linear(1024, prod(img_shape)) -> Tanh.

    The order of layers must not change: checkpoints restored into this module
    are presumably keyed by each layer's index within ``self.model``.
    """

    def __init__(self):
        super(Generator, self).__init__()

        def stage(fan_in, fan_out, normalize=True):
            """Return the layer list of one Linear[-BatchNorm1d]-LeakyReLU stage."""
            stage_layers = [nn.Linear(fan_in, fan_out)]
            if normalize:
                # NOTE(review): assuming Jittor's signature
                # BatchNorm1d(num_features, eps, momentum, ...), this positional
                # 0.8 is eps (not momentum). Kept unchanged because the restored
                # weights were presumably trained with it.
                stage_layers.append(nn.BatchNorm1d(fan_out, 0.8))
            stage_layers.append(nn.LeakyReLU(scale=0.2))
            return stage_layers

        widths = [opt.latent_dim, 128, 256, 512, 1024]
        layers = []
        for position, (fan_in, fan_out) in enumerate(zip(widths, widths[1:])):
            # Only the first stage (latent -> 128) skips batch normalization.
            layers.extend(stage(fan_in, fan_out, normalize=(position != 0)))
        layers.append(nn.Linear(widths[-1], int(np.prod(img_shape))))
        layers.append(nn.Tanh())
        self.model = nn.Sequential(*layers)

    def execute(self, z):
        """Generate a batch of images.

        ``z`` is a 2-D batch whose second dimension is ``opt.latent_dim`` (the
        first Linear's input width); the result is reshaped to
        ``(batch, *img_shape)`` with values in (-1, 1) from the final Tanh.
        """
        flat = self.model(z)
        batch = flat.shape[0]
        return flat.view((batch,) + tuple(img_shape))

# Discriminator
class Discriminator(nn.Module):
    """Fully connected discriminator scoring images with a value in (0, 1).

    Layer sequence inside ``self.model``:
    Linear(prod(img_shape), 512) -> LeakyReLU(0.2) -> Linear(512, 256) ->
    LeakyReLU(0.2) -> Linear(256, 1) -> Sigmoid. The order must not change,
    since restored checkpoints are presumably keyed by layer index.
    """

    def __init__(self):
        super(Discriminator, self).__init__()
        in_features = int(np.prod(img_shape))
        self.model = nn.Sequential(
            nn.Linear(in_features, 512),
            nn.LeakyReLU(scale=0.2),
            nn.Linear(512, 256),
            nn.LeakyReLU(scale=0.2),
            nn.Linear(256, 1),
            nn.Sigmoid(),
        )

    def execute(self, img):
        """Flatten each image in the batch and return its (batch, 1) validity score."""
        batch = img.shape[0]
        return self.model(img.view((batch, -1)))

def deal_image(img, path=None, nrow=None, show=True):
    """Tile a batch of images into one square grid, then save and/or display it.

    Args:
        img: array of shape (N, C, H, W) with values in [-1, 1] (e.g. the
            Tanh output of the generator). N must equal ``nrow * nrow``.
        path: if truthy, the grid is written there with ``cv2.imwrite``;
            missing parent directories are created first.
        nrow: number of tiles along each side of the grid. Defaults to
            ``sqrt(N)``.
        show: if True (the default), display the grid in an OpenCV window and
            block until a key is pressed. Pass False for headless use.

    Returns:
        uint8 array of shape (nrow * H, nrow * W, C). Tile (r, c) holds image
        ``c * nrow + r``, i.e. images fill the grid column by column.

    Raises:
        ValueError: if ``img`` is not 4-D, or N is not ``nrow * nrow`` with
            ``nrow >= 1``.
        OSError: if ``cv2.imwrite`` reports that the file was not written.
    """
    img = np.asarray(img)
    if img.ndim != 4:
        raise ValueError("expected an (N, C, H, W) array, got shape {}".format(img.shape))
    n_imgs, channels, height, width = img.shape
    if nrow is None:
        nrow = int(round(np.sqrt(n_imgs)))
    if nrow < 1 or n_imgs != nrow * nrow:
        raise ValueError(
            "need exactly nrow*nrow images (nrow={}), got {}".format(nrow, n_imgs))

    # After the reshape the axes are (grid column, grid row, C, H, W).
    # Reordering to (C, grid row, H, grid column, W) and flattening gives a
    # (C, nrow*H, nrow*W) mosaic where image c*nrow + r sits at tile (r, c).
    grid = img.reshape(nrow, nrow, channels, height, width)
    grid = grid.transpose(2, 1, 3, 0, 4).reshape(channels, nrow * height, nrow * width)

    # Map [-1, 1] -> [0, 255] and convert to uint8. cv2.imshow multiplies
    # floating-point images by 255, so passing floats in [0, 255] would show a
    # saturated window; uint8 is also what cv2.imwrite would convert to anyway.
    grid = (grid + 1.0) / 2.0 * 255
    grid = np.clip(np.rint(grid), 0, 255).astype(np.uint8)

    # (C, H, W) -> (H, W, C), the channel-last layout OpenCV expects; make it
    # contiguous so it can be handed to cv2 without layout issues.
    grid = np.ascontiguousarray(grid.transpose(1, 2, 0))

    if path:
        import os
        parent = os.path.dirname(path)
        if parent:
            os.makedirs(parent, exist_ok=True)
        # cv2.imwrite signals failure by returning False instead of raising.
        if not cv2.imwrite(path, grid):
            raise OSError("cv2.imwrite failed to write {}".format(path))
    if show:
        cv2.imshow('1', grid)
        cv2.waitKey(0)
    return grid

# Build the generator and discriminator and restore their trained weights.
# Jittor's Module.load(path) reads the checkpoint file and applies it to the
# module's parameters, so a single call per model is sufficient.
generator = Generator()
g_model_path = "saved_models/generator_last.pkl"
generator.load(g_model_path)
# The discriminator is restored as well, although the sampling below does not use it.
discriminator = Discriminator()
d_model_path = "saved_models/discriminator_last.pkl"
discriminator.load(d_model_path)

# Draw a batch of 64 standard-normal latent vectors and generate images.
# NOTE(review): the generator is not switched to eval mode here; if its
# BatchNorm layers use batch statistics, the batch size affects the samples,
# which is why the original batch of 64 is kept even though only a grid's
# worth of images is displayed.
z = jt.array(np.random.normal(0, 1, (64, opt.latent_dim)).astype(np.float32))
gen_imgs = generator(z)
n_row = 5
deal_image(gen_imgs.data[:n_row * n_row], "images_test/1.png", nrow=n_row)