SRGAN 설명 및 코드 구현하기
SRGAN 설명
Super-resolution GAN
SRGAN은 저화질의 이미지를 고화질의 이미지로 만듭니다.
학습 동안, 고화질 이미지(HR)는 저화질 이미지로(LR) 다운샘플링 됩니다. GAN 생성자는 저화질(LR)이미지를 고화질(SR)이미지로 업샘플링 시킵니다.
여기서, 고화질 이미지를 구분하기 위해 구분자를 사용하며, 구분자와 생성자를 학습시키기 위해 GAN loss를 역전파 시키게 됩니다.
데이터셋 준비
CelebA
CelebA 데이터셋은 대표적인 얼굴(face) 데이터셋이다. 이때 CelebA 데이터셋은 약 200,000개 정도의 얼굴 이미지로 구성된다. 기본적으로 10,000명가량의 사람이 포함되어 있다. 즉, 한 명당 20장 정도의 이미지가 있다고 보면 된다. 각 이미지는 178 x 218 해상도로 존재한다.
www.kaggle.com/jessicali9530/celeba-dataset
에서 받아옵니다.
# 압축 풀기
!unzip -qq '/content/drive/MyDrive/archive.zip'
import cv2
import numpy as np
import matplotlib.pyplot as plt
import tensorflow as tf
import glob
data_list = glob.glob('/content/img_align_celeba/img_align_celeba/*.jpg')
train = []
label = []
for path in data_list[:10000]:
img = cv2.imread(path)
img = cv2.cvtColor(img,cv2.COLOR_BGR2RGB)
img = np.array(img)
train.append(img)
label.append(np.ones(1))
train = np.array(train)
train.shape
(10000, 218, 178, 3)
def preprocessing(lr, hr):
hr = tf.cast(hr, tf.float32) / 255.
# 이미지의 크기가 크므로 (96,96,3) 크기로 임의 영역을 잘라내어 사용합니다.
hr_patch = tf.image.random_crop(hr, size=[40,40,3])
# 잘라낸 고해상도 이미지의 가로, 세로 픽셀 수를 1/4배로 줄입니다
# 이렇게 만든 저해상도 이미지를 SRGAN의 입력으로 사용합니다.
lr_patch = tf.image.resize(hr_patch, [40//4, 40//4], "bicubic")
return lr_patch, hr_patch
train = tf.data.Dataset.from_tensor_slices((label, train))
train = train.map(preprocessing).shuffle(buffer_size=10).repeat().batch(8)
# import cv2
# import numpy as np
# import matplotlib.pyplot as plt
# import tensorflow_datasets as tfds
# # 데이터 불러오기
# train, valid = tfds.load(
# "div2k/bicubic_x4",
# split=['train', 'validation'],
# as_supervised=True
# )
# def preprocessing(lr, hr):
# hr = tf.cast(hr, tf.float32) / 255.
# # 이미지의 크기가 크므로 (96,96,3) 크기로 임의 영역을 잘라내어 사용합니다.
# hr_patch = tf.image.random_crop(hr, size=[96,96,3])
# # 잘라낸 고해상도 이미지의 가로, 세로 픽셀 수를 1/4배로 줄입니다
# # 이렇게 만든 저해상도 이미지를 SRGAN의 입력으로 사용합니다.
# lr_patch = tf.image.resize(hr_patch, [96//4, 96//4], "bicubic")
# return lr_patch, hr_patch
# train = train.map(preprocessing).shuffle(buffer_size=10).repeat().batch(8)
Generator
SRGAN의 generator부분으로 중요한 부분은
-
residual block과 subpixel block입니다.
-
활성화 함수로 PRelu를 사용합니다.
-
Sub-pixel block에는 PixelShuffel이 사용되는데 여기서
-
tf.nn.depth_to_space(input, block_size, name=None) 함수가 이용됩니다,
이 함수는 만약 입력의 구조(shape)가 [batch, height, width, depth]라면, 출력 텐서의 구조는[batch, heightblock_size, widthblock_size, depth/(block_size*block_size)]와 같습니다.
합성곱 연산 사이의 활성화에서 데이터를 그대로 유지한 채로 크기를 변경시키는 때 유용합니다(예로, 풀링 대신 사용할 수 있습니다). 합성곱 연산만으로 이루어진 모델의 훈련에도 유용합니다.
from tensorflow.keras import Input, Model, layers
# 그림의 Residual 블록을 정의합니다.
def res_block(x):
out = layers.Conv2D(64, 3, 1, "same")(x)
out = layers.BatchNormalization()(out)
out = layers.PReLU(shared_axes=[1,2])(out)
out = layers.Conv2D(64, 3, 1, "same")(out)
out = layers.BatchNormalization()(out)
return layers.Add()([x, out])
# 그림의 뒤쪽 Sub-pixel 블록을 정의합니다.
def upsample_block(x):
out = layers.Conv2D(256, 3, 1, "same")(x)
# 그림의 PixelShuffler 라고 쓰여진 부분을 아래와 같이 구현합니다.
out = layers.Lambda(lambda x: tf.nn.depth_to_space(x, 2))(out)
return layers.PReLU(shared_axes=[1,2])(out)
# 전체 Generator를 정의합니다.
def get_generator(input_shape=(None, None, 3)):
inputs = Input(input_shape)
out = layers.Conv2D(64, 9, 1, "same")(inputs)
out = residual = layers.PReLU(shared_axes=[1,2])(out)
for _ in range(5):
out = res_block(out)
out = layers.Conv2D(64, 3, 1, "same")(out)
out = layers.BatchNormalization()(out)
out = layers.Add()([residual, out])
for _ in range(2):
out = upsample_block(out)
out = layers.Conv2D(3, 9, 1, "same", activation="tanh")(out)
return Model(inputs, out)
Discriminator
# 그림의 파란색 블록을 정의합니다.
def disc_base_block(x, n_filters=128):
out = layers.Conv2D(n_filters, 3, 1, "same")(x)
out = layers.BatchNormalization()(out)
out = layers.LeakyReLU()(out)
out = layers.Conv2D(n_filters, 3, 2, "same")(out)
out = layers.BatchNormalization()(out)
return layers.LeakyReLU()(out)
# 전체 Discriminator 정의합니다.
def get_discriminator(input_shape=(None, None, 3)):
inputs = Input(input_shape)
out = layers.Conv2D(64, 3, 1, "same")(inputs)
out = layers.LeakyReLU()(out)
out = layers.Conv2D(64, 3, 2, "same")(out)
out = layers.BatchNormalization()(out)
out = layers.LeakyReLU()(out)
for n_filters in [128, 256, 512]:
out = disc_base_block(out, n_filters)
out = layers.Dense(1024)(out)
out = layers.LeakyReLU()(out)
out = layers.Dense(1, activation="sigmoid")(out)
return Model(inputs, out)
Content Loss
SRGAN은 VGG19으로 content loss를 계산합니다. Tensorflow는 이미지넷 데이터로부터 잘 학습된 VGG19를 제공하고 있습니다.
from tensorflow.keras import applications
def get_feature_extractor(input_shape=(None, None, 3)):
vgg = applications.vgg19.VGG19(
include_top = False,
weights = "imagenet",
input_shape=input_shape
)
# 아래 vgg.layers[20]은 vgg 내의 마지막 conv layer입니다.
return Model(vgg.input, vgg.layers[20].output)
Train
from tensorflow.keras import losses, metrics, optimizers
generator = get_generator()
discriminator = get_discriminator()
vgg = get_feature_extractor()
# 사용할 loss function 및 optimizer 를 정의합니다.
bce = losses.BinaryCrossentropy(from_logits=False)
mse = losses.MeanSquaredError()
gene_opt = optimizers.Adam()
disc_opt = optimizers.Adam()
def get_gene_loss(fake_out):
return bce(tf.ones_like(fake_out), fake_out)
def get_disc_loss(real_out, fake_out):
return bce(tf.ones_like(real_out), real_out) + bce(tf.zeros_like(fake_out), fake_out)
@tf.function
def get_content_loss(hr_real, hr_fake):
hr_real = applications.vgg19.preprocess_input(hr_real)
hr_fake = applications.vgg19.preprocess_input(hr_fake)
hr_real_feature = vgg(hr_real) / 12.75
hr_fake_feature = vgg(hr_fake) / 12.75
return mse(hr_real_feature, hr_fake_feature)
@tf.function
def step(lr, hr_real):
with tf.GradientTape() as gene_tape, tf.GradientTape() as disc_tape:
hr_fake = generator(lr, training=True)
real_out = discriminator(hr_real, training=True)
fake_out = discriminator(hr_fake, training=True)
perceptual_loss = get_content_loss(hr_real, hr_fake) + 1e-3 * get_gene_loss(fake_out)
discriminator_loss = get_disc_loss(real_out, fake_out)
gene_gradient = gene_tape.gradient(perceptual_loss, generator.trainable_variables)
disc_gradient = disc_tape.gradient(discriminator_loss, discriminator.trainable_variables)
gene_opt.apply_gradients(zip(gene_gradient, generator.trainable_variables))
disc_opt.apply_gradients(zip(disc_gradient, discriminator.trainable_variables))
return perceptual_loss, discriminator_loss
gene_losses = metrics.Mean()
disc_losses = metrics.Mean()
for epoch in range(1, 2):
for i, (lr, hr) in enumerate(train):
g_loss, d_loss = step(lr, hr)
gene_losses.update_state(g_loss)
disc_losses.update_state(d_loss)
# 10회 반복마다 loss를 출력합니다.
if (i+1) % 10 == 0:
print(f"EPOCH[{epoch}] - STEP[{i+1}] \nGenerator_loss:{gene_losses.result():.4f} \nDiscriminator_loss:{disc_losses.result():.4f}", end="\n\n")
if (i+1) == 200:
break
gene_losses.reset_states()
disc_losses.reset_states()
Test
import numpy as np
def apply_srgan(image):
image = tf.cast(image[np.newaxis, ...], tf.float32)
sr = generator.predict(image)
sr = tf.clip_by_value(sr, 0, 255)
sr = tf.round(sr)
sr = tf.cast(sr, tf.uint8)
return np.array(sr)[0]
img = cv2.imread('/content/drive/MyDrive/lenna.jpg',1)
img = cv2.cvtColor(img,cv2.COLOR_BGR2RGB)
img_resize = cv2.resize(img,dsize=(0,0),fx=0.25,fy=0.25,interpolation=cv2.INTER_CUBIC)
bicubic_hr = cv2.resize(img_resize,dsize=(0,0),fx=4,fy=4,interpolation=cv2.INTER_CUBIC)
for lr, hr in train:
img = hr[0]
img_resize = lr[0]
bicubic_hr = cv2.resize(img_resize,dsize=(0,0),fx=4,fy=4,interpolation=cv2.INTER_CUBIC)
plt.figure(figsize=(20,12))
plt.subplot(1,3,1)
plt.title('Original')
plt.imshow(img)
plt.subplot(1,3,2)
plt.title('SRGAN')
plt.imshow(apply_srgan(img_resize))
plt.subplot(1,3,3)
plt.title('bicubic_hr')
plt.imshow(bicubic_hr)
댓글남기기