GAN Series-03 Deep Convolutional Generative Adversarial Network


Introduction

The DCGAN paper can be seen as the one that applies GANs to CNNs. It uses the same objective function as the original GAN; the difference is that the authors build the training models out of CNNs and deconvolution (DeConv) layers. In DCGAN, the Discriminator is a CNN that takes an image as input and outputs the probability that it is real. The Generator works in the opposite direction: it takes noise as input and generates an image. Because the vector grows larger as it passes through each layer, the reverse of what a CNN does, these layers are called DeConv (more precisely, transposed convolutions).

Key Features of DCGAN

  1. No pooling. The Discriminator replaces pooling with strided Convolution Layers; the Generator uses Transposed Convolution Layers instead of unpooling (see the sketch after this list).
  2. Batch Normalization in both the Generator and the Discriminator. It mitigates training problems caused by poor initialization, helps gradients propagate to every layer, and keeps the Generator from collapsing all samples to a single point, stabilizing learning.
  3. LeakyReLU as the activation function throughout the model.
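
To make point 1 concrete, here is a minimal standalone sketch (shapes are made up for illustration) showing that a strided convolution halves the spatial size, replacing pooling, while a transposed convolution doubles it:

import tensorflow as tf

# A strided convolution downsamples without pooling (Discriminator side),
# while a transposed convolution upsamples (Generator side).
x = tf.random.normal([1, 64, 64, 3])
down = tf.keras.layers.Conv2D(128, (4, 4), strides=(2, 2), padding='same')(x)
up = tf.keras.layers.Conv2DTranspose(3, (4, 4), strides=(2, 2), padding='same')(down)
print(down.shape)  # (1, 32, 32, 128)
print(up.shape)    # (1, 64, 64, 3)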

Network Architecture

Parameter Declarations

Besides the usual folders and hyperparameters, there is an extra TFRecordPath flag, because before training the whole dataset is first converted into TFRecord format.

from absl import app
from absl import flags
from absl import logging

FLAGS = flags.FLAGS
# Dataset options
flags.DEFINE_string('DatasetPath', 'YourDatasetPath',
                    'path to dataset')
flags.DEFINE_enum('Partition', 'all', ['train', 'val', 'test', 'all'],
                  'specify train, val or test part')
flags.DEFINE_enum('ImageType', 'img_celeba', ['img_celeba', 'img_align_celeba', 'img_align_celeba_png', 'MNIST'],
                  'specify celeba image type or use mnist')
flags.DEFINE_boolean('bbox', True, 'crop image by bbox (only used when ImageType is img_celeba)')

# Output paths
flags.DEFINE_string('TFRecordPath',
                    'TFRecordFilePath',
                    'path to save TFRecord file')
flags.DEFINE_string('LOG_PATH', 'logs/', 'path to log_dir')
flags.DEFINE_string('MODEL_PATH', 'models/', 'path to save the model')

# Training hyperparameters
flags.DEFINE_integer('DisInSize', 64, 'discriminator input image size')
flags.DEFINE_integer('noise_dim', 100, 'noise dimension')
flags.DEFINE_integer('BATCH_SIZE', 128, 'batch size')
flags.DEFINE_float('lr', 2e-4, 'learning rate')
flags.DEFINE_integer('epochs', 35, 'epoch')
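
The conversion script itself is not shown in this post. Below is a minimal sketch of what it could look like, assuming each image is stored as raw JPEG bytes under an 'image' key; both the key and the function name write_tfrecord are my assumptions:

import tensorflow as tf

def write_tfrecord(image_paths, tfrecord_path):
    # Hypothetical converter: serialize each image file into one tf.train.Example.
    with tf.io.TFRecordWriter(tfrecord_path) as writer:
        for path in image_paths:
            with open(path, 'rb') as f:
                raw = f.read()  # JPEG bytes exactly as stored on disk
            feature = {'image': tf.train.Feature(bytes_list=tf.train.BytesList(value=[raw]))}
            example = tf.train.Example(features=tf.train.Features(feature=feature))
            writer.write(example.SerializeToString())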

Dataset

Training uses the CelebA and MNIST datasets. CelebA ships three different variants of the images: img_celeba, img_align_celeba, and img_align_celeba_png. The latter two differ only in file format (jpg vs. png), while img_celeba additionally comes with bbox labels that frame the location of the face in each image.

FLAGS.ImageType decides whether to load CelebA or MNIST. The data for the selected FLAGS.Partition is read from the TFRecord files, and tf.data.Dataset is used to read the data and build the input pipeline.

import os

import tensorflow as tf
from tensorflow.keras.datasets.mnist import load_data

def create_dataset():
    '''
    -----------------------Data Set-----------------------
    '''
    if FLAGS.ImageType == 'MNIST':
        FLAGS.bbox = False
        [(train_x, train_y), (test_x, test_y)] = load_data('mnist.npz')

        train_images = train_x.reshape(train_x.shape[0], 28, 28, 1).astype('float32')
        train_images = (train_images - 127.5) / 127.5  # scale to [-1, 1] for tanh

        train_dataset = tf.data.Dataset.from_tensor_slices(train_images)
        train_dataset = train_dataset.map(lambda x: tf.image.resize_with_pad(x, FLAGS.DisInSize, FLAGS.DisInSize))
    else:
        logging.info('Convert CelebA data to TFRecord.')
        # LoadTFRecordDataset is the author's TFRecord-parsing helper (a sketch follows below)
        if FLAGS.Partition == 'all':
            TFRecordPath = os.path.join(FLAGS.TFRecordPath, 'CelebA_train.tfrecord')
            logging.info(TFRecordPath)
            train_dataset = LoadTFRecordDataset(TFRecordPath)
            for partition in ['val', 'test']:
                TFRecordPath = os.path.join(FLAGS.TFRecordPath, 'CelebA_{}.tfrecord'.format(partition))
                logging.info(TFRecordPath)
                part_dataset = LoadTFRecordDataset(TFRecordPath)
                # concatenate returns a new dataset, so the result must be reassigned
                train_dataset = train_dataset.concatenate(part_dataset)
        else:
            TFRecordPath = os.path.join(FLAGS.TFRecordPath, 'CelebA_{}.tfrecord'.format(FLAGS.Partition))
            logging.info(TFRecordPath)
            train_dataset = LoadTFRecordDataset(TFRecordPath)
    train_dataset = train_dataset.shuffle(300)
    train_dataset = train_dataset.batch(FLAGS.BATCH_SIZE)
    train_dataset = train_dataset.prefetch(20)

    # Fixed noise, so progress is visualized on the same 16 samples every epoch
    NOISE = tf.random.normal([16, FLAGS.noise_dim])
    return train_dataset, NOISE
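
LoadTFRecordDataset is the author's helper and is not shown in the post. Here is a minimal sketch of a compatible reader, assuming the record layout from the conversion sketch above (JPEG bytes under an 'image' key):

def LoadTFRecordDataset(tfrecord_path):
    # Hypothetical reader matching the hypothetical writer above.
    feature_description = {'image': tf.io.FixedLenFeature([], tf.string)}

    def _parse(example_proto):
        example = tf.io.parse_single_example(example_proto, feature_description)
        image = tf.io.decode_jpeg(example['image'], channels=3)
        image = tf.image.resize(image, [FLAGS.DisInSize, FLAGS.DisInSize])
        # Scale to [-1, 1] to match the Generator's tanh output.
        return (image - 127.5) / 127.5

    return tf.data.TFRecordDataset(tfrecord_path).map(
        _parse, num_parallel_calls=tf.data.AUTOTUNE)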

Generator and Discriminator Network Architectures

Because the datasets include both color and grayscale images, the Discriminator takes a Channel argument when it is created, to match the input dimension:

import tensorflow as tf
from tensorflow.keras import Model
from tensorflow.keras.layers import (
    Conv2D,
    Flatten,
    Dense,
    LeakyReLU,
    BatchNormalization,
    Input,
    Conv2DTranspose,
    Reshape,
    Dropout,
)
from absl.flags import FLAGS

def Dis_Net(ImageSize, Channel):
    Length, Width = ImageSize
    x = inputs_x = Input([Length, Width, Channel])
    # 64x64 -> 32x32; no BatchNorm on the first layer
    x = Conv2D(128, (4, 4), strides=(2, 2), padding='same')(x)
    x = LeakyReLU()(x)

    # 32x32 -> 16x16
    x = Conv2D(256, (4, 4), strides=(2, 2), padding='same')(x)
    x = BatchNormalization()(x)
    x = LeakyReLU()(x)
    x = Dropout(0.3)(x)

    # 16x16 -> 8x8
    x = Conv2D(512, (4, 4), strides=(2, 2), padding='same')(x)
    x = BatchNormalization()(x)
    x = LeakyReLU()(x)
    x = Dropout(0.3)(x)

    # 8x8 -> 4x4
    x = Conv2D(1024, (4, 4), strides=(2, 2), padding='same')(x)
    x = BatchNormalization()(x)
    x = LeakyReLU()(x)
    x = Dropout(0.3)(x)

    # 4x4 valid convolution collapses the feature map to 1x1
    x = Conv2D(1, (4, 4), strides=(1, 1), padding='valid')(x)
    x = Dropout(0.3)(x)

    x = Flatten()(x)
    x = Dense(1)(x)
    output = tf.sigmoid(x)  # probability that the input image is real
    return Model(inputs_x, output, name='Discriminator')

def Gen_Net(Channel):
    x = inputs_x = Input([FLAGS.noise_dim])
    x = Reshape((1, 1, FLAGS.noise_dim))(x)

    # 1x1 -> 4x4
    x = Conv2DTranspose(1024, (4, 4), strides=(1, 1), padding='valid')(x)
    x = BatchNormalization()(x)
    x = LeakyReLU()(x)

    # 4x4 -> 8x8
    x = Conv2DTranspose(512, (4, 4), strides=(2, 2), padding='same')(x)
    x = BatchNormalization()(x)
    x = LeakyReLU()(x)

    # 8x8 -> 16x16
    x = Conv2DTranspose(256, (4, 4), strides=(2, 2), padding='same')(x)
    x = BatchNormalization()(x)
    x = LeakyReLU()(x)

    # 16x16 -> 32x32
    x = Conv2DTranspose(128, (4, 4), strides=(2, 2), padding='same')(x)
    x = BatchNormalization()(x)
    x = LeakyReLU()(x)

    # 32x32 -> 64x64; tanh maps pixel values to [-1, 1]
    output = Conv2DTranspose(Channel, (4, 4), strides=(2, 2), padding='same', activation='tanh')(x)
    return Model(inputs_x, output, name='Generator')
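
As a quick sanity check of the shapes, you can print both model summaries (a usage sketch; it assumes the flags have already been parsed so FLAGS.noise_dim is readable):

G = Gen_Net(3)
D = Dis_Net([64, 64], 3)
G.summary()  # 100-dim noise -> 1x1x100 -> 4x4 -> 8x8 -> 16x16 -> 32x32 -> 64x64x3
D.summary()  # 64x64x3 image -> 32x32 -> 16x16 -> 8x8 -> 4x4 -> 1x1 -> real/fake probability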

Gen & Dis Net, Optimizer and Loss Function

Here we use a PiecewiseConstantDecay learning-rate schedule: the rate drops at epochs 10, 20, and 25:

from tensorflow.keras.optimizers.schedules import PiecewiseConstantDecay

def setup_model():
    if FLAGS.ImageType == 'MNIST':
        Generator = Gen_Net(1)
        Discriminator = Dis_Net([FLAGS.DisInSize, FLAGS.DisInSize], 1)
    else:
        Generator = Gen_Net(3)
        Discriminator = Dis_Net([FLAGS.DisInSize, FLAGS.DisInSize], 3)
    # Steps per epoch, used to convert epoch boundaries into step boundaries
    if FLAGS.ImageType == 'MNIST':
        one_epoch_size = 60000//FLAGS.BATCH_SIZE + 1
    else:
        one_epoch_size = 162769//FLAGS.BATCH_SIZE + 1
    boundaries = [one_epoch_size*10, one_epoch_size*20, one_epoch_size*25]
    values = [FLAGS.lr, FLAGS.lr/10, FLAGS.lr/100, FLAGS.lr/1000]
    lr_fn = PiecewiseConstantDecay(boundaries, values)

    # beta_1 = 0.5, as recommended in the DCGAN paper
    G_opt = tf.keras.optimizers.Adam(lr_fn, 0.5)
    D_opt = tf.keras.optimizers.Adam(lr_fn, 0.5)
    return Generator, Discriminator, G_opt, D_opt
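
The section heading promises a loss function, but train_step itself is not shown in the post. Below is a minimal sketch of what it could look like, assuming the standard GAN binary cross-entropy objective mentioned in the introduction; the metric updates are my guess based on the CSV columns (Real_P, Fake_P, Gen_loss, Dis_loss) logged in the training loop below:

cross_entropy = tf.keras.losses.BinaryCrossentropy()  # models already output sigmoid probabilities

@tf.function
def train_step(models, opts, images, loss):
    # Hedged sketch; the author's actual train_step may differ.
    Generator, Discriminator = models
    G_opt, D_opt = opts
    dis_r_p, dis_f_p, G_loss, D_loss = loss

    noise = tf.random.normal([tf.shape(images)[0], FLAGS.noise_dim])
    with tf.GradientTape() as g_tape, tf.GradientTape() as d_tape:
        fake_images = Generator(noise, training=True)
        real_p = Discriminator(images, training=True)
        fake_p = Discriminator(fake_images, training=True)

        # Discriminator: push real probabilities toward 1 and fake toward 0.
        d_loss = (cross_entropy(tf.ones_like(real_p), real_p)
                  + cross_entropy(tf.zeros_like(fake_p), fake_p))
        # Generator: push fake probabilities toward 1 (non-saturating loss).
        g_loss = cross_entropy(tf.ones_like(fake_p), fake_p)

    g_grads = g_tape.gradient(g_loss, Generator.trainable_variables)
    d_grads = d_tape.gradient(d_loss, Discriminator.trainable_variables)
    G_opt.apply_gradients(zip(g_grads, Generator.trainable_variables))
    D_opt.apply_gradients(zip(d_grads, Discriminator.trainable_variables))

    # Accumulate the per-epoch means that are written to the CSV file.
    dis_r_p(tf.reduce_mean(real_p))
    dis_f_p(tf.reduce_mean(fake_p))
    G_loss(g_loss)
    D_loss(d_loss)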

Whole Training Process

Since the same script can train on five different image types via the flags, the overall structure of the training process is a bit more complex:

import time

from tqdm import tqdm

SaveName = '{}-{}'.format(FLAGS.ImageType, str(FLAGS.bbox))
train_dataset, NOISE = create_dataset()
# Initial Log File
if not os.path.exists(FLAGS.LOG_PATH):
    os.mkdir(FLAGS.LOG_PATH)
csv_path = os.path.join(FLAGS.LOG_PATH, '{}-loss.csv'.format(SaveName))
with open(csv_path, 'w') as f:
    f.write('epoch,Real_P,Fake_P,Gen_loss,Dis_loss\n')
format_str = '{:5d},{:.6f},{:.6f},{:.6f},{:.6f}\n'
dis_r_p = tf.keras.metrics.Mean()
dis_f_p = tf.keras.metrics.Mean()
G_loss = tf.keras.metrics.Mean()
D_loss = tf.keras.metrics.Mean()
loss = [dis_r_p, dis_f_p, G_loss, D_loss]

Generator, Discriminator, G_opt, D_opt = setup_model()
models = [Generator, Discriminator]
opts = [G_opt, D_opt]

'''
-----------------------Training-----------------------
'''
for epoch in range(FLAGS.epochs):
    start = time.time()
    for image_batch in tqdm(train_dataset.as_numpy_iterator()):
        train_step(models, opts, image_batch, loss)

    # Record Loss
    with open(csv_path, 'a') as f:
        f.write(format_str.format(epoch,
                                  loss[0].result().numpy(),
                                  loss[1].result().numpy(),
                                  loss[2].result().numpy(),
                                  loss[3].result().numpy()))
    for metric in loss:
        metric.reset_states()
    # Each Epoch Save Image
    generate_and_save_images(Generator(NOISE, training=False), epoch + 1, SaveName)

    # Save the model every 15 epochs
    if (epoch + 1) % 15 == 0:
        Gen_save_path = os.path.join(FLAGS.MODEL_PATH, SaveName, 'Generator')
        Dis_save_path = os.path.join(FLAGS.MODEL_PATH, SaveName, 'Discriminator')
        Generator.save_weights(Gen_save_path)
        Discriminator.save_weights(Dis_save_path)

    logging.info('Time for epoch {} is {:.3f} sec'.format(epoch + 1, time.time()-start))
    time.sleep(0.2)
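
generate_and_save_images is another helper the post does not show. Here is a minimal sketch that tiles the 16 fixed-noise samples into a 4x4 grid; the file naming is my assumption:

import matplotlib.pyplot as plt

def generate_and_save_images(images, epoch, save_name):
    # images: a batch of 16 Generator outputs with pixel values in [-1, 1].
    fig = plt.figure(figsize=(4, 4))
    for i in range(images.shape[0]):
        plt.subplot(4, 4, i + 1)
        # squeeze drops the channel axis for grayscale; map [-1, 1] back to [0, 1]
        plt.imshow((images[i].numpy().squeeze() + 1) / 2, cmap='gray')
        plt.axis('off')
    fig.savefig('{}-image_at_epoch_{:04d}.png'.format(save_name, epoch))
    plt.close(fig)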

Results

img_celeba with bbox

Below are the per-epoch outputs from training on img_celeba with the bbox information:

Image Generated by Each Epoch

img_celeba

This one trains directly on img_celeba, with the output of each epoch shown below. Because the images retain more context (the face region is not cropped out), the generated images look much messier and the faces are not generated very well:

Image Generated by Each Epoch

img_align_celeba

Training on img_align_celeba, with the output of each epoch shown below. These are the images CelebA provides out of the box, cropped around the face and then padded appropriately:

Image Generated by Each Epoch

MNIST

Training on MNIST, with the output of each epoch shown below. Since the generated images are 64*64, they look less grainy than the results from the vanilla GAN:

Image Generated by Each Epoch

Conclusion

The DCGAN paper also analyzes the latent space the noise lives in. By examining points interpolated between two noise vectors, the authors check whether this latent space carries meaning the way Word2Vec embeddings do, and they also perform arithmetic on the noise vectors of generated images. The most striking example: taking the vector for a man with sunglasses, subtracting a man without sunglasses, and adding a woman yields images of a woman with sunglasses. The latent space thus encodes visual semantics, which is eye-opening.
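
A sketch of the mechanics of that arithmetic (the concept vectors here are random placeholders; in the paper each one is the average of several z vectors whose generated faces show the attribute):

import tensorflow as tf

# Placeholder concept vectors; in practice, average the z's of samples showing each attribute.
z_man_with_glasses = tf.reduce_mean(tf.random.normal([3, 100]), axis=0)
z_man_no_glasses = tf.reduce_mean(tf.random.normal([3, 100]), axis=0)
z_woman_no_glasses = tf.reduce_mean(tf.random.normal([3, 100]), axis=0)

z = z_man_with_glasses - z_man_no_glasses + z_woman_no_glasses
# image = Generator(z[tf.newaxis], training=False)  # -> woman with glasses, per the paper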

GitHub: GAN-03 Deep Convolutional Generative Adversarial Network
