# Distributed inference with multiple GPUs

λΆ„μ‚° μ„€μ •μ—μ„œλŠ” μ—¬λŸ¬ 개의 ν”„λ‘¬ν”„νŠΈλ₯Ό λ™μ‹œμ— 생성할 λ•Œ μœ μš©ν•œ πŸ€— [Accelerate](https://huggingface.co/docs/accelerate/index) λ˜λŠ” [PyTorch Distributed](https://pytorch.org/tutorials/beginner/dist_overview.html)λ₯Ό μ‚¬μš©ν•˜μ—¬ μ—¬λŸ¬ GPUμ—μ„œ 좔둠을 μ‹€ν–‰ν•  수 μžˆμŠ΅λ‹ˆλ‹€.

이 κ°€μ΄λ“œμ—μ„œλŠ” λΆ„μ‚° 좔둠을 μœ„ν•΄ πŸ€— Accelerate와 PyTorch Distributedλ₯Ό μ‚¬μš©ν•˜λŠ” 방법을 λ³΄μ—¬λ“œλ¦½λ‹ˆλ‹€.

## πŸ€— Accelerate

πŸ€— [Accelerate](https://huggingface.co/docs/accelerate/index) is a library designed to make it easy to train or run inference across distributed setups. It simplifies the process of setting up the distributed environment, allowing you to focus on your PyTorch code.

To begin, create a Python file and initialize an [`accelerate.PartialState`] to create a distributed environment; your setup is automatically detected, so you don't need to explicitly define the `rank` or `world_size`. Move the [`DiffusionPipeline`] to `distributed_state.device` to assign a GPU to each process.

이제 μ»¨ν…μŠ€νŠΈ κ΄€λ¦¬μžλ‘œ [`~accelerate.PartialState.split_between_processes`] μœ ν‹Έλ¦¬ν‹°λ₯Ό μ‚¬μš©ν•˜μ—¬ ν”„λ‘œμ„ΈμŠ€ μˆ˜μ— 따라 ν”„λ‘¬ν”„νŠΈλ₯Ό μžλ™μœΌλ‘œ λΆ„λ°°ν•©λ‹ˆλ‹€.


```py
import torch
from accelerate import PartialState
from diffusers import DiffusionPipeline

pipeline = DiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16)
distributed_state = PartialState()
pipeline.to(distributed_state.device)

with distributed_state.split_between_processes(["a dog", "a cat"]) as prompt:
    result = pipeline(prompt).images[0]
    result.save(f"result_{distributed_state.process_index}.png")
```

Use the `--num_processes` argument to specify the number of GPUs to use, and call `accelerate launch` to run the script:

```bash
accelerate launch run_distributed.py --num_processes=2
```
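Under the hood, [`~accelerate.PartialState.split_between_processes`] gives each process a contiguous chunk of the input list; when the list doesn't divide evenly, the earlier processes receive one extra item. A minimal sketch of that splitting logic (a hypothetical `chunk_for_process` helper for illustration, not Accelerate's actual implementation):

```python
# Sketch of splitting a prompt list into contiguous per-process chunks,
# assuming earlier processes absorb the remainder (hypothetical helper).
def chunk_for_process(items, process_index, num_processes):
    """Return the contiguous slice of `items` owned by one process."""
    base, extra = divmod(len(items), num_processes)
    start = process_index * base + min(process_index, extra)
    end = start + base + (1 if process_index < extra else 0)
    return items[start:end]

prompts = ["a dog", "a cat", "a bird"]
# With 2 processes, process 0 gets two prompts and process 1 gets one:
print(chunk_for_process(prompts, 0, 2))  # ['a dog', 'a cat']
print(chunk_for_process(prompts, 1, 2))  # ['a bird']
```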

<Tip>μžμ„Έν•œ λ‚΄μš©μ€ [πŸ€— Accelerateλ₯Ό μ‚¬μš©ν•œ λΆ„μ‚° μΆ”λ‘ ](https://huggingface.co/docs/accelerate/en/usage_guides/distributed_inference#distributed-inference-with-accelerate) κ°€μ΄λ“œλ₯Ό μ°Έμ‘°ν•˜μ„Έμš”.

</Tip>

## PyTorch Distributed

PyTorch supports [`DistributedDataParallel`](https://pytorch.org/docs/stable/generated/torch.nn.parallel.DistributedDataParallel.html), which enables data parallelism.

To start, create a Python file and import `torch.distributed` and `torch.multiprocessing` to set up the distributed process group and to spawn the processes for inference on each GPU. You should also initialize a [`DiffusionPipeline`]:

ν™•μ‚° νŒŒμ΄ν”„λΌμΈμ„ `rank`둜 μ΄λ™ν•˜κ³  `get_rank`λ₯Ό μ‚¬μš©ν•˜μ—¬ 각 ν”„λ‘œμ„ΈμŠ€μ— GPUλ₯Ό ν• λ‹Ήν•˜λ©΄ 각 ν”„λ‘œμ„ΈμŠ€κ°€ λ‹€λ₯Έ ν”„λ‘¬ν”„νŠΈλ₯Ό μ²˜λ¦¬ν•©λ‹ˆλ‹€:

```py
import torch
import torch.distributed as dist
import torch.multiprocessing as mp

from diffusers import DiffusionPipeline

# load the pipeline once; each spawned process moves it to its own GPU
sd = DiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16)
```

To run inference, you'll want to create a function where [`init_process_group`](https://pytorch.org/docs/stable/distributed.html#torch.distributed.init_process_group) handles creating the distributed environment with the type of backend to use, the `rank` of the current process, and the `world_size`, or the number of participating processes.

For example, `world_size` is 2 if you run inference in parallel over 2 GPUs.
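The per-rank prompt selection in the function below can also be written as a simple list lookup, which scales past two GPUs without adding more branches; a plain-Python sketch of that mapping (no GPUs required, cycling with modulo keeps it valid if `world_size` exceeds the number of prompts):

```python
# Each rank picks its own prompt by indexing a shared list
# (sketch of the rank -> prompt mapping used below).
prompts = ["a dog", "a cat"]

def prompt_for_rank(rank):
    return prompts[rank % len(prompts)]

print(prompt_for_rank(0))  # a dog
print(prompt_for_rank(1))  # a cat
```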

```py
def run_inference(rank, world_size):
    # create the distributed environment for this process
    dist.init_process_group("nccl", rank=rank, world_size=world_size)

    # assign this process's GPU to the pipeline
    sd.to(rank)

    # each process handles a different prompt
    if torch.distributed.get_rank() == 0:
        prompt = "a dog"
    elif torch.distributed.get_rank() == 1:
        prompt = "a cat"

    image = sd(prompt).images[0]
    image.save(f"./{'_'.join(prompt.split())}.png")
```

To run the distributed inference, call [`mp.spawn`](https://pytorch.org/docs/stable/multiprocessing.html#torch.multiprocessing.spawn) to run the `run_inference` function on the number of GPUs defined in `world_size`:

```py
def main():
    world_size = 2
    mp.spawn(run_inference, args=(world_size,), nprocs=world_size, join=True)


if __name__ == "__main__":
    main()
```

Once you've completed the inference script, use the `--nproc_per_node` argument to specify the number of GPUs to use and call `torchrun` to run the script:

```bash
torchrun --nproc_per_node=2 run_distributed.py
```
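Note that `torchrun` itself launches one copy of the script per GPU and exports the assigned `RANK`, `LOCAL_RANK`, and `WORLD_SIZE` as environment variables, so a script driven by `torchrun` can read those directly instead of spawning processes with `mp.spawn`. A minimal sketch of reading them, with single-process defaults so the script also runs standalone (the `get_dist_env` helper is hypothetical):

```python
import os

def get_dist_env():
    """Read the rank/world size that torchrun exports, defaulting
    to a single-process setup when launched without torchrun."""
    rank = int(os.environ.get("RANK", 0))
    world_size = int(os.environ.get("WORLD_SIZE", 1))
    return rank, world_size

rank, world_size = get_dist_env()
```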