Generative Adversarial Network

Author

차상진

Published

May 6, 2024

1. Imports

import torch 
import torchvision
import fastai.vision.all 
import fastai.vision.gan
import matplotlib.pyplot as plt 

2. GAN?

A-1. 생성모형이란? 쉬운 설명

- 사진 속에 있는 고양이와 개를 분류하는 것보다 고양이 이미지와 개 이미지를 만드는 것이 더 어렵지 않은가?

A-2. 수학적인 설명

- 이미지 \(X\)가 주어졌을 때 레이블을 \(y\)라고 한다면. 이미지를 보고 레이블을 맞추는 것은 \(P(y|X)\)이다

- 그렇지만 이미지를 생성하는 것은 \(P(X,y)\)이다

- 쉽게 말하면 고양이인지 개인지 맞추는 것은 \(y\)만 맞추면 됨. 하지만 이미지를 생성하는 것은 \(X,y\)를 모두 맞추어야 함

###
\(P(y|X) = \frac{P(X,y)}{P(X)} = \frac{P(X,y)}{\sum_{y}P(X,y)}\)

- \(P(X,y)\)를 알면 \(P(y|X)\)은 알 수 있음 하지만 역은 불가능

- 한 마디로 이미지를 생성하는 것이 이미지를 분류하는 것보다 더 어렵다

B. GAN은 생성모형 중 하나다

- GAN의 원리는 예를 들어 설명하면 경찰과 위조지폐범이 경쟁을 통해서 서로 발전하는 모형으로 설명 가능하다.

- 위조범은 경찰이 속게끔 위조지폐를 더 정교하게 만들고 경찰은 그 위조지폐를 잘 구별하게끔 학습을 한다.

- 굉장히 우수한 경찰이라도 진짜와 가짜를 구분하지 못 할때(=진짜 이미지를 0.5의 확률로만 진짜라고 할 때= 가짜이미지를 0.5의 확률로 가짜라고 할 때) 학습을 멈춘다.

3. GAN의 구현

A.DATA

path = fastai.data.external.untar_data(fastai.data.external.URLs.MNIST)
path
Path('/root/.fastai/data/mnist_png')
X_real = torch.stack([torchvision.io.read_image(str(l)) for l in (path/'training/3').ls()],axis=0)/255
X_real.shape
torch.Size([6131, 1, 28, 28])
plt.imshow(X_real[0].reshape(28,28),cmap='grey')

B. Faker 생성

- 네트워크 입력 : (n,??)

- 네트워크 출력 : (n,1,28,28)

torch.randn(1,4)
tensor([[ 0.0454, -0.7508, -0.9412, -1.6791]])

reshape을 위한 class를 정의하자

class reshape2828(torch.nn.Module):
    def __init__(self):
        super().__init__()
    def forward(self,X):
        return X.reshape(-1,1,28,28)
net_faker = torch.nn.Sequential(
    torch.nn.Linear(4,64),
    torch.nn.ReLU(),
    torch.nn.Linear(64,64),
    torch.nn.ReLU(),
    torch.nn.Linear(64,784),
    torch.nn.Sigmoid(),
    reshape2828()
)
net_faker(torch.randn(1,4)).shape
torch.Size([1, 1, 28, 28])

C. 경찰 생성

- 네트워크의 입력 : (n,1,28,28)

- 네트워크의 출력 : 0 or 1

net_police = torch.nn.Sequential(
    torch.nn.Flatten(),
    torch.nn.Linear(784,64),
    torch.nn.ReLU(),
    torch.nn.Linear(64,30),
    torch.nn.ReLU(),
    torch.nn.Linear(30,1),
    torch.nn.Sigmoid()
)

D. 패트와 매트

plt.imshow(X_real[0].reshape(28,28),cmap='grey')

진짜 이미지 -> 0 / 가짜 이미지 -> 1

net_police(X_real[[0]])
tensor([[0.5080]], grad_fn=<SigmoidBackward0>)
noise = torch.randn(1,4)
plt.imshow(net_faker(noise).data.reshape(28,28),cmap='grey')

yhat_fake = net_police(net_faker(noise))
yhat_fake
tensor([[0.5001]], grad_fn=<SigmoidBackward0>)

1과는 매우 거리가 멈

아무 학습도 안 했기때문에 당연… 하지만 무능!!

fig , ax = plt.subplots(1,2)
ax[0].imshow(net_faker(noise).data.reshape(28,28),cmap='grey')
ax[1].imshow(X_real[0].reshape(28,28),cmap='grey')

이걸 구분을 잘 못하고 있음…

E. 경찰이 공부했다.

진짜 이미지 -> 0 , 가짜 이미지 -> 1

- step1

noise = torch.randn(6131,4)
X_fake = net_faker(noise)
y_real = torch.tensor([0]*6131).reshape(-1,1).float()
y_fake = torch.tensor([1]*6131).reshape(-1,1).float()
yhat_real = net_police(X_real) # 경찰이 진짜 이미지를 보고 판별한 결과
yhat_fake = net_police(X_fake) # 경찰이 가짜 이미지를 보고 판별한 결과

- step2

bce = torch.nn.BCELoss()
loss_police = bce(yhat_fake , y_fake) + bce(yhat_real , y_real)
loss_police
tensor(1.4071, grad_fn=<AddBackward0>)

- step3,4는 앞에서 공부한 내용과 비슷하기에 생략하고 바로 epoch을 돌리자

net_police = torch.nn.Sequential(
    torch.nn.Flatten(),
    torch.nn.Linear(784,64),
    torch.nn.ReLU(),
    torch.nn.Linear(64,30),
    torch.nn.ReLU(),
    torch.nn.Linear(30,1),
    torch.nn.Sigmoid()
)
bce = torch.nn.BCELoss()
optimizr_police = torch.optim.Adam(net_police.parameters())

for epoc in range(30):
    noise = torch.randn(6131,4) # epoc을 돌릴 때마다 새로운 noise를 뽑아야함
    X_fake = net_faker(noise)
    yhat_real = net_police(X_real)
    yhat_fake = net_police(X_fake)
    loss_police = bce(yhat_real,y_real) + bce(yhat_fake,y_fake)
    loss_police.backward()
    optimizr_police.step()
    optimizr_police.zero_grad()
net_police(X_real).mean()
tensor(0.0224, grad_fn=<MeanBackward0>)

꽤 늘었다

F. 발전하는 페이커

- step1

Noise = torch.randn(6131,4)
X_fake = net_faker(Noise) 

- step2

yhat_faker = net_police(X_fake)
loss_faker = bce(yhat_faker,y_real)

가짜를 보고 진짜라고 생각해야한다

- step3~4는 별로 특별한게 없음. 그래서 바로 epoch을 진행시켜보자.

net_faker = torch.nn.Sequential(
    torch.nn.Linear(in_features=4, out_features=30),
    torch.nn.ReLU(),
    torch.nn.Linear(in_features=30, out_features=64),
    torch.nn.ReLU(),
    torch.nn.Linear(in_features=64, out_features=784),
    torch.nn.Sigmoid(),
    reshape2828()
)
#bce = torch.nn.BCELoss()
optimizr_faker = torch.optim.Adam(net_faker.parameters()) # 얘를 실행시키면 새로운 옵티마이저가 나와서 초기값으로 돌아감
for epoc in range(10):
    noise = torch.randn(6131,4)
    X_fake = net_faker(noise)
    yhat_fake = net_police(X_fake)
    loss_faker = bce(yhat_fake,y_real)
    loss_faker.backward()
    optimizr_faker.step()
    optimizr_faker.zero_grad()
fig,ax = plt.subplots(2,5,figsize=(10,4))
k = 0 
for i in range(2):
    for j in range(5):
        ax[i][j].imshow(X_fake[k].reshape(28,28).data,cmap="gray")
        ax[i][j].set_title(f"police hat = {yhat_fake[k].item():.4f}")
        k = k+1 
fig.tight_layout()

((yhat_fake > 0.5) == 0).float().mean()
tensor(0.)

경찰이 가짜이미지를 진짜라고 생각한 비율

G. 경쟁

torch.manual_seed(21345)
net_police = torch.nn.Sequential(
    torch.nn.Flatten(),
    torch.nn.Linear(784,64),
    torch.nn.ReLU(),
    torch.nn.Linear(64,30),
    torch.nn.ReLU(),
    torch.nn.Linear(30,1),
    torch.nn.Sigmoid()
)
net_faker = torch.nn.Sequential(
    torch.nn.Linear(in_features=4, out_features=30),
    torch.nn.ReLU(),
    torch.nn.Linear(in_features=30, out_features=64),
    torch.nn.ReLU(),
    torch.nn.Linear(in_features=64, out_features=784),
    torch.nn.Sigmoid(),
    reshape2828()
)
bce = torch.nn.BCELoss()
optimizr_police = torch.optim.Adam(net_police.parameters(),lr=0.001,betas=(0.5,0.999))
optimizr_faker = torch.optim.Adam(net_faker.parameters(),lr=0.002,betas=(0.5,0.999))

- police 네트워크 학습

for epoc in range(30):
    noise = torch.randn(6131,4)
    X_fake = net_faker(noise)
    yhat_real = net_police(X_real)
    yhat_fake = net_police(X_fake)
    loss_police = bce(yhat_real,y_real) + bce(yhat_fake,y_fake)
    loss_police.backward()
    optimizr_police.step()
    optimizr_police.zero_grad()
print(f"""
yhat_real = {yhat_real[0].item():.4f} // 이건 0에 가까워야함 
yhat_fake = {yhat_fake[0].item():.4f} // 이건 1에 가까워야함 
""")

yhat_real = 0.0002 // 이건 0에 가까워야함 
yhat_fake = 0.9945 // 이건 1에 가까워야함 

- faker 네트워크 학습

for epoc in range(10):
    # step1
    Noise = torch.randn(6131,4) 
    X_fake = net_faker(Noise) 
    # step2
    yhat_fake = net_police(X_fake) 
    loss_faker = bce(yhat_fake,y_real)
    # step3 
    loss_faker.backward()
    # step4 
    optimizr_faker.step()
    optimizr_faker.zero_grad()
#---#
fig,ax = plt.subplots(2,5,figsize=(10,4))
k = 0 
for i in range(2):
    for j in range(5):
        ax[i][j].imshow(X_fake[k].reshape(28,28).data,cmap="gray")
        ax[i][j].set_title(f"police hat = {yhat_fake[k].item():.4f}")
        k = k+1 
fig.tight_layout()
print(f"""
scam ratio = {((yhat_fake > 0.5) == 0).float().mean():.4f} 
loss_faker = {loss_faker:.4f} 
""")

scam ratio = 0.0000 
loss_faker = 4.4663 

4. 결론

- 서로 경쟁하듯이 발전해야하는데 한 쪽이 너무 강해져서 경쟁을 할 수 없는 상황이 되면 학습은 실패한다.

- 두 네트워크의 균형이 매우 중요한 균형이 깨지면 학습이 실패한다.