์ ์: Will Constable, Wei Feng ๋ฒ์ญ: ๊ฐ์งํ .. note:
|edit| ์ด ํํ ๋ฆฌ์ผ์ ์ฌ๊ธฐ์ ๋ณด๊ณ ํธ์งํ์ธ์ `github <https://github.com/pytorchkorea/tutorials-kr/blob/main/beginner_source/dist_overview.rst>`__.
์ด ๋ฌธ์๋ torch.distributed ํจํค์ง์ ๊ฐ์ ํ์ด์ง์
๋๋ค.
์ด ํ์ด์ง์ ๋ชฉํ๋ ๋ฌธ์๋ฅผ ์ฃผ์ ๋ณ๋ก ๋ถ๋ฅํ๊ณ
๊ฐ ์ฃผ์ ๋ฅผ ๊ฐ๋ตํ ์ค๋ช
ํ๋ ๊ฒ์
๋๋ค. PyTorch๋ก ๋ถ์ฐ ํ์ต ์ ํ๋ฆฌ์ผ์ด์
์ ์ฒ์ ๊ตฌ์ถํ๋ค๋ฉด,
์ด ๋ฌธ์๋ฅผ ์ฐธ๊ณ ํ์ฌ ์ฌ๋ฌ๋ถ์ ์ฌ์ฉ ์ฌ๋ก์ ๊ฐ์ฅ ์ ํฉํ ๊ธฐ์ ์ ์ฐพ์๋ณด๋ ๊ฒ์ ๊ถ์ฅํฉ๋๋ค.
ํ์ดํ ์น ๋ถ์ฐ ๋ผ์ด๋ธ๋ฌ๋ฆฌ๋ ์ฌ๋ฌ ๋ณ๋ ฌํ ๋ชจ๋, ํต์ ๊ณ์ธต, ๊ทธ๋ฆฌ๊ณ ๋๊ท๋ชจ ํ์ต ์์ ์ ์คํ ๋ฐ ๋๋ฒ๊น ์ ์ํ ์ธํ๋ผ๋ก ๊ตฌ์ฑ๋ฉ๋๋ค.
์ด๋ฌํ ๋ณ๋ ฌํ ๋ชจ๋์ ๊ณ ์์ค ๊ธฐ๋ฅ์ ์ ๊ณตํ๋ฉฐ ๊ธฐ์กด ๋ชจ๋ธ๊ณผ ์กฐํฉํ์ฌ ์ฌ์ฉํ ์ ์์ต๋๋ค.
- ๋ถ์ฐ ๋ฐ์ดํฐ ๋ณ๋ ฌ ์ฒ๋ฆฌ (DDP, Distributed Data-Parallel)
- ์์ ์ค๋ฉ ๋ฐ์ดํฐ ๋ณ๋ ฌ ํ์ต (FSDP2, Fully Sharded Data-Parallel Training)
- ํ ์ ๋ณ๋ ฌ ์ฒ๋ฆฌ (TP, Tensor Parallel)
- ํ์ดํ๋ผ์ธ ๋ณ๋ ฌ ์ฒ๋ฆฌ (PP, Pipeline Parallel)
DTensor ์ DeviceMesh ๋ N์ฐจ์ ํ๋ก์ธ์ค ๊ทธ๋ฃน์์ ํ
์๋ฅผ ์ค๋ฉํ๊ฑฐ๋ ๋ณต์ ํ๋ ๋ฐฉ์์ผ๋ก ๋ณ๋ ฌํ๋ฅผ ๊ตฌ์ฑํ ๋ ์ฌ์ฉํ๋ ๊ธฐ๋ณธ ๊ตฌ์ฑ์์์
๋๋ค.
- DTensor ๋ ์ค๋ฉ๋๊ฑฐ๋/๋ณต์ ๋ ํ ์๋ฅผ ๋ํ๋ด๋ฉฐ, ์ฐ์ฐ์ ์๊ตฌ์ ๋ฐ๋ผ ํ ์๋ฅผ ์ฌ์ค๋ฉํ๊ธฐ ์ํ ํต์ ์ ์๋์ผ๋ก ์ํํฉ๋๋ค.
- DeviceMesh ๋ ๊ฐ์๊ธฐ ๋๋ฐ์ด์ค์ ์ปค๋ฎค๋์ผ์ดํฐ(communicator)๋ฅผ ๋ค์ฐจ์ ๋ฐฐ์ด๋ก ์ถ์ํํ๋ฉฐ, ๋ค์ฐจ์ ๋ณ๋ ฌ์ฑ์์ ์งํฉ(collective) ํต์ ์ ์ํํ๊ธฐ ์ํ ํ์
ProcessGroup์ธ์คํด์ค๋ค์ ๊ด๋ฆฌํฉ๋๋ค. ๋ ์์๋ณด๋ ค๋ฉด Device Mesh ๋ ์ํผ ๋ฅผ ์ง์ ๋ฐ๋ผ ํด๋ณด์ธ์.
- PyTorch ๋ถ์ฐ ํต์ ๊ณ์ธต (C10D) ์ ์งํฉ ํต์ API (์: all_reduce(์ ์ฒด ์ถ์)
- , all_gather(์ ์ฒด ์์ง)) ์ P2P ํต์ API (์: send(๋๊ธฐ ์ ์ก) , isend(๋น๋๊ธฐ ์ ์ก))๋ฅผ ๋ชจ๋ ์ ๊ณตํ๋ฉฐ, ์ด๋ฌํ API๋ ๋ชจ๋ ๋ณ๋ ฌํ ๊ตฌํ์์ ๋ด๋ถ์ ์ผ๋ก ์ฌ์ฉ๋ฉ๋๋ค. PyTorch๋ก ๋ถ์ฐ ์ ํ๋ฆฌ์ผ์ด์ ์์ฑํ๊ธฐ ๋ C10D ํต์ API ์ฌ์ฉ ์์ ๋ฅผ ๋ณด์ฌ ์ค๋๋ค.
torchrun ์ ๋๋ฆฌ ์ฐ์ด๋ ์คํ๊ธฐ ์คํฌ๋ฆฝํธ๋ก, ๋ถ์ฐ PyTorch ํ๋ก๊ทธ๋จ์ ์คํํ๊ธฐ ์ํด ๋ก์ปฌ ๋ฐ ์๊ฒฉ ๋จธ์ ์์ ํ๋ก์ธ์ค๋ฅผ ์์ฑํฉ๋๋ค.
๋ฐ์ดํฐ ๋ณ๋ ฌํ(Data Parallelism)๋ ๋๋ฆฌ ์ฑํ๋ SPMD(single-program multiple-data) ํ์ต ํจ๋ฌ๋ค์์ผ๋ก, ๋ชจ๋ธ์ด ๋ชจ๋ ํ๋ก์ธ์ค์ ๋ณต์ ๋๊ณ ๊ฐ ๋ชจ๋ธ์ ๋ณต์ ๋ณธ์ด ์๋ก ๋ค๋ฅธ ์ ๋ ฅ ๋ฐ์ดํฐ ์ํ ์งํฉ์ ๋ํด ๋ก์ปฌ ๋ณํ๋๋ฅผ ๊ณ์ฐํฉ๋๋ค. ๊ทธ๋ฐ ๋ค์ ๊ฐ ์ตํฐ๋ง์ด์ ์คํ ์ ์ ๋ฐ์ดํฐ-๋ณ๋ ฌ ํต์ ๊ทธ๋ฃน ๋ด์์ ๋ณํ๋๋ฅผ ํ๊ท ํํฉ๋๋ค.
๋ชจ๋ธ ๋ณ๋ ฌํ(Model Parallelism) ๊ธฐ๋ฒ(๋๋ ์ค๋ฉ๋ ๋ฐ์ดํฐ ๋ณ๋ ฌํ)์ ๋ชจ๋ธ์ด GPU ๋ฉ๋ชจ๋ฆฌ์ ๋ค์ด๊ฐ์ง ์์ ๋ ํ์ํ๋ฉฐ, ์๋ก ๊ฒฐํฉํด ๋ค์ฐจ์(N-D) ๋ณ๋ ฌํ ๊ธฐ๋ฒ์ ๊ตฌ์ฑํ ์ ์์ต๋๋ค.
๋ชจ๋ธ์ ์ ์ฉํ ๋ณ๋ ฌํ ๊ธฐ๋ฒ์ ๊ฒฐ์ ํ ๋๋ ๋ค์์ ์ผ๋ฐ์ ์ธ ์ง์นจ์ ์ฐธ๊ณ ํ์ธ์.
- ๋ชจ๋ธ์ด ๋จ์ผ GPU๋ฅผ ํ์ฌํ ์ ์์ง๋ง, ์ฌ๋ฌ GPU๋ก ์ฝ๊ฒ ํ์ต์ ํ์ฅํ๊ณ ์ถ๋ค๋ฉด
DistributedDataParallel (DDP, ๋ถ์ฐ ๋ฐ์ดํฐ ๋ณ๋ ฌํ) ๋ฅผ ์ฌ์ฉํ์ธ์.
- ์ฌ๋ฌ ๋ ธ๋๋ฅผ ์ฌ์ฉํ๋ ๊ฒฝ์ฐ, ์ฌ๋ฌ PyTorch ํ๋ก์ธ์ค๋ฅผ ์์ํ๋ ค๋ฉด torchrun ์ ์ฌ์ฉํ์ธ์.
- ์ฐธ๊ณ : ์์ํ๊ธฐ ๋ถ์ฐ ๋ฐ์ดํฐ ๋ณ๋ ฌ(DDP)
- ๋ชจ๋ธ์ด ๋จ์ผ GPU์ ํ์ฌ๋์ง ์์ ๋๋ FullyShardedDataParallel (FSDP2, ์์ ์ค๋ฉ ๋ฐ์ดํฐ ๋ณ๋ ฌํ) ์ ์ฌ์ฉํ์ธ์.
- ์ฐธ๊ณ : ์์ํ๊ธฐ FSDP2
- FSDP2๋ก๋ ํ์ฅ ํ๊ณ์ ๋๋ฌํ ๊ฒฝ์ฐ, Tensor Parallel (TP, Tensor ๋ณ๋ ฌํ) ๋ฐ/๋๋ Pipeline Parallel (PP, ํ์ดํ๋ผ์ธ ๋ณ๋ ฌํ) ๋ฅผ ์ฌ์ฉํ์ธ์.
- ํ ์ ๋ณ๋ ฌํ ํํ ๋ฆฌ์ผ ์ ํ์ธํด ๋ณด์ธ์.
- ์ฐธ๊ณ : TorchTitan 3D ๋ณ๋ ฌํ ์ ์ฒด(end to end) ์์
Note
๋ฐ์ดํฐ ๋ณ๋ ฌ ํ์ต์ ์๋ ํผํฉ ์ ๋ฐ๋(AMP, Automatic Mixed Precision) ์ ํจ๊ป์์๋ ๋์ํฉ๋๋ค.
PyTorch ๋ถ์ฐ์ ๊ธฐ์ฌํ๊ณ ์ถ๋ค๋ฉด ๊ฐ๋ฐ์ ๊ฐ์ด๋ ๋ฅผ ์ฐธ๊ณ ํ์ธ์.