Spaces:
Runtime error
CPU์์ ํจ์จ์ ์ธ ์ถ๋ก ํ๊ธฐ [[efficient-inference-on-cpu]]
์ด ๊ฐ์ด๋๋ CPU์์ ๋๊ท๋ชจ ๋ชจ๋ธ์ ํจ์จ์ ์ผ๋ก ์ถ๋ก ํ๋ ๋ฐฉ๋ฒ์ ์ค์ ์ ๋๊ณ ์์ต๋๋ค.
๋ ๋น ๋ฅธ ์ถ๋ก ์ ์ํ BetterTransformer
[[bettertransformer-for-faster-inference]]
์ฐ๋ฆฌ๋ ์ต๊ทผ CPU์์ ํ
์คํธ, ์ด๋ฏธ์ง ๋ฐ ์ค๋์ค ๋ชจ๋ธ์ ๋น ๋ฅธ ์ถ๋ก ์ ์ํด BetterTransformer
๋ฅผ ํตํฉํ์ต๋๋ค. ์ด ํตํฉ์ ๋ํ ๋ ์์ธํ ๋ด์ฉ์ ์ด ๋ฌธ์๋ฅผ ์ฐธ์กฐํ์ธ์.
PyTorch JIT ๋ชจ๋ (TorchScript) [[pytorch-jitmode-torchscript]]
TorchScript๋ PyTorch ์ฝ๋์์ ์ง๋ ฌํ์ ์ต์ ํ๊ฐ ๊ฐ๋ฅํ ๋ชจ๋ธ์ ์์ฑํ ๋ ์ฐ์
๋๋ค. TorchScript๋ก ๋ง๋ค์ด์ง ํ๋ก๊ทธ๋จ์ ๊ธฐ์กด Python ํ๋ก์ธ์ค์์ ์ ์ฅํ ๋ค, ์ข
์์ฑ์ด ์๋ ์๋ก์ด ํ๋ก์ธ์ค๋ก ๊ฐ์ ธ์ฌ ์ ์์ต๋๋ค. PyTorch์ ๊ธฐ๋ณธ ์ค์ ์ธ eager
๋ชจ๋์ ๋น๊ตํ์๋, jit
๋ชจ๋๋ ์ฐ์ฐ์ ๊ฒฐํฉ๊ณผ ๊ฐ์ ์ต์ ํ ๋ฐฉ๋ฒ๋ก ์ ํตํด ๋ชจ๋ธ ์ถ๋ก ์์ ๋๋ถ๋ถ ๋ ๋์ ์ฑ๋ฅ์ ์ ๊ณตํฉ๋๋ค.
TorchScript์ ๋ํ ์น์ ํ ์๊ฐ๋ PyTorch TorchScript ํํ ๋ฆฌ์ผ์ ์ฐธ์กฐํ์ธ์.
JIT ๋ชจ๋์ ํจ๊ปํ๋ IPEX ๊ทธ๋ํ ์ต์ ํ [[ipex-graph-optimization-with-jitmode]]
Intelยฎ Extension for PyTorch(IPEX)๋ Transformers ๊ณ์ด ๋ชจ๋ธ์ jit ๋ชจ๋์์ ์ถ๊ฐ์ ์ธ ์ต์ ํ๋ฅผ ์ ๊ณตํฉ๋๋ค. jit ๋ชจ๋์ ๋๋ถ์ด Intelยฎ Extension for PyTorch(IPEX)๋ฅผ ํ์ฉํ์๊ธธ ๊ฐ๋ ฅํ ๊ถ์ฅ๋๋ฆฝ๋๋ค. Transformers ๋ชจ๋ธ์์ ์์ฃผ ์ฌ์ฉ๋๋ ์ผ๋ถ ์ฐ์ฐ์ ํจํด์ ์ด๋ฏธ jit ๋ชจ๋ ์ฐ์ฐ์ ๊ฒฐํฉ(operator fusion)์ ํํ๋ก Intelยฎ Extension for PyTorch(IPEX)์์ ์ง์๋๊ณ ์์ต๋๋ค. Multi-head-attention, Concat Linear, Linear+Add, Linear+Gelu, Add+LayerNorm ๊ฒฐํฉ ํจํด ๋ฑ์ด ์ด์ฉ ๊ฐ๋ฅํ๋ฉฐ ํ์ฉํ์ ๋ ์ฑ๋ฅ์ด ์ฐ์ํฉ๋๋ค. ์ฐ์ฐ์ ๊ฒฐํฉ์ ์ด์ ์ ์ฌ์ฉ์์๊ฒ ๊ณ ์ค๋ํ ์ ๋ฌ๋ฉ๋๋ค. ๋ถ์์ ๋ฐ๋ฅด๋ฉด, ์ง์ ์๋ต, ํ ์คํธ ๋ถ๋ฅ ๋ฐ ํ ํฐ ๋ถ๋ฅ์ ๊ฐ์ ๊ฐ์ฅ ์ธ๊ธฐ ์๋ NLP ํ์คํฌ ์ค ์ฝ 70%๊ฐ ์ด๋ฌํ ๊ฒฐํฉ ํจํด์ ์ฌ์ฉํ์ฌ Float32 ์ ๋ฐ๋์ BFloat16 ํผํฉ ์ ๋ฐ๋ ๋ชจ๋์์ ์ฑ๋ฅ์์ ์ด์ ์ ์ป์ ์ ์์ต๋๋ค.
IPEX ๊ทธ๋ํ ์ต์ ํ์ ๋ํ ์์ธํ ์ ๋ณด๋ฅผ ํ์ธํ์ธ์.
IPEX ์ค์น: [[ipex-installation]]
IPEX ๋ฐฐํฌ ์ฃผ๊ธฐ๋ PyTorch๋ฅผ ๋ฐ๋ผ์ ์ด๋ฃจ์ด์ง๋๋ค. ์์ธํ ์ ๋ณด๋ IPEX ์ค์น ๋ฐฉ๋ฒ์ ํ์ธํ์ธ์.
JIT ๋ชจ๋ ์ฌ์ฉ๋ฒ [[usage-of-jitmode]]
ํ๊ฐ ๋๋ ์์ธก์ ์ํด Trainer์์ JIT ๋ชจ๋๋ฅผ ์ฌ์ฉํ๋ ค๋ฉด Trainer์ ๋ช
๋ น ์ธ์์ jit_mode_eval
์ ์ถ๊ฐํด์ผ ํฉ๋๋ค.
PyTorch์ ๋ฒ์ ์ด 1.14.0 ์ด์์ด๋ผ๋ฉด, jit ๋ชจ๋๋ jit.trace์์ dict ์ ๋ ฅ์ด ์ง์๋๋ฏ๋ก, ๋ชจ๋ ๋ชจ๋ธ์ ์์ธก๊ณผ ํ๊ฐ๊ฐ ๊ฐ์ ๋ ์ ์์ต๋๋ค.
PyTorch์ ๋ฒ์ ์ด 1.14.0 ๋ฏธ๋ง์ด๋ผ๋ฉด, ์ง์ ์๋ต ๋ชจ๋ธ๊ณผ ๊ฐ์ด forward ๋งค๊ฐ๋ณ์์ ์์๊ฐ jit.trace์ ํํ ์ ๋ ฅ ์์์ ์ผ์นํ๋ ๋ชจ๋ธ์ ๋์ด ๋ ์ ์์ต๋๋ค. ํ ์คํธ ๋ถ๋ฅ ๋ชจ๋ธ๊ณผ ๊ฐ์ด forward ๋งค๊ฐ๋ณ์ ์์๊ฐ jit.trace์ ํํ ์ ๋ ฅ ์์์ ๋ค๋ฅธ ๊ฒฝ์ฐ, jit.trace๊ฐ ์คํจํ๋ฉฐ ์์ธ๊ฐ ๋ฐ์ํฉ๋๋ค. ์ด๋ ์์ธ์ํฉ์ ์ฌ์ฉ์์๊ฒ ์๋ฆฌ๊ธฐ ์ํด Logging์ด ์ฌ์ฉ๋ฉ๋๋ค.
Transformers ์ง์ ์๋ต์ ์ฌ์ฉ ์ฌ๋ก ์์๋ฅผ ์ฐธ์กฐํ์ธ์.
CPU์์ jit ๋ชจ๋๋ฅผ ์ฌ์ฉํ ์ถ๋ก :
python run_qa.py \ --model_name_or_path csarron/bert-base-uncased-squad-v1 \ --dataset_name squad \ --do_eval \ --max_seq_length 384 \ --doc_stride 128 \ --output_dir /tmp/ \ --no_cuda \ --jit_mode_eval
CPU์์ IPEX์ ํจ๊ป jit ๋ชจ๋๋ฅผ ์ฌ์ฉํ ์ถ๋ก :
python run_qa.py \ --model_name_or_path csarron/bert-base-uncased-squad-v1 \ --dataset_name squad \ --do_eval \ --max_seq_length 384 \ --doc_stride 128 \ --output_dir /tmp/ \ --no_cuda \ --use_ipex \ --jit_mode_eval