Skip to content

Latest commit

ย 

History

History
87 lines (56 loc) ยท 6.27 KB

File metadata and controls

87 lines (56 loc) ยท 6.27 KB

PyTorch ๋ถ„์‚ฐ ๊ฐœ์š”

์ €์ž: Will Constable, Wei Feng ๋ฒˆ์—ญ: ๊ฐ•์ง€ํ˜„ .. note:

|edit| ์ด ํŠœํ† ๋ฆฌ์–ผ์„ ์—ฌ๊ธฐ์„œ ๋ณด๊ณ  ํŽธ์ง‘ํ•˜์„ธ์š” `github <https://github.com/pytorchkorea/tutorials-kr/blob/main/beginner_source/dist_overview.rst>`__.

์ด ๋ฌธ์„œ๋Š” torch.distributed ํŒจํ‚ค์ง€์˜ ๊ฐœ์š” ํŽ˜์ด์ง€์ž…๋‹ˆ๋‹ค. ์ด ํŽ˜์ด์ง€์˜ ๋ชฉํ‘œ๋Š” ๋ฌธ์„œ๋ฅผ ์ฃผ์ œ๋ณ„๋กœ ๋ถ„๋ฅ˜ํ•˜๊ณ  ๊ฐ ์ฃผ์ œ๋ฅผ ๊ฐ„๋žตํžˆ ์„ค๋ช…ํ•˜๋Š” ๊ฒƒ์ž…๋‹ˆ๋‹ค. PyTorch๋กœ ๋ถ„์‚ฐ ํ•™์Šต ์• ํ”Œ๋ฆฌ์ผ€์ด์…˜์„ ์ฒ˜์Œ ๊ตฌ์ถ•ํ•œ๋‹ค๋ฉด, ์ด ๋ฌธ์„œ๋ฅผ ์ฐธ๊ณ ํ•˜์—ฌ ์—ฌ๋Ÿฌ๋ถ„์˜ ์‚ฌ์šฉ ์‚ฌ๋ก€์— ๊ฐ€์žฅ ์ ํ•ฉํ•œ ๊ธฐ์ˆ ์„ ์ฐพ์•„๋ณด๋Š” ๊ฒƒ์„ ๊ถŒ์žฅํ•ฉ๋‹ˆ๋‹ค.

์„œ๋ก 

ํŒŒ์ดํ† ์น˜ ๋ถ„์‚ฐ ๋ผ์ด๋ธŒ๋Ÿฌ๋ฆฌ๋Š” ์—ฌ๋Ÿฌ ๋ณ‘๋ ฌํ™” ๋ชจ๋“ˆ, ํ†ต์‹  ๊ณ„์ธต, ๊ทธ๋ฆฌ๊ณ  ๋Œ€๊ทœ๋ชจ ํ•™์Šต ์ž‘์—…์˜ ์‹คํ–‰ ๋ฐ ๋””๋ฒ„๊น…์„ ์œ„ํ•œ ์ธํ”„๋ผ๋กœ ๊ตฌ์„ฑ๋ฉ๋‹ˆ๋‹ค.

๋ณ‘๋ ฌ ์ฒ˜๋ฆฌ API

์ด๋Ÿฌํ•œ ๋ณ‘๋ ฌํ™” ๋ชจ๋“ˆ์€ ๊ณ ์ˆ˜์ค€ ๊ธฐ๋Šฅ์„ ์ œ๊ณตํ•˜๋ฉฐ ๊ธฐ์กด ๋ชจ๋ธ๊ณผ ์กฐํ•ฉํ•˜์—ฌ ์‚ฌ์šฉํ•  ์ˆ˜ ์žˆ์Šต๋‹ˆ๋‹ค.

์ƒค๋”ฉ ๊ธฐ๋ณธ ์š”์†Œ(Sharding primitives)

DTensor ์™€ DeviceMesh ๋Š” N์ฐจ์› ํ”„๋กœ์„ธ์Šค ๊ทธ๋ฃน์—์„œ ํ…์„œ๋ฅผ ์ƒค๋”ฉํ•˜๊ฑฐ๋‚˜ ๋ณต์ œํ•˜๋Š” ๋ฐฉ์‹์œผ๋กœ ๋ณ‘๋ ฌํ™”๋ฅผ ๊ตฌ์„ฑํ•  ๋•Œ ์‚ฌ์šฉํ•˜๋Š” ๊ธฐ๋ณธ ๊ตฌ์„ฑ์š”์†Œ์ž…๋‹ˆ๋‹ค.

  • DTensor ๋Š” ์ƒค๋”ฉ๋˜๊ฑฐ๋‚˜/๋ณต์ œ๋œ ํ…์„œ๋ฅผ ๋‚˜ํƒ€๋‚ด๋ฉฐ, ์—ฐ์‚ฐ์˜ ์š”๊ตฌ์— ๋”ฐ๋ผ ํ…์„œ๋ฅผ ์žฌ์ƒค๋”ฉํ•˜๊ธฐ ์œ„ํ•œ ํ†ต์‹ ์„ ์ž๋™์œผ๋กœ ์ˆ˜ํ–‰ํ•ฉ๋‹ˆ๋‹ค.
  • DeviceMesh ๋Š” ๊ฐ€์†๊ธฐ ๋””๋ฐ”์ด์Šค์˜ ์ปค๋ฎค๋‹ˆ์ผ€์ดํ„ฐ(communicator)๋ฅผ ๋‹ค์ฐจ์› ๋ฐฐ์—ด๋กœ ์ถ”์ƒํ™”ํ•˜๋ฉฐ, ๋‹ค์ฐจ์› ๋ณ‘๋ ฌ์„ฑ์—์„œ ์ง‘ํ•ฉ(collective) ํ†ต์‹ ์„ ์ˆ˜ํ–‰ํ•˜๊ธฐ ์œ„ํ•œ ํ•˜์œ„ ProcessGroup ์ธ์Šคํ„ด์Šค๋“ค์„ ๊ด€๋ฆฌํ•ฉ๋‹ˆ๋‹ค. ๋” ์•Œ์•„๋ณด๋ ค๋ฉด Device Mesh ๋ ˆ์‹œํ”ผ ๋ฅผ ์ง์ ‘ ๋”ฐ๋ผ ํ•ด๋ณด์„ธ์š”.

ํ†ต์‹  API

PyTorch ๋ถ„์‚ฐ ํ†ต์‹  ๊ณ„์ธต (C10D) ์€ ์ง‘ํ•ฉ ํ†ต์‹  API (์˜ˆ: all_reduce(์ „์ฒด ์ถ•์†Œ)
, all_gather(์ „์ฒด ์ˆ˜์ง‘)) ์™€ P2P ํ†ต์‹  API (์˜ˆ: send(๋™๊ธฐ ์ „์†ก) , isend(๋น„๋™๊ธฐ ์ „์†ก))๋ฅผ ๋ชจ๋‘ ์ œ๊ณตํ•˜๋ฉฐ, ์ด๋Ÿฌํ•œ API๋Š” ๋ชจ๋“  ๋ณ‘๋ ฌํ™” ๊ตฌํ˜„์—์„œ ๋‚ด๋ถ€์ ์œผ๋กœ ์‚ฌ์šฉ๋ฉ๋‹ˆ๋‹ค. PyTorch๋กœ ๋ถ„์‚ฐ ์• ํ”Œ๋ฆฌ์ผ€์ด์…˜ ์ž‘์„ฑํ•˜๊ธฐ ๋Š” C10D ํ†ต์‹  API ์‚ฌ์šฉ ์˜ˆ์ œ๋ฅผ ๋ณด์—ฌ ์ค๋‹ˆ๋‹ค.

์‹คํ–‰๊ธฐ(Launcher)

torchrun ์€ ๋„๋ฆฌ ์“ฐ์ด๋Š” ์‹คํ–‰๊ธฐ ์Šคํฌ๋ฆฝํŠธ๋กœ, ๋ถ„์‚ฐ PyTorch ํ”„๋กœ๊ทธ๋žจ์„ ์‹คํ–‰ํ•˜๊ธฐ ์œ„ํ•ด ๋กœ์ปฌ ๋ฐ ์›๊ฒฉ ๋จธ์‹ ์—์„œ ํ”„๋กœ์„ธ์Šค๋ฅผ ์ƒ์„ฑํ•ฉ๋‹ˆ๋‹ค.

๋ชจ๋ธ ํ™•์žฅ์„ ์œ„ํ•œ ๋ณ‘๋ ฌํ™” ์ ์šฉ

๋ฐ์ดํ„ฐ ๋ณ‘๋ ฌํ™”(Data Parallelism)๋Š” ๋„๋ฆฌ ์ฑ„ํƒ๋œ SPMD(single-program multiple-data) ํ•™์Šต ํŒจ๋Ÿฌ๋‹ค์ž„์œผ๋กœ, ๋ชจ๋ธ์ด ๋ชจ๋“  ํ”„๋กœ์„ธ์Šค์— ๋ณต์ œ๋˜๊ณ  ๊ฐ ๋ชจ๋ธ์˜ ๋ณต์ œ๋ณธ์ด ์„œ๋กœ ๋‹ค๋ฅธ ์ž…๋ ฅ ๋ฐ์ดํ„ฐ ์ƒ˜ํ”Œ ์ง‘ํ•ฉ์— ๋Œ€ํ•ด ๋กœ์ปฌ ๋ณ€ํ™”๋„๋ฅผ ๊ณ„์‚ฐํ•ฉ๋‹ˆ๋‹ค. ๊ทธ๋Ÿฐ ๋‹ค์Œ ๊ฐ ์˜ตํ‹ฐ๋งˆ์ด์ € ์Šคํ… ์ „์— ๋ฐ์ดํ„ฐ-๋ณ‘๋ ฌ ํ†ต์‹  ๊ทธ๋ฃน ๋‚ด์—์„œ ๋ณ€ํ™”๋„๋ฅผ ํ‰๊ท ํ™”ํ•ฉ๋‹ˆ๋‹ค.

๋ชจ๋ธ ๋ณ‘๋ ฌํ™”(Model Parallelism) ๊ธฐ๋ฒ•(๋˜๋Š” ์ƒค๋”ฉ๋œ ๋ฐ์ดํ„ฐ ๋ณ‘๋ ฌํ™”)์€ ๋ชจ๋ธ์ด GPU ๋ฉ”๋ชจ๋ฆฌ์— ๋“ค์–ด๊ฐ€์ง€ ์•Š์„ ๋•Œ ํ•„์š”ํ•˜๋ฉฐ, ์„œ๋กœ ๊ฒฐํ•ฉํ•ด ๋‹ค์ฐจ์›(N-D) ๋ณ‘๋ ฌํ™” ๊ธฐ๋ฒ•์„ ๊ตฌ์„ฑํ•  ์ˆ˜ ์žˆ์Šต๋‹ˆ๋‹ค.

๋ชจ๋ธ์— ์ ์šฉํ•  ๋ณ‘๋ ฌํ™” ๊ธฐ๋ฒ•์„ ๊ฒฐ์ •ํ•  ๋•Œ๋Š” ๋‹ค์Œ์˜ ์ผ๋ฐ˜์ ์ธ ์ง€์นจ์„ ์ฐธ๊ณ ํ•˜์„ธ์š”.

  1. ๋ชจ๋ธ์ด ๋‹จ์ผ GPU๋ฅผ ํƒ‘์žฌํ•  ์ˆ˜ ์žˆ์ง€๋งŒ, ์—ฌ๋Ÿฌ GPU๋กœ ์‰ฝ๊ฒŒ ํ•™์Šต์„ ํ™•์žฅํ•˜๊ณ  ์‹ถ๋‹ค๋ฉด DistributedDataParallel (DDP, ๋ถ„์‚ฐ ๋ฐ์ดํ„ฐ ๋ณ‘๋ ฌํ™”) ๋ฅผ ์‚ฌ์šฉํ•˜์„ธ์š”.
  2. ๋ชจ๋ธ์ด ๋‹จ์ผ GPU์— ํƒ‘์žฌ๋˜์ง€ ์•Š์„ ๋•Œ๋Š” FullyShardedDataParallel (FSDP2, ์™„์ „ ์ƒค๋”ฉ ๋ฐ์ดํ„ฐ ๋ณ‘๋ ฌํ™”) ์„ ์‚ฌ์šฉํ•˜์„ธ์š”.
  3. FSDP2๋กœ๋Š” ํ™•์žฅ ํ•œ๊ณ„์— ๋„๋‹ฌํ•œ ๊ฒฝ์šฐ, Tensor Parallel (TP, Tensor ๋ณ‘๋ ฌํ™”) ๋ฐ/๋˜๋Š” Pipeline Parallel (PP, ํŒŒ์ดํ”„๋ผ์ธ ๋ณ‘๋ ฌํ™”) ๋ฅผ ์‚ฌ์šฉํ•˜์„ธ์š”.

Note

๋ฐ์ดํ„ฐ ๋ณ‘๋ ฌ ํ•™์Šต์€ ์ž๋™ ํ˜ผํ•ฉ ์ •๋ฐ€๋„(AMP, Automatic Mixed Precision) ์™€ ํ•จ๊ป˜์—์„œ๋„ ๋™์ž‘ํ•ฉ๋‹ˆ๋‹ค.

PyTorch ๋ถ„์‚ฐ ๊ฐœ๋ฐœ์ž

PyTorch ๋ถ„์‚ฐ์— ๊ธฐ์—ฌํ•˜๊ณ  ์‹ถ๋‹ค๋ฉด ๊ฐœ๋ฐœ์ž ๊ฐ€์ด๋“œ ๋ฅผ ์ฐธ๊ณ ํ•˜์„ธ์š”.