from transformers import AutoTokenizer
from bertviz.transformers_neuron_view import BertModel
from bertviz.neuron_view import show
from torch import nn
from transformers import AutoConfig
import torch
from math import sqrt
import torch.nn.functional as F
# step 1: 모델, 토크나이저, text 선언
= "bert-base-uncased" # 원하는 모델 선언
model_ckpt = AutoTokenizer.from_pretrained(model_ckpt) # 모델에서 사용된 pretrained 된 tokenizer를 불러온다.
tokenizer = BertModel.from_pretrained(model_ckpt) # pretrained 된 model을 불러온다.
model = "time flies like an arrow"
text
# step 2: 토크나이징
= tokenizer(text, return_tensors="pt", add_special_tokens=False) # pytorch를 사용해서 CLS,SEP 토큰을 제외하여 토크나이징하는 코드
inputs # 결과는 토크나이저에 있는 어휘사전에 고유한 ID에 매핑된 값이다.
inputs.input_ids
# step 3: 모델의 하이퍼파라미터를 불러오고 임베딩
= AutoConfig.from_pretrained(model_ckpt) # config란 모델의 하이퍼파라미터들을 저장하는 객체, 모델에서 pretrained된 정보를 불러오는 코드.
config = nn.Embedding(config.vocab_size, config.hidden_size) # 임베딩의 층을 생성하는 것
token_emb #config.vocab_size는 모델이 사용하는 단어사전의 크기. config.hidden_size는 모델의 임베딩 차원이다. 각 단어를 몇 차원의 벡터로 변환할지에 대한 것이다
# 결과의 30522는 단어사전의 크기이고 768은 임베딩 차원이다.
= token_emb(inputs.input_ids) #위에서 선언한 token_emb을 이용해서 5개의 단어를 임베딩시켰다.
inputs_embeds # 임베딩 시켰다는 것은 각 단어가 768개의 특성을 가지는 벡터값으로 변환된 것이라고 이해하면 된다.
inputs_embeds.size()
# step 4: 점곱과 softmax를 이용한 가중치 계산
= key = value = inputs_embeds
query = key.size(-1) # 임베딩 차원을 선택
dim_k = torch.bmm(query, key.transpose(1,2)) / sqrt(dim_k) # 점곱을 사용. self attention의 경우 batch단위로 연산이 필요하기에 torch.bmm을 사용
scores # 점곱을 진행하면 dim_k개의 요소가 더해지므로 dim_k와 비례해서 점곱 값이 결정됨. 그래서 dim_k의 제곱근으로 나누면 통계적으로 표준화와 비슷한 역할을 한다.
= F.softmax(scores, dim=-1) # softmax를 사용
weights
# step 5: 최종적으로 계산된 가중치와 value를 점곱
= torch.bmm(weights, value) # 마지막으로 value에 가중치를 곱해준다.
attn_outputs attn_outputs.size()
/root/anaconda3/envs/asdf/lib/python3.9/site-packages/huggingface_hub/file_download.py:797: FutureWarning: `resume_download` is deprecated and will be removed in version 1.0.0. Downloads always resume when possible. If you want to force a new download, use `force_download=True`.
warnings.warn(
torch.Size([1, 5, 768])