μ‚¬μš©μž μ •μ˜ ν•¨μˆ˜μ™€ 이쀑 μ—­μ „νŒŒ

It is sometimes useful to run backward twice through the backward graph, for example to compute higher-order gradients. However, supporting double backward takes an understanding of autograd and some care: supporting a single backward pass does not automatically mean that double backward is supported. This tutorial shows how to write a custom function that supports double backward and points out some things to look out for.
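
Before diving into custom functions, here is a minimal sketch (using only built-in operations; the variable names are illustrative) of how a second derivative is obtained: the first call to torch.autograd.grad records the gradient computation because create_graph=True, and a second call then differentiates through that recorded graph.

import torch

x = torch.tensor(2.0, requires_grad=True)
y = x ** 3
# First backward pass: dy/dx = 3x^2. create_graph=True records this computation.
dy_dx, = torch.autograd.grad(y, x, create_graph=True)
# Second backward pass through the recorded graph: d2y/dx2 = 6x.
d2y_dx2, = torch.autograd.grad(dy_dx, x)
print(dy_dx.item(), d2y_dx2.item())  # 12.0 12.0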

When writing a custom autograd function to be used with double backward, it is important to know when the operations performed inside the function are recorded by autograd, when they are not, and, most importantly, how save_for_backward works with all of this.

μ‚¬μš©μž μ •μ˜ ν•¨μˆ˜λŠ” μ•”λ¬΅μ μœΌλ‘œ grad λͺ¨λ“œμ— 두 κ°€μ§€ λ°©μ‹μœΌλ‘œ 영ν–₯을 μ€λ‹ˆλ‹€:

  • μˆœμ „νŒŒλ₯Ό μ§„ν–‰ν•˜λŠ” λ™μ•ˆ autogradλŠ” μˆœμ „νŒŒ ν•¨μˆ˜μ•ˆμ—μ„œ λ™μž‘ν•˜λŠ” μ–΄λ–€ 연산도 κ·Έλž˜ν”„μ— κΈ°λ‘ν•˜μ§€ μ•ŠμŠ΅λ‹ˆλ‹€. μˆœμ „νŒŒκ°€ λλ‚˜κ³  μ‚¬μš©μž μ •μ˜ ν•¨μˆ˜μ˜ μ—­μ „νŒŒλŠ” μˆœμ „νŒŒμ˜ 결과의 grad_fn 이 λ©λ‹ˆλ‹€.
  • μ—­μ „νŒŒκ°€ μ§„ν–‰λ˜λŠ” λ™μ•ˆ create_graphκ°€ μ§€μ •λ˜μ–΄ μžˆλ‹€λ©΄ autogradλŠ” μ—­μ „νŒŒμ˜ 연산을 κ·Έλž˜ν”„μ— κΈ°λ‘ν•©λ‹ˆλ‹€.

λ‹€μŒμœΌλ‘œ, save_for_backward κ°€ μœ„μ˜ λ‚΄μš©κ³Ό μ–΄λ–»κ²Œ μƒν˜Έμž‘μš©ν•˜λŠ”μ§€ μ΄ν•΄ν•˜κΈ° μœ„ν•΄μ„œ, λͺ‡ κ°€μ§€ μ˜ˆμ‹œλ₯Ό μ‚΄νŽ΄λ³΄κ² μŠ΅λ‹ˆλ‹€:

Saving the Inputs

Consider this simple squaring function, which saves the input tensor for the backward pass. Double backward works automatically when autograd is able to record the operations performed in the backward pass, so there is usually nothing to worry about when we save an input for backward: the input should have a grad_fn if it is a function of any tensor that requires grad, and this is what allows gradients to be properly propagated.

import torch

class Square(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x):
        # Because we are saving one of the inputs, use `save_for_backward`
        # Save non-tensors and non-inputs/non-outputs directly on ctx
        ctx.save_for_backward(x)
        return x**2

    @staticmethod
    def backward(ctx, grad_out):
        # A function supports double backward automatically if autograd
        # is able to record the computations performed in backward
        x, = ctx.saved_tensors
        return grad_out * 2 * x

# Use double precision because finite differencing method magnifies errors
x = torch.rand(3, 3, requires_grad=True, dtype=torch.double)
torch.autograd.gradcheck(Square.apply, x)
# Use gradgradcheck to verify second-order derivatives
torch.autograd.gradgradcheck(Square.apply, x)

We can use torchviz to visualize the graph and see why this works:

import torchviz

x = torch.tensor(1., requires_grad=True).clone()
out = Square.apply(x)
grad_x, = torch.autograd.grad(out, x, create_graph=True)
torchviz.make_dot((grad_x, x, out), {"grad_x": grad_x, "x": x, "out": out})

We can see that the gradient with respect to x is itself a function of x (dout/dx = 2x), and the graph for this function has been properly constructed.

https://user-images.githubusercontent.com/13428986/126559699-e04f3cb1-aaf2-4a9a-a83d-b8767d04fbd9.png
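
As a quick numerical check (a small addition reusing x, out, and grad_x from above), differentiating grad_x once more should give the constant second derivative d^2(x^2)/dx^2 = 2:

grad2_x, = torch.autograd.grad(grad_x, x)
print(grad2_x)  # tensor(2.)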

Saving the Outputs

A slight variation on the previous example is to save the output instead of the input. The mechanics are similar because outputs are also associated with a grad_fn.

class Exp(torch.autograd.Function):
    # Simple case where everything goes well
    @staticmethod
    def forward(ctx, x):
        # This time we save the output
        result = torch.exp(x)
        # Note that we should use `save_for_backward` here when
        # the tensor saved is an output (or an input).
        ctx.save_for_backward(result)
        return result

    @staticmethod
    def backward(ctx, grad_out):
        result, = ctx.saved_tensors
        return result * grad_out

x = torch.tensor(1., requires_grad=True, dtype=torch.double).clone()
# Validate our gradients using gradcheck
torch.autograd.gradcheck(Exp.apply, x)
torch.autograd.gradgradcheck(Exp.apply, x)

torchviz둜 κ·Έλž˜ν”„ μ‹œκ°ν™”ν•˜κΈ°:

out = Exp.apply(x)
grad_x, = torch.autograd.grad(out, x, create_graph=True)
torchviz.make_dot((grad_x, x, out), {"grad_x": grad_x, "x": x, "out": out})

https://user-images.githubusercontent.com/13428986/126559780-d141f2ba-1ee8-4c33-b4eb-c9877b27a954.png
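
As a small sanity check (an addition reusing x, out, and grad_x from above): since d/dx exp(x) = exp(x), both the first and the second derivative should equal the output itself.

print(torch.allclose(grad_x, out))   # True: the first derivative equals exp(x)
grad2_x, = torch.autograd.grad(grad_x, x)
print(torch.allclose(grad2_x, out))  # True: the second derivative equals exp(x) as well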

Saving Intermediate Results

Saving intermediate results is trickier. Let's demonstrate this by implementing:

sinh(x) := (e^x - e^{-x}) / 2

sinh의 λ„ν•¨μˆ˜λŠ” coshμ΄λ―€λ‘œ, μˆœμ „νŒŒμ˜ 쀑간결과인 exp(x) 와 exp(-x) λ₯Ό μ—­μ „νŒŒ 계산에 μž¬μ‚¬μš©ν•˜λ©΄ νš¨μœ¨μ μž…λ‹ˆλ‹€.

However, intermediate results should not be saved directly on ctx and used in the backward pass. Because the forward pass is performed in no-grad mode, if an intermediate result of the forward pass is used to compute gradients in the backward pass, the backward graph of the gradients will not include the operations that computed that intermediate result. This leads to incorrect gradients.

class Sinh(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x):
        expx = torch.exp(x)
        expnegx = torch.exp(-x)
        ctx.save_for_backward(expx, expnegx)
        # In order to be able to save the intermediate results, a trick is to
        # include them as our outputs, so that the backward graph is constructed
        return (expx - expnegx) / 2, expx, expnegx

    @staticmethod
    def backward(ctx, grad_out, _grad_out_exp, _grad_out_negexp):
        expx, expnegx = ctx.saved_tensors
        grad_input = grad_out * (expx + expnegx) / 2
        # We cannot skip accumulating these even though we won't use the outputs
        # directly. They will be used later in the second backward.
        grad_input += _grad_out_exp * expx
        grad_input -= _grad_out_negexp * expnegx
        return grad_input

def sinh(x):
    # Create a wrapper that only returns the first output
    return Sinh.apply(x)[0]

x = torch.rand(3, 3, requires_grad=True, dtype=torch.double)
torch.autograd.gradcheck(sinh, x)
torch.autograd.gradgradcheck(sinh, x)
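
As an extra consistency check (a small addition that compares against the built-in torch.sinh and torch.cosh), the wrapper's values and first derivative should match the library functions:

y = sinh(x)
print(torch.allclose(y, torch.sinh(x)))   # True: values match torch.sinh
gx, = torch.autograd.grad(y.sum(), x, create_graph=True)
print(torch.allclose(gx, torch.cosh(x)))  # True: first derivative matches torch.cosh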

torchviz둜 κ·Έλž˜ν”„ μ‹œκ°ν™”ν•˜κΈ°:

out = sinh(x)
grad_x, = torch.autograd.grad(out.sum(), x, create_graph=True)
torchviz.make_dot((grad_x, x, out), params={"grad_x": grad_x, "x": x, "out": out})

https://user-images.githubusercontent.com/13428986/126560494-e48eba62-be84-4b29-8c90-a7f6f40b1438.png

Saving Intermediate Results: What NOT to Do

Now we show what happens when we don't also return the intermediate results as outputs: grad_x will not even have a backward graph, because it is purely a function of expx and expnegx, which do not require grad.

class SinhBad(torch.autograd.Function):
    # This is an example of what NOT to do!
    @staticmethod
    def forward(ctx, x):
        expx = torch.exp(x)
        expnegx = torch.exp(-x)
        ctx.expx = expx
        ctx.expnegx = expnegx
        return (expx - expnegx) / 2

    @staticmethod
    def backward(ctx, grad_out):
        expx = ctx.expx
        expnegx = ctx.expnegx
        grad_input = grad_out * (expx + expnegx) / 2
        return grad_input

torchviz둜 κ·Έλž˜ν”„ μ‹œκ°ν™”ν•˜κΈ°. grad_x κ°€ κ·Έλž˜ν”„μ— ν¬ν•¨λ˜μ§€ μ•ŠλŠ”κ²ƒμ„ ν™•μΈν•˜μ„Έμš”!

out = SinhBad.apply(x)
grad_x, = torch.autograd.grad(out.sum(), x, create_graph=True)
torchviz.make_dot((grad_x, x, out), params={"grad_x": grad_x, "x": x, "out": out})

https://user-images.githubusercontent.com/13428986/126565889-13992f01-55bc-411a-8aee-05b721fe064a.png
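
To make the failure concrete (a small addition reusing grad_x from above): nothing was recorded for the backward computation, so grad_x has no grad_fn and attempting a second backward raises an error.

print(grad_x.grad_fn)  # None: the backward computation was not recorded
try:
    torch.autograd.grad(grad_x.sum(), x)
except RuntimeError as e:
    print("Second backward failed:", e)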

When Backward Is Not Tracked

Finally, let's look at a case where it may not be possible for autograd to track the gradients of a function's backward at all. Imagine that cube_backward is a function that uses a non-PyTorch library such as SciPy or NumPy, or is written as a C++ extension. The workaround here is to create another custom function, CubeBackward, in which you also manually specify the backward of cube_backward!

def cube_forward(x):
    return x**3

def cube_backward(grad_out, x):
    return grad_out * 3 * x**2

def cube_backward_backward(grad_out, sav_grad_out, x):
    return grad_out * sav_grad_out * 6 * x

def cube_backward_backward_grad_out(grad_out, x):
    return grad_out * 3 * x**2

class Cube(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x):
        ctx.save_for_backward(x)
        return cube_forward(x)

    @staticmethod
    def backward(ctx, grad_out):
        x, = ctx.saved_tensors
        return CubeBackward.apply(grad_out, x)

class CubeBackward(torch.autograd.Function):
    @staticmethod
    def forward(ctx, grad_out, x):
        ctx.save_for_backward(x, grad_out)
        return cube_backward(grad_out, x)

    @staticmethod
    def backward(ctx, grad_out):
        x, sav_grad_out = ctx.saved_tensors
        dx = cube_backward_backward(grad_out, sav_grad_out, x)
        dgrad_out = cube_backward_backward_grad_out(grad_out, x)
        return dgrad_out, dx

x = torch.tensor(2., requires_grad=True, dtype=torch.double)

torch.autograd.gradcheck(Cube.apply, x)
torch.autograd.gradgradcheck(Cube.apply, x)

torchviz둜 κ·Έλž˜ν”„ μ‹œκ°ν™”ν•˜κΈ°:

out = Cube.apply(x)
grad_x, = torch.autograd.grad(out, x, create_graph=True)
torchviz.make_dot((grad_x, x, out), params={"grad_x": grad_x, "x": x, "out": out})

https://user-images.githubusercontent.com/13428986/126559935-74526b4d-d419-4983-b1f0-a6ee99428531.png
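
As a final numerical check (a small addition reusing x and grad_x from above): at x = 2, the first derivative 3x^2 and the second derivative 6x both evaluate to 12.

grad2_x, = torch.autograd.grad(grad_x, x)
print(grad_x.item(), grad2_x.item())  # 12.0 12.0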

To conclude, whether double backward works for your custom function simply depends on whether the backward pass can be tracked by autograd. The first two examples show cases where double backward works out of the box, while the third and fourth examples demonstrate techniques for making a backward function trackable when it otherwise would not be.