Spaces:
Runtime error
TorchScript๋ก ๋ด๋ณด๋ด๊ธฐ[[export-to-torchscript]]
TorchScript๋ฅผ ํ์ฉํ ์คํ์ ์์ง ์ด๊ธฐ ๋จ๊ณ๋ก, ๊ฐ๋ณ์ ์ธ ์ ๋ ฅ ํฌ๊ธฐ ๋ชจ๋ธ๋ค์ ํตํด ๊ทธ ๊ธฐ๋ฅ์ฑ์ ๊ณ์ ํ๊ตฌํ๊ณ ์์ต๋๋ค. ์ด ๊ธฐ๋ฅ์ ์ ํฌ๊ฐ ๊ด์ฌ์ ๋๊ณ ์๋ ๋ถ์ผ ์ค ํ๋์ด๋ฉฐ, ์์ผ๋ก ์ถ์๋ ๋ฒ์ ์์ ๋ ๋ง์ ์ฝ๋ ์์ , ๋ ์ ์ฐํ ๊ตฌํ, ๊ทธ๋ฆฌ๊ณ Python ๊ธฐ๋ฐ ์ฝ๋์ ์ปดํ์ผ๋ TorchScript๋ฅผ ๋น๊ตํ๋ ๋ฒค์น๋งํฌ๋ฅผ ๋ฑ์ ํตํด ๋ถ์์ ์ฌํํ ์์ ์ ๋๋ค.
TorchScript ๋ฌธ์์์๋ ์ด๋ ๊ฒ ๋งํฉ๋๋ค.
TorchScript๋ PyTorch ์ฝ๋์์ ์ง๋ ฌํ ๋ฐ ์ต์ ํ ๊ฐ๋ฅํ ๋ชจ๋ธ์ ์์ฑํ๋ ๋ฐฉ๋ฒ์ ๋๋ค.
JIT๊ณผ TRACE๋ ๊ฐ๋ฐ์๊ฐ ๋ชจ๋ธ์ ๋ด๋ณด๋ด์ ํจ์จ ์งํฅ์ ์ธ C++ ํ๋ก๊ทธ๋จ๊ณผ ๊ฐ์ ๋ค๋ฅธ ํ๋ก๊ทธ๋จ์์ ์ฌ์ฌ์ฉํ ์ ์๋๋ก ํ๋ PyTorch ๋ชจ๋์ ๋๋ค.
PyTorch ๊ธฐ๋ฐ Python ํ๋ก๊ทธ๋จ๊ณผ ๋ค๋ฅธ ํ๊ฒฝ์์ ๋ชจ๋ธ์ ์ฌ์ฌ์ฉํ ์ ์๋๋ก, ๐ค Transformers ๋ชจ๋ธ์ TorchScript๋ก ๋ด๋ณด๋ผ ์ ์๋ ์ธํฐํ์ด์ค๋ฅผ ์ ๊ณตํฉ๋๋ค. ์ด ๋ฌธ์์์๋ TorchScript๋ฅผ ์ฌ์ฉํ์ฌ ๋ชจ๋ธ์ ๋ด๋ณด๋ด๊ณ ์ฌ์ฉํ๋ ๋ฐฉ๋ฒ์ ์ค๋ช ํฉ๋๋ค.
๋ชจ๋ธ์ ๋ด๋ณด๋ด๋ ค๋ฉด ๋ ๊ฐ์ง๊ฐ ํ์ํฉ๋๋ค:
torchscript
ํ๋๊ทธ๋ก ๋ชจ๋ธ ์ธ์คํด์คํ- ๋๋ฏธ ์ ๋ ฅ์ ์ฌ์ฉํ ์์ ํ(forward pass)
์ด ํ์ ์กฐ๊ฑด๋ค์ ์๋์ ์์ธํ ์ค๋ช ๋ ๊ฒ์ฒ๋ผ ๊ฐ๋ฐ์๋ค์ด ์ฃผ์ํด์ผ ํ ์ฌ๋ฌ ์ฌํญ๋ค์ ์๋ฏธํฉ๋๋ค.
TorchScript ํ๋๊ทธ์ ๋ฌถ์ธ ๊ฐ์ค์น(tied weights)[[torchscript-flag-and-tied-weights]]
torchscript
ํ๋๊ทธ๊ฐ ํ์ํ ์ด์ ๋ ๋๋ถ๋ถ์ ๐ค Transformers ์ธ์ด ๋ชจ๋ธ์์ Embedding
๋ ์ด์ด์ Decoding
๋ ์ด์ด ๊ฐ์ ๋ฌถ์ธ ๊ฐ์ค์น(tied weights)๊ฐ ์กด์ฌํ๊ธฐ ๋๋ฌธ์
๋๋ค.
TorchScript๋ ๋ฌถ์ธ ๊ฐ์ค์น๋ฅผ ๊ฐ์ง ๋ชจ๋ธ์ ๋ด๋ณด๋ผ ์ ์์ผ๋ฏ๋ก, ๋ฏธ๋ฆฌ ๊ฐ์ค์น๋ฅผ ํ๊ณ ๋ณต์ ํด์ผ ํฉ๋๋ค.
torchscript
ํ๋๊ทธ๋ก ์ธ์คํด์คํ๋ ๋ชจ๋ธ์ Embedding
๋ ์ด์ด์ Decoding
๋ ์ด์ด๊ฐ ๋ถ๋ฆฌ๋์ด ์์ผ๋ฏ๋ก ์ดํ์ ํ๋ จํด์๋ ์ ๋ฉ๋๋ค.
ํ๋ จ์ ํ๊ฒ ๋๋ฉด ๋ ๋ ์ด์ด ๊ฐ ๋๊ธฐํ๊ฐ ํด์ ๋์ด ์์์น ๋ชปํ ๊ฒฐ๊ณผ๊ฐ ๋ฐ์ํ ์ ์์ต๋๋ค.
์ธ์ด ๋ชจ๋ธ ํค๋๋ฅผ ๊ฐ์ง ์์ ๋ชจ๋ธ์ ๊ฐ์ค์น๊ฐ ๋ฌถ์ฌ ์์ง ์์์ ์ด ๋ฌธ์ ๊ฐ ๋ฐ์ํ์ง ์์ต๋๋ค.
์ด๋ฌํ ๋ชจ๋ธ๋ค์ torchscript
ํ๋๊ทธ ์์ด ์์ ํ๊ฒ ๋ด๋ณด๋ผ ์ ์์ต๋๋ค.
๋๋ฏธ ์ ๋ ฅ๊ณผ ํ์ค ๊ธธ์ด[[dummy-inputs-and-standard-lengths]]
๋๋ฏธ ์ ๋ ฅ(dummy inputs)์ ๋ชจ๋ธ์ ์์ ํ(forward pass)์ ์ฌ์ฉ๋ฉ๋๋ค. ์ ๋ ฅ ๊ฐ์ด ๋ ์ด์ด๋ฅผ ํตํด ์ ํ๋๋ ๋์, PyTorch๋ ๊ฐ ํ ์์์ ์คํ๋ ๋ค๋ฅธ ์ฐ์ฐ์ ์ถ์ ํฉ๋๋ค. ์ด๋ฌํ ๊ธฐ๋ก๋ ์ฐ์ฐ์ ๋ชจ๋ธ์ *์ถ์ (trace)*์ ์์ฑํ๋ ๋ฐ ์ฌ์ฉ๋ฉ๋๋ค.
์ถ์ ์ ์ ๋ ฅ์ ์ฐจ์์ ๊ธฐ์ค์ผ๋ก ์์ฑ๋ฉ๋๋ค. ๋ฐ๋ผ์ ๋๋ฏธ ์ ๋ ฅ์ ์ฐจ์์ ์ ํ๋์ด, ๋ค๋ฅธ ์ํ์ค ๊ธธ์ด๋ ๋ฐฐ์น ํฌ๊ธฐ์์๋ ์๋ํ์ง ์์ต๋๋ค. ๋ค๋ฅธ ํฌ๊ธฐ๋ก ์๋ํ ๊ฒฝ์ฐ ๋ค์๊ณผ ๊ฐ์ ์ค๋ฅ๊ฐ ๋ฐ์ํฉ๋๋ค:
`The expanded size of the tensor (3) must match the existing size (7) at non-singleton dimension 2`
์ถ๋ก ์ค ๋ชจ๋ธ์ ๊ณต๊ธ๋ ๊ฐ์ฅ ํฐ ์ ๋ ฅ๋งํผ ํฐ ๋๋ฏธ ์ ๋ ฅ ํฌ๊ธฐ๋ก ๋ชจ๋ธ์ ์ถ์ ํ๋ ๊ฒ์ด ์ข์ต๋๋ค. ํจ๋ฉ์ ๋๋ฝ๋ ๊ฐ์ ์ฑ์ฐ๋ ๋ฐ ๋์์ด ๋ ์ ์์ต๋๋ค. ๊ทธ๋ฌ๋ ๋ชจ๋ธ์ด ๋ ํฐ ์ ๋ ฅ ํฌ๊ธฐ๋ก ์ถ์ ๋๊ธฐ ๋๋ฌธ์, ํ๋ ฌ์ ์ฐจ์์ด ์ปค์ง๊ณ ๊ณ์ฐ๋์ด ๋ง์์ง๋๋ค.
๋ค์ํ ์ํ์ค ๊ธธ์ด ๋ชจ๋ธ์ ๋ด๋ณด๋ผ ๋๋ ๊ฐ ์ ๋ ฅ์ ๋ํด ์ํ๋๋ ์ด ์ฐ์ฐ ํ์์ ์ฃผ์ํ๊ณ ์ฑ๋ฅ์ ์ฃผ์ ๊น๊ฒ ํ์ธํ์ธ์.
Python์์ TorchScript ์ฌ์ฉํ๊ธฐ[[using-torchscript-in-python]]
์ด ์น์ ์์๋ ๋ชจ๋ธ์ ์ ์ฅํ๊ณ ๊ฐ์ ธ์ค๋ ๋ฐฉ๋ฒ, ์ถ์ ์ ์ฌ์ฉํ์ฌ ์ถ๋ก ํ๋ ๋ฐฉ๋ฒ์ ๋ณด์ฌ์ค๋๋ค.
๋ชจ๋ธ ์ ์ฅํ๊ธฐ[[saving-a-model]]
BertModel
์ TorchScript๋ก ๋ด๋ณด๋ด๋ ค๋ฉด BertConfig
ํด๋์ค์์ BertModel
์ ์ธ์คํด์คํํ ๋ค์, traced_bert.pt
๋ผ๋ ํ์ผ๋ช
์ผ๋ก ๋์คํฌ์ ์ ์ฅํ๋ฉด ๋ฉ๋๋ค.
from transformers import BertModel, BertTokenizer, BertConfig
import torch
enc = BertTokenizer.from_pretrained("bert-base-uncased")
# ์
๋ ฅ ํ
์คํธ ํ ํฐํํ๊ธฐ
text = "[CLS] Who was Jim Henson ? [SEP] Jim Henson was a puppeteer [SEP]"
tokenized_text = enc.tokenize(text)
# ์
๋ ฅ ํ ํฐ ์ค ํ๋๋ฅผ ๋ง์คํนํ๊ธฐ
masked_index = 8
tokenized_text[masked_index] = "[MASK]"
indexed_tokens = enc.convert_tokens_to_ids(tokenized_text)
segments_ids = [0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1]
# ๋๋ฏธ ์
๋ ฅ ๋ง๋ค๊ธฐ
tokens_tensor = torch.tensor([indexed_tokens])
segments_tensors = torch.tensor([segments_ids])
dummy_input = [tokens_tensor, segments_tensors]
# torchscript ํ๋๊ทธ๋ก ๋ชจ๋ธ ์ด๊ธฐํํ๊ธฐ
# ์ด ๋ชจ๋ธ์ LM ํค๋๊ฐ ์์ผ๋ฏ๋ก ํ์ํ์ง ์์ง๋ง, ํ๋๊ทธ๋ฅผ True๋ก ์ค์ ํฉ๋๋ค.
config = BertConfig(
vocab_size_or_config_json_file=32000,
hidden_size=768,
num_hidden_layers=12,
num_attention_heads=12,
intermediate_size=3072,
torchscript=True,
)
# ๋ชจ๋ธ์ ์ธ์คํดํธํํ๊ธฐ
model = BertModel(config)
# ๋ชจ๋ธ์ ํ๊ฐ ๋ชจ๋๋ก ๋์ด์ผ ํฉ๋๋ค.
model.eval()
# ๋ง์ฝ *from_pretrained*๋ฅผ ์ฌ์ฉํ์ฌ ๋ชจ๋ธ์ ์ธ์คํด์คํํ๋ ๊ฒฝ์ฐ, TorchScript ํ๋๊ทธ๋ฅผ ์ฝ๊ฒ ์ค์ ํ ์ ์์ต๋๋ค
model = BertModel.from_pretrained("bert-base-uncased", torchscript=True)
# ์ถ์ ์์ฑํ๊ธฐ
traced_model = torch.jit.trace(model, [tokens_tensor, segments_tensors])
torch.jit.save(traced_model, "traced_bert.pt")
๋ชจ๋ธ ๊ฐ์ ธ์ค๊ธฐ[[loading-a-model]]
์ด์ ์ด์ ์ ์ ์ฅํ BertModel
, ์ฆ traced_bert.pt
๋ฅผ ๋์คํฌ์์ ๊ฐ์ ธ์ค๊ณ , ์ด์ ์ ์ด๊ธฐํํ dummy_input
์์ ์ฌ์ฉํ ์ ์์ต๋๋ค.
loaded_model = torch.jit.load("traced_bert.pt")
loaded_model.eval()
all_encoder_layers, pooled_output = loaded_model(*dummy_input)
์ถ์ ๋ ๋ชจ๋ธ์ ์ฌ์ฉํ์ฌ ์ถ๋ก ํ๊ธฐ[[using-a-traced-model-for-inference]]
__call__
์ด์ค ์ธ๋์ค์ฝ์ด(dunder) ๋ฉ์๋๋ฅผ ์ฌ์ฉํ์ฌ ์ถ๋ก ์ ์ถ์ ๋ ๋ชจ๋ธ์ ์ฌ์ฉํ์ธ์:
traced_model(tokens_tensor, segments_tensors)
Neuron SDK๋ก Hugging Face TorchScript ๋ชจ๋ธ์ AWS์ ๋ฐฐํฌํ๊ธฐ[[deploy-hugging-face-torchscript-models-to-aws-with-the-neuron-sdk]]
AWS๊ฐ ํด๋ผ์ฐ๋์์ ์ ๋น์ฉ, ๊ณ ์ฑ๋ฅ ๋จธ์ ๋ฌ๋ ์ถ๋ก ์ ์ํ Amazon EC2 Inf1 ์ธ์คํด์ค ์ ํ๊ตฐ์ ์ถ์ํ์ต๋๋ค. Inf1 ์ธ์คํด์ค๋ ๋ฅ๋ฌ๋ ์ถ๋ก ์ํฌ๋ก๋์ ํนํ๋ ๋ง์ถค ํ๋์จ์ด ๊ฐ์๊ธฐ์ธ AWS Inferentia ์นฉ์ผ๋ก ๊ตฌ๋๋ฉ๋๋ค. AWS Neuron์ Inferentia๋ฅผ ์ํ SDK๋ก, Inf1์ ๋ฐฐํฌํ๊ธฐ ์ํ transformers ๋ชจ๋ธ ์ถ์ ๋ฐ ์ต์ ํ๋ฅผ ์ง์ํฉ๋๋ค. Neuron SDK๋ ๋ค์๊ณผ ๊ฐ์ ๊ธฐ๋ฅ์ ์ ๊ณตํฉ๋๋ค:
- ์ฝ๋ ํ ์ค๋ง ๋ณ๊ฒฝํ๋ฉด ํด๋ผ์ฐ๋ ์ถ๋ก ๋ฅผ ์ํด TorchScript ๋ชจ๋ธ์ ์ถ์ ํ๊ณ ์ต์ ํํ ์ ์๋ ์ฌ์ด API
- ์ฆ์ ์ฌ์ฉ ๊ฐ๋ฅํ ์ฑ๋ฅ ์ต์ ํ๋ก ๋น์ฉ ํจ์จ ํฅ์
- PyTorch ๋๋ TensorFlow๋ก ๊ตฌ์ถ๋ Hugging Face transformers ๋ชจ๋ธ ์ง์
์์ฌ์ [[implications]]
BERT (Bidirectional Encoder Representations from Transformers) ์ํคํ ์ฒ ๋๋ ๊ทธ ๋ณํ์ธ distilBERT ๋ฐ roBERTa๋ฅผ ๊ธฐ๋ฐ์ผ๋ก ํ Transformers ๋ชจ๋ธ์ ์ถ์ถ ๊ธฐ๋ฐ ์ง์์๋ต, ์ํ์ค ๋ถ๋ฅ ๋ฐ ํ ํฐ ๋ถ๋ฅ์ ๊ฐ์ ๋น์์ฑ ์์ ์ Inf1์์ ์ต์์ ์ฑ๋ฅ์ ๋ณด์ ๋๋ค. ๊ทธ๋ฌ๋ ํ ์คํธ ์์ฑ ์์ ๋ AWS Neuron MarianMT ํํ ๋ฆฌ์ผ์ ๋ฐ๋ผ Inf1์์ ์คํ๋๋๋ก ์กฐ์ ํ ์ ์์ต๋๋ค.
Inferentia์์ ๋ฐ๋ก ๋ณํํ ์ ์๋ ๋ชจ๋ธ์ ๋ํ ์์ธํ ์ ๋ณด๋ Neuron ๋ฌธ์์ Model Architecture Fit ์น์ ์์ ํ์ธํ ์ ์์ต๋๋ค.
์ข ์์ฑ[[dependencies]]
AWS Neuron์ ์ฌ์ฉํ์ฌ ๋ชจ๋ธ์ ๋ณํํ๋ ค๋ฉด Neuron SDK ํ๊ฒฝ์ด ํ์ํฉ๋๋ค. ์ด๋ AWS Deep Learning AMI์ ๋ฏธ๋ฆฌ ๊ตฌ์ฑ๋์ด ์์ต๋๋ค.
AWS Neuron์ผ๋ก ๋ชจ๋ธ ๋ณํํ๊ธฐ[[converting-a-model-for-aws-neuron]]
BertModel
์ ์ถ์ ํ๋ ค๋ฉด, Python์์ TorchScript ์ฌ์ฉํ๊ธฐ์์์ ๋์ผํ ์ฝ๋๋ฅผ ์ฌ์ฉํด์ AWS NEURON์ฉ ๋ชจ๋ธ์ ๋ณํํฉ๋๋ค.
torch.neuron
ํ๋ ์์ํฌ ์ต์คํ
์
์ ๊ฐ์ ธ์ Python API๋ฅผ ํตํด Neuron SDK์ ๊ตฌ์ฑ ์์์ ์ ๊ทผํฉ๋๋ค:
from transformers import BertModel, BertTokenizer, BertConfig
import torch
import torch.neuron
๋ค์ ์ค๋ง ์์ ํ๋ฉด ๋ฉ๋๋ค:
- torch.jit.trace(model, [tokens_tensor, segments_tensors])
+ torch.neuron.trace(model, [token_tensor, segments_tensors])
์ด๋ก์จ Neuron SDK๊ฐ ๋ชจ๋ธ์ ์ถ์ ํ๊ณ Inf1 ์ธ์คํด์ค์ ์ต์ ํํ ์ ์๊ฒ ๋ฉ๋๋ค.
AWS Neuron SDK์ ๊ธฐ๋ฅ, ๋๊ตฌ, ์์ ํํ ๋ฆฌ์ผ ๋ฐ ์ต์ ์ ๋ฐ์ดํธ์ ๋ํด ์์ธํ ์์๋ณด๋ ค๋ฉด AWS NeuronSDK ๋ฌธ์๋ฅผ ์ฐธ์กฐํ์ธ์.