3 분 소요

SRCNN 설명

SRCNN은 이미지의 해상도를 높이는 task인 super-resolution로 CNN을 사용하여 low-resolution 이미지를 high-resolution 이미지로 mapping 합니다. 즉, 저해상도와 고해상도 사이의 관계를 학습하여 하나의 함수를 만드는 것으로 이해해볼 수 있습니다.

다운로드.jpg

  1. 기존 방법들보다 구조가 간단하지만 정확도가 훨씬 높다.

  2. CPU에서 써도 될만큼 속도가 빠르다.

  3. 모델이 크고 깊을수록, 데이터셋이 더 다양할수록 성능이 증가한다. 기존 방법들에선 그러한 방법들의 적용이 힘들다.

  4. 이미지의 RGB 3채널을 처리함과 동시에 고해상도 이미지를 얻을 수 있다.

https://deep-learning-study.tistory.com/688

블로그를 참고했습니다.

데이터셋 준비

CelebA

CelebA 데이터셋은 대표적인 얼굴(face) 데이터셋이다. 이때 CelebA 데이터셋은 약 200,000개 정도의 얼굴 이미지로 구성된다. 기본적으로 10,000명가량의 사람이 포함되어 있다. 즉, 한 명당 20장 정도의 이미지가 있다고 보면 된다. 각 이미지는 178 x 218 해상도로 존재한다.

www.kaggle.com/jessicali9530/celeba-dataset

에서 받아옵니다.

import torch
import torch.optim as optim
import torch.nn as nn
from torchvision.transforms import ToPILImage
from torch.utils.data import DataLoader, Dataset
from tqdm import tqdm
import torch.optim as optim

import torch.nn as nn
import torch.nn.functional as F
import time

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# 압축 풀기
!unzip -qq '/content/drive/MyDrive/archive.zip'
import cv2
import numpy as np
import matplotlib.pyplot as plt
import glob

data_list = glob.glob('/content/img_align_celeba/img_align_celeba/*.jpg')
class SRdataset(Dataset):

    def __init__(self, paths):
        self.paths = paths        

    def __len__(self):
        return len(self.paths)

    def __getitem__(self,idx):
        path = self.paths[idx]

        img = cv2.imread(path)
        img = cv2.cvtColor(img,cv2.COLOR_BGR2RGB)
        img = cv2.resize(img, dsize=(40, 40), interpolation=cv2.INTER_AREA)
        img = img.astype(np.float32) / 255.0
        inp = cv2.GaussianBlur(img,(0,0),1)
        img = np.transpose(img, (2,0,1))
        inp = np.transpose(inp, (2,0,1))

        input_sample, label_sample = torch.tensor(inp, dtype=torch.float32), torch.tensor(img, dtype=torch.float32)

        return input_sample,label_sample
train_ds = SRdataset(data_list)
train_dl = DataLoader(train_ds, batch_size=64)
for image, label in train_dl:
  img = image[0]
  lab = label[0]
  break

plt.subplot(1,2,1)
plt.imshow(np.transpose(img, (1,2,0)))
plt.subplot(1,2,2)
plt.imshow(np.transpose(lab, (1,2,0)))
<matplotlib.image.AxesImage at 0x7f4969825390>
<Figure size 432x288 with 2 Axes>

모델 생성

class SRCNN(nn.Module):
    def __init__(self):
        super().__init__()

        # padding_mode='replicate'는 zero padding이 아닌, 주변 값을 복사해서 padding 합니다.
        self.conv1 = nn.Conv2d(3, 64, 9, padding=2, padding_mode='replicate')
        self.conv2 = nn.Conv2d(64, 32, 1, padding=2, padding_mode='replicate')
        self.conv3 = nn.Conv2d(32, 3, 5, padding=2, padding_mode='replicate')

    def forward(self, x):
        x = F.relu(self.conv1(x))
        x = F.relu(self.conv2(x))
        x = self.conv3(x)

        return x
# 가중치 초기화
def initialize_weights(model):
    classname = model.__class__.__name__
    if classname.find('Conv') != -1:
        nn.init.normal_(model.weight.data, 0.0, 0.02)

model = SRCNN().to(device)
model.apply(initialize_weights);
# 손실함수
loss_func = nn.MSELoss()

optimizer = optim.Adam(model.parameters())

PSNR은 super-resolution의 평가지표입니다.

import math

# PSNR function: 모델의 출력값과 high-resoultion의 유사도를 측정합니다.
# PSNR 값이 클수록 좋습니다.
def psnr(label, outputs, max_val=1.):
    label = label.cpu().detach().numpy()
    outputs = outputs.cpu().detach().numpy()
    img_diff = outputs - label
    rmse = math.sqrt(np.mean((img_diff)**2))
    if rmse == 0: # label과 output이 완전히 일치하는 경우
        return 100
    else:
        psnr = 20 * math.log10(max_val/rmse)
        return psnr
# train 함수
def train_step(model, data_dl):
    model.train()
    running_loss = 0.0
    running_psnr = 0.0

    for i, (image, label) in enumerate(data_dl):
        image = image.to(device)
        label = label.to(device)

        optimizer.zero_grad()
        outputs = model(image)
        loss = loss_func(outputs, label)
        loss.backward()
        optimizer.step()

        running_loss += loss.item()
        batch_psnr = psnr(label, outputs)
        running_psnr += batch_psnr
    
    final_loss = running_loss / len(data_dl.dataset)
    final_psnr = running_psnr / int(len(train_ds)/data_dl.batch_size)
    return final_loss, final_psnr
num_epochs = 10

train_loss = []
train_psnr = []
start = time.time()
for epoch in range(num_epochs):
    print(f'Epoch {epoch + 1} of {num_epochs}')
    train_epoch_loss, train_epoch_psnr = train_step(model, train_dl)

    train_loss.append(train_epoch_loss)
    train_psnr.append(train_epoch_psnr)
    end = time.time()
    print(f'Train PSNR: {train_epoch_psnr:.3f}, Time: {end-start:.2f} sec')
Epoch 1 of 10
Train PSNR: 24.711, Time: 391.86 sec
Epoch 2 of 10
Train PSNR: 26.460, Time: 738.09 sec
Epoch 3 of 10
Train PSNR: 27.029, Time: 1083.62 sec
Epoch 4 of 10
Train PSNR: 27.019, Time: 1429.15 sec
Epoch 5 of 10
Train PSNR: 27.772, Time: 1771.29 sec
Epoch 6 of 10
Train PSNR: 28.237, Time: 2108.81 sec
Epoch 7 of 10
Train PSNR: 27.671, Time: 2450.54 sec
Epoch 8 of 10
Train PSNR: 28.368, Time: 2794.14 sec
Epoch 9 of 10
Train PSNR: 28.659, Time: 3131.25 sec
Epoch 10 of 10
Train PSNR: 27.954, Time: 3464.51 sec
# loss plots
plt.figure(figsize=(10, 7))
plt.plot(train_loss, color='orange', label='train loss')
plt.xlabel('Epochs')
plt.ylabel('Loss')
plt.legend()
plt.show()

# psnr plots
plt.figure(figsize=(10, 7))
plt.plot(train_psnr, color='green', label='train PSNR dB')
plt.xlabel('Epochs')
plt.ylabel('PSNR (dB)')
plt.legend()
plt.show()
<Figure size 720x504 with 1 Axes>
<Figure size 720x504 with 1 Axes>
for img, label in train_dl:
    img = img[0]
    label = label[0]
    break

# super-resolution
model.eval()
with torch.no_grad():
    img_ = img.unsqueeze(0)
    img_ = img_.to(device)
    output = model(img_)
    output = output.squeeze(0)
plt.figure(figsize=(15,15))
plt.subplot(1,3,1)
plt.imshow(np.transpose(img, (1,2,0)))
plt.title('input')
plt.subplot(1,3,2)
plt.imshow(np.transpose(output.cpu(), (1,2,0)))
plt.title('output')
plt.subplot(1,3,3)
plt.imshow(np.transpose(label, (1,2,0)))
plt.title('origingal')
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Text(0.5, 1.0, 'origingal')
<Figure size 1080x1080 with 3 Axes>

태그:

카테고리:

업데이트:

댓글남기기