Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
39 changes: 36 additions & 3 deletions src/github_app.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,21 @@
from __future__ import annotations

import contextlib
import logging
import time
from typing import Generator

import jwt
import requests

GITHUB_API_TIMEOUT = 10
GITHUB_API_MAX_ATTEMPTS = 3
GITHUB_API_RETRYABLE_ERRORS = (
requests.exceptions.ConnectionError,
requests.exceptions.SSLError,
requests.exceptions.Timeout,
)


class GithubAppToken:
def __init__(self, private_key, app_id) -> None:
Expand All @@ -19,8 +28,9 @@ def __init__(self, private_key, app_id) -> None:
# configured by the GitHub App and expire after one hour.
@contextlib.contextmanager
def get_token(self, installation_id: int) -> Generator[str, None, None]:
req = requests.post(
url=f"https://api.github.com/app/installations/{installation_id}/access_tokens",
req = _request_github_api(
"POST",
f"https://api.github.com/app/installations/{installation_id}/access_tokens",
headers=self.headers,
)
req.raise_for_status()
Expand All @@ -29,7 +39,8 @@ def get_token(self, installation_id: int) -> Generator[str, None, None]:
# This token expires in an hour
yield resp["token"]
finally:
requests.delete(
_request_github_api(
"DELETE",
"https://api.github.com/installation/token",
headers={"Authorization": f"token {resp['token']}"},
)
Expand All @@ -51,3 +62,25 @@ def get_authentication_header(self, private_key, app_id):
"Accept": "application/vnd.github.v3+json",
"Authorization": f"Bearer {jwt_token}",
}


def _request_github_api(method: str, url: str, **kwargs) -> requests.Response:
for attempt in range(1, GITHUB_API_MAX_ATTEMPTS + 1):
try:
return requests.request(
method,
url,
timeout=GITHUB_API_TIMEOUT,
**kwargs,
)
except GITHUB_API_RETRYABLE_ERRORS:
if attempt == GITHUB_API_MAX_ATTEMPTS:
raise

logging.warning(
"Transient GitHub App API request failed; retrying",
exc_info=True,
)
time.sleep(2 ** (attempt - 1))

raise RuntimeError("Unreachable GitHub App API retry state")
71 changes: 71 additions & 0 deletions tests/test_github_app.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,71 @@
from __future__ import annotations

from unittest.mock import Mock

import pytest
import requests

from src.github_app import GITHUB_API_TIMEOUT, GithubAppToken


class DummyResponse:
def __init__(self, payload=None, error=None) -> None:
self.payload = payload or {}
self.error = error

def raise_for_status(self):
if self.error:
raise self.error

def json(self):
return self.payload


@pytest.fixture
def github_app_token():
token = GithubAppToken.__new__(GithubAppToken)
token.headers = {"Authorization": "Bearer jwt"}
return token


def test_get_token_retries_transient_token_creation_errors(
github_app_token,
monkeypatch,
):
calls = []
sleep = Mock()
monkeypatch.setattr("src.github_app.time.sleep", sleep)

def request(method, url, **kwargs):
calls.append((method, url, kwargs))
post_attempts = len([call for call in calls if call[0] == "POST"])
if method == "POST" and post_attempts == 1:
raise requests.exceptions.ConnectTimeout("connect timeout")
return DummyResponse({"token": "github-token"})

monkeypatch.setattr("src.github_app.requests.request", request)

with github_app_token.get_token(42) as token:
assert token == "github-token"

post_calls = [call for call in calls if call[0] == "POST"]
assert len(post_calls) == 2
assert calls[-1][0] == "DELETE"
assert all(call[2]["timeout"] == GITHUB_API_TIMEOUT for call in calls)
sleep.assert_called_once_with(1)


def test_get_token_does_not_retry_http_status_errors(github_app_token, monkeypatch):
calls = []

def request(method, url, **kwargs):
calls.append((method, url, kwargs))
return DummyResponse(error=requests.HTTPError("server error"))

monkeypatch.setattr("src.github_app.requests.request", request)

with pytest.raises(requests.HTTPError):
with github_app_token.get_token(42):
pass

assert len(calls) == 1
Loading