``torch.compile``์ ์ปดํ์ผ ์์ ์บ์ฑ
์ ์ Oguz Ulgen ๋ฒ์ญ ๊น์์ค Introduction ------------------
PyTorch Compiler๋ ์ปดํ์ผ ์ง์ฐ ์๊ฐ์ ์ค์ด๊ธฐ ์ํด ์ฌ๋ฌ ๊ฐ์ง ์บ์ฑ ๊ธฐ๋ฅ์ ์ ๊ณตํฉ๋๋ค. ์ด ๋ ์ํผ์์๋ ์ด๋ฌํ ์บ์ฑ ๊ธฐ๋ฅ๋ค์ ์์ธํ ์ค๋ช ํ๊ณ , ์ฌ์ฉ์๊ฐ ์์ ์ ํ์ฉ ๋ชฉ์ ์ ๊ฐ์ฅ ์ ํฉํ ์ต์ ์ ์ ํํ ์ ์๋๋ก ์๋ดํฉ๋๋ค.
์บ์๋ฅผ ์ค์ ํ๋ ๋ฐฉ๋ฒ์ ์ปดํ์ผ ์์ ์บ์ฑ ์ค์ ๋ฌธ์๋ฅผ ์ฐธ๊ณ ํ์ธ์.
๋ํ PT CacheBench ๋ฒค์น๋งํฌ ์์ ์บ์ฑ ์ฑ๋ฅ ๋น๊ต ๊ฒฐ๊ณผ๋ ํ์ธํ ์ ์์ต๋๋ค.
์ด ๋ ์ํผ๋ฅผ ์์ํ๊ธฐ ์ ์ ๋ค์ ํญ๋ชฉ์ ์ค๋นํ๋์ง ํ์ธํ์ธ์.
torch.compile์ ๋ํ ๊ธฐ๋ณธ์ ์ธ ์ดํด๊ฐ ํ์ํฉ๋๋ค. ์๋ ์๋ฃ๋ฅผ ์ฐธ๊ณ ํ์ธ์.- PyTorch 2.4 ์ด์ ๋ฒ์
torch.compile ์ ๋ค์๊ณผ ๊ฐ์ ์บ์ฑ ๊ธฐ๋ฅ์ ์ ๊ณตํฉ๋๋ค.
- ์๋ ํฌ ์๋ ์บ์ฑ (
Mega-Cache๋ผ๊ณ ๋ ๋ถ๋ฆผ) TorchDynamo,TorchInductor,Triton๋ชจ๋๋ณ ์บ์ฑ
์บ์๊ฐ ์ฌ๋ฐ๋ฅด๊ฒ ๋์ํ๊ธฐ ์ํด์๋ ์บ์ ์ํฐํฉํธ๊ฐ ๋์ผํ PyTorch ๋ฐ Triton ๋ฒ์ ์์ ์์ฑ๋ ๊ฒ์ด์ด์ผ ํ๋ฉฐ, ๋๋ฐ์ด์ค๊ฐ CUDA๋ก ์ค์ ๋ ๊ฒฝ์ฐ์๋ ๊ฐ์ GPU ํ๊ฒฝ์์ ์ฌ์ฉ๋์ด์ผ ํ๋ค๋ ์ ์ ์ ์ํด์ผ ํฉ๋๋ค.
Mega-Cacheโ๋ก ์ง์นญ๋๋ ์๋ ํฌ ์๋ ์บ์ฑ์, ์บ์ ๋ฐ์ดํฐ๋ฅผ ๋ฐ์ดํฐ๋ฒ ์ด์ค์ ์ ์ฅํด ๋ค๋ฅธ ๋จธ์ ์์๋ ๋ถ๋ฌ์ฌ ์ ์๋ ์ด์ ๊ฐ๋ฅํ(portable) ์บ์ฑ ์๋ฃจ์
์ ์ฐพ๋ ์ฌ์ฉ์์๊ฒ ์ด์์ ์ธ ๋ฐฉ๋ฒ์
๋๋ค.
Mega-Cache ๋ ๋ค์ ๋ ๊ฐ์ง ์ปดํ์ผ๋ฌ API๋ฅผ ์ ๊ณตํฉ๋๋ค.
torch.compiler.save_cache_artifacts()torch.compiler.load_cache_artifacts()
์ผ๋ฐ์ ์ธ ์ฌ์ฉ ๋ฐฉ์์ ๋ค์๊ณผ ๊ฐ์ต๋๋ค. ๋ชจ๋ธ์ ์ปดํ์ผํ๊ณ ์คํํ ํ, ์ฌ์ฉ์๋ torch.compiler.save_cache_artifacts() ํจ์๋ฅผ ํธ์ถํ์ฌ ์ด์ ๊ฐ๋ฅํ ํํ์ ์ปดํ์ผ๋ฌ ์ํฐํฉํธ๋ฅผ ๋ฐํ๋ฐ์ต๋๋ค.
๊ทธ ํ, ๋ค๋ฅธ ๋จธ์ ์์ ์ด ์ํฐํฉํธ๋ฅผ torch.compiler.load_cache_artifacts() ์ ์ ๋ฌํ์ฌ torch.compile ์บ์๋ฅผ ๋ฏธ๋ฆฌ ์ฑ์ ์บ์๋ฅผ ๋น ๋ฅด๊ฒ ์ด๊ธฐํํ ์ ์์ต๋๋ค.
๋ค์ ์์๋ฅผ ์ดํด๋ณด์ธ์. ๋จผ์ ๋ชจ๋ธ์ ์ปดํ์ผํ๊ณ ์บ์ ์ํฐํฉํธ๋ฅผ ์ ์ฅํฉ๋๋ค.
@torch.compile
def fn(x, y):
return x.sin() @ y
a = torch.rand(100, 100, dtype=dtype, device=device)
b = torch.rand(100, 100, dtype=dtype, device=device)
result = fn(a, b)
artifacts = torch.compiler.save_cache_artifacts()
assert artifacts is not None
artifact_bytes, cache_info = artifacts
# ์ด์ artifact_bytes๋ฅผ ๋ฐ์ดํฐ๋ฒ ์ด์ค์ ์ ์ฅํ ์๋ ์์ต๋๋ค.
# cache_info๋ ๊ธฐ๋ก(logging)ํ ์ ์์ต๋๋ค.Later, you can jump-start the cache by the following:
# ๋ฐ์ดํฐ๋ฒ ์ด์ค์์ ์ํฐํฉํธ๋ฅผ ๋ค์ด๋ก๋ํ๊ฑฐ๋ ๋ถ๋ฌ์ฌ ์๋ ์์ต๋๋ค.
torch.compiler.load_cache_artifacts(artifact_bytes)์ด ์์
์ ๋ค์ ์น์
์์ ๋ค๋ฃฐ ๋ชจ๋ ๋ชจ๋๋ณ ์บ์(modular caches)๋ฅผ ๋ฏธ๋ฆฌ ์ฑ์๋๋ค. ์ฌ๊ธฐ์๋ PGO, AOTAutograd, Inductor, Triton, ๊ทธ๋ฆฌ๊ณ Autotuning ์ด ํฌํจ๋ฉ๋๋ค.
์์ ์ธ๊ธํ Mega-Cache ๋ ์ฌ์ฉ์์ ๋ณ๋ ๊ฐ์
์์ด ์๋์ผ๋ก ๋์ํ๋ ๊ฐ๋ณ ๊ตฌ์ฑ์์๋ค๋ก ์ด๋ฃจ์ด์ ธ ์์ต๋๋ค. ๊ธฐ๋ณธ์ ์ผ๋ก PyTorch Compiler๋ TorchDynamo, TorchInductor, ๊ทธ๋ฆฌ๊ณ Triton ์ ์ํ
๋ก์ปฌ ๋์คํฌ ๊ธฐ๋ฐ(on-disk) ์บ์๋ฅผ ํจ๊ป ์ ๊ณตํฉ๋๋ค. ์ด๋ฌํ ์บ์์๋ ๋ค์์ด ํฌํจ๋ฉ๋๋ค.
FXGraphCache: ํ์ผ ๊ณผ์ ์์ ์ฌ์ฉ๋๋ ๊ทธ๋ํ ๊ธฐ๋ฐ ์ค๊ฐ ํํ(IR, Intermediate Representation) ๊ตฌ์ฑ์์๋ฅผ ์ ์ฅํ๋ ์บ์์ ๋๋ค.TritonCache: ์ปดํ์ผ ๊ฒฐ๊ณผ๋ฅผ ์ ์ฅํ๋ ์บ์๋ก,Triton์ ์ํด ์์ฑ๋cubinํ์ผ๊ณผ ๊ธฐํ ์บ์ฑ ๊ด๋ จ ์ํฐํฉํธ๋ฅผ ํฌํจํฉ๋๋ค.InductorCache:FXGraphCache์Triton์บ์๋ฅผ ํจ๊ป ํฌํจํ๋ ํตํฉ ์บ์(bundled cache) ์ ๋๋ค.AOTAutogradCache: ํตํฉ ๊ทธ๋ํ(joint graph) ๊ด๋ จ ์ํฐํฉํธ๋ฅผ ์ ์ฅํ๋ ์บ์์ ๋๋ค.PGO-cache: ๋์ ์ ๋ ฅ ํํ ์ ๋ํ ๊ฒฐ์ ์ ๋ณด๋ฅผ ์ ์ฅํ์ฌ ์ฌ์ปดํ์ผ ํ์๋ฅผ ์ค์ด๋ ๋ฐ ์ฌ์ฉ๋๋ ์บ์์ ๋๋ค.- AutotuningCache:
Inductor๋Triton์ปค๋์ ์์ฑํ๊ณ , ๊ฐ์ฅ ๋น ๋ฅธ ์ปค๋์ ์ ํํ๊ธฐ ์ํด ๋๊ฐ ๋ ๋น ๋ฅธ์ง, ํจ์จ์ ์ธ์ง๋ฅผ ๋น๊ตํฉ๋๋ค.torch.compile์ ๋ด์ฅ๋AutotuningCache๋ ์ด ๊ฒฐ๊ณผ๋ฅผ ์บ์ฑํฉ๋๋ค.
์ด ๋ชจ๋ ์บ์ ์ํฐํฉํธ๋ TORCHINDUCTOR_CACHE_DIR ๊ฒฝ๋ก์ ์ ์ฅ๋ฉ๋๋ค. ๊ธฐ๋ณธ๊ฐ(default)์ /tmp/torchinductor_myusername ํํ๋ก ์ค์ ๋ฉ๋๋ค.
Redis ๊ธฐ๋ฐ ์บ์๋ฅผ ํ์ฉํ๊ณ ์ ํ๋ ์ฌ์ฉ์๋ฅผ ์ํด ์๊ฒฉ ์บ์ฑ ์ต์ ๋ ์ ๊ณตํฉ๋๋ค. Redis ๊ธฐ๋ฐ ์บ์ฑ์ ํ์ฑํํ๋ ๋ฐฉ๋ฒ์ ๋ํด์๋ ์ปดํ์ผ ์์ ์บ์ฑ ์ค์ ๋ฌธ์๋ฅผ ์ฐธ๊ณ ํ์ธ์.
์ด ๋ ์ํผ์์๋ PyTorch Inductor์ ์บ์ฑ ๋ฉ์ปค๋์ฆ์ด ๋ก์ปฌ ์บ์์ ์๊ฒฉ ์บ์๋ฅผ ๋ชจ๋ ํ์ฉํ์ฌ ์ปดํ์ผ ์ง์ฐ ์๊ฐ์ ํฌ๊ฒ ์ค์ผ ์ ์๋ค๋ ์ ์ ๋ฐฐ์ ์ต๋๋ค. ์ด๋ฌํ ์บ์๋ค์ ์ฌ์ฉ์์ ๋ณ๋ ๊ฐ์ ์์ด ๋ฐฑ๊ทธ๋ผ์ด๋์์ ์ํํ๊ฒ ์๋ํฉ๋๋ค.