11"""
Reasoning about Shapes in PyTorch
=================================
Translation: `0seob <https://github.com/0seob>`_

When writing models with PyTorch, it is commonly the case that the parameters
to a given layer depend on the shape of the output of the previous layer. For
example, the ``in_features`` of an ``nn.Linear`` layer must match the
``size(-1)`` of the input. For some layers, the shape computation involves
complex equations, for example convolution operations (the output height of
an ``nn.Conv2d`` is ``(H_in + 2 * padding - dilation * (kernel_size - 1) - 1)
// stride + 1``).

One way around this is to run the forward pass with random inputs, but this is
wasteful in terms of memory and compute.

Instead, we can make use of the ``meta`` device to determine the output shapes
of a layer without materializing any data.
"""

import torch
import timeit
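
# A layer created on the ``meta`` device carries shape and dtype information
# but holds no data, so running an input through it computes the output shape
# without allocating anything. The sizes below are illustrative choices, not
# mandated by the API.
t = torch.rand(2, 3, 10, 10, device="meta")
conv = torch.nn.Conv2d(3, 5, 2, device="meta")
start = timeit.default_timer()
out = conv(t)
end = timeit.default_timer()

print(out)
print(f"Time taken: {end - start}")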


##########################################################################
# Observe that since data is not materialized, passing arbitrarily large
# inputs will not significantly alter the time taken for shape computation.

t_large = torch.rand(2 ** 10, 3, 2 ** 16, 2 ** 16, device="meta")
start = timeit.default_timer()
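# Running the oversized input through the same ``conv`` layer completes the
# timing; on the meta device this is a pure shape computation, so it should
# take roughly as long as the small input above.
out = conv(t_large)
end = timeit.default_timer()

print(out)
print(f"Time taken: {end - start}")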


######################################################
# Consider an arbitrary network such as the following:

import torch.nn as nn
import torch.nn.functional as F
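

# A small convolutional network; the layer sizes below are one reasonable
# choice, picked so that a 32x32 input flattens to 16 * 5 * 5 features
# before the fully connected layers.
class Net(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 6, 5)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.fc1 = nn.Linear(16 * 5 * 5, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)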

    def forward(self, x):
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        x = torch.flatten(x, 1)  # flatten all dimensions except batch
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x


###############################################################################
# We can view the intermediate shapes within an entire network by registering a
# forward hook to each layer that prints the shape of the output.

def fw_hook(module, input, output):
    print(f"Shape of output to {module} is {output.shape}.")


# Any tensor created within this torch.device context manager will be
# on the meta device.
with torch.device("meta"):
    net = Net()
    inp = torch.randn((1024, 3, 32, 32))
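
# ``register_forward_hook`` is the standard ``nn.Module`` API for this; the
# loop below is a minimal sketch that attaches the hook to every submodule
# and then runs a forward pass on the meta tensors, printing each
# intermediate shape along the way.
for name, layer in net.named_modules():
    layer.register_forward_hook(fw_hook)

out = net(inp)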