
Compile Time Caching in ``torch.compile``

์ €์ž Oguz Ulgen ๋ฒˆ์—ญ ๊น€์˜์ค€ Introduction ------------------

PyTorch Compiler provides several caching offerings to reduce compilation latency. This recipe explains these offerings in detail and helps you pick the option that best fits your use case.

์บ์‹œ๋ฅผ ์„ค์ •ํ•˜๋Š” ๋ฐฉ๋ฒ•์€ ์ปดํŒŒ์ผ ์‹œ์  ์บ์‹ฑ ์„ค์ • ๋ฌธ์„œ๋ฅผ ์ฐธ๊ณ ํ•˜์„ธ์š”.

You can also find caching performance comparisons in the PT CacheBench benchmarks.

Prerequisites

Before starting this recipe, make sure that you have the following prerequisites in place.

์บ์‹ฑ ๊ธฐ๋Šฅ

torch.compile ์€ ๋‹ค์Œ๊ณผ ๊ฐ™์€ ์บ์‹ฑ ๊ธฐ๋Šฅ์„ ์ œ๊ณตํ•ฉ๋‹ˆ๋‹ค.

  • End-to-end caching (also known as Mega-Cache)
  • Modular caching of TorchDynamo, TorchInductor, and Triton

์บ์‹œ๊ฐ€ ์˜ฌ๋ฐ”๋ฅด๊ฒŒ ๋™์ž‘ํ•˜๊ธฐ ์œ„ํ•ด์„œ๋Š” ์บ์‹œ ์•„ํ‹ฐํŒฉํŠธ๊ฐ€ ๋™์ผํ•œ PyTorch ๋ฐ Triton ๋ฒ„์ „์—์„œ ์ƒ์„ฑ๋œ ๊ฒƒ์ด์–ด์•ผ ํ•˜๋ฉฐ, ๋””๋ฐ”์ด์Šค๊ฐ€ CUDA๋กœ ์„ค์ •๋œ ๊ฒฝ์šฐ์—๋Š” ๊ฐ™์€ GPU ํ™˜๊ฒฝ์—์„œ ์‚ฌ์šฉ๋˜์–ด์•ผ ํ•œ๋‹ค๋Š” ์ ์— ์œ ์˜ํ•ด์•ผ ํ•ฉ๋‹ˆ๋‹ค.

torch.compile end-to-end caching (Mega-Cache)

End-to-end caching, also known as Mega-Cache, is the ideal option for users looking for a portable caching solution whose cache data can be stored in a database and fetched later, possibly on a different machine.

Mega-Cache provides the following two compiler APIs:

  • torch.compiler.save_cache_artifacts()
  • torch.compiler.load_cache_artifacts()

์ผ๋ฐ˜์ ์ธ ์‚ฌ์šฉ ๋ฐฉ์‹์€ ๋‹ค์Œ๊ณผ ๊ฐ™์Šต๋‹ˆ๋‹ค. ๋ชจ๋ธ์„ ์ปดํŒŒ์ผํ•˜๊ณ  ์‹คํ–‰ํ•œ ํ›„, ์‚ฌ์šฉ์ž๋Š” torch.compiler.save_cache_artifacts() ํ•จ์ˆ˜๋ฅผ ํ˜ธ์ถœํ•˜์—ฌ ์ด์‹ ๊ฐ€๋Šฅํ•œ ํ˜•ํƒœ์˜ ์ปดํŒŒ์ผ๋Ÿฌ ์•„ํ‹ฐํŒฉํŠธ๋ฅผ ๋ฐ˜ํ™˜๋ฐ›์Šต๋‹ˆ๋‹ค. ๊ทธ ํ›„, ๋‹ค๋ฅธ ๋จธ์‹ ์—์„œ ์ด ์•„ํ‹ฐํŒฉํŠธ๋ฅผ torch.compiler.load_cache_artifacts() ์— ์ „๋‹ฌํ•˜์—ฌ torch.compile ์บ์‹œ๋ฅผ ๋ฏธ๋ฆฌ ์ฑ„์›Œ ์บ์‹œ๋ฅผ ๋น ๋ฅด๊ฒŒ ์ดˆ๊ธฐํ™”ํ•  ์ˆ˜ ์žˆ์Šต๋‹ˆ๋‹ค.

๋‹ค์Œ ์˜ˆ์‹œ๋ฅผ ์‚ดํŽด๋ณด์„ธ์š”. ๋จผ์ € ๋ชจ๋ธ์„ ์ปดํŒŒ์ผํ•˜๊ณ  ์บ์‹œ ์•„ํ‹ฐํŒฉํŠธ๋ฅผ ์ €์žฅํ•ฉ๋‹ˆ๋‹ค.

import torch

# The original snippet assumes ``device`` and ``dtype`` are already defined;
# define them here so the example runs as-is.
device = "cuda" if torch.cuda.is_available() else "cpu"
dtype = torch.float32

@torch.compile
def fn(x, y):
    return x.sin() @ y

a = torch.rand(100, 100, dtype=dtype, device=device)
b = torch.rand(100, 100, dtype=dtype, device=device)

result = fn(a, b)

artifacts = torch.compiler.save_cache_artifacts()

assert artifacts is not None
artifact_bytes, cache_info = artifacts

# Now, you can store artifact_bytes in a database
# You can log cache_info

Later, you can jump-start the cache by the following:

# ๋ฐ์ดํ„ฐ๋ฒ ์ด์Šค์—์„œ ์•„ํ‹ฐํŒฉํŠธ๋ฅผ ๋‹ค์šด๋กœ๋“œํ•˜๊ฑฐ๋‚˜ ๋ถˆ๋Ÿฌ์˜ฌ ์ˆ˜๋„ ์žˆ์Šต๋‹ˆ๋‹ค.
torch.compiler.load_cache_artifacts(artifact_bytes)

์ด ์ž‘์—…์€ ๋‹ค์Œ ์„น์…˜์—์„œ ๋‹ค๋ฃฐ ๋ชจ๋“  ๋ชจ๋“ˆ๋ณ„ ์บ์‹œ(modular caches)๋ฅผ ๋ฏธ๋ฆฌ ์ฑ„์›๋‹ˆ๋‹ค. ์—ฌ๊ธฐ์—๋Š” PGO, AOTAutograd, Inductor, Triton, ๊ทธ๋ฆฌ๊ณ  Autotuning ์ด ํฌํ•จ๋ฉ๋‹ˆ๋‹ค.

Modular caching of TorchDynamo, TorchInductor, and Triton

The aforementioned Mega-Cache is composed of individual components that work automatically, without any user intervention. By default, PyTorch Compiler ships with local on-disk caches for TorchDynamo, TorchInductor, and Triton. These caches include:

  • FXGraphCache: a cache of the graph-based intermediate representation (IR) components used during compilation.
  • TritonCache: a cache of Triton compilation results, including the cubin files generated by Triton and other caching artifacts.
  • InductorCache: a bundled cache combining FXGraphCache and the Triton cache.
  • AOTAutogradCache: a cache of joint-graph artifacts.
  • PGO-cache: a cache of dynamic shape decisions, used to reduce the number of recompilations.
  • AutotuningCache:
    • Inductor generates Triton kernels and benchmarks them against each other to pick the fastest kernel.
    • The AutotuningCache built into torch.compile caches these results.

All of these cache artifacts are written to TORCHINDUCTOR_CACHE_DIR, which by default looks like /tmp/torchinductor_myusername.
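If the default location is unsuitable (for example, on a shared machine or one where /tmp is periodically cleared), the on-disk caches can be redirected via that environment variable; the path below is just an example:

```shell
# Redirect all local torch.compile on-disk cache artifacts to a custom
# directory (example path; set it before the Python process starts).
export TORCHINDUCTOR_CACHE_DIR=/path/to/my/cache
```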

์›๊ฒฉ ์บ์‹ฑ(Remote Caching)

We also provide a remote caching option for users who would like to take advantage of a Redis-based cache. Check out the Compile Time Caching Configurations recipe to learn how to enable Redis-based caching.
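As a sketch of what such a setup looks like (the authoritative list of variables is in the configuration recipe linked above; the host and port values here are examples), the Redis-backed remote caches are enabled through environment variables along these lines:

```shell
# Point the compiler at a Redis instance (example values).
export TORCHINDUCTOR_REDIS_HOST=localhost
export TORCHINDUCTOR_REDIS_PORT=6379

# Opt individual caches into the remote Redis backend.
export TORCHINDUCTOR_FX_GRAPH_REMOTE_CACHE=1
export TORCHINDUCTOR_AUTOGRAD_REMOTE_CACHE=1
export TORCHINDUCTOR_AUTOTUNE_REMOTE_CACHE=1
```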

Conclusion

์ด ๋ ˆ์‹œํ”ผ์—์„œ๋Š” PyTorch Inductor์˜ ์บ์‹ฑ ๋ฉ”์ปค๋‹ˆ์ฆ˜์ด ๋กœ์ปฌ ์บ์‹œ์™€ ์›๊ฒฉ ์บ์‹œ๋ฅผ ๋ชจ๋‘ ํ™œ์šฉํ•˜์—ฌ ์ปดํŒŒ์ผ ์ง€์—ฐ ์‹œ๊ฐ„์„ ํฌ๊ฒŒ ์ค„์ผ ์ˆ˜ ์žˆ๋‹ค๋Š” ์ ์„ ๋ฐฐ์› ์Šต๋‹ˆ๋‹ค. ์ด๋Ÿฌํ•œ ์บ์‹œ๋“ค์€ ์‚ฌ์šฉ์ž์˜ ๋ณ„๋„ ๊ฐœ์ž… ์—†์ด ๋ฐฑ๊ทธ๋ผ์šด๋“œ์—์„œ ์›ํ™œํ•˜๊ฒŒ ์ž‘๋™ํ•ฉ๋‹ˆ๋‹ค.