Spaces:
Runtime error
A newer version of the Gradio SDK is available:
5.25.2
TensorFlow๋ก TPU์์ ํ๋ จํ๊ธฐ[[training-on-tpu-with-tensorflow]]
์์ธํ ์ค๋ช ์ด ํ์ํ์ง ์๊ณ ๋ฐ๋ก TPU ์ํ ์ฝ๋๋ฅผ ์์ํ๊ณ ์ถ๋ค๋ฉด ์ฐ๋ฆฌ์ TPU ์์ ๋ ธํธ๋ถ!์ ํ์ธํ์ธ์.
TPU๊ฐ ๋ฌด์์ธ๊ฐ์?[[what-is-a-tpu]]
TPU๋ ํ ์ ์ฒ๋ฆฌ ์ฅ์น์ ๋๋ค. Google์์ ์ค๊ณํ ํ๋์จ์ด๋ก, GPU์ฒ๋ผ ์ ๊ฒฝ๋ง ๋ด์์ ํ ์ ์ฐ์ฐ์ ๋์ฑ ๋น ๋ฅด๊ฒ ์ฒ๋ฆฌํ๊ธฐ ์ํด ์ฌ์ฉ๋ฉ๋๋ค. ๋คํธ์ํฌ ํ๋ จ๊ณผ ์ถ๋ก ๋ชจ๋์ ์ฌ์ฉํ ์ ์์ต๋๋ค. ์ผ๋ฐ์ ์ผ๋ก Google์ ํด๋ผ์ฐ๋ ์๋น์ค๋ฅผ ํตํด ์ด์ฉํ ์ ์์ง๋ง, Google Colab๊ณผ Kaggle Kernel์ ํตํด ์๊ท๋ชจ TPU๋ฅผ ๋ฌด๋ฃ๋ก ์ง์ ์ด์ฉํ ์๋ ์์ต๋๋ค.
๐ค Transformers์ ๋ชจ๋ Tensorflow ๋ชจ๋ธ์ Keras ๋ชจ๋ธ์ด๊ธฐ ๋๋ฌธ์, ์ด ๋ฌธ์์์ ๋ค๋ฃจ๋ ๋๋ถ๋ถ์ ๋ฉ์๋๋ ๋์ฒด๋ก ๋ชจ๋ Keras ๋ชจ๋ธ์ ์ํ TPU ํ๋ จ์ ์ ์ฉํ ์ ์์ต๋๋ค! ํ์ง๋ง Transformer์ ๋ฐ์ดํฐ ์ธํธ์ HuggingFace ์ํ๊ณ(hug-o-system?)์ ํนํ๋ ๋ช ๊ฐ์ง ์ฌํญ์ด ์์ผ๋ฉฐ, ํด๋น ์ฌํญ์ ๋ํด ์ค๋ช ํ ๋ ๋ฐ๋์ ์ธ๊ธํ๋๋ก ํ๊ฒ ์ต๋๋ค.
์ด๋ค ์ข ๋ฅ์ TPU๊ฐ ์๋์?[[what-kinds-of-tpu-are-available]]
์ ๊ท ์ฌ์ฉ์๋ TPU์ ๋ฒ์์ ๋ค์ํ ์ด์ฉ ๋ฐฉ๋ฒ์ ๋ํด ๋งค์ฐ ํผ๋์ค๋ฌ์ํ๋ ๊ฒฝ์ฐ๊ฐ ๋ง์ต๋๋ค. TPU ๋ ธ๋์ TPU VM์ ์ฐจ์ด์ ์ ๊ฐ์ฅ ๋จผ์ ์ดํดํด์ผ ํ ํต์ฌ์ ์ธ ๊ตฌ๋ถ ์ฌํญ์ ๋๋ค.
TPU ๋ ธ๋๋ฅผ ์ฌ์ฉํ๋ค๋ฉด, ์ค์ ๋ก๋ ์๊ฒฉ TPU๋ฅผ ๊ฐ์ ์ ์ผ๋ก ์ด์ฉํ๋ ๊ฒ์ ๋๋ค. ๋คํธ์ํฌ์ ๋ฐ์ดํฐ ํ์ดํ๋ผ์ธ์ ์ด๊ธฐํํ ๋ค์, ์ด๋ฅผ ์๊ฒฉ ๋ ธ๋๋ก ์ ๋ฌํ ๋ณ๋์ VM์ด ํ์ํฉ๋๋ค. Google Colab์์ TPU๋ฅผ ์ฌ์ฉํ๋ ๊ฒฝ์ฐ, TPU ๋ ธ๋ ๋ฐฉ์์ผ๋ก ์ด์ฉํ๊ฒ ๋ฉ๋๋ค.
TPU ๋ ธ๋๋ฅผ ์ฌ์ฉํ๋ ๊ฒ์ ์ด๋ฅผ ์ฌ์ฉํ์ง ์๋ ์ฌ์ฉ์์๊ฒ ์๊ธฐ์น ์์ ํ์์ด ๋ฐ์ํ๊ธฐ๋ ํฉ๋๋ค! ํนํ, TPU๋ ํ์ด์ฌ ์ฝ๋๋ฅผ ์คํํ๋ ๊ธฐ๊ธฐ(machine)์ ๋ฌผ๋ฆฌ์ ์ผ๋ก ๋ค๋ฅธ ์์คํ ์ ์๊ธฐ ๋๋ฌธ์ ๋ก์ปฌ ๊ธฐ๊ธฐ์ ๋ฐ์ดํฐ๋ฅผ ์ ์ฅํ ์ ์์ต๋๋ค. ์ฆ, ์ปดํจํฐ์ ๋ด๋ถ ์ ์ฅ์์์ ๊ฐ์ ธ์ค๋ ๋ฐ์ดํฐ ํ์ดํ๋ผ์ธ์ ์ ๋ ์๋ํ์ง ์์ต๋๋ค! ๋ก์ปฌ ๊ธฐ๊ธฐ์ ๋ฐ์ดํฐ๋ฅผ ์ ์ฅํ๋ ๋์ ์, ๋ฐ์ดํฐ ํ์ดํ๋ผ์ธ์ด ์๊ฒฉ TPU ๋ ธ๋์์ ์คํ ์ค์ผ ๋์๋ ๋ฐ์ดํฐ ํ์ดํ๋ผ์ธ์ด ๊ณ์ ์ด์ฉํ ์ ์๋ Google Cloud Storage์ ๋ฐ์ดํฐ๋ฅผ ์ ์ฅํด์ผ ํฉ๋๋ค.
๋ฉ๋ชจ๋ฆฌ์ ์๋ ๋ชจ๋ ๋ฐ์ดํฐ๋ฅผ np.ndarray
๋๋ tf.Tensor
๋ก ๋ง์ถ ์ ์๋ค๋ฉด, Google Cloud Storage์ ์
๋ก๋ํ ํ์ ์์ด, Colab ๋๋ TPU ๋
ธ๋๋ฅผ ์ฌ์ฉํด์ ํด๋น ๋ฐ์ดํฐ์ fit()
ํ ์ ์์ต๋๋ค.
๐คํน์ํ Hugging Face ํ๐ค: TF ์ฝ๋ ์์ ์์ ๋ณผ ์ ์๋ Dataset.to_tf_dataset()
๋ฉ์๋์ ๊ทธ ์์ ๋ํผ(wrapper)์ธ model.prepare_tf_dataset()
๋ ๋ชจ๋ TPU ๋
ธ๋์์ ์๋ํ์ง ์์ต๋๋ค. ๊ทธ ์ด์ ๋ tf.data.Dataset
์ ์์ฑํ๋๋ผ๋ โ์์ํโ tf.data
ํ์ดํ๋ผ์ธ์ด ์๋๋ฉฐ tf.numpy_function
๋๋ Dataset.from_generator()
๋ฅผ ์ฌ์ฉํ์ฌ ๊ธฐ๋ณธ HuggingFace Dataset
์์ ๋ฐ์ดํฐ๋ฅผ ์ ์กํ๊ธฐ ๋๋ฌธ์
๋๋ค. ์ด HuggingFace Dataset
๋ ๋ก์ปฌ ๋์คํฌ์ ์๋ ๋ฐ์ดํฐ๋ก ์ง์๋๋ฉฐ ์๊ฒฉ TPU ๋
ธ๋๊ฐ ์ฝ์ ์ ์์ต๋๋ค.
TPU๋ฅผ ์ด์ฉํ๋ ๋ ๋ฒ์งธ ๋ฐฉ๋ฒ์ TPU VM์ ์ฌ์ฉํ๋ ๊ฒ์ ๋๋ค. TPU VM์ ์ฌ์ฉํ ๋, GPU VM์์ ํ๋ จํ๋ ๊ฒ๊ณผ ๊ฐ์ด TPU๊ฐ ์ฅ์ฐฉ๋ ๊ธฐ๊ธฐ์ ์ง์ ์ฐ๊ฒฐํฉ๋๋ค. ํนํ ๋ฐ์ดํฐ ํ์ดํ๋ผ์ธ๊ณผ ๊ด๋ จํ์ฌ, TPU VM์ ๋์ฒด๋ก ์์ ํ๊ธฐ ๋ ์ฝ์ต๋๋ค. ์์ ๋ชจ๋ ๊ฒฝ๊ณ ๋ TPU VM์๋ ํด๋น๋์ง ์์ต๋๋ค!
์ด ๋ฌธ์๋ ์๊ฒฌ์ด ํฌํจ๋ ๋ฌธ์์ด๋ฉฐ, ์ ํฌ์ ์๊ฒฌ์ด ์ฌ๊ธฐ์ ์์ต๋๋ค: ๊ฐ๋ฅํ๋ฉด TPU ๋ ธ๋๋ฅผ ์ฌ์ฉํ์ง ๋ง์ธ์. TPU ๋ ธ๋๋ TPU VM๋ณด๋ค ๋ ๋ณต์กํ๊ณ ๋๋ฒ๊น ํ๊ธฐ๊ฐ ๋ ์ด๋ ต์ต๋๋ค. ๋ํ ํฅํ์๋ ์ง์๋์ง ์์ ๊ฐ๋ฅ์ฑ์ด ๋์ต๋๋ค. Google์ ์ต์ TPU์ธ TPUv4๋ TPU VM์ผ๋ก๋ง ์ด์ฉํ ์ ์์ผ๋ฏ๋ก, TPU ๋ ธ๋๋ ์ ์ ๋ "๊ตฌ์" ์ด์ฉ ๋ฐฉ๋ฒ์ด ๋ ๊ฒ์ผ๋ก ์ ๋ง๋ฉ๋๋ค. ๊ทธ๋ฌ๋ TPU ๋ ธ๋๋ฅผ ์ฌ์ฉํ๋ Colab๊ณผ Kaggle Kernel์์๋ง ๋ฌด๋ฃ TPU ์ด์ฉ์ด ๊ฐ๋ฅํ ๊ฒ์ผ๋ก ํ์ธ๋์ด, ํ์ํ ๊ฒฝ์ฐ ์ด๋ฅผ ๋ค๋ฃจ๋ ๋ฐฉ๋ฒ์ ์ค๋ช ํด ๋๋ฆฌ๊ฒ ์ต๋๋ค! ์ด์ ๋ํ ์์ธํ ์ค๋ช ์ด ๋ด๊ธด ์ฝ๋ ์ํ์ TPU ์์ ๋ ธํธ๋ถ์์ ํ์ธํ์๊ธฐ ๋ฐ๋๋๋ค.
์ด๋ค ํฌ๊ธฐ์ TPU๋ฅผ ์ฌ์ฉํ ์ ์๋์?[[what-sizes-of-tpu-are-available]]
๋จ์ผ TPU(v2-8/v3-8/v4-8)๋ 8๊ฐ์ ๋ณต์ ๋ณธ(replicas)์ ์คํํฉ๋๋ค. TPU๋ ์๋ฐฑ ๋๋ ์์ฒ ๊ฐ์ ๋ณต์ ๋ณธ์ ๋์์ ์คํํ ์ ์๋ pod๋ก ์กด์ฌํฉ๋๋ค. ๋จ์ผ TPU๋ฅผ ํ๋ ์ด์ ์ฌ์ฉํ์ง๋ง ์ ์ฒด Pod๋ณด๋ค ์ ๊ฒ ์ฌ์ฉํ๋ ๊ฒฝ์ฐ(์๋ฅผ ๋ค๋ฉด, v3-32), TPU ๊ตฌ์ฑ์ pod ์ฌ๋ผ์ด์ค๋ผ๊ณ ํฉ๋๋ค.
Colab์ ํตํด ๋ฌด๋ฃ TPU์ ์ด์ฉํ๋ ๊ฒฝ์ฐ, ๊ธฐ๋ณธ์ ์ผ๋ก ๋จ์ผ v2-8 TPU๋ฅผ ์ ๊ณต๋ฐ์ต๋๋ค.
XLA์ ๋ํด ๋ค์ด๋ณธ ์ ์ด ์์ต๋๋ค. XLA๋ ๋ฌด์์ด๊ณ TPU์ ์ด๋ค ๊ด๋ จ์ด ์๋์?[[i-keep-hearing-about-this-xla-thing-whats-xla-and-how-does-it-relate-to-tpus]]
XLA๋ ์ต์ ํ ์ปดํ์ผ๋ฌ๋ก, TensorFlow์ JAX์์ ๋ชจ๋ ์ฌ์ฉ๋ฉ๋๋ค. JAX์์๋ ์ ์ผํ ์ปดํ์ผ๋ฌ์ด์ง๋ง, TensorFlow์์๋ ์ ํ ์ฌํญ์
๋๋ค(ํ์ง๋ง TPU์์๋ ํ์์
๋๋ค!). Keras ๋ชจ๋ธ์ ํ๋ จํ ๋ ์ด๋ฅผ ํ์ฑํํ๋ ๊ฐ์ฅ ์ฌ์ด ๋ฐฉ๋ฒ์ jit_compile=True
์ธ์๋ฅผ model.compile()
์ ์ ๋ฌํ๋ ๊ฒ์
๋๋ค. ์ค๋ฅ๊ฐ ์๊ณ ์ฑ๋ฅ์ด ์ํธํ๋ค๋ฉด, TPU๋ก ์ ํํ ์ค๋น๊ฐ ๋์๋ค๋ ์ข์ ์ ํธ์
๋๋ค!
TPU์์ ๋๋ฒ๊น ํ๋ ๊ฒ์ ๋๊ฐ CPU/GPU๋ณด๋ค ์กฐ๊ธ ๋ ์ด๋ ต๊ธฐ ๋๋ฌธ์, TPU์์ ์๋ํ๊ธฐ ์ ์ ๋จผ์ XLA๋ก CPU/GPU์์ ์ฝ๋๋ฅผ ์คํํ๋ ๊ฒ์ ๊ถ์ฅํฉ๋๋ค. ๋ฌผ๋ก ์ค๋ ํ์ตํ ํ์๋ ์์ต๋๋ค. ์ฆ, ๋ชจ๋ธ๊ณผ ๋ฐ์ดํฐ ํ์ดํ๋ผ์ธ์ด ์์๋๋ก ์๋ํ๋์ง ํ์ธํ๊ธฐ ์ํด ๋ช ๋จ๊ณ๋ง ๊ฑฐ์น๋ฉด ๋ฉ๋๋ค.
XLA๋ก ์ปดํ์ผ๋ ์ฝ๋๋ ๋์ฒด๋ก ๋ ๋น ๋ฆ
๋๋ค. ๋ฐ๋ผ์ TPU์์ ์คํํ ๊ณํ์ด ์๋๋ผ๋, jit_compile=True
๋ฅผ ์ถ๊ฐํ๋ฉด ์ฑ๋ฅ์ด ํฅ์๋ ์ ์์ต๋๋ค. ํ์ง๋ง XLA ํธํ์ฑ์ ๋ํ ์๋ ์ฃผ์ ์ฌํญ์ ๋ฐ๋์ ํ์ธํ์ธ์!
๋ผ์ํ ๊ฒฝํ์์ ์ป์ ํ: jit_compile=True
๋ฅผ ์ฌ์ฉํ๋ฉด ์๋๋ฅผ ๋์ด๊ณ CPU/GPU ์ฝ๋๊ฐ XLA์ ํธํ๋๋์ง ๊ฒ์ฆํ ์ ์๋ ์ข์ ๋ฐฉ๋ฒ์ด์ง๋ง, ์ค์ TPU์์ ํ๋ จํ ๋ ๊ทธ๋๋ก ๋จ๊ฒจ๋๋ฉด ๋ง์ ๋ฌธ์ ๋ฅผ ์ด๋ํ ์ ์์ต๋๋ค. XLA ์ปดํ์ผ์ TPU์์ ์์์ ์ผ๋ก ์ด๋ค์ง๋ฏ๋ก, ์ค์ TPU์์ ์ฝ๋๋ฅผ ์คํํ๊ธฐ ์ ์ ํด๋น ์ค์ ์ ๊ฑฐํ๋ ๊ฒ์ ์์ง ๋ง์ธ์!
์ XLA ๋ชจ๋ธ๊ณผ ํธํํ๋ ค๋ฉด ์ด๋ป๊ฒ ํด์ผ ํ๋์?[[how-do-i-make-my-model-xla-compatible]]
๋๋ถ๋ถ์ ๊ฒฝ์ฐ, ์ฌ๋ฌ๋ถ์ ์ฝ๋๋ ์ด๋ฏธ XLA์ ํธํ๋ ๊ฒ์ ๋๋ค! ๊ทธ๋ฌ๋ ํ์ค TensorFlow์์ ์๋ํ์ง๋ง, XLA์์๋ ์๋ํ์ง ์๋ ๋ช ๊ฐ์ง ์ฌํญ์ด ์์ต๋๋ค. ์ด๋ฅผ ์๋ ์ธ ๊ฐ์ง ํต์ฌ ๊ท์น์ผ๋ก ๊ฐ์ถ๋ ธ์ต๋๋ค:
ํน์ํ HuggingFace ํ๐ค: ์ ํฌ๋ TensorFlow ๋ชจ๋ธ๊ณผ ์์ค ํจ์๋ฅผ XLA์ ํธํ๋๋๋ก ์ฌ์์ฑํ๋ ๋ฐ ๋ง์ ๋
ธ๋ ฅ์ ๊ธฐ์ธ์์ต๋๋ค. ์ ํฌ์ ๋ชจ๋ธ๊ณผ ์์ค ํจ์๋ ๋๊ฐ ๊ธฐ๋ณธ์ ์ผ๋ก ๊ท์น #1๊ณผ #2๋ฅผ ๋ฐ๋ฅด๋ฏ๋ก transformers
๋ชจ๋ธ์ ์ฌ์ฉํ๋ ๊ฒฝ์ฐ, ์ด๋ฅผ ๊ฑด๋๋ธ ์ ์์ต๋๋ค. ํ์ง๋ง ์์ฒด ๋ชจ๋ธ๊ณผ ์์ค ํจ์๋ฅผ ์์ฑํ ๋๋ ์ด๋ฌํ ๊ท์น์ ์์ง ๋ง์ธ์!
XLA ๊ท์น #1: ์ฝ๋์์ โ๋ฐ์ดํฐ ์ข ์ ์กฐ๊ฑด๋ฌธโ์ ์ฌ์ฉํ ์ ์์ต๋๋ค[[xla-rule-1-your-code-cannot-have-datadependent-conditionals]]
์ด๋ค if
๋ฌธ๋ tf.Tensor
๋ด๋ถ์ ๊ฐ์ ์ข
์๋ ์ ์๋ค๋ ๊ฒ์ ์๋ฏธํฉ๋๋ค. ์๋ฅผ ๋ค์ด, ์ด ์ฝ๋ ๋ธ๋ก์ XLA๋ก ์ปดํ์ผํ ์ ์์ต๋๋ค!
if tf.reduce_sum(tensor) > 10:
tensor = tensor / 2.0
์ฒ์์๋ ๋งค์ฐ ์ ํ์ ์ผ๋ก ๋ณด์ผ ์ ์์ง๋ง, ๋๋ถ๋ถ์ ์ ๊ฒฝ๋ง ์ฝ๋์์๋ ์ด๋ฅผ ์ํํ ํ์๊ฐ ์์ต๋๋ค. tf.cond
๋ฅผ ์ฌ์ฉํ๊ฑฐ๋(์ฌ๊ธฐ ๋ฌธ์๋ฅผ ์ฐธ์กฐ), ๋ค์๊ณผ ๊ฐ์ด ์กฐ๊ฑด๋ฌธ์ ์ ๊ฑฐํ๊ณ ๋์ ์งํ ๋ณ์๋ฅผ ์ฌ์ฉํ๋ ์๋ฆฌํ ์ํ ํธ๋ฆญ์ ์ฐพ์๋ด์ด ์ด ์ ํ์ ์ฐํํ ์ ์์ต๋๋ค:
sum_over_10 = tf.cast(tf.reduce_sum(tensor) > 10, tf.float32)
tensor = tensor / (1.0 + sum_over_10)
์ด ์ฝ๋๋ ์์ ์ฝ๋์ ์ ํํ ๋์ผํ ํจ๊ณผ๋ฅผ ๊ตฌํํ์ง๋ง, ์กฐ๊ฑด๋ฌธ์ ์ ๊ฑฐํ์ฌ ๋ฌธ์ ์์ด XLA๋ก ์ปดํ์ผ๋๋๋ก ํฉ๋๋ค!
XLA ๊ท์น #2: ์ฝ๋์์ "๋ฐ์ดํฐ ์ข ์ ํฌ๊ธฐ"๋ฅผ ๊ฐ์ง ์ ์์ต๋๋ค[[xla-rule-2-your-code-cannot-have-datadependent-shapes]]
์ฝ๋์์ ๋ชจ๋ tf.Tensor
๊ฐ์ฒด์ ํฌ๊ธฐ๊ฐ ํด๋น ๊ฐ์ ์ข
์๋ ์ ์๋ค๋ ๊ฒ์ ์๋ฏธํฉ๋๋ค. ์๋ฅผ ๋ค์ด, tf.unique
ํจ์๋ ์
๋ ฅ์์ ๊ฐ ๊ณ ์ ๊ฐ์ ์ธ์คํด์ค ํ๋๋ฅผ ํฌํจํ๋ tensor
๋ฅผ ๋ฐํํ๊ธฐ ๋๋ฌธ์ XLA๋ก ์ปดํ์ผํ ์ ์์ต๋๋ค. ์ด ์ถ๋ ฅ์ ํฌ๊ธฐ๋ ์
๋ ฅ Tensor
๊ฐ ์ผ๋ง๋ ๋ฐ๋ณต์ ์ธ์ง์ ๋ฐ๋ผ ๋ถ๋ช
ํ ๋ฌ๋ผ์ง ๊ฒ์ด๋ฏ๋ก, XLA๋ ์ด๋ฅผ ์ฒ๋ฆฌํ์ง ๋ชปํฉ๋๋ค!
์ผ๋ฐ์ ์ผ๋ก, ๋๋ถ๋ถ์ ์ ๊ฒฝ๋ง ์ฝ๋๋ ๊ธฐ๋ณธ๊ฐ์ผ๋ก ๊ท์น 2๋ฅผ ๋ฐ๋ฆ ๋๋ค. ๊ทธ๋ฌ๋ ๋ฌธ์ ๊ฐ ๋๋ ๋ช ๊ฐ์ง ๋ํ์ ์ธ ์ฌ๋ก๊ฐ ์์ต๋๋ค. ๊ฐ์ฅ ํํ ์ฌ๋ก ์ค ํ๋๋ ๋ ์ด๋ธ ๋ง์คํน์ ์ฌ์ฉํ์ฌ ์์ค(loss)์ ๊ณ์ฐํ ๋, ํด๋น ์์น๋ฅผ ๋ฌด์ํ๋๋ก ๋ํ๋ด๊ธฐ ์ํด ๋ ์ด๋ธ์ ์์ ๊ฐ์ผ๋ก ์ค์ ํ๋ ๊ฒฝ์ฐ์ ๋๋ค. ๋ ์ด๋ธ ๋ง์คํน์ ์ง์ํ๋ NumPy๋ PyTorch ์์ค ํจ์๋ฅผ ๋ณด๋ฉด ๋ถ ์ธ๋ฑ์ฑ์ ์ฌ์ฉํ๋ ๋ค์๊ณผ ๊ฐ์ ์ฝ๋๋ฅผ ์์ฃผ ์ ํ ์ ์์ต๋๋ค:
label_mask = labels >= 0
masked_outputs = outputs[label_mask]
masked_labels = labels[label_mask]
loss = compute_loss(masked_outputs, masked_labels)
mean_loss = torch.mean(loss)
์ด ์ฝ๋๋ NumPy๋ PyTorch์์๋ ๋ฌธ์ ์์ด ์๋ํ์ง๋ง, XLA์์๋ ์์๋ฉ๋๋ค! ์ ๊ทธ๋ด๊น์? ์ผ๋ง๋ ๋ง์ ์์น๊ฐ ๋ง์คํน๋๋์ง์ ๋ฐ๋ผ masked_outputs
์ masked_labels
์ ํฌ๊ธฐ๊ฐ ๋ฌ๋ผ์ ธ์, ๋ฐ์ดํฐ ์ข
์ ํฌ๊ธฐ๊ฐ ๋๊ธฐ ๋๋ฌธ์
๋๋ค. ๊ทธ๋ฌ๋ ๊ท์น #1๊ณผ ๋ง์ฐฌ๊ฐ์ง๋ก, ์ด ์ฝ๋๋ฅผ ๋ค์ ์์ฑํ๋ฉด ๋ฐ์ดํฐ ์ข
์์ ๋ชจ์ ํฌ๊ธฐ๊ฐ ์ ํํ ๋์ผํ ์ถ๋ ฅ์ ์ฐ์ถํ ์ ์์ต๋๋ค.
label_mask = tf.cast(labels >= 0, tf.float32)
loss = compute_loss(outputs, labels)
loss = loss * label_mask # Set negative label positions to 0
mean_loss = tf.reduce_sum(loss) / tf.reduce_sum(label_mask)
์ฌ๊ธฐ์, ๋ชจ๋ ์์น์ ๋ํ ์์ค์ ๊ณ์ฐํ์ง๋ง, ํ๊ท ์ ๊ณ์ฐํ ๋ ๋ถ์์ ๋ถ๋ชจ ๋ชจ๋์์ ๋ง์คํฌ๋ ์์น๋ฅผ 0์ผ๋ก ์ฒ๋ฆฌํฉ๋๋ค. ์ด๋ ๋ฐ์ดํฐ ์ข
์ ํฌ๊ธฐ๋ฅผ ๋ฐฉ์งํ๊ณ XLA ํธํ์ฑ์ ์ ์งํ๋ฉด์ ์ฒซ ๋ฒ์งธ ๋ธ๋ก๊ณผ ์ ํํ ๋์ผํ ๊ฒฐ๊ณผ๋ฅผ ์ฐ์ถํฉ๋๋ค. ๊ท์น #1์์์ ๋์ผํ ํธ๋ฆญ์ ์ฌ์ฉํ์ฌ tf.bool
์ tf.float32
๋ก ๋ณํํ๊ณ ์ด๋ฅผ ์งํ ๋ณ์๋ก ์ฌ์ฉํฉ๋๋ค. ํด๋น ํธ๋ฆญ์ ๋งค์ฐ ์ ์ฉํ๋ฉฐ, ์์ฒด ์ฝ๋๋ฅผ XLA๋ก ๋ณํํด์ผ ํ ๊ฒฝ์ฐ ๊ธฐ์ตํด ๋์ธ์!
XLA ๊ท์น #3: XLA๋ ๊ฐ๊ธฐ ๋ค๋ฅธ ์ ๋ ฅ ํฌ๊ธฐ๊ฐ ๋ํ๋ ๋๋ง๋ค ๋ชจ๋ธ์ ๋ค์ ์ปดํ์ผํด์ผ ํฉ๋๋ค[[xla-rule-3-xla-will-need-to-recompile-your-model-for-every-different-input-shape-it-sees]]
์ด๊ฒ์ ๊ฐ์ฅ ํฐ ๋ฌธ์ ์ ๋๋ค. ์ ๋ ฅ ํฌ๊ธฐ๊ฐ ๋งค์ฐ ๊ฐ๋ณ์ ์ธ ๊ฒฝ์ฐ, XLA๋ ๋ชจ๋ธ์ ๋ฐ๋ณตํด์ ๋ค์ ์ปดํ์ผํด์ผ ํ๋ฏ๋ก ์ฑ๋ฅ์ ํฐ ๋ฌธ์ ๊ฐ ๋ฐ์ํ ์ ์์ต๋๋ค. ์ด ๋ฌธ์ ๋ ํ ํฐํ ํ ์ ๋ ฅ ํ ์คํธ์ ๊ธธ์ด๊ฐ ๊ฐ๋ณ์ ์ธ NLP ๋ชจ๋ธ์์ ์ฃผ๋ก ๋ฐ์ํฉ๋๋ค. ๋ค๋ฅธ ๋ชจ๋ฌ๋ฆฌํฐ์์๋ ์ ์ ํฌ๊ธฐ๊ฐ ๋ ํํ๋ฉฐ, ํด๋น ๊ท์น์ด ํจ์ฌ ๋ ๋ฌธ์ ์ ๋ฉ๋๋ค.
๊ท์น #3์ ์ด๋ป๊ฒ ์ฐํํ ์ ์์๊น์? ํต์ฌ์ ํจ๋ฉ์
๋๋ค. ๋ชจ๋ ์
๋ ฅ์ ๋์ผํ ๊ธธ์ด๋ก ํจ๋ฉํ ๋ค์, attention_mask
๋ฅผ ์ฌ์ฉํ๋ฉด ์ด๋ค XLA ๋ฌธ์ ๋ ์์ด ๊ฐ๋ณ ํฌ๊ธฐ์์ ๊ฐ์ ธ์จ ๊ฒ๊ณผ ๋์ผํ ๊ฒฐ๊ณผ๋ฅผ ๊ฐ์ ธ์ฌ ์ ์์ต๋๋ค. ๊ทธ๋ฌ๋ ๊ณผ๋ํ ํจ๋ฉ์ ์ฌ๊ฐํ ์๋ ์ ํ๋ฅผ ์ผ๊ธฐํ ์๋ ์์ต๋๋ค. ๋ชจ๋ ์ํ์ ์ ์ฒด ๋ฐ์ดํฐ ์ธํธ์ ์ต๋ ๊ธธ์ด๋ก ํจ๋ฉํ๋ฉด, ๋ฌดํํ ํจ๋ฉ ํ ํฐ์ผ๋ก ๊ตฌ์ฑ๋ ๋ฐฐ์น๊ฐ ์์ฑ๋์ด ๋ง์ ์ฐ์ฐ๊ณผ ๋ฉ๋ชจ๋ฆฌ๊ฐ ๋ญ๋น๋ ์ ์์ต๋๋ค!
์ด ๋ฌธ์ ์ ๋ํ ์๋ฒฝํ ํด๊ฒฐ์ฑ ์ ์์ต๋๋ค. ํ์ง๋ง, ๋ช ๊ฐ์ง ํธ๋ฆญ์ ์๋ํด๋ณผ ์ ์์ต๋๋ค. ํ ๊ฐ์ง ์ ์ฉํ ํธ๋ฆญ์ ์ํ ๋ฐฐ์น๋ฅผ 32 ๋๋ 64 ํ ํฐ๊ณผ ๊ฐ์ ์ซ์์ ๋ฐฐ์๊น์ง ํจ๋ฉํ๋ ๊ฒ์ ๋๋ค. ์ด๋ ํ ํฐ ์๊ฐ ์ํญ ์ฆ๊ฐํ์ง๋ง, ๋ชจ๋ ์ ๋ ฅ ํฌ๊ธฐ๊ฐ 32 ๋๋ 64์ ๋ฐฐ์์ฌ์ผ ํ๊ธฐ ๋๋ฌธ์ ๊ณ ์ ํ ์ ๋ ฅ ํฌ๊ธฐ์ ์๊ฐ ๋ํญ ์ค์ด๋ญ๋๋ค. ๊ณ ์ ํ ์ ๋ ฅ ํฌ๊ธฐ๊ฐ ์ ๋ค๋ ๊ฒ์ XLA ์ปดํ์ผ ํ์๊ฐ ์ ์ด์ง๋ค๋ ๊ฒ์ ์๋ฏธํฉ๋๋ค!
๐คํน์ํ HuggingFace ํ๐ค: ํ ํฌ๋์ด์ ์ ๋ฐ์ดํฐ ์ฝ๋ ์ดํฐ์ ๋์์ด ๋ ์ ์๋ ๋ฉ์๋๊ฐ ์์ต๋๋ค. ํ ํฌ๋์ด์ ๋ฅผ ๋ถ๋ฌ์ฌ ๋ padding="max_length"
๋๋ padding="longest"
๋ฅผ ์ฌ์ฉํ์ฌ ํจ๋ฉ๋ ๋ฐ์ดํฐ๋ฅผ ์ถ๋ ฅํ๋๋ก ํ ์ ์์ต๋๋ค. ํ ํฌ๋์ด์ ์ ๋ฐ์ดํฐ ์ฝ๋ ์ดํฐ๋ ๋ํ๋๋ ๊ณ ์ ํ ์
๋ ฅ ํฌ๊ธฐ์ ์๋ฅผ ์ค์ด๊ธฐ ์ํด ์ฌ์ฉํ ์ ์๋ pad_to_multiple_of
์ธ์๋ ์์ต๋๋ค!
์ค์ TPU๋ก ๋ชจ๋ธ์ ํ๋ จํ๋ ค๋ฉด ์ด๋ป๊ฒ ํด์ผ ํ๋์?[[how-do-i-actually-train-my-model-on-tpu]]
ํ๋ จ์ด XLA์ ํธํ๋๊ณ (TPU ๋
ธ๋/Colab์ ์ฌ์ฉํ๋ ๊ฒฝ์ฐ) ๋ฐ์ดํฐ ์ธํธ๊ฐ ์ ์ ํ๊ฒ ์ค๋น๋์๋ค๋ฉด, TPU์์ ์คํํ๋ ๊ฒ์ ๋๋๋๋ก ์ฝ์ต๋๋ค! ์ฝ๋์์ ๋ช ์ค๋ง ์ถ๊ฐํ์ฌ, TPU๋ฅผ ์ด๊ธฐํํ๊ณ ๋ชจ๋ธ๊ณผ ๋ฐ์ดํฐ ์ธํธ๊ฐ TPUStrategy
๋ฒ์ ๋ด์ ์์ฑ๋๋๋ก ๋ณ๊ฒฝํ๋ฉด ๋ฉ๋๋ค. ์ฐ๋ฆฌ์ TPU ์์ ๋
ธํธ๋ถ์ ์ฐธ์กฐํ์ฌ ์ค์ ๋ก ์๋ํ๋ ๋ชจ์ต์ ํ์ธํด ๋ณด์ธ์!
์์ฝ[[summary]]
์ฌ๊ธฐ์ ๋ง์ ๋ด์ฉ์ด ํฌํจ๋์ด ์์ผ๋ฏ๋ก, TPU ํ๋ จ์ ์ํ ๋ชจ๋ธ์ ์ค๋นํ ๋ ๋ฐ๋ฅผ ์ ์๋ ๊ฐ๋ตํ ์ฒดํฌ๋ฆฌ์คํธ๋ก ์์ฝํด ๋ณด๊ฒ ์ต๋๋ค:
- ์ฝ๋๊ฐ XLA์ ์ธ ๊ฐ์ง ๊ท์น์ ๋ฐ๋ฅด๋์ง ํ์ธํฉ๋๋ค.
- CPU/GPU์์
jit_compile=True
๋ก ๋ชจ๋ธ์ ์ปดํ์ผํ๊ณ XLA๋ก ํ๋ จํ ์ ์๋์ง ํ์ธํฉ๋๋ค. - ๋ฐ์ดํฐ ์ธํธ๋ฅผ ๋ฉ๋ชจ๋ฆฌ์ ๊ฐ์ ธ์ค๊ฑฐ๋ TPU ํธํ ๋ฐ์ดํฐ ์ธํธ๋ฅผ ๊ฐ์ ธ์ค๋ ๋ฐฉ์์ ์ฌ์ฉํฉ๋๋ค(๋ ธํธ๋ถ ์ฐธ์กฐ)
- ์ฝ๋๋ฅผ Colab(accelerator๊ฐ โTPUโ๋ก ์ค์ ๋จ) ๋๋ Google Cloud์ TPU VM์ผ๋ก ๋ง์ด๊ทธ๋ ์ด์ ํฉ๋๋ค.
- TPU ์ด๊ธฐํ ์ฝ๋๋ฅผ ์ถ๊ฐํฉ๋๋ค(๋ ธํธ๋ถ ์ฐธ์กฐ)
TPUStrategy
๋ฅผ ์์ฑํ๊ณ ๋ฐ์ดํฐ ์ธํธ๋ฅผ ๊ฐ์ ธ์ค๋ ๊ฒ๊ณผ ๋ชจ๋ธ ์์ฑ์ดstrategy.scope()
๋ด์ ์๋์ง ํ์ธํฉ๋๋ค(๋ ธํธ๋ถ ์ฐธ์กฐ)- TPU๋ก ์ด๋ํ ๋
jit_compile=True
๋ฅผ ๋ค์ ์ค์ ํ๋ ๊ฒ์ ์์ง ๋ง์ธ์! - ๐๐๐๐ฅบ๐ฅบ๐ฅบ
- model.fit()์ ๋ถ๋ฌ์ต๋๋ค.
- ์ฌ๋ฌ๋ถ์ด ํด๋์ต๋๋ค!