μ μ: Anshul Sinha
μ΄ νν 리μΌμμλ PyTorchμ DistributedTensor(DTensor)μ ν¨κ» ``CommDebugMode``λ₯Ό μ¬μ©νλ λ°©λ²μ μ΄ν΄λ΄ λλ€. μ΄λ₯Ό ν΅ν΄ λΆμ° νμ΅ νκ²½μμ μνλλ μ§ν© μ°μ°(collective operation)μ μΆμ νμ¬ λλ²κΉ ν μ μμ΅λλ€.
- Python 3.8 - 3.11
- PyTorch 2.2 μ΄μ
``CommDebugMode``λ 무μμ΄λ©°, μ μ μ©νκ°
λͺ¨λΈμ ν¬κΈ°κ° 컀μ§μ λ°λΌ, μ¬μ©μλ λ€μν λ³λ ¬ν(parallelism) μ λ΅μ μ‘°ν©νμ¬ λΆμ° νμ΅(distributed training)μ νμ₯νλ € ν©λλ€. νμ§λ§ κΈ°μ‘΄ μ루μ κ°μ μνΈμ΄μ©μ±(interoperability) λΆμ‘±μ μ¬μ ν ν° κ³Όμ λ‘ λ¨μ μμ΅λλ€. μ΄λ μλ‘ λ€λ₯Έ λ³λ ¬ν μ λ΅μ μ°κ²°ν μ μλ ν΅ν©λ μΆμν(unified abstraction)κ° λΆμ‘±νκΈ° λλ¬Έμ λλ€.
μ΄ λ¬Έμ λ₯Ό ν΄κ²°νκΈ° μν΄ PyTorchλ DistributedTensor(DTensor) λ₯Ό λμ νμ΅λλ€. DTensorλ λΆμ° νμ΅ νκ²½μμ ν μ ν΅μ μ 볡μ‘μ±μ μΆμννμ¬ μ¬μ©μμκ² μΌκ΄λκ³ κ°κ²°ν κ²½νμ μ 곡ν©λλ€.
κ·Έλ¬λ μ΄λ¬ν ν΅ν© μΆμνλ₯Ό μ¬μ©νλ κ³Όμ μμ, λ΄λΆμ μΌλ‘ μ΄λ€ μμ μ μ§ν© ν΅μ μ΄ μνλλμ§ λͺ νν μκΈ° μ΄λ €μ κ³ κΈ μ¬μ©μκ° λλ²κΉ νκ±°λ λ¬Έμ λ₯Ό μλ³νκΈ° μ΄λ ΅μ΅λλ€.
μ΄λ ``CommDebugMode``λ Pythonμ 컨ν μ€νΈ λ§€λμ (context manager)λ‘μ DTensor μ¬μ© μ€ λ°μνλ μ§ν© μ°μ°μ μμ κ³Ό μ΄μ λ₯Ό μκ°μ μΌλ‘ μΆμ ν μ μλ μ£Όμ λλ²κΉ λꡬμ λλ€. μ΄λ₯Ό ν΅ν΄ μ¬μ©μλ μΈμ , μ collective μ°μ°μ΄ μ€νλλμ§λ₯Ό λͺ νν νμ ν μ μμ΅λλ€.
λ€μμ ``CommDebugMode``λ₯Ό μ¬μ©νλ μμμ λλ€:
# μ΄ μμ μμ μ¬μ©λ λͺ¨λΈμ ν
μ λ³λ ¬ν(tensor parallelism)λ₯Ό μ μ©ν MLPModuleμ
λλ€.
comm_mode = CommDebugMode()
with comm_mode:
output = model(inp)
# μ°μ° λ¨μμ collective μΆμ μ 보λ₯Ό μΆλ ₯
print(comm_mode.generate_comm_debug_tracing_table(noise_level=0))
# μ°μ° λ¨μμ collective μΆμ μ 보λ₯Ό νμΌλ‘ κΈ°λ‘
comm_mode.log_comm_debug_tracing_table_to_file(
noise_level=1, file_name="transformer_operation_log.txt"
)
# μ°μ° λ¨μμ collective μΆμ μ 보λ₯Ό JSON νμΌλ‘ λ€ν(dump)
# μλμ μκ°ν λΈλΌμ°μ μμ μ΄ JSON νμΌμ μ¬μ©ν μ μμ΅λλ€.
comm_mode.generate_json_dump(noise_level=2)λ€μμ noise level 0μμ MLPModuleμ μΆλ ₯ μμμ λλ€:
Expected Output:
Global
FORWARD PASS
*c10d_functional.all_reduce: 1
MLPModule
FORWARD PASS
*c10d_functional.all_reduce: 1
MLPModule.net1
MLPModule.relu
MLPModule.net2
FORWARD PASS
*c10d_functional.all_reduce: 1CommDebugMode``λ₯Ό μ¬μ©νλ €λ©΄ λͺ¨λΈ μ€ν μ½λλ₯Ό ``CommDebugMode λΈλ‘ μμ κ°μΈκ³ ,
μνλ μ 보λ₯Ό νμνλ APIλ₯Ό νΈμΆνλ©΄ λ©λλ€.
λν noise_level μΈμλ₯Ό μ¬μ©ν΄ μΆλ ₯λλ μ 보μ μμΈ μμ€(verbosity level)μ μ μ΄ν μ μμ΅λλ€.
κ° noise levelμ λ€μκ³Ό κ°μ μ 보λ₯Ό μ 곡ν©λλ€:
μμ μμμμ λ³Ό μ μλ―μ΄, collective μ°μ°μΈ all_reduceλ ``MLPModule``μ forward λ¨κ³μμ ν λ² λ°μν©λλ€. λν ``CommDebugMode``λ₯Ό μ¬μ©νλ©΄ μ΄ all-reduce μ°μ°μ΄ ``MLPModule``μ λ λ²μ§Έ μ ν κ³μΈ΅(linear layer)μμ λ°μνλ€λ μ μ μ νν νμΈν μ μμ΅λλ€.
μλλ μμ±λ JSON νμΌμ μ λ‘λνμ¬ μκ°μ μΌλ‘ νμν μ μλ μΈν°λν°λΈ λͺ¨λ νΈλ¦¬ μκ°ν(interactive module tree visualization)μ λλ€:
CommDebugMode Module Tree ul, #tree-container { list-style-type: none; margin: 0; padding: 0; } .caret { cursor: pointer; user-select: none; } .caret::before { content: "\25B6"; color:black; display: inline-block; margin-right: 6px; } .caret-down::before { transform: rotate(90deg); } .tree { padding-left: 20px; } .tree ul { padding-left: 20px; } .nested { display: none; } .active { display: block; } .forward-pass, .backward-pass { margin-left: 40px; } .forward-pass table { margin-left: 40px; width: auto; } .forward-pass table td, .forward-pass table th { padding: 8px; } .forward-pass ul { display: none; } table { font-family: arial, sans-serif; border-collapse: collapse; width: 100%; } td, th { border: 1px solid #dddddd; text-align: left; padding: 8px; } tr:nth-child(even) { background-color: #dddddd; } #drop-area { position: relative; width: 25%; height: 100px; border: 2px dashed #ccc; border-radius: 5px; padding: 0px; text-align: center; } .drag-drop-block { display: inline-block; width: 200px; height: 50px; background-color: #f7f7f7; border: 1px solid #ccc; border-radius: 5px; padding: 10px; font-size: 14px; color: #666; cursor: pointer; } #file-input { position: absolute; top: 0; left: 0; width: 100%; height: 100%; opacity: 0; }μ΄ λ μνΌμμλ PyTorchμ ``CommDebugMode``λ₯Ό μ¬μ©νμ¬ μ§ν© ν΅μ (collective communication)μ ν¬ν¨νλ DistributedTensor λ° λ³λ ¬ν μ루μ μ λλ²κΉ νλ λ°©λ²μ λ°°μ μ΅λλ€. λν μμ±λ JSON μΆλ ₯μ λ΄μ₯λ μκ°ν λΈλΌμ°μ μμ μ§μ λΆλ¬μ νμΈν μλ μμ΅λλ€.
``CommDebugMode``μ λν λ³΄λ€ μμΈν λ΄μ©μ comm_mode_features_example.py λ₯Ό μ°Έκ³ νμΈμ.