BEGAN (Boundary EquilibriumGAN) 설명 및 구현
BEGAN
- 특징: 가장 큰 특징은 구분자로 분류기가 아닌 CAE를 사용한다는 것이다.
구분자는 실제 이미지는 재구성 오차가 작아지도록 학습하고, 생성된 이미지는 재구성 오차가 커지도록 학습한다.
따라서 생성자는 이 재구성 오차가 적어지도록 학습한다.
이를 통해 gan은 단순하지만 강력한 구조와 빠르고 안정적인 학습과 수렴하고 이미지의 다양성과 질(quality) 사이의 trade-off를 조정하는 것이 가능해진다.
Loss
-
Lg는 생성자로 인한 재구성 오차이다.
-
Ld에서 x는 실제 이미지이고 L(x)는 실제 이미지에 대한 재구성 오차, L(G)는 생성된 이미지에 대한 재구성 오차이다. 이때 kt는 양수로 L(G)의 가중치이다.
-
kt 또한 학습시키면서 계속 조정해줘야 할 파라미터로 t는 스텝 수, r는 실제와 생성 이미지의 재구성 오차의 균형을 나타내는 파라미터로 r이 작을수록 실제 이미지를 모방하게 되고 r이 클수록 다양성이 있는 이미지를 생성한다. 논문에서는 r={0.3,0.5,0.7} 과 같은 패턴을 사용한다.
데이터 Celeba(유명인 얼굴 데이터셋) 받아오기
아래 주소에서 데이터를 받아온다.
https://www.kaggle.com/datasets/jessicali9530/celeba-dataset?select=img_align_celeba
import numpy as np
import matplotlib.pyplot as plt
import time
import tensorflow as tf
import keras.backend as K
from tensorflow.keras.preprocessing.image import ImageDataGenerator
from tensorflow.keras.layers import Input, Conv2D, Dense, Flatten, Reshape, UpSampling2D
from tensorflow.keras import Sequential, Model
batch_size = 16
img_size = (64,64,3)
n_filters = 64
n_layers = 4
z_dim = 32
data_dir = '/content/drive/MyDrive/dataset/img_align_celeba'
data_gen = ImageDataGenerator(rescale=1/255.)
train_data_generator = data_gen.flow_from_directory(directory=data_dir,
batch_size=batch_size,
target_size=img_size[:2])
Found 54719 images belonging to 1 classes.
모델 BEGAN 구현
def build_encoder(input_shape, z_dim, n_filters, n_layers):
model = Sequential()
model.add(
Conv2D(n_filters, 3, activation='elu', input_shape=input_shape, padding='same'),
)
model.add(Conv2D(n_filters, 3, padding='same'))
for i in range(2, n_layers + 1):
model.add(Conv2D(i*n_filters, 3, activation='elu', padding='same'))
model.add(Conv2D(i*n_filters, 3, activation='elu', strides=2, padding='same'))
model.add(Conv2D(n_layers*n_filters, 3, padding='same'))
model.add(Flatten())
model.add(Dense(z_dim))
return model
def build_decoder(output_shape, z_dim, n_filters, n_layers):
scale = 2**(n_layers - 1)
fc_shape = (output_shape[0]//scale, output_shape[1]//scale, n_filters)
fc_size = fc_shape[0] * fc_shape[1] * fc_shape[2]
model = Sequential()
model.add(Dense(fc_size, input_shape=(z_dim,)))
model.add(Reshape(fc_shape))
for i in range(n_layers - 1):
model.add(Conv2D(n_filters//(i+2), 3, activation='elu', padding='same'))
model.add(Conv2D(n_filters//(i+2), 3, activation='elu', padding='same'))
model.add(UpSampling2D())
model.add(Conv2D(n_filters//n_layers, 3, activation='elu', padding='same'))
model.add(Conv2D(n_filters//n_layers, 3, activation='elu', padding='same'))
model.add(Conv2D(3, 3, padding='same'))
return model
def build_generator(img_shape, z_dim, n_filters, n_layers):
decoder = build_decoder(img_shape, z_dim, n_filters, n_layers)
return decoder
def build_discriminator(img_shape, z_dim, n_filters, n_layers):
encoder = build_encoder(img_shape, z_dim, n_filters, n_layers)
decoder = build_decoder(img_shape, z_dim, n_filters, n_layers)
return Sequential((encoder, decoder))
def build_discrim_trainer(discriminator):
img_shape = discriminator.input_shape[1:]
real_inputs = Input(img_shape)
fake_inputs = Input(img_shape)
real_outputs = discriminator(real_inputs)
fake_outputs = discriminator(fake_inputs)
return Model(inputs=[real_inputs, fake_inputs], outputs=[real_outputs, fake_outputs])
generator = build_generator(img_size, z_dim, n_filters, n_layers)
discriminator = build_discriminator(img_size, z_dim, n_filters, n_layers)
discriminator_trainer = build_discrim_trainer(discriminator)
generator.summary()
Model: "sequential" _________________________________________________________________ Layer (type) Output Shape Param # ================================================================= dense (Dense) (None, 4096) 135168 reshape (Reshape) (None, 8, 8, 64) 0 conv2d (Conv2D) (None, 8, 8, 32) 18464 conv2d_1 (Conv2D) (None, 8, 8, 32) 9248 up_sampling2d (UpSampling2D (None, 16, 16, 32) 0 ) conv2d_2 (Conv2D) (None, 16, 16, 21) 6069 conv2d_3 (Conv2D) (None, 16, 16, 21) 3990 up_sampling2d_1 (UpSampling (None, 32, 32, 21) 0 2D) conv2d_4 (Conv2D) (None, 32, 32, 16) 3040 conv2d_5 (Conv2D) (None, 32, 32, 16) 2320 up_sampling2d_2 (UpSampling (None, 64, 64, 16) 0 2D) conv2d_6 (Conv2D) (None, 64, 64, 16) 2320 conv2d_7 (Conv2D) (None, 64, 64, 16) 2320 conv2d_8 (Conv2D) (None, 64, 64, 3) 435 ================================================================= Total params: 183,374 Trainable params: 183,374 Non-trainable params: 0 _________________________________________________________________
discriminator.summary()
Model: "sequential_3" _________________________________________________________________ Layer (type) Output Shape Param # ================================================================= sequential_1 (Sequential) (None, 32) 2960608 sequential_2 (Sequential) (None, 64, 64, 3) 183374 ================================================================= Total params: 3,143,982 Trainable params: 3,143,982 Non-trainable params: 0 _________________________________________________________________
discriminator_trainer.summary()
Model: "model" __________________________________________________________________________________________________ Layer (type) Output Shape Param # Connected to ================================================================================================== input_1 (InputLayer) [(None, 64, 64, 3)] 0 [] input_2 (InputLayer) [(None, 64, 64, 3)] 0 [] sequential_3 (Sequential) (None, 64, 64, 3) 3143982 ['input_1[0][0]', 'input_2[0][0]'] ================================================================================================== Total params: 3,143,982 Trainable params: 3,143,982 Non-trainable params: 0 __________________________________________________________________________________________________
Loss 구현
-
생성자 손실은 build_generator_loss 함수로 구현하며 실제 이미지는 필요없기 때문에 무시하고 재구성 이미지와 생성 이미지의 mae loss 값을 리턴한다.
-
구분자 손실은 실제 이미지와 생성 이미지를 각각의 재구성 이미지와 mae loss로 구현하며 loss_weight를 통해 위에서 설명한 kt를 가중치로 설정한다. 이때 K.variable을 통해 학습 도중에 가중치를 갱신할 수 있도록 한다.
from keras.losses import mean_absolute_error
def build_generator_loss(discriminator):
def loss(y_true, y_pred):
reconst = discriminator(y_pred)
return mean_absolute_error(reconst, y_pred)
return loss
def measure(real_loss, fake_loss, gamma):
return real_loss + np.abs(gamma*real_loss - fake_loss)
g_lr = 0.0001
generator_loss = build_generator_loss(discriminator)
generator.compile(loss=generator_loss, optimizer=tf.keras.optimizers.Adam(g_lr))
d_lr = 0.0001
k_var = 0.0
k = K.variable(k_var)
discriminator_trainer.compile(loss=[mean_absolute_error, mean_absolute_error], loss_weights=[1., -k], optimizer=tf.keras.optimizers.Adam(d_lr))
def sample_image(seeds):
imgs = generator.predict(seeds)
plt.figure(figsize=(16,10))
for i in range(1,len(imgs) + 1):
plt.subplot(5, 5, i)
plt.imshow((imgs[i-1] * 255).astype(np.uint8))
plt.show()
Train(학습)
gamma = 0.5
lr_k = 0.001
total_steps = 1000
start = time.time()
sample_seeds = np.random.uniform(-1, 1, (25, z_dim))
history = []
for step, batch in enumerate(train_data_generator):
if len(batch) < batch_size:
continue
if step > total_steps:
break
z_g = np.random.uniform(-1, 1, (batch_size, z_dim))
z_d = np.random.uniform(-1, 1, (batch_size, z_dim))
g_pred = generator.predict(z_d)
generator.train_on_batch(z_g, batch)
_, real_loss, fake_loss = discriminator_trainer.train_on_batch([batch, g_pred], [batch, g_pred])
k_var += lr_k*(gamma*real_loss - fake_loss)
K.set_value(k, k_var)
history.append({'real_loss': real_loss, 'fake_loss': fake_loss})
if step % 100 == 0:
meas = np.mean([measure(loss['real_loss'], loss['fake_loss'], gamma) for loss in history])
print("step: {}, total_loss: {}, time: {}".format(step, meas, time.time()-start))
sample_image(sample_seeds)
sample_image(sample_seeds)
댓글남기기