import torch
import torchvision
import fastai.vision.all
import fastai.vision.gan
import matplotlib.pyplot as plt
Generative Adversarial Network
1. Imports
2. GAN?
A-1. 생성모형이란? 쉬운 설명
-
사진 속에 있는 고양이와 개를 분류하는 것보다 고양이 이미지와 개 이미지를 만드는 것이 더 어렵지 않은가?
A-2. 수학적인 설명
-
이미지 \(X\)가 주어졌을 때 레이블을 \(y\)라고 한다면. 이미지를 보고 레이블을 맞추는 것은 \(P(y|X)\)이다
-
그렇지만 이미지를 생성하는 것은 \(P(X,y)\)이다
-
쉽게 말하면 고양이인지 개인지 맞추는 것은 \(y\)만 맞추면 됨. 하지만 이미지를 생성하는 것은 \(X,y\)를 모두 맞추어야 함
-
\(P(X,y)\)를 알면 \(P(y|X)\)은 알 수 있음 하지만 역은 불가능
-
한 마디로 이미지를 생성하는 것이 이미지를 분류하는 것보다 더 어렵다
B. GAN은 생성모형 중 하나다
-
GAN의 원리는 예를 들어 설명하면 경찰과 위조지폐범이 경쟁을 통해서 서로 발전하는 모형으로 설명 가능하다.
-
위조범은 경찰이 속게끔 위조지폐를 더 정교하게 만들고 경찰은 그 위조지폐를 잘 구별하게끔 학습을 한다.
-
굉장히 우수한 경찰이라도 진짜와 가짜를 구분하지 못 할때(=진짜 이미지를 0.5의 확률로만 진짜라고 할 때= 가짜이미지를 0.5의 확률로 가짜라고 할 때) 학습을 멈춘다.
3. GAN의 구현
A.DATA
= fastai.data.external.untar_data(fastai.data.external.URLs.MNIST)
path path
Path('/root/.fastai/data/mnist_png')
= torch.stack([torchvision.io.read_image(str(l)) for l in (path/'training/3').ls()],axis=0)/255
X_real X_real.shape
torch.Size([6131, 1, 28, 28])
0].reshape(28,28),cmap='grey') plt.imshow(X_real[
B. Faker 생성
-
네트워크 입력 : (n,??)
-
네트워크 출력 : (n,1,28,28)
1,4) torch.randn(
tensor([[ 0.0454, -0.7508, -0.9412, -1.6791]])
reshape을 위한 class를 정의하자
class reshape2828(torch.nn.Module):
def __init__(self):
super().__init__()
def forward(self,X):
return X.reshape(-1,1,28,28)
= torch.nn.Sequential(
net_faker 4,64),
torch.nn.Linear(
torch.nn.ReLU(),64,64),
torch.nn.Linear(
torch.nn.ReLU(),64,784),
torch.nn.Linear(
torch.nn.Sigmoid(),
reshape2828() )
1,4)).shape net_faker(torch.randn(
torch.Size([1, 1, 28, 28])
C. 경찰 생성
-
네트워크의 입력 : (n,1,28,28)
-
네트워크의 출력 : 0 or 1
= torch.nn.Sequential(
net_police
torch.nn.Flatten(),784,64),
torch.nn.Linear(
torch.nn.ReLU(),64,30),
torch.nn.Linear(
torch.nn.ReLU(),30,1),
torch.nn.Linear(
torch.nn.Sigmoid() )
D. 패트와 매트
0].reshape(28,28),cmap='grey') plt.imshow(X_real[
진짜 이미지 -> 0 / 가짜 이미지 -> 1
0]]) net_police(X_real[[
tensor([[0.5080]], grad_fn=<SigmoidBackward0>)
= torch.randn(1,4)
noise 28,28),cmap='grey') plt.imshow(net_faker(noise).data.reshape(
= net_police(net_faker(noise))
yhat_fake yhat_fake
tensor([[0.5001]], grad_fn=<SigmoidBackward0>)
1과는 매우 거리가 멈
아무 학습도 안 했기때문에 당연… 하지만 무능!!
= plt.subplots(1,2)
fig , ax 0].imshow(net_faker(noise).data.reshape(28,28),cmap='grey')
ax[1].imshow(X_real[0].reshape(28,28),cmap='grey') ax[
이걸 구분을 잘 못하고 있음…
E. 경찰이 공부했다.
진짜 이미지 -> 0 , 가짜 이미지 -> 1
-
step1
= torch.randn(6131,4)
noise = net_faker(noise)
X_fake = torch.tensor([0]*6131).reshape(-1,1).float()
y_real = torch.tensor([1]*6131).reshape(-1,1).float() y_fake
= net_police(X_real) # 경찰이 진짜 이미지를 보고 판별한 결과
yhat_real = net_police(X_fake) # 경찰이 가짜 이미지를 보고 판별한 결과 yhat_fake
-
step2
= torch.nn.BCELoss()
bce = bce(yhat_fake , y_fake) + bce(yhat_real , y_real)
loss_police loss_police
tensor(1.4071, grad_fn=<AddBackward0>)
-
step3,4는 앞에서 공부한 내용과 비슷하기에 생략하고 바로 epoch을 돌리자
= torch.nn.Sequential(
net_police
torch.nn.Flatten(),784,64),
torch.nn.Linear(
torch.nn.ReLU(),64,30),
torch.nn.Linear(
torch.nn.ReLU(),30,1),
torch.nn.Linear(
torch.nn.Sigmoid()
)= torch.nn.BCELoss()
bce = torch.optim.Adam(net_police.parameters())
optimizr_police
for epoc in range(30):
= torch.randn(6131,4) # epoc을 돌릴 때마다 새로운 noise를 뽑아야함
noise = net_faker(noise)
X_fake = net_police(X_real)
yhat_real = net_police(X_fake)
yhat_fake = bce(yhat_real,y_real) + bce(yhat_fake,y_fake)
loss_police
loss_police.backward()
optimizr_police.step() optimizr_police.zero_grad()
net_police(X_real).mean()
tensor(0.0224, grad_fn=<MeanBackward0>)
꽤 늘었다
F. 발전하는 페이커
-
step1
= torch.randn(6131,4)
Noise = net_faker(Noise) X_fake
-
step2
= net_police(X_fake)
yhat_faker = bce(yhat_faker,y_real) loss_faker
가짜를 보고 진짜라고 생각해야한다
-
step3~4는 별로 특별한게 없음. 그래서 바로 epoch을 진행시켜보자.
= torch.nn.Sequential(
net_faker =4, out_features=30),
torch.nn.Linear(in_features
torch.nn.ReLU(),=30, out_features=64),
torch.nn.Linear(in_features
torch.nn.ReLU(),=64, out_features=784),
torch.nn.Linear(in_features
torch.nn.Sigmoid(),
reshape2828()
)#bce = torch.nn.BCELoss()
= torch.optim.Adam(net_faker.parameters()) # 얘를 실행시키면 새로운 옵티마이저가 나와서 초기값으로 돌아감 optimizr_faker
for epoc in range(10):
= torch.randn(6131,4)
noise = net_faker(noise)
X_fake = net_police(X_fake)
yhat_fake = bce(yhat_fake,y_real)
loss_faker
loss_faker.backward()
optimizr_faker.step() optimizr_faker.zero_grad()
= plt.subplots(2,5,figsize=(10,4))
fig,ax = 0
k for i in range(2):
for j in range(5):
28,28).data,cmap="gray")
ax[i][j].imshow(X_fake[k].reshape(f"police hat = {yhat_fake[k].item():.4f}")
ax[i][j].set_title(= k+1
k fig.tight_layout()
> 0.5) == 0).float().mean() ((yhat_fake
tensor(0.)
경찰이 가짜이미지를 진짜라고 생각한 비율
G. 경쟁
21345)
torch.manual_seed(= torch.nn.Sequential(
net_police
torch.nn.Flatten(),784,64),
torch.nn.Linear(
torch.nn.ReLU(),64,30),
torch.nn.Linear(
torch.nn.ReLU(),30,1),
torch.nn.Linear(
torch.nn.Sigmoid()
)= torch.nn.Sequential(
net_faker =4, out_features=30),
torch.nn.Linear(in_features
torch.nn.ReLU(),=30, out_features=64),
torch.nn.Linear(in_features
torch.nn.ReLU(),=64, out_features=784),
torch.nn.Linear(in_features
torch.nn.Sigmoid(),
reshape2828()
)= torch.nn.BCELoss()
bce = torch.optim.Adam(net_police.parameters(),lr=0.001,betas=(0.5,0.999))
optimizr_police = torch.optim.Adam(net_faker.parameters(),lr=0.002,betas=(0.5,0.999)) optimizr_faker
-
police 네트워크 학습
for epoc in range(30):
= torch.randn(6131,4)
noise = net_faker(noise)
X_fake = net_police(X_real)
yhat_real = net_police(X_fake)
yhat_fake = bce(yhat_real,y_real) + bce(yhat_fake,y_fake)
loss_police
loss_police.backward()
optimizr_police.step()
optimizr_police.zero_grad()print(f"""
yhat_real = {yhat_real[0].item():.4f} // 이건 0에 가까워야함
yhat_fake = {yhat_fake[0].item():.4f} // 이건 1에 가까워야함
""")
yhat_real = 0.0002 // 이건 0에 가까워야함
yhat_fake = 0.9945 // 이건 1에 가까워야함
-
faker 네트워크 학습
for epoc in range(10):
# step1
= torch.randn(6131,4)
Noise = net_faker(Noise)
X_fake # step2
= net_police(X_fake)
yhat_fake = bce(yhat_fake,y_real)
loss_faker # step3
loss_faker.backward()# step4
optimizr_faker.step()
optimizr_faker.zero_grad()#---#
= plt.subplots(2,5,figsize=(10,4))
fig,ax = 0
k for i in range(2):
for j in range(5):
28,28).data,cmap="gray")
ax[i][j].imshow(X_fake[k].reshape(f"police hat = {yhat_fake[k].item():.4f}")
ax[i][j].set_title(= k+1
k
fig.tight_layout()print(f"""
scam ratio = {((yhat_fake > 0.5) == 0).float().mean():.4f}
loss_faker = {loss_faker:.4f}
""")
scam ratio = 0.0000
loss_faker = 4.4663
4. 결론
-
서로 경쟁하듯이 발전해야하는데 한 쪽이 너무 강해져서 경쟁을 할 수 없는 상황이 되면 학습은 실패한다.
-
두 네트워크의 균형이 매우 중요한 균형이 깨지면 학습이 실패한다.