Yamarae 2018. 8. 17. 00:27


2014년에 처음 선보인 이안 굿펠로우(Ian Goodfellow)의 GAN(Generative Adversarial Network) 모델은 서서히 냉각되던 딥 러닝에 대한 열기를 다시 활활 타오르게 한 발명품이다. 연구에도 유행이 있는 것을 감안한다면, 현재의 최신 유행 딥 러닝 모델은 GAN이다. 


GAN을 접하기 이전에는, 워낙 최신 알고리즘인데다가 이름이 주는 위압감 때문에 막연하게 어려운 알고리즘이라고 생각했었다. 하지만 실제로 접하고 보니, 오히려 CNN이나 RNN을 처음 접했을 때 보다 쉬운 느낌이 들었다. 사실 GAN의 아이디어는 어렵지 않기 때문이다. 본 포스팅에서는 GAN을 (수식을 최대한 배제하고) 쉽게 이해해보고, 이를 간단한 Keras 코드로 구현해 볼 것이다. 본 포스팅은 Neural Network과 Keras에 대한 기본적인 이해가 있으면 더욱 읽기 수월하다.







1. GAN의 목적



GAN을 이름 그대로 해석하자면, '적대적 생성 모델' 정도가 되겠다. '적대적' 이라는 말은 잠시 후에 생각해보도록 하고, 그렇다면 먼저 '생성 모델' 에 대해 생각해보자. GAN은 무엇을 생성한다는 것일까? 일반적인 머신 러닝 모델이 생성해내는 것은 class에 대한 예측값, 혹은 continuous random variable에 대한 interval prediction 등의 결과이다. 하지만 이런 결과들을 어떤 형태를 만들어내는 것은 아니다. 가장 높은 probability 혹은 likelihood를 찾아낼 뿐이다. GAN은 여기서 한발 더 나아가, '데이터의 형태' 를 만들어내고자 한다. 데이터의 형태라는 것은 분포 혹은 분산을 의미한다. 이미지의 형태를 예로 들어보자. 우리는 픽셀들의 분포에 따라 이 모양은 코, 이 모양은 눈이라는 것을 인식한다. 명암이나 사진의 전체적인 채도와는 큰 상관이 없다. 그래서 분포를 만들어낸다는 것은, 단순히 결과값을 도출해내는 함수를 만드는 것을 넘어서 '실제적인 형태' 를 갖춘 데이터를 만들어낸다는 것이다. 요약하자면 GAN은 어떤 분포 혹은 분산 자체를 만들어내는 모델이라고 할 수 있겠다.


Ian Goodfellow의 논문에 수록된 그림. 이렇게 '분포'를 만드는 모델을 학습하는 것이 GAN의 목적이다.





2. 적대적 생성



다음으로 '적대적' 이라는 말에 대해 파헤쳐보자. '생성' 이라는 단어에서 GAN의 목적이 분포의 생성이라는 것을 알 수 있었다면, 적대적이라는 말에서는 GAN의 핵심 아이디어에 대해 알 수 있다. GAN을 설명할 때 흔히 사용되는 비유법은 지폐위조범과 경찰의 예시이다.


GAN 모델을 가장 쉽게 설명할 수 있는 그림. 

[출처 : https://files.slack.com/files-pri/T25783BPY-F9SHTP6F9/picture2.png?pub_secret=6821873e68]



지폐 위조범에게 Generator라는 역할을 부여하자. 그리고 경찰은 Discriminator라는 역할을 부여한다. 디카프리오 주연의 <Catch me if you can>이라는 영화를 본 사람이라면 이미지를 떠올리기가 더욱 수월할 것이다. 위조범은 경찰의 단속을 피하기 위해서 더욱 정교한 가짜 지폐를 만들어내고, 경찰은 점점 더 정교한 기법으로 지폐를 판별해내는 방법을 개발해낸다. 영화를 봤다면 아마도, 디카프리오가 결국 경찰의 협조 하에 가짜 지폐를 판독하는 장면을 떠올렸을 것이다. 이처럼 각각의 역할을 가진 두 개의 모델을 통해 진짜같은 가짜를 생성해내는 능력을 키워주는 것이 GAN의 핵심 아이디어이고, 그래서 적대적이라는 용어가 붙은 것이다.





3. GAN의 학습 방법



GAN의 아이디어에 대해 이해했다면, 이제 본격적으로 GAN을 파헤쳐볼 차례다. 본 포스팅에서는 우리에게 가장 익숙한 mnist 이미지를 생성하는 모델을 기준으로 예시를 들겠다. 먼저 이미지를 판별하는 Discriminator(이하 D)는 CNN 판별기처럼 구성할 수 있다. 그래서 이해하기에 매우 쉽다. 하지만 Generator(이하 G)를 구성하는 것은 새로운 아이디어가 도입된다. 



GAN의 모델 구성도 및 학습 flow

[출처 : https://hyeongminlee.github.io/post/gan001_gan/]



G는 random한 noise를 생성해내는 vector z를 input으로 하며(그림의 Noise), D가 판별하고자 하는 input image(여기서는 28X28의 mnist 이미지)를 output으로 하는 neural network unit이라고 할 수 있다. 이렇게 GAN의 코어가 되는 모델은 D와 G 두 가지이다. 학습 과정에서는 실제 mnist 이미지, Real Image를 D로 하여금 '진짜'라고 학습시키는 1번 과정, 그리고 vector z와 G에 의해 생성된 Fake Image를 '가짜'라고 학습시키는 2번과정으로 나뉜다. 여기서 유의할 점은 D가 두번 학습되고 G는 1번 학습되는 것이 아니라, 1번 과정에서의 Real Image와 Fake Image를 D의 x input으로 합쳐서 학습한다는 것이다. 이쯤에서 잠시 Keras 코드를 살펴보자.


    def train_D(self):
        """
        train Discriminator
        """

        # Real data
        real = self.data.get_real_sample()

        # Generated data
        z = self.data.get_z_sample(self.batch_size)
        generated_images = self.gan.G.predict(z)

        # labeling and concat generated, real images
        x = np.concatenate((real, generated_images), axis=0)
        y = [0.9] * self.batch_size + [0] * self.batch_size

        # train discriminator
        self.gan.D.trainable = True
        loss = self.gan.D.train_on_batch(x, y)
        return loss

    def train_G(self):
        """
        train Generator
        """

        # Generated data
        z = self.data.get_z_sample(self.batch_size)

        # labeling
        y = [1] * self.batch_size

        # train generator
        self.gan.D.trainable = False
        loss = self.gan.GD.train_on_batch(z, y)
        return loss


train_D는 D를 학습하는 부분, 그리고 train_G는 D(G(z))에서 G를 학습하는 부분이다. D.trainable을 사용하여, 위에서 설명한 대로 D는 한 번만 학습되도록 구현하였다. 코드에서 D(G(z))에서 D의 학습을 False로 한다면, 결국 G만 학습이 된다. 눈여겨 보아야 할 부분은 'x = np.concatenate((real, generated_images), axis=0)' 이다. 이 부분을 통해 진짜이미지와 가짜이미지를 D에게 한번에 학습시킨다.


수식을 최대한 배제한 채로 GAN 모델의 아이디어를 설명하고자 했지만, 단 한 번 수식을 반드시 설명해야 하는 이유가 있다. 바로 Loss 함수 때문이다. 어떠한 모델도 Loss 함수에 대한 이해 없이는 제대로 된 이해를 했다고 보기 힘들기 때문이다. GAN 프레임 워크는 코어가 되는 두 개의 모델의 학습에 따라 진행된다. D의 목표는 Real, 혹은 Fake 이미지를 제대로 분류해내는 것이다. 그리고 G의 임무는 완벽하게 D가 틀리도록 하는 것이다. 그래서 두 코어 모델의 Loss 지표는 반대가 되며, 이 때문에도 '적대적' 모델로 불린다.


두 코어의 목적을 하나로 합친 목적함수는 아래 그림과 같다. D와 G에 0~1 의 Probability를 갖는 값들을 직접 넣어보면서 목적함수가 최적이 되는 지점을 추리해보자. 수식에 대한 이해는 Divergence Theorem에 근거한 것인데, 자세한 내용은 원문을 참고하도록 하자.







4. 그 외


GAN 모델은 일반적인 머신 러닝, 혹은 딥 러닝 모델과는 달리 명확한 평가의 기준이 없다. Loss는 단지 학습을 위한 오토 파라미터의 구실을 하는 셈이고, 실제적인 Loss를 나타내거나 Accuracy와 같은 기준이 되는 명확한 평가지표가 존재하지 않는다. 이미지를 생성하는 GAN의 경우, 사람의 육안으로 결과물을 평가할 수 있을 뿐이다.


Loss함수를 정의하고 이를 최적화 할때, 실제 환경에서는 생각보다 G의 초기 성능이 안좋다. 그래서 D(G(z))가 0에 가깝게 되는데, 원론적인 수식으로 적용하면 학습이 잘 안된다. 그래서 약간의 테크닉을 써서 수식을 살짝 바꾼다. (역시 수식에 대한 자세한 내용, 연구적인 내용에 관심이 있는 사람이라면 논문을 참고하자) 또한 D : G의 학습 비율은 1 : 5 와 같은 형태로 불균형하게 하는것이 일반적인듯 하다. D가 G에 비해 너무 정확하다면, G의 gradient가 vanishing되는 문제가 생기기도 하고, 반대의 경우도 생긴다. 어쨌든 이러한 문제점을 해결하기 위해 학습 비율을 조정을 하면 해결된다고 한다. 실제 코드를 돌려보면, D, G의 iterate를 다른 비율로 학습하는 것이 훨씬 학습이 잘된다.







GAN은 D, G 그리고 Noise에 대한 함수와 네트워크 구성을 자유롭게 하면서, GAN의 advanced 모델들을 구현해보는 쏠쏠한 재미가 있다. 그러한 모델들을 구현해보기 전에, 우선 기초가 되는 Gaussian 분포 생성과 mnist 이미지 생성에 대한 튜토리얼을 진행해보는 것을 권장한다.


[Gaussian, mnist generate 예제 전체 코드 링크]


import argparse

import numpy as np

from keras.models import Model, Sequential
from keras.layers.core import Reshape, Dense, Dropout, Flatten
from keras.layers.advanced_activations import LeakyReLU
from keras.layers.convolutional import Conv2D, MaxPooling2D, UpSampling2D
from keras.layers.normalization import BatchNormalization
from keras.datasets import mnist
from keras.optimizers import Adam
from keras import initializers
from keras import backend as K


K.set_image_data_format('channels_first')


class Data:
    """
    Define dataset for training GAN
    """
    def __init__(self, batch_size, z_input_dim):
        # load mnist dataset
        # 이미지는 보통 -1~1 사이의 값으로 normalization : generator의 outputlayer를 tanh로
        (X_train, y_train), (X_test, y_test) = mnist.load_data()
        self.x_data = ((X_train.astype(np.float32) - 127.5) / 127.5)
        self.x_data = self.x_data.reshape((self.x_data.shape[0], 1) + self.x_data.shape[1:])
        self.batch_size = batch_size
        self.z_input_dim = z_input_dim

    def get_real_sample(self):
        """
        get real sample mnist images

        :return: batch_size number of mnist image data
        """
        return self.x_data[np.random.randint(0, self.x_data.shape[0], size=self.batch_size)]

    def get_z_sample(self, sample_size):
        """
        get z sample data

        :return: random z data (batch_size, z_input_dim) size
        """
        return np.random.uniform(-1.0, 1.0, (sample_size, self.z_input_dim))


class GAN:
    def __init__(self, learning_rate, z_input_dim):
        """
        init params

        :param learning_rate: learning rate of optimizer
        :param z_input_dim: input dim of z
        """
        self.learning_rate = learning_rate
        self.z_input_dim = z_input_dim
        self.D = self.discriminator()
        self.G = self.generator()
        self.GD = self.combined()

    def discriminator(self):
        """
        define discriminator
        """
        D = Sequential()
        D.add(Conv2D(256, (5, 5),
                     padding='same',
                     input_shape=(1, 28, 28),
                     kernel_initializer=initializers.RandomNormal(stddev=0.02)))
        D.add(LeakyReLU(0.2))
        D.add(MaxPooling2D(pool_size=(2, 2), strides=2))
        D.add(Dropout(0.3))
        D.add(Conv2D(512, (5, 5), padding='same'))
        D.add(LeakyReLU(0.2))
        D.add(MaxPooling2D(pool_size=(2, 2), strides=2))
        D.add(Dropout(0.3))
        D.add(Flatten())
        D.add(Dense(256))
        D.add(LeakyReLU(0.2))
        D.add(Dropout(0.3))
        D.add(Dense(1, activation='sigmoid'))

        adam = Adam(lr=self.learning_rate, beta_1=0.5)
        D.compile(loss='binary_crossentropy', optimizer=adam, metrics=['accuracy'])
        return D

    def generator(self):
        """
        define generator
        """
        G = Sequential()
        G.add(Dense(512, input_dim=self.z_input_dim))
        G.add(LeakyReLU(0.2))
        G.add(Dense(128 * 7 * 7))
        G.add(LeakyReLU(0.2))
        G.add(BatchNormalization())
        G.add(Reshape((128, 7, 7), input_shape=(128 * 7 * 7,)))
        G.add(UpSampling2D(size=(2, 2)))
        G.add(Conv2D(64, (5, 5), padding='same', activation='tanh'))
        G.add(UpSampling2D(size=(2, 2)))
        G.add(Conv2D(1, (5, 5), padding='same', activation='tanh'))

        adam = Adam(lr=self.learning_rate, beta_1=0.5)
        G.compile(loss='binary_crossentropy', optimizer=adam, metrics=['accuracy'])
        return G

    def combined(self):
        """
        defien combined gan model
        """
        G, D = self.G, self.D
        D.trainable = False
        GD = Sequential()
        GD.add(G)
        GD.add(D)

        adam = Adam(lr=self.learning_rate, beta_1=0.5)
        GD.compile(loss='binary_crossentropy', optimizer=adam, metrics=['accuracy'])
        D.trainable = True
        return GD


class Model:
    def __init__(self, args):
        self.epochs = args.epochs
        self.batch_size = args.batch_size
        self.learning_rate = args.learning_rate
        self.z_input_dim = args.z_input_dim
        self.data = Data(self.batch_size, self.z_input_dim)

        # the reason why D, G differ in iter : Generator needs more training than Discriminator
        self.n_iter_D = args.n_iter_D
        self.n_iter_G = args.n_iter_G
        self.gan = GAN(self.learning_rate, self.z_input_dim)
        self.d_loss = []
        self.g_loss = []

        # print status
        batch_count = self.data.x_data.shape[0] / self.batch_size
        print('Epochs:', self.epochs)
        print('Batch size:', self.batch_size)
        print('Batches per epoch:', batch_count)
        print('Learning rate:', self.learning_rate)
        print('Image data format:', K.image_data_format())

    def fit(self):
        for epoch in range(self.epochs):

            # train discriminator by real data
            dloss = 0
            for iter in range(self.n_iter_D):
                dloss = self.train_D()

            # train GD by generated fake data
            gloss = 0
            for iter in range(self.n_iter_G):
                gloss = self.train_G()

            # print loss data
            print('Discriminator loss:', str(dloss))
            print('Generator loss:', str(gloss))

    def train_D(self):
        """
        train Discriminator
        """

        # Real data
        real = self.data.get_real_sample()

        # Generated data
        z = self.data.get_z_sample(self.batch_size)
        generated_images = self.gan.G.predict(z)

        # labeling and concat generated, real images
        x = np.concatenate((real, generated_images), axis=0)
        y = [0.9] * self.batch_size + [0] * self.batch_size

        # train discriminator
        self.gan.D.trainable = True
        loss = self.gan.D.train_on_batch(x, y)
        return loss

    def train_G(self):
        """
        train Generator
        """

        # Generated data
        z = self.data.get_z_sample(self.batch_size)

        # labeling
        y = [1] * self.batch_size

        # train generator
        self.gan.D.trainable = False
        loss = self.gan.GD.train_on_batch(z, y)
        return loss

def main():
    # set hyper parameters
    parser = argparse.ArgumentParser()
    parser.add_argument('--batch_size', type=int, default=128,
                        help='Batch size for networks')
    parser.add_argument('--epochs', type=int, default=200,
                        help='Epochs for the networks')
    parser.add_argument('--learning_rate', type=float, default=0.0002,
                        help='Learning rate')
    parser.add_argument('--z_input_dim', type=int, default=100,
                        help='Input dimension for the generator.')
    parser.add_argument('--n_iter_D', type=int, default=1,
                        help='training iteration for D')
    parser.add_argument('--n_iter_G', type=int, default=5,
                        help='training iteration for G')
    args = parser.parse_args()

    # run model
    model = Model(args)
    model.fit()


if __name__ == '__main__':
    main()




실행 가능한 전체 코드 (Google Colab) : https://colab.research.google.com/drive/1RcYxh17bwST6Wi0N5TNidoeSZc-k_XoT