diff --git a/main.py b/main.py index c5738b9..3f77b25 100644 --- a/main.py +++ b/main.py @@ -21,6 +21,8 @@ import logging import re import sys +import os +import json from nacl.signing import SigningKey from nacl.encoding import HexEncoder @@ -41,9 +43,38 @@ # Wallet helpers # ────────────────────────────────────────────── -def create_wallet(): +def load_or_create_wallet(datadir: str | None): + path = datadir or "." + keystore_path = os.path.join(path, "keystore.json") + + # Security Warning: + # Keys are currently stored in unencrypted raw hex format for minimality. + # In a production environment, this file should be encrypted with a "spending password" + # so that the private key only lives in memory when decrypted by the user. + if os.path.exists(keystore_path): + try: + with open(keystore_path, "r") as f: + data = json.load(f) + sk_hex = data.get("private_key") + if sk_hex: + sk = SigningKey(bytes.fromhex(sk_hex)) + pk = sk.verify_key.encode(encoder=HexEncoder).decode() + logger.info("Loaded existing wallet from %s", keystore_path) + return sk, pk + except Exception as e: + logger.warning("Failed to load keystore: %s", e) + sk = SigningKey.generate() pk = sk.verify_key.encode(encoder=HexEncoder).decode() + + os.makedirs(path, exist_ok=True) + try: + with open(keystore_path, "w") as f: + json.dump({"private_key": sk.encode(encoder=HexEncoder).decode()}, f) + logger.info("Created new wallet at %s", keystore_path) + except Exception as e: + logger.warning("Failed to save keystore: %s", e) + return sk, pk @@ -292,6 +323,8 @@ def gradient_text(text: str, c1: tuple[int, int, int], c2: tuple[int, int, int]) {C_CYAN}║{C_RESET} {C_GREEN}peers{C_RESET} - show connected peers {C_CYAN}║{C_RESET} {C_CYAN}║{C_RESET} {C_GREEN}connect {C_RESET} - connect to a peer {C_CYAN}║{C_RESET} {C_CYAN}║{C_RESET} {C_GREEN}address{C_RESET} - show your public key {C_CYAN}║{C_RESET} +{C_CYAN}║{C_RESET} {C_GREEN}export_key{C_RESET} - show your private key {C_CYAN}║{C_RESET} +{C_CYAN}║{C_RESET} {C_GREEN}import_key {C_RESET} - import a private key {C_CYAN}║{C_RESET} {C_CYAN}║{C_RESET} {C_GREEN}chain{C_RESET} - show chain summary {C_CYAN}║{C_RESET} {C_CYAN}║{C_RESET} {C_GREEN}help{C_RESET} - show this help {C_CYAN}║{C_RESET} {C_CYAN}║{C_RESET} {C_GREEN}quit{C_RESET} - shut down {C_CYAN}║{C_RESET} @@ -303,7 +336,11 @@ async def cli_loop(sk, pk, chain, mempool, network): """Read commands from stdin asynchronously.""" loop = asyncio.get_event_loop() print(HELP_TEXT) - print(f" {C_YELLOW}Your address:{C_RESET} {C_BOLD}{pk}{C_RESET}\n") + + def print_prompt_info(current_pk): + print(f" {C_YELLOW}Your address:{C_RESET} {C_BOLD}{current_pk}{C_RESET}\n") + + print_prompt_info(pk) while True: try: @@ -450,6 +487,34 @@ async def cli_loop(sk, pk, chain, mempool, network): elif cmd == "address": print(f" {pk}") + # ── export_key ── + elif cmd == "export_key": + sk_hex = sk.encode(encoder=HexEncoder).decode() + print(f" {C_YELLOW}Private Key:{C_RESET} {sk_hex}") + print(f" {C_RED}Warning: Never share this key!{C_RESET}") + + # ── import_key ── + elif cmd == "import_key": + if len(parts) < 2: + print(" Usage: import_key ") + continue + + try: + new_sk = SigningKey(bytes.fromhex(parts[1])) + new_pk = new_sk.verify_key.encode(encoder=HexEncoder).decode() + + path = datadir or "." + keystore_path = os.path.join(path, "keystore.json") + os.makedirs(path, exist_ok=True) + with open(keystore_path, "w") as f: + json.dump({"private_key": new_sk.encode(encoder=HexEncoder).decode()}, f) + + sk, pk = new_sk, new_pk + print(f" {C_GREEN}✅ Key imported successfully!{C_RESET}") + print_prompt_info(pk) + except Exception as e: + print(f" {C_RED}❌ Failed to import key: {e}{C_RESET}") + # ── chain ── elif cmd == "chain": print(f" Chain length: {len(chain.chain)} blocks") @@ -506,7 +571,7 @@ async def cli_loop(sk, pk, chain, mempool, network): async def run_node(port: int, host: str, connect_to: str | None, fund: int, datadir: str | None): """Boot the node, optionally connect to a peer, then enter the CLI.""" - sk, pk = create_wallet() + sk, pk = load_or_create_wallet(datadir) # Load existing chain from disk, or start fresh chain = None @@ -589,7 +654,7 @@ def main(): parser.add_argument("--port", type=int, default=9000, help="TCP port to listen on (default: 9000)") parser.add_argument("--connect", type=str, default=None, help="Peer address to connect to (multiaddr)") parser.add_argument("--fund", type=int, default=100, help="Initial coins to fund this wallet (default: 100)") - parser.add_argument("--datadir", type=str, default=None, help="Directory to save/load blockchain state (enables persistence)") + parser.add_argument("--datadir", type=str, default=".minichain", help="Directory to save/load blockchain state (enables persistence)") args = parser.parse_args() logging.basicConfig( diff --git a/tests/test_persistence_runtime.py b/tests/test_persistence_runtime.py index 894ccca..51b90f4 100644 --- a/tests/test_persistence_runtime.py +++ b/tests/test_persistence_runtime.py @@ -109,7 +109,7 @@ async def fake_cli_loop(sk, pk, chain, mempool, network): with patch.object(main_module, "P2PNetwork", FakeNetwork), patch.object( main_module, "cli_loop", fake_cli_loop - ), patch.object(main_module, "create_wallet", return_value=(fixed_sk, fixed_pk)): + ), patch.object(main_module, "load_or_create_wallet", return_value=(fixed_sk, fixed_pk)): await main_module.run_node( port=9401, host="127.0.0.1",