
DCGAN Implementation

Dataset: MNIST

TensorFlow

```python
import numpy as np
import matplotlib.pyplot as plt
from tqdm import tqdm

import tensorflow as tf
from tensorflow.keras.layers import Input, Conv2D, Dense, BatchNormalization, Flatten, Reshape, Conv2DTranspose, LeakyReLU, Dropout
from tensorflow.keras import Model
```

Dataset preparation

```python
# Load MNIST; only the training images are needed here.
(train, label), (_, _) = tf.keras.datasets.mnist.load_data()
```

```python
# Preview the first five training images.
for i in range(5):
  plt.subplot(1, 5, i+1)
  plt.imshow(train[i])
```

```python
# Add a channel dimension and scale pixel values from [0, 255] to [-1, 1],
# matching the tanh output of the generator.
train = train.reshape(train.shape[0], 28, 28, 1)
train = (train - 127.5) / 127.5

buffer_size = 6000
batch_size = 64
dataset = tf.data.Dataset.from_tensor_slices(train).shuffle(buffer_size).batch(batch_size)
```
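
The scaling to [-1, 1] pairs with the tanh activation on the generator's output layer. A quick look at one batch (a minimal sketch) should confirm the (64, 28, 28, 1) shape and the value range:

```python
# Inspect a single batch from the input pipeline.
for batch in dataset.take(1):
  print(batch.shape, float(tf.reduce_min(batch)), float(tf.reduce_max(batch)))
```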

Model creation

```python
# Generator: project a 100-dimensional noise vector to a 7x7x256 feature map,
# then upsample with transposed convolutions to a 28x28x1 image.
g_input = Input(shape=(100,))

x = Dense(7*7*256)(g_input)
x = BatchNormalization()(x)
x = LeakyReLU()(x)

x = Reshape((7,7,256))(x)
x = Conv2DTranspose(128, 5, 1, padding='same')(x)
x = BatchNormalization()(x)
x = LeakyReLU()(x)

x = Conv2DTranspose(64, 5, 2, padding='same')(x)
x = BatchNormalization()(x)
x = LeakyReLU()(x)

x = Conv2DTranspose(1, 5, 2, padding='same', activation='tanh')(x)

generator = Model(g_input, x)
```

```python
# Sample one image from the untrained generator.
noise = tf.random.normal([1, 100])
sample = generator(noise)
plt.imshow(sample[0, :, :, 0])
```
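
With `padding='same'` and no output padding, a Keras `Conv2DTranspose` layer simply multiplies the spatial size by its stride, so the feature map should grow 7 -> 7 -> 14 -> 28 on the way to a full MNIST-sized image. `generator.summary()` is a quick way to confirm this (the expected shapes are sketched in the comments):

```python
# Expected shape progression:
#   Dense + Reshape            -> (7, 7, 256)
#   Conv2DTranspose(128, 5, 1) -> (7, 7, 128)
#   Conv2DTranspose(64, 5, 2)  -> (14, 14, 64)
#   Conv2DTranspose(1, 5, 2)   -> (28, 28, 1)
generator.summary()
```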

```python
# Discriminator: downsample the 28x28x1 image with strided convolutions
# and output the probability that the input is real.
d_input = Input((28,28,1))

x = Conv2D(64, 5, 2, padding='same')(d_input)
x = LeakyReLU()(x)

x = Conv2D(128, 5, 2, padding='same')(x)
x = LeakyReLU()(x)

x = Flatten()(x)
x = Dropout(0.4)(x)
x = Dense(1, activation='sigmoid')(x)

discriminator = Model(d_input, x)
```

```python
# The untrained discriminator should output a value near 0.5 for this generated sample.
dec = discriminator(sample)
dec
```

    <tf.Tensor: shape=(1, 1), dtype=float32, numpy=array([[0.4998645]], dtype=float32)>

```python
# The discriminator already ends with a sigmoid, so the loss takes probabilities
# (from_logits defaults to False).
loss_fn = tf.keras.losses.BinaryCrossentropy()

def dis_loss(real, fake):
  # Real samples should be classified as 1, generated samples as 0.
  real_loss = loss_fn(tf.ones_like(real), real)
  fake_loss = loss_fn(tf.zeros_like(fake), fake)
  return real_loss + fake_loss

def gen_loss(fake):
  # The generator wants its samples to be classified as 1.
  return loss_fn(tf.ones_like(fake), fake)
```
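
Together these helpers implement the standard GAN objective with the non-saturating generator loss: the discriminator minimizes

$$\mathcal{L}_D = -\mathbb{E}_{x}\left[\log D(x)\right] - \mathbb{E}_{z}\left[\log\left(1 - D(G(z))\right)\right],$$

while the generator minimizes

$$\mathcal{L}_G = -\mathbb{E}_{z}\left[\log D(G(z))\right].$$
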
```python
g_opt = tf.keras.optimizers.Adam(1e-4)
d_opt = tf.keras.optimizers.Adam(1e-4)

epochs = 50
z_dim = 100
```

```python
@tf.function
def train_step(images):
  noise = tf.random.normal([batch_size, z_dim])

  # Run both networks under separate tapes so each can be updated
  # from the same forward pass.
  with tf.GradientTape() as gen_tape, tf.GradientTape() as dis_tape:
    gen_imgs = generator(noise, training=True)

    real = discriminator(images, training=True)
    fake = discriminator(gen_imgs, training=True)

    g_loss = gen_loss(fake)
    d_loss = dis_loss(real, fake)

  g_gradient = gen_tape.gradient(g_loss, generator.trainable_variables)
  d_gradient = dis_tape.gradient(d_loss, discriminator.trainable_variables)

  g_opt.apply_gradients(zip(g_gradient, generator.trainable_variables))
  d_opt.apply_gradients(zip(d_gradient, discriminator.trainable_variables))

  return g_loss, d_loss
```
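
With the step defined, a one-batch smoke test (a minimal sketch) is a cheap way to confirm that the shapes line up before running full epochs:

```python
# Note: this applies one optimizer update to each network.
for images in dataset.take(1):
  g, d = train_step(images)
  print(float(g), float(d))
```
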
```python
def train(dataset, epochs):
  for epoch in tqdm(range(epochs)):

    for images in dataset:
      g_loss, d_loss = train_step(images)

    # Report the losses of the last batch and show a few samples after each epoch.
    print('g_loss: {}, d_loss: {}'.format(g_loss, d_loss))
    generate_images(generator)

def generate_images(gen):
  noises = tf.random.normal([10, z_dim])
  imgs = gen(noises)
  for i in range(len(imgs)):
    plt.subplot(1, 10, i+1)
    plt.imshow(imgs[i, :, :, 0])
  plt.show()
```

```python
train(dataset, 30)
```

```python
# Generate and display ten images with the trained generator.
noises = tf.random.normal([10, z_dim])
imgs = generator(noises)
plt.figure(figsize=(20,16))
for i in range(len(imgs)):
  plt.subplot(1, 10, i+1)
  plt.imshow(imgs[i, :, :, 0])
plt.show()
```

PyTorch

```python
import torch
import torch.nn as nn
import torchvision.datasets as dsets
import torchvision.transforms as transforms

import matplotlib.pyplot as plt
from tqdm import tqdm
```

Dataset preparation

```python
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# Convert to tensors and normalize pixel values to [-1, 1].
transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize(mean=[0.5], std=[0.5])])
mnist_train = dsets.MNIST(root='data/',
                          train=True,  # training split
                          transform=transform,
                          download=True)
```

```python
# Visualize nine random samples.
figure = plt.figure(figsize=(8, 8))
cols, rows = 3, 3
for i in range(1, cols * rows + 1):
    sample_idx = torch.randint(len(mnist_train), size=(1,)).item()
    img, label = mnist_train[sample_idx]
    figure.add_subplot(rows, cols, i)
    plt.axis("off")  # hide the x and y axes
    plt.imshow(img.squeeze(), cmap="gray")
plt.show()
```

```python
batch_size = 64
z_dim = 100
epochs = 50

dataloader = torch.utils.data.DataLoader(dataset=mnist_train,
                                         batch_size=batch_size,
                                         shuffle=True)
```
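
As with the TensorFlow pipeline, a quick look at one batch (a minimal sketch) should show shape (64, 1, 28, 28) and values in [-1, 1]:

```python
imgs, labels = next(iter(dataloader))
print(imgs.shape, imgs.min().item(), imgs.max().item())
```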

Model creation

```python
class Generator(nn.Module):
  def __init__(self, z_dim):
    super(Generator, self).__init__()
    self.z_dim = z_dim
    # Upsample a (z_dim, 1, 1) noise tensor to a (1, 28, 28) image.
    self.gen = nn.Sequential(
        nn.ConvTranspose2d(z_dim, 128, 7, 1, 0),
        nn.BatchNorm2d(128),
        nn.LeakyReLU(0.2),
        nn.ConvTranspose2d(128, 64, 4, 2, padding=1),
        nn.BatchNorm2d(64),
        nn.LeakyReLU(0.2),
        nn.ConvTranspose2d(64, 1, 4, 2, padding=1),
        nn.Tanh()
    )

  def forward(self, x):
    return self.gen(x)
```

```python
class Discriminator(nn.Module):
  def __init__(self):
    super(Discriminator, self).__init__()
    # Downsample a (1, 28, 28) image to a single real/fake probability.
    self.disc = nn.Sequential(
        nn.Conv2d(1, 64, 4, stride=2, padding=1),
        nn.BatchNorm2d(64),
        nn.LeakyReLU(0.2),
        nn.Conv2d(64, 128, 4, stride=2, padding=1),
        nn.BatchNorm2d(128),
        nn.LeakyReLU(0.2),
        nn.Conv2d(128, 1, 7, stride=2, padding=0),
        nn.Sigmoid()
    )

  def forward(self, x):
    return self.disc(x)
```

```python
generator = Generator(z_dim).to(device)
discriminator = Discriminator().to(device)

# Sample one image from the untrained generator.
noise = torch.randn(1, z_dim, 1, 1).to(device)
sample = generator(noise)
img = sample.view(-1, 28, 28)
plt.imshow(img.squeeze().cpu().detach())
```

```python
discriminator(sample)
```

    tensor([[[[0.7300]]]], device='cuda:0', grad_fn=<SigmoidBackward0>)
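
The kernel, stride, and padding choices make the sizes work out: for `ConvTranspose2d`, out = (in - 1) * stride - 2 * padding + kernel, giving 1 -> 7 -> 14 -> 28 in the generator; for `Conv2d`, out = (in + 2 * padding - kernel) // stride + 1, giving 28 -> 14 -> 7 -> 1 in the discriminator. A quick shape check (a minimal sketch) with the models above:

```python
z = torch.randn(2, z_dim, 1, 1).to(device)
print(generator(z).shape)                 # torch.Size([2, 1, 28, 28])
print(discriminator(generator(z)).shape)  # torch.Size([2, 1, 1, 1])
```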

```python
# Fixed noise to visually track the generator's progress across epochs.
fix_noise = torch.randn(8, z_dim, 1, 1).to(device)

loss_fn = nn.BCELoss()
g_opt = torch.optim.Adam(generator.parameters(), lr=0.0002)
d_opt = torch.optim.Adam(discriminator.parameters(), lr=0.0002)
```


```python
def generate_images(gen):
  # Show what the generator currently produces for the fixed noise batch.
  imgs = gen(fix_noise)
  for i in range(len(imgs)):
    plt.subplot(1, len(imgs), i+1)
    plt.imshow(imgs[i].view(28,28).cpu().detach(), cmap="gray")
  plt.show()

for i in tqdm(range(epochs)):
  for _, (img, label) in enumerate(dataloader):
    real = img.to(device)
    noise = torch.randn(batch_size, z_dim, 1, 1).to(device)
    fake = generator(noise)

    # Discriminator step: push real images toward 1 and generated images toward 0.
    real_output = discriminator(real)
    fake_output = discriminator(fake)

    real_loss = loss_fn(real_output, torch.ones_like(real_output))
    fake_loss = loss_fn(fake_output, torch.zeros_like(fake_output))
    d_loss = (real_loss + fake_loss) / 2
    d_opt.zero_grad()
    # retain_graph=True keeps the generator part of the graph alive
    # for the generator update below.
    d_loss.backward(retain_graph=True)
    d_opt.step()

    # Generator step: try to fool the just-updated discriminator.
    output = discriminator(fake)
    g_loss = loss_fn(output, torch.ones_like(output))
    g_opt.zero_grad()
    g_loss.backward()
    g_opt.step()

  generate_images(generator)
```


```python
generate_images(generator)
```
