Technical exchange community of Chuanzhi Education (传智教育, A-share listed company, stock code 003032), Beijing Changping campus


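Introduction

Besides the VAE, the generative adversarial network (Generative Adversarial Nets, GAN) is another very popular unsupervised generative model.
A GAN consists of two core networks:
  • Generator: denoted G. By learning from a large number of samples, it learns to generate samples realistic enough to pass for real ones, much like a VAE
  • Discriminator: denoted D. It receives both real samples and samples generated by G, and tries to tell them apart
  • G and D compete with each other; through training, G's generative ability and D's discriminative ability both improve and gradually converge
GANs are very hard to train, and many details need attention in order to generate high-quality images:
  • Use Batch Normalization and LeakyReLU appropriately
  • Replace pooling with convolutions whose strides is 2
  • Train the two networks alternately, so that neither side becomes too strong
Here we take MNIST as the example and implement a GAN with TensorFlow. Because it uses deep convolutional neural networks, it is also called a DCGAN (Deep Convolutional GAN).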
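To make two of the tips above concrete, here is a small pure-Python sketch (no TensorFlow needed). It assumes LeakyReLU is computed as max(x, leak * x) with a leak of 0.2, which only works for a leak between 0 and 1, and that a stride-2 convolution with 'same' padding produces ceil(n / 2) outputs per spatial dimension. The function names `leaky_relu` and `same_conv_out` are made up for illustration.

```python
import math

def leaky_relu(x, leak=0.2):
    """LeakyReLU on a scalar: keep positives, scale negatives by `leak`."""
    return max(x, leak * x)

def same_conv_out(size, stride=2):
    """Spatial output size of a convolution with 'same' padding."""
    return math.ceil(size / stride)

# Negative inputs are shrunk rather than zeroed, so some gradient still flows.
print([leaky_relu(v) for v in (-2.0, 0.0, 3.0)])

# Four stride-2 convolutions applied to a 28x28 MNIST image.
size, sizes = 28, [28]
for _ in range(4):
    size = same_conv_out(size)
    sizes.append(size)
print(sizes)
```

Each stride-2 convolution roughly halves the feature map (28, 14, 7, 4, 2), which is the downsampling that a pooling layer would otherwise provide, except that here the downsampling is learned.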


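Principles

Given a noise vector z drawn from a random distribution, the generator maps it to a fake sample through a complex mapping function.
The discriminator uses another complex mapping function: for a real or a fake sample, it outputs a value between 0 and 1, and the larger the value, the more likely the sample is real.
The overall objective function is as follows: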
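For reference, the standard GAN minimax objective is usually written in the form below. The symbols $p_{\text{data}}$ (distribution of real samples) and $p_z$ (noise distribution) are notation assumed here; G, D and z are as described above.

```latex
\min_G \max_D V(D, G) =
  \mathbb{E}_{x \sim p_{\text{data}}(x)}\bigl[\log D(x)\bigr]
  + \mathbb{E}_{z \sim p_z(z)}\bigl[\log\bigl(1 - D(G(z))\bigr)\bigr]
```

D tries to maximize V by pushing D(x) toward 1 and D(G(z)) toward 0, while G tries to minimize V by pushing D(G(z)) toward 1.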
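Implementation

Load the libraries:

# -*- coding: utf-8 -*-
# Note: the code in this post uses the TensorFlow 1.x API (tf.placeholder, tf.layers, tf.contrib)
import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt
# Jupyter/IPython magic; remove the next line when running as a plain script
%matplotlib inline
import os, imageio

Load the data: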
from tensorflow.examples.tutorials.mnist import input_data

mnist = input_data.read_data_sets('MNIST_data')

Define some constants, the network inputs, and helper functions:
batch_size = 100
z_dim = 100
OUTPUT_DIR = 'samples'
if not os.path.exists(OUTPUT_DIR):
    os.mkdir(OUTPUT_DIR)

# Network inputs: real images, noise vectors, and a flag for batch normalization
X = tf.placeholder(dtype=tf.float32, shape=[None, 28, 28, 1], name='X')
noise = tf.placeholder(dtype=tf.float32, shape=[None, z_dim], name='noise')
is_training = tf.placeholder(dtype=tf.bool, name='is_training')

def lrelu(x, leak=0.2):
    return tf.maximum(x, leak * x)

def sigmoid_cross_entropy_with_logits(x, y):
    return tf.nn.sigmoid_cross_entropy_with_logits(logits=x, labels=y)

The discriminator:
def discriminator(image, reuse=None, is_training=is_training):
    momentum = 0.9
    with tf.variable_scope('discriminator', reuse=reuse):
        # Stride-2 convolutions downsample in place of pooling
        h0 = lrelu(tf.layers.conv2d(image, kernel_size=5, filters=64, strides=2, padding='same'))

        h1 = tf.layers.conv2d(h0, kernel_size=5, filters=128, strides=2, padding='same')
        h1 = lrelu(tf.contrib.layers.batch_norm(h1, is_training=is_training, decay=momentum))

        h2 = tf.layers.conv2d(h1, kernel_size=5, filters=256, strides=2, padding='same')
        h2 = lrelu(tf.contrib.layers.batch_norm(h2, is_training=is_training, decay=momentum))

        h3 = tf.layers.conv2d(h2, kernel_size=5, filters=512, strides=2, padding='same')
        h3 = lrelu(tf.contrib.layers.batch_norm(h3, is_training=is_training, decay=momentum))

        h4 = tf.contrib.layers.flatten(h3)
        h4 = tf.layers.dense(h4, units=1)
        # Return both the probability and the raw logits
        return tf.nn.sigmoid(h4), h4

The generator:
def generator(z, is_training=is_training):
    momentum = 0.9
    with tf.variable_scope('generator', reuse=None):
        d = 3
        h0 = tf.layers.dense(z, units=d * d * 512)
        h0 = tf.reshape(h0, shape=[-1, d, d, 512])
        h0 = tf.nn.relu(tf.contrib.layers.batch_norm(h0, is_training=is_training, decay=momentum))

        h1 = tf.layers.conv2d_transpose(h0, kernel_size=5, filters=256, strides=2, padding='same')
        h1 = tf.nn.relu(tf.contrib.layers.batch_norm(h1, is_training=is_training, decay=momentum))

        h2 = tf.layers.conv2d_transpose(h1, kernel_size=5, filters=128, strides=2, padding='same')
        h2 = tf.nn.relu(tf.contrib.layers.batch_norm(h2, is_training=is_training, decay=momentum))

        h3 = tf.layers.conv2d_transpose(h2, kernel_size=5, filters=64, strides=2, padding='same')
        h3 = tf.nn.relu(tf.contrib.layers.batch_norm(h3, is_training=is_training, decay=momentum))

        # tanh output in [-1, 1]; the layer is named 'g' so it can be looked up after reloading
        h4 = tf.layers.conv2d_transpose(h3, kernel_size=5, filters=1, strides=1, padding='valid', activation=tf.nn.tanh, name='g')
        return h4

Define the loss functions. Note that the discriminator is built twice here, but the two copies share their parameters:
g = generator(noise)
d_real, d_real_logits = discriminator(X)
d_fake, d_fake_logits = discriminator(g, reuse=True)

vars_g = [var for var in tf.trainable_variables() if var.name.startswith('generator')]
vars_d = [var for var in tf.trainable_variables() if var.name.startswith('discriminator')]

# D: real samples are labeled 1, generated samples are labeled 0
loss_d_real = tf.reduce_mean(sigmoid_cross_entropy_with_logits(d_real_logits, tf.ones_like(d_real)))
loss_d_fake = tf.reduce_mean(sigmoid_cross_entropy_with_logits(d_fake_logits, tf.zeros_like(d_fake)))
# G: wants D to label generated samples as 1
loss_g = tf.reduce_mean(sigmoid_cross_entropy_with_logits(d_fake_logits, tf.ones_like(d_fake)))
loss_d = loss_d_real + loss_d_fake

Define the optimizers. Note that each loss must be paired with the matching set of trainable variables:
# Run the batch-norm moving-average updates together with each training step
update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
with tf.control_dependencies(update_ops):
    optimizer_d = tf.train.AdamOptimizer(learning_rate=0.0002, beta1=0.5).minimize(loss_d, var_list=vars_d)
    optimizer_g = tf.train.AdamOptimizer(learning_rate=0.0002, beta1=0.5).minimize(loss_g, var_list=vars_g)

Define a helper function that tiles several images into a grid for display:
def montage(images):
    if isinstance(images, list):
        images = np.array(images)
    img_h = images.shape[1]
    img_w = images.shape[2]
    n_plots = int(np.ceil(np.sqrt(images.shape[0])))
    # Gray canvas (0.5) with a 1-pixel border around every tile
    m = np.ones((images.shape[1] * n_plots + n_plots + 1, images.shape[2] * n_plots + n_plots + 1)) * 0.5
    for i in range(n_plots):
        for j in range(n_plots):
            this_filter = i * n_plots + j
            if this_filter < images.shape[0]:
                this_img = images[this_filter]
                m[1 + i + i * img_h:1 + i + (i + 1) * img_h,
                  1 + j + j * img_w:1 + j + (j + 1) * img_w] = this_img
    return m

Start training. G is trained twice in every iteration:
sess = tf.Session()
sess.run(tf.global_variables_initializer())
# Fixed noise, reused for every sample grid so that progress is comparable
z_samples = np.random.uniform(-1.0, 1.0, [batch_size, z_dim]).astype(np.float32)
samples = []
loss = {'d': [], 'g': []}

for i in range(60000):
    n = np.random.uniform(-1.0, 1.0, [batch_size, z_dim]).astype(np.float32)
    batch = mnist.train.next_batch(batch_size=batch_size)[0]
    batch = np.reshape(batch, [-1, 28, 28, 1])
    # Rescale pixels from [0, 1] to [-1, 1] to match the generator's tanh output
    batch = (batch - 0.5) * 2

    d_ls, g_ls = sess.run([loss_d, loss_g], feed_dict={X: batch, noise: n, is_training: True})
    loss['d'].append(d_ls)
    loss['g'].append(g_ls)

    # One discriminator step, then two generator steps
    sess.run(optimizer_d, feed_dict={X: batch, noise: n, is_training: True})
    sess.run(optimizer_g, feed_dict={X: batch, noise: n, is_training: True})
    sess.run(optimizer_g, feed_dict={X: batch, noise: n, is_training: True})

    if i % 1000 == 0:
        print(i, d_ls, g_ls)
        gen_imgs = sess.run(g, feed_dict={noise: z_samples, is_training: False})
        # Map the tanh output back to [0, 1] for display
        gen_imgs = (gen_imgs + 1) / 2
        imgs = [img[:, :, 0] for img in gen_imgs]
        gen_imgs = montage(imgs)
        plt.axis('off')
        plt.imshow(gen_imgs, cmap='gray')
        plt.savefig(os.path.join(OUTPUT_DIR, 'sample_%d.jpg' % i))
        plt.show()
        samples.append(gen_imgs)

plt.plot(loss['d'], label='Discriminator')
plt.plot(loss['g'], label='Generator')
plt.legend(loc='upper right')
plt.savefig('Loss.png')
plt.show()
imageio.mimsave(os.path.join(OUTPUT_DIR, 'samples.gif'), samples, fps=5)

The generated images are shown below. Because the loss function does not involve any pixel-by-pixel comparison, the edges of the shapes do not come out blurry.

[Figure: generated samples]

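As a sanity check on the generator defined earlier, the spatial sizes can be traced by hand. The sketch below is pure Python and rests on two assumptions about how a transposed convolution sizes its output: with 'same' padding the size is multiplied by the stride, and with 'valid' padding it is size * stride + max(kernel - stride, 0). The helper names `deconv_out`, `to_tanh_range` and `to_unit_range` are made up for illustration.

```python
def deconv_out(size, kernel, stride, padding):
    """Output length of a transposed convolution along one spatial dimension."""
    if padding == 'same':
        return size * stride
    if padding == 'valid':
        return size * stride + max(kernel - stride, 0)
    raise ValueError('unknown padding: %r' % padding)

# (kernel, stride, padding) of the four transposed convolutions in the generator
layers = [(5, 2, 'same'), (5, 2, 'same'), (5, 2, 'same'), (5, 1, 'valid')]

size = 3            # d = 3 after the dense layer and the reshape
sizes = [size]
for kernel, stride, padding in layers:
    size = deconv_out(size, kernel, stride, padding)
    sizes.append(size)
print(sizes)

def to_tanh_range(pixel):
    """Map a pixel in [0, 1] to [-1, 1], as done to each training batch."""
    return (pixel - 0.5) * 2

def to_unit_range(value):
    """Map a tanh output in [-1, 1] back to [0, 1] for display."""
    return (value + 1) / 2
```

Under these assumptions the feature map grows 3, 6, 12, 24, 28, so the last 'valid' layer is what brings 24 up to the 28x28 MNIST size. The two rescaling helpers are inverses of each other, which is why real batches are shifted to [-1, 1] before training and generated images are shifted back before plotting.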

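Looking back at the loss definitions, it may help to see what sigmoid cross-entropy on logits actually computes. The sketch below is plain Python rather than the TensorFlow op, and uses the naive formula, which is fine for small logits. For a logit x and label y, the loss is -y * log(sigmoid(x)) - (1 - y) * log(1 - sigmoid(x)). The helper names `sigmoid` and `bce_with_logits` and the two logit values are hypothetical.

```python
import math

def sigmoid(x):
    return 1.0 / (1.0 + math.exp(-x))

def bce_with_logits(logit, label):
    """Binary cross-entropy from a raw logit (naive form, not numerically hardened)."""
    p = sigmoid(logit)
    return -(label * math.log(p) + (1.0 - label) * math.log(1.0 - p))

# Hypothetical discriminator logits for one real and one generated sample
real_logit, fake_logit = 2.0, -1.0

# Discriminator: real labeled 1, fake labeled 0
loss_d = bce_with_logits(real_logit, 1.0) + bce_with_logits(fake_logit, 0.0)
# Generator: the same fake logit, but labeled 1
loss_g = bce_with_logits(fake_logit, 1.0)

print(round(loss_d, 4), round(loss_g, 4))
```

With label 1 the loss reduces to -log D, and with label 0 it reduces to -log(1 - D). The generator loss in this post is therefore -log D(G(z)), the commonly used non-saturating form, rather than a direct minimization of log(1 - D(G(z))).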
Save the model for later use:
saver = tf.train.Saver()
saver.save(sess, './mnist_dcgan', global_step=60000)

Load the model when needed, for example to use it on a standalone machine:
# -*- coding: utf-8 -*-
import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt

batch_size = 100
z_dim = 100

def montage(images):
    if isinstance(images, list):
        images = np.array(images)
    img_h = images.shape[1]
    img_w = images.shape[2]
    n_plots = int(np.ceil(np.sqrt(images.shape[0])))
    m = np.ones((images.shape[1] * n_plots + n_plots + 1, images.shape[2] * n_plots + n_plots + 1)) * 0.5
    for i in range(n_plots):
        for j in range(n_plots):
            this_filter = i * n_plots + j
            if this_filter < images.shape[0]:
                this_img = images[this_filter]
                m[1 + i + i * img_h:1 + i + (i + 1) * img_h,
                  1 + j + j * img_w:1 + j + (j + 1) * img_w] = this_img
    return m

sess = tf.Session()
sess.run(tf.global_variables_initializer())

# Rebuild the graph from the .meta file, then restore the trained weights
saver = tf.train.import_meta_graph('./mnist_dcgan-60000.meta')
saver.restore(sess, tf.train.latest_checkpoint('./'))

# Look up the tensors by the names they were given when the graph was built
graph = tf.get_default_graph()
g = graph.get_tensor_by_name('generator/g/Tanh:0')
noise = graph.get_tensor_by_name('noise:0')
is_training = graph.get_tensor_by_name('is_training:0')

n = np.random.uniform(-1.0, 1.0, [batch_size, z_dim]).astype(np.float32)
gen_imgs = sess.run(g, feed_dict={noise: n, is_training: False})
gen_imgs = (gen_imgs + 1) / 2
imgs = [img[:, :, 0] for img in gen_imgs]
gen_imgs = montage(imgs)
plt.axis('off')
plt.imshow(gen_imgs, cmap='gray')
plt.show()

References


Link: https://juejin.im/post/5ba25816e51d450e78260ce9


