<!--Copyright 2022 The HuggingFace Team. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.

⚠️ Note that this file is in Markdown but contains specific syntax for our doc-builder (similar to MDX) that may not be
rendered properly in your Markdown viewer.

-->

# Hyperparameter search using the Trainer API [[hyperparameter-search-using-trainer-api]]

🤗 Transformers provides a [`Trainer`] class optimized for training 🤗 Transformers models, so you can train a model without writing your own training loop. [`Trainer`] also provides an API for hyperparameter search. This document shows how to use that API, with examples.

## Hyperparameter search backends [[hyperparameter-search-backend]]

[`Trainer`] currently supports the following four hyperparameter search backends:
[optuna](https://optuna.org/), [sigopt](https://sigopt.com/), [raytune](https://docs.ray.io/en/latest/tune/index.html), and [wandb](https://wandb.ai/site/sweeps).

Before using one of them as the hyperparameter search backend, install the corresponding library with the command below.

```bash
# Installs all four backends; keep only the packages you need.
pip install optuna sigopt wandb "ray[tune]"
```
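
To confirm which backends are importable in the current environment, a small check like the one below may help. It is only a sketch that uses the standard library: the `BACKEND_MODULES` mapping and the `available_backends` helper are hypothetical names introduced here, not part of 🤗 Transformers.

```python
from importlib.util import find_spec

# Hypothetical mapping from each Trainer backend name to the top-level
# module that its package is expected to provide.
BACKEND_MODULES = {
    "optuna": "optuna",
    "sigopt": "sigopt",
    "ray": "ray",
    "wandb": "wandb",
}


def available_backends():
    """Return the backend names whose top-level module can be found."""
    return [
        name
        for name, module in BACKEND_MODULES.items()
        if find_spec(module) is not None
    ]


print(available_backends())
```

Any name in the printed list should be usable as the `backend` argument described later in this document.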

## How to enable hyperparameter search in an example [[how-to-enable-hyperparameter-search-in-example]]

Define the hyperparameter search space. Each hyperparameter search backend requires a different format.

For sigopt, see the sigopt [object_parameter](https://docs.sigopt.com/ai-module-api-references/api_reference/objects/object_parameter) documentation and write the search space as follows:

```py
>>> def sigopt_hp_space(trial):
...     return [
...         {"bounds": {"min": 1e-6, "max": 1e-4}, "name": "learning_rate", "type": "double"},
...         {
...             "categorical_values": ["16", "32", "64", "128"],
...             "name": "per_device_train_batch_size",
...             "type": "categorical",
...         },
...     ]
```

For optuna, see the optuna [search space tutorial](https://optuna.readthedocs.io/en/stable/tutorial/10_key_features/002_configurations.html#sphx-glr-tutorial-10-key-features-002-configurations-py) and write the search space as follows:

```py
>>> def optuna_hp_space(trial):
...     return {
...         "learning_rate": trial.suggest_float("learning_rate", 1e-6, 1e-4, log=True),
...         "per_device_train_batch_size": trial.suggest_categorical("per_device_train_batch_size", [16, 32, 64, 128]),
...     }
```
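
To see what kind of dictionary such a function returns without installing optuna, the function can be called with a stand-in trial object. The sketch below is an illustration under stated assumptions: `FakeTrial` is a hypothetical class that only mimics the two `suggest_*` methods used above, and its sampling (uniform in log space when `log=True`) is a simplification, not optuna's actual sampler.

```python
import math
import random


class FakeTrial:
    """Hypothetical stand-in for an optuna trial, for illustration only."""

    def __init__(self, seed=0):
        self._rng = random.Random(seed)

    def suggest_float(self, name, low, high, log=False):
        if log:
            # Sample uniformly in log space, then map back to the original scale.
            return math.exp(self._rng.uniform(math.log(low), math.log(high)))
        return self._rng.uniform(low, high)

    def suggest_categorical(self, name, choices):
        return self._rng.choice(choices)


def optuna_hp_space(trial):
    return {
        "learning_rate": trial.suggest_float("learning_rate", 1e-6, 1e-4, log=True),
        "per_device_train_batch_size": trial.suggest_categorical(
            "per_device_train_batch_size", [16, 32, 64, 128]
        ),
    }


# Each call yields one candidate set of hyperparameters.
params = optuna_hp_space(FakeTrial(seed=0))
print(params)
```

The keys of the returned dictionary are the names of the training arguments that the search varies from one trial to the next.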

For raytune, see the raytune [search space API](https://docs.ray.io/en/latest/tune/api/search_space.html) documentation and write the search space as follows:

```py
>>> from ray import tune

>>> def ray_hp_space(trial):
...     return {
...         "learning_rate": tune.loguniform(1e-6, 1e-4),
...         "per_device_train_batch_size": tune.choice([16, 32, 64, 128]),
...     }
```

For wandb, see the wandb [sweep configuration](https://docs.wandb.ai/guides/sweeps/configuration) documentation and write the search space as follows:

```py
>>> def wandb_hp_space(trial):
...     return {
...         "method": "random",
...         "metric": {"name": "objective", "goal": "minimize"},
...         "parameters": {
...             "learning_rate": {"distribution": "uniform", "min": 1e-6, "max": 1e-4},
...             "per_device_train_batch_size": {"values": [16, 32, 64, 128]},
...         },
...     }
```

Define a `model_init` function and pass it to the [`Trainer`]. For example:

```py
>>> from transformers import AutoModelForSequenceClassification

>>> # `model_args` and `config` are assumed to be defined earlier in the example script.
>>> def model_init(trial):
...     return AutoModelForSequenceClassification.from_pretrained(
...         model_args.model_name_or_path,
...         from_tf=bool(".ckpt" in model_args.model_name_or_path),
...         config=config,
...         cache_dir=model_args.cache_dir,
...         revision=model_args.model_revision,
...         use_auth_token=True if model_args.use_auth_token else None,
...     )
```

Create a [`Trainer`] with your `model_init` function, training arguments, training and test datasets, and evaluation function:

```py
>>> trainer = Trainer(
...     model=None,
...     args=training_args,
...     train_dataset=small_train_dataset,
...     eval_dataset=small_eval_dataset,
...     compute_metrics=compute_metrics,
...     tokenizer=tokenizer,
...     model_init=model_init,
...     data_collator=data_collator,
... )
```

Call the hyperparameter search and retrieve the parameters of the best trial. The backend can be `"optuna"`, `"sigopt"`, `"wandb"`, or `"ray"`. The direction can be `"minimize"` or `"maximize"`, and determines whether the objective is minimized or maximized.

You can define your own `compute_objective` function. If you do not, the default `compute_objective` is called, and the sum of evaluation metrics such as f1 is returned as the objective value.

```py
>>> best_trial = trainer.hyperparameter_search(
...     direction="maximize",
...     backend="optuna",
...     hp_space=optuna_hp_space,
...     n_trials=20,
...     compute_objective=compute_objective,
... )
```
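
The call above passes a `compute_objective` function that the document does not define. One possible version is sketched below, assuming that `compute_metrics` reports an F1 score and that it therefore appears in the evaluation metrics under the key `"eval_f1"`. The key name and the `example_metrics` values are illustrative assumptions, not output from a real run.

```python
def compute_objective(metrics):
    """Return the single value the search should optimize.

    `metrics` is the dictionary of evaluation metrics for one trial.
    Assumption: it contains an "eval_f1" entry produced by compute_metrics.
    """
    return metrics["eval_f1"]


# Made-up metrics, only to show the shape of the input.
example_metrics = {
    "eval_loss": 0.41,
    "eval_f1": 0.88,
    "eval_accuracy": 0.90,
    "epoch": 3.0,
}

print(compute_objective(example_metrics))  # → 0.88
```

Because a higher F1 score is better, an objective like this pairs with `direction="maximize"`. An objective based on the loss would pair with `direction="minimize"` instead.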

## Hyperparameter search for DDP fine-tuning [[hyperparameter-search-for-ddp-finetune]]

Currently, hyperparameter search for DDP (Distributed Data Parallelism) is available with optuna and sigopt. The rank-zero process starts the hyperparameter search and passes the result to the other processes.