Skip to content

Latest commit

Β 

History

History
215 lines (181 loc) Β· 8.1 KB

File metadata and controls

215 lines (181 loc) Β· 8.1 KB

PyTorchμ—μ„œ CommDebugMode μ‹œμž‘ν•˜κΈ°

μ €μž: Anshul Sinha

이 νŠœν† λ¦¬μ–Όμ—μ„œλŠ” PyTorch의 DistributedTensor(DTensor)와 ν•¨κ»˜ ``CommDebugMode``λ₯Ό μ‚¬μš©ν•˜λŠ” 방법을 μ‚΄νŽ΄λ΄…λ‹ˆλ‹€. 이λ₯Ό 톡해 λΆ„μ‚° ν•™μŠ΅ ν™˜κ²½μ—μ„œ μˆ˜ν–‰λ˜λŠ” μ§‘ν•© μ—°μ‚°(collective operation)을 μΆ”μ ν•˜μ—¬ 디버깅할 수 μžˆμŠ΅λ‹ˆλ‹€.

사전 μ€€λΉ„(Prerequisites)

  • Python 3.8 - 3.11
  • PyTorch 2.2 이상

``CommDebugMode``λž€ 무엇이며, μ™œ μœ μš©ν•œκ°€

λͺ¨λΈμ˜ 크기가 컀짐에 따라, μ‚¬μš©μžλŠ” λ‹€μ–‘ν•œ 병렬화(parallelism) μ „λž΅μ„ μ‘°ν•©ν•˜μ—¬ λΆ„μ‚° ν•™μŠ΅(distributed training)을 ν™•μž₯ν•˜λ € ν•©λ‹ˆλ‹€. ν•˜μ§€λ§Œ κΈ°μ‘΄ μ†”λ£¨μ…˜ κ°„μ˜ μƒν˜Έμš΄μš©μ„±(interoperability) 뢀쑱은 μ—¬μ „νžˆ 큰 과제둜 남아 μžˆμŠ΅λ‹ˆλ‹€. μ΄λŠ” μ„œλ‘œ λ‹€λ₯Έ 병렬화 μ „λž΅μ„ μ—°κ²°ν•  수 μžˆλŠ” ν†΅ν•©λœ 좔상화(unified abstraction)κ°€ λΆ€μ‘±ν•˜κΈ° λ•Œλ¬Έμž…λ‹ˆλ‹€.

이 문제λ₯Ό ν•΄κ²°ν•˜κΈ° μœ„ν•΄ PyTorchλŠ” DistributedTensor(DTensor) λ₯Ό λ„μž…ν–ˆμŠ΅λ‹ˆλ‹€. DTensorλŠ” λΆ„μ‚° ν•™μŠ΅ ν™˜κ²½μ—μ„œ ν…μ„œ ν†΅μ‹ μ˜ λ³΅μž‘μ„±μ„ μΆ”μƒν™”ν•˜μ—¬ μ‚¬μš©μžμ—κ²Œ μΌκ΄€λ˜κ³  κ°„κ²°ν•œ κ²½ν—˜μ„ μ œκ³΅ν•©λ‹ˆλ‹€.

κ·ΈλŸ¬λ‚˜ μ΄λŸ¬ν•œ 톡합 좔상화λ₯Ό μ‚¬μš©ν•˜λŠ” κ³Όμ •μ—μ„œ, λ‚΄λΆ€μ μœΌλ‘œ μ–΄λ–€ μ‹œμ μ— μ§‘ν•© 톡신이 μˆ˜ν–‰λ˜λŠ”μ§€ λͺ…ν™•νžˆ μ•ŒκΈ° μ–΄λ €μ›Œ κ³ κΈ‰ μ‚¬μš©μžκ°€ λ””λ²„κΉ…ν•˜κ±°λ‚˜ 문제λ₯Ό μ‹λ³„ν•˜κΈ° μ–΄λ ΅μŠ΅λ‹ˆλ‹€.

μ΄λ•Œ ``CommDebugMode``λŠ” Python의 μ»¨ν…μŠ€νŠΈ λ§€λ‹ˆμ €(context manager)λ‘œμ„œ DTensor μ‚¬μš© 쀑 λ°œμƒν•˜λŠ” μ§‘ν•© μ—°μ‚°μ˜ μ‹œμ κ³Ό 이유λ₯Ό μ‹œκ°μ μœΌλ‘œ 좔적할 수 μžˆλŠ” μ£Όμš” 디버깅 λ„κ΅¬μž…λ‹ˆλ‹€. 이λ₯Ό 톡해 μ‚¬μš©μžλŠ” μ–Έμ œ, μ™œ collective 연산이 μ‹€ν–‰λ˜λŠ”μ§€λ₯Ό λͺ…ν™•νžˆ νŒŒμ•…ν•  수 μžˆμŠ΅λ‹ˆλ‹€.

CommDebugMode μ‚¬μš©λ²•

λ‹€μŒμ€ ``CommDebugMode``λ₯Ό μ‚¬μš©ν•˜λŠ” μ˜ˆμ‹œμž…λ‹ˆλ‹€:

# 이 μ˜ˆμ œμ—μ„œ μ‚¬μš©λœ λͺ¨λΈμ€ ν…μ„œ 병렬화(tensor parallelism)λ₯Ό μ μš©ν•œ MLPModuleμž…λ‹ˆλ‹€.
comm_mode = CommDebugMode()
with comm_mode:
    output = model(inp)

# μ—°μ‚° λ‹¨μœ„μ˜ collective 좔적 정보λ₯Ό 좜λ ₯
print(comm_mode.generate_comm_debug_tracing_table(noise_level=0))

# μ—°μ‚° λ‹¨μœ„μ˜ collective 좔적 정보λ₯Ό 파일둜 기둝
comm_mode.log_comm_debug_tracing_table_to_file(
    noise_level=1, file_name="transformer_operation_log.txt"
)

# μ—°μ‚° λ‹¨μœ„μ˜ collective 좔적 정보λ₯Ό JSON 파일둜 덀프(dump)
# μ•„λž˜μ˜ μ‹œκ°ν™” λΈŒλΌμš°μ €μ—μ„œ 이 JSON νŒŒμΌμ„ μ‚¬μš©ν•  수 μžˆμŠ΅λ‹ˆλ‹€.
comm_mode.generate_json_dump(noise_level=2)

λ‹€μŒμ€ noise level 0μ—μ„œ MLPModule의 좜λ ₯ μ˜ˆμ‹œμž…λ‹ˆλ‹€:

Expected Output:
    Global
      FORWARD PASS
        *c10d_functional.all_reduce: 1
        MLPModule
          FORWARD PASS
            *c10d_functional.all_reduce: 1
            MLPModule.net1
            MLPModule.relu
            MLPModule.net2
              FORWARD PASS
                *c10d_functional.all_reduce: 1

CommDebugMode``λ₯Ό μ‚¬μš©ν•˜λ €λ©΄ λͺ¨λΈ μ‹€ν–‰ μ½”λ“œλ₯Ό ``CommDebugMode 블둝 μ•ˆμ— 감싸고, μ›ν•˜λŠ” 정보λ₯Ό ν‘œμ‹œν•˜λŠ” APIλ₯Ό ν˜ΈμΆœν•˜λ©΄ λ©λ‹ˆλ‹€.

λ˜ν•œ noise_level 인자λ₯Ό μ‚¬μš©ν•΄ 좜λ ₯λ˜λŠ” μ •λ³΄μ˜ 상세 μˆ˜μ€€(verbosity level)을 μ œμ–΄ν•  수 μžˆμŠ΅λ‹ˆλ‹€. 각 noise level은 λ‹€μŒκ³Ό 같은 정보λ₯Ό μ œκ³΅ν•©λ‹ˆλ‹€:

0. λͺ¨λ“ˆ λ‹¨μœ„μ˜ collective μ—°μ‚° 개수 좜λ ₯
1. μ€‘μš”ν•˜μ§€ μ•Šμ€ 연산을 μ œμ™Έν•œ DTensor μ—°μ‚° 및 λͺ¨λ“ˆ 샀딩(sharding) 정보 좜λ ₯
2. μ€‘μš”ν•˜μ§€ μ•Šμ€ 연산을 μ œμ™Έν•œ ν…μ„œ λ‹¨μœ„ μ—°μ‚° 좜λ ₯
3. λͺ¨λ“  μ—°μ‚° 좜λ ₯

μœ„μ˜ μ˜ˆμ‹œμ—μ„œ λ³Ό 수 μžˆλ“―μ΄, collective 연산인 all_reduceλŠ” ``MLPModule``의 forward λ‹¨κ³„μ—μ„œ ν•œ 번 λ°œμƒν•©λ‹ˆλ‹€. λ˜ν•œ ``CommDebugMode``λ₯Ό μ‚¬μš©ν•˜λ©΄ 이 all-reduce 연산이 ``MLPModule``의 두 번째 μ„ ν˜• 계측(linear layer)μ—μ„œ λ°œμƒν•œλ‹€λŠ” 점을 μ •ν™•νžˆ 확인할 수 μžˆμŠ΅λ‹ˆλ‹€.

μ•„λž˜λŠ” μƒμ„±λœ JSON νŒŒμΌμ„ μ—…λ‘œλ“œν•˜μ—¬ μ‹œκ°μ μœΌλ‘œ 탐색할 수 μžˆλŠ” μΈν„°λž™ν‹°λΈŒ λͺ¨λ“ˆ 트리 μ‹œκ°ν™”(interactive module tree visualization)μž…λ‹ˆλ‹€:

CommDebugMode Module Tree ul, #tree-container { list-style-type: none; margin: 0; padding: 0; } .caret { cursor: pointer; user-select: none; } .caret::before { content: "\25B6"; color:black; display: inline-block; margin-right: 6px; } .caret-down::before { transform: rotate(90deg); } .tree { padding-left: 20px; } .tree ul { padding-left: 20px; } .nested { display: none; } .active { display: block; } .forward-pass, .backward-pass { margin-left: 40px; } .forward-pass table { margin-left: 40px; width: auto; } .forward-pass table td, .forward-pass table th { padding: 8px; } .forward-pass ul { display: none; } table { font-family: arial, sans-serif; border-collapse: collapse; width: 100%; } td, th { border: 1px solid #dddddd; text-align: left; padding: 8px; } tr:nth-child(even) { background-color: #dddddd; } #drop-area { position: relative; width: 25%; height: 100px; border: 2px dashed #ccc; border-radius: 5px; padding: 0px; text-align: center; } .drag-drop-block { display: inline-block; width: 200px; height: 50px; background-color: #f7f7f7; border: 1px solid #ccc; border-radius: 5px; padding: 10px; font-size: 14px; color: #666; cursor: pointer; } #file-input { position: absolute; top: 0; left: 0; width: 100%; height: 100%; opacity: 0; }
Drag file here

κ²°λ‘ (Conclusion)

이 λ ˆμ‹œν”Όμ—μ„œλŠ” PyTorch의 ``CommDebugMode``λ₯Ό μ‚¬μš©ν•˜μ—¬ μ§‘ν•© 톡신(collective communication)을 ν¬ν•¨ν•˜λŠ” DistributedTensor 및 병렬화 μ†”λ£¨μ…˜μ„ λ””λ²„κΉ…ν•˜λŠ” 방법을 λ°°μ› μŠ΅λ‹ˆλ‹€. λ˜ν•œ μƒμ„±λœ JSON 좜λ ₯을 λ‚΄μž₯된 μ‹œκ°ν™” λΈŒλΌμš°μ €μ—μ„œ 직접 λΆˆλŸ¬μ™€ 확인할 μˆ˜λ„ μžˆμŠ΅λ‹ˆλ‹€.

``CommDebugMode``에 λŒ€ν•œ 보닀 μžμ„Έν•œ λ‚΄μš©μ€ comm_mode_features_example.py λ₯Ό μ°Έκ³ ν•˜μ„Έμš”.