|
| 1 | +#!/usr/bin/env python3 |
| 2 | +""" |
| 3 | +Regenerate `src/lib/model-multipliers.generated.ts` from the latest release of |
| 4 | +rajbos/github-copilot-model-notifier. |
| 5 | +
|
| 6 | +The source repo publishes a markdown table of current models in the release body |
| 7 | +under a `### Current Models` heading. This script parses that table and fully |
| 8 | +overwrites the generated TypeScript file. The companion file |
| 9 | +`model-multipliers.legacy.ts` contains hand-maintained backward-compat entries |
| 10 | +and is never touched here. |
| 11 | +
|
| 12 | +Usage: |
| 13 | + python scripts/update-model-multipliers.py |
| 14 | +
|
| 15 | +Env: |
| 16 | + GITHUB_TOKEN Optional. Used to authenticate the GitHub API request. |
| 17 | +""" |
| 18 | + |
| 19 | +from __future__ import annotations |
| 20 | + |
| 21 | +import json |
| 22 | +import os |
| 23 | +import re |
| 24 | +import sys |
| 25 | +import urllib.error |
| 26 | +import urllib.request |
| 27 | +from pathlib import Path |
| 28 | + |
| 29 | +RELEASE_URL = ( |
| 30 | + "https://api.github.com/repos/rajbos/" |
| 31 | + "github-copilot-model-notifier/releases/latest" |
| 32 | +) |
| 33 | +REPO_ROOT = Path(__file__).resolve().parent.parent |
| 34 | +GENERATED_PATH = REPO_ROOT / "src" / "lib" / "model-multipliers.generated.ts" |
| 35 | + |
| 36 | + |
| 37 | +def fetch_latest_release() -> dict: |
| 38 | + """Fetch the latest release JSON from the source repo.""" |
| 39 | + headers = { |
| 40 | + "Accept": "application/vnd.github+json", |
| 41 | + "User-Agent": "github-copilot-premium-reqs-usage-updater", |
| 42 | + } |
| 43 | + token = os.environ.get("GITHUB_TOKEN") |
| 44 | + if token: |
| 45 | + headers["Authorization"] = f"Bearer {token}" |
| 46 | + |
| 47 | + req = urllib.request.Request(RELEASE_URL, headers=headers) |
| 48 | + try: |
| 49 | + with urllib.request.urlopen(req, timeout=30) as resp: |
| 50 | + return json.loads(resp.read().decode("utf-8")) |
| 51 | + except urllib.error.HTTPError as e: |
| 52 | + sys.stderr.write( |
| 53 | + f"HTTP error {e.code} fetching latest release: {e.reason}\n" |
| 54 | + ) |
| 55 | + raise |
| 56 | + |
| 57 | + |
| 58 | +def parse_models_table(body: str) -> dict[str, float]: |
| 59 | + """Parse the `### Current Models` markdown table from the release body. |
| 60 | +
|
| 61 | + Returns a mapping of model name -> paid multiplier (float). |
| 62 | + Multiplier 'Not applicable' is mapped to 0. |
| 63 | + """ |
| 64 | + if not body: |
| 65 | + raise ValueError("Release body is empty; cannot parse models table.") |
| 66 | + |
| 67 | + m = re.search( |
| 68 | + r"###\s+Current Models\s*\n(.+?)(?=\n#{1,6}\s|\Z)", |
| 69 | + body, |
| 70 | + re.DOTALL, |
| 71 | + ) |
| 72 | + if not m: |
| 73 | + raise ValueError( |
| 74 | + "Could not find '### Current Models' section in release body." |
| 75 | + ) |
| 76 | + section = m.group(1) |
| 77 | + |
| 78 | + models: dict[str, float] = {} |
| 79 | + for line in section.splitlines(): |
| 80 | + line = line.strip() |
| 81 | + if not line.startswith("|") or not line.endswith("|"): |
| 82 | + continue |
| 83 | + cells = [c.strip() for c in line.strip("|").split("|")] |
| 84 | + if len(cells) < 3: |
| 85 | + continue |
| 86 | + if cells[0].lower() == "model": |
| 87 | + continue |
| 88 | + if set(cells[0]) <= set("-: "): |
| 89 | + continue |
| 90 | + |
| 91 | + name = cells[0] |
| 92 | + raw_mult = cells[2] |
| 93 | + if raw_mult.lower() == "not applicable": |
| 94 | + mult: float = 0.0 |
| 95 | + else: |
| 96 | + try: |
| 97 | + mult = float(raw_mult) |
| 98 | + except ValueError: |
| 99 | + sys.stderr.write( |
| 100 | + f"Warning: skipping {name!r} - " |
| 101 | + f"unparseable multiplier {raw_mult!r}\n" |
| 102 | + ) |
| 103 | + continue |
| 104 | + models[name] = mult |
| 105 | + |
| 106 | + if not models: |
| 107 | + raise ValueError("Parsed zero models from release body.") |
| 108 | + return models |
| 109 | + |
| 110 | + |
| 111 | +def format_multiplier(value: float) -> str: |
| 112 | + """Render a multiplier as JS/TS literal: drop .0 for whole numbers.""" |
| 113 | + if value == int(value): |
| 114 | + return str(int(value)) |
| 115 | + return repr(value) |
| 116 | + |
| 117 | + |
| 118 | +def js_string_literal(s: str) -> str: |
| 119 | + """Render a string as a single-quoted JS/TS literal.""" |
| 120 | + escaped = s.replace("\\", "\\\\").replace("'", "\\'") |
| 121 | + return f"'{escaped}'" |
| 122 | + |
| 123 | + |
| 124 | +def render_generated_file(models: dict[str, float], tag_name: str) -> str: |
| 125 | + sorted_names = sorted(models.keys(), key=str.lower) |
| 126 | + default_names = [n for n in sorted_names if models[n] == 0] |
| 127 | + |
| 128 | + lines = [ |
| 129 | + "// AUTO-GENERATED FILE — DO NOT EDIT BY HAND.", |
| 130 | + "//", |
| 131 | + "// Source: https://github.com/rajbos/github-copilot-model-notifier (latest release)", |
| 132 | + "// Updated by: scripts/update-model-multipliers.py (run daily via GitHub Actions)", |
| 133 | + "//", |
| 134 | + "// To make manual changes, edit `model-multipliers.legacy.ts` instead.", |
| 135 | + "", |
| 136 | + f"export const CURRENT_MODELS_SOURCE_RELEASE = {js_string_literal(tag_name)};", |
| 137 | + "", |
| 138 | + "export const CURRENT_MODEL_MULTIPLIERS: Record<string, number> = {", |
| 139 | + ] |
| 140 | + for name in sorted_names: |
| 141 | + lines.append( |
| 142 | + f" {js_string_literal(name)}: {format_multiplier(models[name])}," |
| 143 | + ) |
| 144 | + lines.append("};") |
| 145 | + lines.append("") |
| 146 | + lines.append( |
| 147 | + "// Models with a 0x multiplier (free) are treated as \"Default\" " |
| 148 | + "and grouped together." |
| 149 | + ) |
| 150 | + lines.append("export const CURRENT_DEFAULT_MODELS: string[] = [") |
| 151 | + for name in default_names: |
| 152 | + lines.append(f" {js_string_literal(name)},") |
| 153 | + lines.append("];") |
| 154 | + lines.append("") # trailing newline |
| 155 | + return "\n".join(lines) |
| 156 | + |
| 157 | + |
| 158 | +def parse_existing_models(content: str) -> dict[str, float]: |
| 159 | + """Parse the existing CURRENT_MODEL_MULTIPLIERS object for a diff summary.""" |
| 160 | + m = re.search( |
| 161 | + r"CURRENT_MODEL_MULTIPLIERS[^=]*=\s*\{(.*?)\};", |
| 162 | + content, |
| 163 | + re.DOTALL, |
| 164 | + ) |
| 165 | + if not m: |
| 166 | + return {} |
| 167 | + body = m.group(1) |
| 168 | + models: dict[str, float] = {} |
| 169 | + entry_re = re.compile(r"'((?:\\'|[^'])*)'\s*:\s*([0-9.]+)") |
| 170 | + for line in body.splitlines(): |
| 171 | + line = line.split("//", 1)[0] |
| 172 | + em = entry_re.search(line) |
| 173 | + if em: |
| 174 | + name = em.group(1).replace("\\'", "'") |
| 175 | + try: |
| 176 | + models[name] = float(em.group(2)) |
| 177 | + except ValueError: |
| 178 | + continue |
| 179 | + return models |
| 180 | + |
| 181 | + |
| 182 | +def print_diff_summary( |
| 183 | + old: dict[str, float], new: dict[str, float], tag_name: str |
| 184 | +) -> None: |
| 185 | + added = sorted(set(new) - set(old), key=str.lower) |
| 186 | + removed = sorted(set(old) - set(new), key=str.lower) |
| 187 | + changed = sorted( |
| 188 | + (n for n in set(new) & set(old) if old[n] != new[n]), key=str.lower |
| 189 | + ) |
| 190 | + |
| 191 | + print(f"Source release: {tag_name}") |
| 192 | + if not (added or removed or changed): |
| 193 | + print("No model changes detected.") |
| 194 | + return |
| 195 | + if added: |
| 196 | + print("Added:") |
| 197 | + for n in added: |
| 198 | + print(f" + {n} = {format_multiplier(new[n])}") |
| 199 | + if removed: |
| 200 | + print("Removed:") |
| 201 | + for n in removed: |
| 202 | + print(f" - {n} (was {format_multiplier(old[n])})") |
| 203 | + if changed: |
| 204 | + print("Changed:") |
| 205 | + for n in changed: |
| 206 | + print( |
| 207 | + f" ~ {n}: {format_multiplier(old[n])} -> " |
| 208 | + f"{format_multiplier(new[n])}" |
| 209 | + ) |
| 210 | + |
| 211 | + |
| 212 | +def main() -> int: |
| 213 | + release = fetch_latest_release() |
| 214 | + tag_name = release.get("tag_name", "<unknown>") |
| 215 | + body = release.get("body") or "" |
| 216 | + |
| 217 | + new_models = parse_models_table(body) |
| 218 | + |
| 219 | + old_content = ( |
| 220 | + GENERATED_PATH.read_text(encoding="utf-8") |
| 221 | + if GENERATED_PATH.exists() |
| 222 | + else "" |
| 223 | + ) |
| 224 | + old_models = parse_existing_models(old_content) |
| 225 | + |
| 226 | + new_content = render_generated_file(new_models, tag_name) |
| 227 | + |
| 228 | + if new_content == old_content: |
| 229 | + print(f"Source release: {tag_name}") |
| 230 | + print( |
| 231 | + f"{GENERATED_PATH.relative_to(REPO_ROOT)} is already up to date." |
| 232 | + ) |
| 233 | + return 0 |
| 234 | + |
| 235 | + GENERATED_PATH.write_text(new_content, encoding="utf-8") |
| 236 | + print_diff_summary(old_models, new_models, tag_name) |
| 237 | + print(f"Updated {GENERATED_PATH.relative_to(REPO_ROOT)}") |
| 238 | + return 0 |
| 239 | + |
| 240 | + |
| 241 | +if __name__ == "__main__": |
| 242 | + sys.exit(main()) |
0 commit comments