3 분 소요

BEGAN

  • 특징: 가장 큰 특징은 구분자로 분류기가 아닌 CAE를 사용한다는 것이다.

구분자는 실제 이미지는 재구성 오차가 작아지도록 학습하고, 생성된 이미지는 재구성 오차가 커지도록 학습한다.

따라서 생성자는 이 재구성 오차가 적어지도록 학습한다.

이를 통해 gan은 단순하지만 강력한 구조와 빠르고 안정적인 학습과 수렴하고 이미지의 다양성과 질(quality) 사이의 trade-off를 조정하는 것이 가능해진다.

Loss

image.png

  • Lg는 생성자로 인한 재구성 오차이다.

  • Ld에서 x는 실제 이미지이고 L(x)는 실제 이미지에 대한 재구성 오차, L(G)는 생성된 이미지에 대한 재구성 오차이다. 이때 kt는 양수로 L(G)의 가중치이다.

  • kt 또한 학습시키면서 계속 조정해줘야 할 파라미터로 t는 스텝 수, r는 실제와 생성 이미지의 재구성 오차의 균형을 나타내는 파라미터로 r이 작을수록 실제 이미지를 모방하게 되고 r이 클수록 다양성이 있는 이미지를 생성한다. 논문에서는 r={0.3,0.5,0.7} 과 같은 패턴을 사용한다.

데이터 Celeba(유명인 얼굴 데이터셋) 받아오기

아래 주소에서 데이터를 받아온다.

https://www.kaggle.com/datasets/jessicali9530/celeba-dataset?select=img_align_celeba

import numpy as np
import matplotlib.pyplot as plt
import time
import tensorflow as tf
import keras.backend as K
from tensorflow.keras.preprocessing.image import ImageDataGenerator
from tensorflow.keras.layers import Input, Conv2D, Dense, Flatten, Reshape, UpSampling2D
from tensorflow.keras import Sequential, Model
batch_size = 16
img_size = (64,64,3)
n_filters = 64
n_layers = 4
z_dim = 32
data_dir = '/content/drive/MyDrive/dataset/img_align_celeba'

data_gen = ImageDataGenerator(rescale=1/255.)
train_data_generator = data_gen.flow_from_directory(directory=data_dir,
                                                    batch_size=batch_size,
                                                    target_size=img_size[:2])
Found 54719 images belonging to 1 classes.

모델 BEGAN 구현

def build_encoder(input_shape, z_dim, n_filters, n_layers):
  model = Sequential()
  model.add(
      Conv2D(n_filters, 3, activation='elu', input_shape=input_shape, padding='same'),
  )
  model.add(Conv2D(n_filters, 3, padding='same'))
  for i in range(2, n_layers + 1):
    model.add(Conv2D(i*n_filters, 3, activation='elu', padding='same'))
    model.add(Conv2D(i*n_filters, 3, activation='elu', strides=2, padding='same'))
  
  model.add(Conv2D(n_layers*n_filters, 3, padding='same'))
  model.add(Flatten())
  model.add(Dense(z_dim))

  return model
def build_decoder(output_shape, z_dim, n_filters, n_layers):

  scale = 2**(n_layers - 1)

  fc_shape = (output_shape[0]//scale, output_shape[1]//scale, n_filters)

  fc_size = fc_shape[0] * fc_shape[1] * fc_shape[2]

  model = Sequential()
  model.add(Dense(fc_size, input_shape=(z_dim,)))
  model.add(Reshape(fc_shape))

  for i in range(n_layers - 1):
    model.add(Conv2D(n_filters//(i+2), 3, activation='elu', padding='same'))
    model.add(Conv2D(n_filters//(i+2), 3, activation='elu', padding='same'))
    model.add(UpSampling2D())
  
  model.add(Conv2D(n_filters//n_layers, 3, activation='elu', padding='same'))
  model.add(Conv2D(n_filters//n_layers, 3, activation='elu', padding='same'))

  model.add(Conv2D(3, 3, padding='same'))

  return model
def build_generator(img_shape, z_dim, n_filters, n_layers):
  decoder = build_decoder(img_shape, z_dim, n_filters, n_layers)
  return decoder

def build_discriminator(img_shape, z_dim, n_filters, n_layers):
  encoder = build_encoder(img_shape, z_dim, n_filters, n_layers)
  decoder = build_decoder(img_shape, z_dim, n_filters, n_layers)
  return Sequential((encoder, decoder))
def build_discrim_trainer(discriminator):
  img_shape = discriminator.input_shape[1:]
  real_inputs = Input(img_shape)
  fake_inputs = Input(img_shape)
  real_outputs = discriminator(real_inputs)
  fake_outputs = discriminator(fake_inputs)

  return Model(inputs=[real_inputs, fake_inputs], outputs=[real_outputs, fake_outputs])
generator = build_generator(img_size, z_dim, n_filters, n_layers)
discriminator = build_discriminator(img_size, z_dim, n_filters, n_layers)
discriminator_trainer = build_discrim_trainer(discriminator)

generator.summary()
Model: "sequential"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
=================================================================
 dense (Dense)               (None, 4096)              135168    
                                                                 
 reshape (Reshape)           (None, 8, 8, 64)          0         
                                                                 
 conv2d (Conv2D)             (None, 8, 8, 32)          18464     
                                                                 
 conv2d_1 (Conv2D)           (None, 8, 8, 32)          9248      
                                                                 
 up_sampling2d (UpSampling2D  (None, 16, 16, 32)       0         
 )                                                               
                                                                 
 conv2d_2 (Conv2D)           (None, 16, 16, 21)        6069      
                                                                 
 conv2d_3 (Conv2D)           (None, 16, 16, 21)        3990      
                                                                 
 up_sampling2d_1 (UpSampling  (None, 32, 32, 21)       0         
 2D)                                                             
                                                                 
 conv2d_4 (Conv2D)           (None, 32, 32, 16)        3040      
                                                                 
 conv2d_5 (Conv2D)           (None, 32, 32, 16)        2320      
                                                                 
 up_sampling2d_2 (UpSampling  (None, 64, 64, 16)       0         
 2D)                                                             
                                                                 
 conv2d_6 (Conv2D)           (None, 64, 64, 16)        2320      
                                                                 
 conv2d_7 (Conv2D)           (None, 64, 64, 16)        2320      
                                                                 
 conv2d_8 (Conv2D)           (None, 64, 64, 3)         435       
                                                                 
=================================================================
Total params: 183,374
Trainable params: 183,374
Non-trainable params: 0
_________________________________________________________________
discriminator.summary()
Model: "sequential_3"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
=================================================================
 sequential_1 (Sequential)   (None, 32)                2960608   
                                                                 
 sequential_2 (Sequential)   (None, 64, 64, 3)         183374    
                                                                 
=================================================================
Total params: 3,143,982
Trainable params: 3,143,982
Non-trainable params: 0
_________________________________________________________________
discriminator_trainer.summary()
Model: "model"
__________________________________________________________________________________________________
 Layer (type)                   Output Shape         Param #     Connected to                     
==================================================================================================
 input_1 (InputLayer)           [(None, 64, 64, 3)]  0           []                               
                                                                                                  
 input_2 (InputLayer)           [(None, 64, 64, 3)]  0           []                               
                                                                                                  
 sequential_3 (Sequential)      (None, 64, 64, 3)    3143982     ['input_1[0][0]',                
                                                                  'input_2[0][0]']                
                                                                                                  
==================================================================================================
Total params: 3,143,982
Trainable params: 3,143,982
Non-trainable params: 0
__________________________________________________________________________________________________

Loss 구현

  • 생성자 손실은 build_generator_loss 함수로 구현하며 실제 이미지는 필요없기 때문에 무시하고 재구성 이미지와 생성 이미지의 mae loss 값을 리턴한다.

  • 구분자 손실은 실제 이미지와 생성 이미지를 각각의 재구성 이미지와 mae loss로 구현하며 loss_weight를 통해 위에서 설명한 kt를 가중치로 설정한다. 이때 K.variable을 통해 학습 도중에 가중치를 갱신할 수 있도록 한다.

from keras.losses import mean_absolute_error

def build_generator_loss(discriminator):
  def loss(y_true, y_pred):
    reconst = discriminator(y_pred)
    return mean_absolute_error(reconst, y_pred)
  return loss

def measure(real_loss, fake_loss, gamma):
  return real_loss + np.abs(gamma*real_loss - fake_loss)
g_lr = 0.0001

generator_loss = build_generator_loss(discriminator)
generator.compile(loss=generator_loss, optimizer=tf.keras.optimizers.Adam(g_lr))
d_lr = 0.0001

k_var = 0.0
k = K.variable(k_var)
discriminator_trainer.compile(loss=[mean_absolute_error, mean_absolute_error], loss_weights=[1., -k], optimizer=tf.keras.optimizers.Adam(d_lr))
def sample_image(seeds):
  imgs = generator.predict(seeds)

  plt.figure(figsize=(16,10))

  for i in range(1,len(imgs) + 1):
    plt.subplot(5, 5, i)
    plt.imshow((imgs[i-1] * 255).astype(np.uint8))
  plt.show()

Train(학습)

gamma = 0.5
lr_k = 0.001

total_steps = 1000
start = time.time()

sample_seeds = np.random.uniform(-1, 1, (25, z_dim))
history = []

for step, batch in enumerate(train_data_generator):
  if len(batch) < batch_size:
    continue
  
  if step > total_steps:
    break

  z_g = np.random.uniform(-1, 1, (batch_size, z_dim))
  z_d = np.random.uniform(-1, 1, (batch_size, z_dim))

  g_pred = generator.predict(z_d)

  generator.train_on_batch(z_g, batch)

  _, real_loss, fake_loss = discriminator_trainer.train_on_batch([batch, g_pred], [batch, g_pred])

  k_var += lr_k*(gamma*real_loss - fake_loss)
  K.set_value(k, k_var)

  history.append({'real_loss': real_loss, 'fake_loss': fake_loss})

  if step % 100 == 0:
    meas = np.mean([measure(loss['real_loss'], loss['fake_loss'], gamma) for loss in history])
    print("step: {}, total_loss: {}, time: {}".format(step, meas, time.time()-start))
    sample_image(sample_seeds)
sample_image(sample_seeds)

태그:

카테고리:

업데이트:

댓글남기기