diff --git a/README.md b/README.md index b14934e..8e078a9 100644 --- a/README.md +++ b/README.md @@ -77,29 +77,68 @@ MiniChain is a minimal fully functional blockchain implemented in Python. implementing MiniChain in Python aligns with MiniChain's educational goal. -### Overview of Tasks +### Resources -* Develop a fully functional minimal blockchain in Python, with all the expected components: - peer-to-peer networking, consensus, mempool, ledger, ... +* Read this book: https://www.marabu.dev/blockchain-foundations.pdf -* Bonus task: add smart contracts to the blockchain. +--- -Candidates are expected to refine these tasks in their GSoC proposals. -It is encouraged that you develop an initial prototype during the application phase. +## Getting Started -### Requirements +### Prerequisites -* Use [PyNaCl](https://pynacl.readthedocs.io/en/latest/) library for hashing, signing transactions and verifying signatures. -* Use [Py-libp2p](https://github.com/libp2p/py-libp2p/tree/main) for p2p networking. -* Implement Proof-of-Work as the consensus protocol. -* Use accounts (instead of UTxO) as the accounting model for the ledger. -* Use as few lines of code as possible without compromising readability and understandability. -* For the bonus task, make Python itself be the language used for smart contracts, but watch out for security concerns related to executing arbitrary code from untrusted sources. +- Python 3.10+ +- Install dependencies: + ```bash + pip install -r requirements.txt + ``` + +### 1. Creating a New MiniChain +To bootstrap a brand new blockchain network from scratch, simply start a node. By default, this creates a new Genesis block. +```bash +python main.py --port 9000 --datadir ./node1_data +``` +*Note: Keep this terminal open to interact with the node via the CLI.* + +### 2. Connecting to an Existing Chain +To connect a secondary node to the network, start a new instance on a different port and point it to the seed node using the `--connect` flag. +```bash +python main.py --port 9001 --connect 127.0.0.1:9000 --datadir ./node2_data +``` +The node will automatically sync the blockchain state via the P2P network using the Fork-Choice rule. + +### 3. Mining Blocks +To confirm pending transactions, you need to mine blocks. In the interactive CLI of your node, simply type: +```text +minichain> mine +``` +This runs the Proof-of-Work algorithm, validates transactions, computes the new state root, updates your wallet with the block reward + fees, and broadcasts the block to all connected peers. -### Resources +--- -* Read this book: https://www.marabu.dev/blockchain-foundations.pdf +## Basic Operations (Interactive CLI) + +Once your node is running, you can perform basic blockchain operations directly in your terminal. +**Making a Transfer** +Send coins to another public key: +```text +minichain> send +``` +*Example: `send 8b3401abedb875aff7279b5ab58cb9a0c... 100 1`* + +**Checking Balances** +View the state of all active accounts and contracts on the chain: +```text +minichain> balance +``` + +**Viewing Network State** +```text +minichain> chain # View all blocks +minichain> peers # View connected P2P nodes +minichain> address # View your own public key +``` --- @@ -118,26 +157,28 @@ Check out the `/examples` directory for tutorials: ### Interacting via CLI Start the interactive node using `python main.py` and use the following commands: -1. **Deploy:** `deploy [amount] [gas_limit]` -2. **Call:** `call [amount] [gas_limit]` - ---- +1. **Deploy:** `deploy [amount] [fee]` +2. **Call:** `call [amount] [fee]` -## Tech Stack - -TODO: +Example deployment: +```text +minichain> deploy examples/counter.py 0 100 +``` --- -## Getting Started - -### Prerequisites +## JSON-RPC 2.0 Server -TODO +MiniChain automatically spins up a JSON-RPC 2.0 server alongside the P2P node. By default, it binds to `port 8545` (the standard EVM RPC port). External wallets and dApps can use this to interact with the chain asynchronously. -### Installation +**Example Request (Get Block Number):** +```bash +curl -X POST http://127.0.0.1:8545/ \ + -H "Content-Type: application/json" \ + -d '{"jsonrpc": "2.0", "method": "mc_blockNumber", "id": 1}' +``` -TODO +Available endpoints include: `mc_blockNumber`, `mc_getBlockByNumber`, `mc_getBalance`, and `mc_sendTransaction`. --- diff --git a/main.py b/main.py index fe1c1ce..d625ef4 100644 --- a/main.py +++ b/main.py @@ -21,13 +21,15 @@ import logging import re import sys +import os +import json from nacl.signing import SigningKey from nacl.encoding import HexEncoder from minichain import Transaction, Blockchain, Block, State, Mempool, P2PNetwork, mine_block from minichain.rpc import JSONRPCServer -from minichain.validators import is_valid_receiver +from minichain.validators import is_valid_receiver, ValidationStatus from minichain.block import calculate_receipt_root @@ -41,9 +43,38 @@ # Wallet helpers # ────────────────────────────────────────────── -def create_wallet(): +def load_or_create_wallet(datadir: str | None): + path = datadir or "." + keystore_path = os.path.join(path, "keystore.json") + + # Security Warning: + # Keys are currently stored in unencrypted raw hex format for minimality. + # In a production environment, this file should be encrypted with a "spending password" + # so that the private key only lives in memory when decrypted by the user. + if os.path.exists(keystore_path): + try: + with open(keystore_path, "r") as f: + data = json.load(f) + sk_hex = data.get("private_key") + if sk_hex: + sk = SigningKey(bytes.fromhex(sk_hex)) + pk = sk.verify_key.encode(encoder=HexEncoder).decode() + logger.info("Loaded existing wallet from %s", keystore_path) + return sk, pk + except Exception as e: + logger.warning("Failed to load keystore: %s", e) + sk = SigningKey.generate() pk = sk.verify_key.encode(encoder=HexEncoder).decode() + + os.makedirs(path, exist_ok=True) + try: + with open(keystore_path, "w") as f: + json.dump({"private_key": sk.encode(encoder=HexEncoder).decode()}, f) + logger.info("Created new wallet at %s", keystore_path) + except Exception as e: + logger.warning("Failed to save keystore: %s", e) + return sk, pk @@ -97,7 +128,7 @@ def mine_and_process_block(chain, mempool, miner_pk): mined_block = mine_block(block) - if chain.add_block(mined_block): + if chain.add_block(mined_block) == ValidationStatus.VALID: logger.info("✅ Block #%d mined and added (%d txs)", mined_block.index, len(mineable_txs)) mempool.remove_transactions(mineable_txs) return mined_block @@ -117,6 +148,7 @@ def mine_and_process_block(chain, mempool, miner_pk): def make_network_handler(chain, mempool, network): """Return an async callback that processes incoming P2P messages.""" + from minichain.validators import ValidationStatus async def handler(data): msg_type = data.get("type") @@ -148,24 +180,30 @@ async def handler(data): elif msg_type == "tx": try: tx = Transaction.from_dict(payload) - if getattr(tx, "chain_id", None) != chain.chain_id: - logger.warning("Invalid chain_id in tx from %s", peer_addr) - return - if mempool.add_transaction(tx): - logger.info("📥 Received tx from %s... (amount=%s)", tx.sender[:8], tx.amount) except Exception as e: logger.warning("Invalid tx payload from %s: %s", peer_addr, e) + return ValidationStatus.MALFORMED + + if getattr(tx, "chain_id", None) != chain.chain_id: + logger.warning("Invalid chain_id in tx from %s", peer_addr) + return ValidationStatus.INVALID + + if mempool.add_transaction(tx): + logger.info("📥 Received tx from %s... (amount=%s)", tx.sender[:8], tx.amount) + return ValidationStatus.VALID + else: + return ValidationStatus.FAILED elif msg_type == "block": try: block = Block.from_dict(payload) except Exception as e: logger.warning("Invalid block payload from %s: %s", peer_addr, e) - return + return ValidationStatus.MALFORMED - if chain.add_block(block): + status = chain.add_block(block) + if status == ValidationStatus.VALID: logger.info("📥 Received Block #%d — added to chain", block.index) - # Drop only confirmed transactions so higher nonces can remain queued. mempool.remove_transactions(block.transactions) else: @@ -178,6 +216,7 @@ async def handler(data): # For a fork, request the full chain to use resolve_conflicts req = {"type": "chain_request", "data": {"start_index": 0, "limit": 1000000}} # Request full chain for reorg asyncio.create_task(network._broadcast_raw(req)) + return status elif msg_type == "chain_request": start_index = payload.get("start_index", 0) @@ -221,7 +260,7 @@ async def handler(data): for block in new_chain: if block.index <= chain.last_block.index: continue # Ignore already known blocks - if chain.add_block(block): + if chain.add_block(block) == ValidationStatus.VALID: logger.info("📥 Synced Block #%d", block.index) mempool.remove_transactions(block.transactions) else: @@ -245,35 +284,75 @@ async def handler(data): # Interactive CLI # ────────────────────────────────────────────── -HELP_TEXT = """ -╔════════════════════════════════════════════════╗ -║ MiniChain Commands ║ -╠════════════════════════════════════════════════╣ -║ balance - show all balances ║ -║ send - send coins ║ -║ mine - mine a block ║ -║ peers - show connected peers ║ -║ connect : - connect to a peer ║ -║ address - show your public key ║ -║ chain - show chain summary ║ -║ list-banned - show banned peers ║ -║ ban - ban a peer ║ -║ unban - unban a peer ║ -║ help - show this help ║ -║ quit - shut down ║ -╚════════════════════════════════════════════════╝ +C_CYAN = '\033[96m' +C_BLUE = '\033[94m' +C_YELLOW = '\033[38;2;255;205;0m' # Golden Wallet (#FFCD00) +C_GREEN = '\033[38;2;0;132;61m' # Baggy Green (#00843D) +C_RED = '\033[91m' +C_RESET = '\033[0m' +C_BOLD = '\033[1m' + +def gradient_text(text: str, c1: tuple[int, int, int], c2: tuple[int, int, int]) -> str: + """Applies a smooth horizontal color gradient to text.""" + lines = text.strip('\n').split('\n') + out = [] + max_len = max(len(line) for line in lines) if lines else 1 + + for line in lines: + line_out = "" + for i, char in enumerate(line): + t = i / max(1, max_len - 1) + r = int(c1[0] + (c2[0] - c1[0]) * t) + g = int(c1[1] + (c2[1] - c1[1]) * t) + b = int(c1[2] + (c2[2] - c1[2]) * t) + line_out += f"\033[38;2;{r};{g};{b}m{char}" + out.append(line_out + C_RESET) + return "\n".join(out) + +RAW_LOGO = r""" +███╗ ███╗██╗███╗ ██╗██╗ ██████╗██╗ ██╗ █████╗ ██╗███╗ ██╗ +████╗ ████║██║████╗ ██║██║██╔════╝██║ ██║██╔══██╗██║████╗ ██║ +██╔████╔██║██║██╔██╗ ██║██║██║ ███████║███████║██║██╔██╗ ██║ +██║╚██╔╝██║██║██║╚██╗██║██║██║ ██╔══██║██╔══██║██║██║╚██╗██║ +██║ ╚═╝ ██║██║██║ ╚████║██║╚██████╗██║ ██║██║ ██║██║██║ ╚████║ +╚═╝ ╚═╝╚═╝╚═╝ ╚═══╝╚═╝ ╚═════╝╚═╝ ╚═╝╚═╝ ╚═╝╚═╝╚═╝ ╚═══╝ +""" + +ASCII_LOGO = gradient_text(RAW_LOGO, (255, 205, 0), (0, 132, 61)) + +HELP_TEXT = f""" +{C_BOLD}{ASCII_LOGO}{C_RESET} +{C_CYAN}╔══════════════════════════════════════════════════════════════╗{C_RESET} +{C_CYAN}║{C_RESET} {C_GREEN}balance{C_RESET} - show all balances {C_CYAN}║{C_RESET} +{C_CYAN}║{C_RESET} {C_GREEN}send {C_RESET} - send coins {C_CYAN}║{C_RESET} +{C_CYAN}║{C_RESET} {C_GREEN}deploy {C_RESET} - deploy a contract {C_CYAN}║{C_RESET} +{C_CYAN}║{C_RESET} {C_GREEN}call {C_RESET} - call a contract {C_CYAN}║{C_RESET} +{C_CYAN}║{C_RESET} {C_GREEN}mine{C_RESET} - mine a block {C_CYAN}║{C_RESET} +{C_CYAN}║{C_RESET} {C_GREEN}peers{C_RESET} - show connected peers {C_CYAN}║{C_RESET} +{C_CYAN}║{C_RESET} {C_GREEN}connect {C_RESET} - connect to a peer {C_CYAN}║{C_RESET} +{C_CYAN}║{C_RESET} {C_GREEN}address{C_RESET} - show your public key {C_CYAN}║{C_RESET} +{C_CYAN}║{C_RESET} {C_GREEN}export_key{C_RESET} - show your private key {C_CYAN}║{C_RESET} +{C_CYAN}║{C_RESET} {C_GREEN}import_key {C_RESET} - import a private key {C_CYAN}║{C_RESET} +{C_CYAN}║{C_RESET} {C_GREEN}chain{C_RESET} - show chain summary {C_CYAN}║{C_RESET} +{C_CYAN}║{C_RESET} {C_GREEN}help{C_RESET} - show this help {C_CYAN}║{C_RESET} +{C_CYAN}║{C_RESET} {C_GREEN}quit{C_RESET} - shut down {C_CYAN}║{C_RESET} +{C_CYAN}╚══════════════════════════════════════════════════════════════╝{C_RESET} """ -async def cli_loop(sk, pk, chain, mempool, network): +async def cli_loop(sk, pk, chain, mempool, network, datadir: str | None = None): """Read commands from stdin asynchronously.""" loop = asyncio.get_event_loop() print(HELP_TEXT) - print(f"Your address: {pk}\n") + + def print_prompt_info(current_pk): + print(f" {C_YELLOW}Your address:{C_RESET} {C_BOLD}{current_pk}{C_RESET}\n") + + print_prompt_info(pk) while True: try: - raw = await loop.run_in_executor(None, lambda: input("minichain> ")) + raw = await loop.run_in_executor(None, lambda: input(f"{C_CYAN}minichain>{C_RESET} ")) except (EOFError, KeyboardInterrupt): break @@ -288,9 +367,9 @@ async def cli_loop(sk, pk, chain, mempool, network): if not accounts: print(" (no accounts yet)") for addr, acc in accounts.items(): - tag = " (you)" if addr == pk else "" - contract_tag = " [Contract]" if acc.get("code") else "" - print(f" {addr[:12]}... balance={acc['balance']} nonce={acc['nonce']}{tag}{contract_tag}") + tag = f" {C_GREEN}(you){C_RESET}" if addr == pk else "" + contract_tag = f" {C_CYAN}[Contract]{C_RESET}" if acc.get("code") else "" + print(f" {C_BOLD}{addr[:12]}...{C_RESET} balance={C_YELLOW}{acc['balance']}{C_RESET} nonce={acc['nonce']}{tag}{contract_tag}") # ── send ── elif cmd == "send": @@ -320,9 +399,9 @@ async def cli_loop(sk, pk, chain, mempool, network): if mempool.add_transaction(tx): await network.broadcast_transaction(tx) - print(f" ✅ Tx sent: {amount} coins → {receiver[:12]}...") + print(f" {C_GREEN}✅ Tx sent:{C_RESET} {amount} coins → {receiver[:12]}...") else: - print(" ❌ Transaction rejected (invalid sig, duplicate, or mempool full).") + print(f" {C_RED}❌ Transaction rejected{C_RESET} (invalid sig, duplicate, or mempool full).") # ── deploy ── elif cmd == "deploy": @@ -416,6 +495,34 @@ async def cli_loop(sk, pk, chain, mempool, network): elif cmd == "address": print(f" {pk}") + # ── export_key ── + elif cmd == "export_key": + sk_hex = sk.encode(encoder=HexEncoder).decode() + print(f" {C_YELLOW}Private Key:{C_RESET} {sk_hex}") + print(f" {C_RED}Warning: Never share this key!{C_RESET}") + + # ── import_key ── + elif cmd == "import_key": + if len(parts) < 2: + print(" Usage: import_key ") + continue + + try: + new_sk = SigningKey(bytes.fromhex(parts[1])) + new_pk = new_sk.verify_key.encode(encoder=HexEncoder).decode() + + path = datadir or "." + keystore_path = os.path.join(path, "keystore.json") + os.makedirs(path, exist_ok=True) + with open(keystore_path, "w") as f: + json.dump({"private_key": new_sk.encode(encoder=HexEncoder).decode()}, f) + + sk, pk = new_sk, new_pk + print(f" {C_GREEN}✅ Key imported successfully!{C_RESET}") + print_prompt_info(pk) + except Exception as e: + print(f" {C_RED}❌ Failed to import key: {e}{C_RESET}") + # ── chain ── elif cmd == "chain": print(f" Chain length: {len(chain.chain)} blocks") @@ -426,7 +533,7 @@ async def cli_loop(sk, pk, chain, mempool, network): # ── list-banned ── elif cmd == "list-banned": from minichain.persistence import get_banned_peers - banned = get_banned_peers() + banned = get_banned_peers(path=datadir or ".") if not banned: print(" No peers are currently banned.") else: @@ -441,7 +548,8 @@ async def cli_loop(sk, pk, chain, mempool, network): continue peer_id = parts[1] from minichain.persistence import ban_peer - ban_peer(peer_id, reason="Manual ban via CLI") + ban_peer(peer_id, reason="Manual ban via CLI", path=datadir or ".") + await network.disconnect_peer(f"peer:{peer_id}") print(f" ✅ Peer {peer_id} banned.") # ── unban ── @@ -451,7 +559,7 @@ async def cli_loop(sk, pk, chain, mempool, network): continue peer_id = parts[1] from minichain.persistence import unban_peer - unban_peer(peer_id) + unban_peer(peer_id, path=datadir or ".") print(f" ✅ Peer {peer_id} unbanned.") # ── help ── @@ -472,7 +580,7 @@ async def cli_loop(sk, pk, chain, mempool, network): async def run_node(port: int, host: str, connect_to: str | None, fund: int, datadir: str | None): """Boot the node, optionally connect to a peer, then enter the CLI.""" - sk, pk = create_wallet() + sk, pk = load_or_create_wallet(datadir) # Load existing chain from disk, or start fresh chain = None @@ -493,7 +601,7 @@ async def run_node(port: int, host: str, connect_to: str | None, fund: int, data chain = Blockchain() mempool = Mempool() - network = P2PNetwork() + network = P2PNetwork(data_path=datadir or ".") handler = make_network_handler(chain, mempool, network) network.register_handler(handler) @@ -534,7 +642,7 @@ async def on_peer_connected(writer): await network.connect_to_peer(connect_to) try: - await cli_loop(sk, pk, chain, mempool, network) + await cli_loop(sk, pk, chain, mempool, network, datadir) finally: # Save chain to disk on shutdown if datadir: @@ -555,7 +663,7 @@ def main(): parser.add_argument("--port", type=int, default=9000, help="TCP port to listen on (default: 9000)") parser.add_argument("--connect", type=str, default=None, help="Peer address to connect to (multiaddr)") parser.add_argument("--fund", type=int, default=100, help="Initial coins to fund this wallet (default: 100)") - parser.add_argument("--datadir", type=str, default=None, help="Directory to save/load blockchain state (enables persistence)") + parser.add_argument("--datadir", type=str, default=".minichain", help="Directory to save/load blockchain state (enables persistence)") args = parser.parse_args() logging.basicConfig( diff --git a/minichain/chain.py b/minichain/chain.py index 1ed9b84..8aacd17 100644 --- a/minichain/chain.py +++ b/minichain/chain.py @@ -253,6 +253,10 @@ def resolve_conflicts(self, new_chain_list) -> tuple[bool, list]: logger.warning("Reorg failed: Invalid receipt root at block %s. Expected %s, got %s", block.index, computed_receipt_root, block.receipt_root) return False, [] + if [r.to_dict() for r in block.receipts] != [r.to_dict() for r in receipts]: + logger.warning("Reorg failed: Receipt payload mismatch at block %s", block.index) + return False, [] + if block.state_root != temp_state.state_root(): logger.warning("Reorg failed: Invalid state root at block %s", block.index) return False, [] diff --git a/minichain/p2p.py b/minichain/p2p.py index 28efe38..33cc9f7 100644 --- a/minichain/p2p.py +++ b/minichain/p2p.py @@ -7,6 +7,7 @@ import json import logging import threading +import time import trio import queue @@ -15,16 +16,34 @@ from libp2p.peer.peerinfo import info_from_p2p_addr from multiaddr import Multiaddr from .serialization import canonical_json_hash, canonical_json_dumps +from .validators import ValidationStatus +from .persistence import ban_peer, is_peer_banned logger = logging.getLogger(__name__) SUPPORTED_MESSAGE_TYPES = {"hello", "tx", "block", "chain_request", "chain_response"} PROTOCOL_ID = TProtocol("/minichain/1.0.0") +MAX_FRAME_BYTES = 1 * 1024 * 1024 # 1 MB + +# Misbehavior thresholds — all four are overridable per P2PNetwork instance. +MALFORMED_THRESHOLD = 15 # N: accumulated malformed messages before ban +FAILED_THRESHOLD = 15 # M: accumulated failed messages before ban +INVALID_THRESHOLD = 1 # L: accumulated invalid messages before ban (1 = immediate) +DECAY_INTERVAL_MINUTES = 10 # T: counter half-life period in minutes + class P2PNetwork: """Lightweight peer-to-peer networking using libp2p.""" - def __init__(self, handler_callback=None): + def __init__( + self, + handler_callback=None, + data_path: str = ".", + malformed_threshold: int = MALFORMED_THRESHOLD, + failed_threshold: int = FAILED_THRESHOLD, + invalid_threshold: int = INVALID_THRESHOLD, + decay_interval_minutes: float = DECAY_INTERVAL_MINUTES, + ): self._handler_callback = handler_callback self._on_peer_connected = None self._seen_tx_ids = set() @@ -34,6 +53,26 @@ def __init__(self, handler_callback=None): self._peer_count = 0 self._peer_count_lock = threading.Lock() + # Misbehavior tracking + self.data_path = data_path + self.thresholds = { + "malformed": malformed_threshold, + "failed": failed_threshold, + "invalid": invalid_threshold, + } + self.decay_interval_minutes = decay_interval_minutes + # { peer_id_str -> {"malformed": int, "failed": int, "invalid": int} } + self._peer_counters: dict = {} + + if self.decay_interval_minutes <= 0: + raise ValueError(f"decay_interval_minutes must be positive, got {self.decay_interval_minutes}") + if self.thresholds["malformed"] <= 0: + raise ValueError(f"malformed_threshold must be positive, got {self.thresholds['malformed']}") + if self.thresholds["failed"] <= 0: + raise ValueError(f"failed_threshold must be positive, got {self.thresholds['failed']}") + if self.thresholds["invalid"] <= 0: + raise ValueError(f"invalid_threshold must be positive, got {self.thresholds['invalid']}") + def register_handler(self, handler_callback): self._handler_callback = handler_callback @@ -44,9 +83,10 @@ async def start(self, port: int = 9000, host: str = "127.0.0.1"): self.port = port self.host_addr = host self.loop = asyncio.get_running_loop() - + threading.Thread(target=trio.run, args=(self._trio_main,), daemon=True).start() asyncio.create_task(self._asyncio_reader()) + asyncio.create_task(self._decay_counters()) logger.info(f"Network: Starting libp2p on port {port}") async def stop(self): @@ -101,17 +141,116 @@ def peer_count(self) -> int: with self._peer_count_lock: return self._peer_count + # ── misbehavior helpers ────────────────────────────────────────────────── + + def _increment_counter(self, peer_id: str, category: str) -> bool: + """ + Increment the named counter (malformed/failed/invalid) for peer_id. + Returns True if any counter now meets or exceeds its threshold. + Called only from the asyncio thread — no lock needed. + """ + if peer_id not in self._peer_counters: + self._peer_counters[peer_id] = {"malformed": 0, "failed": 0, "invalid": 0} + self._peer_counters[peer_id][category] += 1 + counts = self._peer_counters[peer_id] + return counts[category] >= self.thresholds[category] + + async def _handle_validation_status( + self, peer_id: str, peer_addr: str, status: ValidationStatus + ): + """ + Apply misbehavior policy for a single ValidationStatus event: + MALFORMED → always disconnect; ban if counter >= threshold + FAILED → drop silently; ban + disconnect if counter >= threshold + INVALID → always ban + disconnect (threshold configurable, default=1) + """ + category = { + ValidationStatus.MALFORMED: "malformed", + ValidationStatus.FAILED: "failed", + ValidationStatus.INVALID: "invalid", + }.get(status) + if category is None: + return + + exceeded = self._increment_counter(peer_id, category) + + if exceeded: + ban_peer(peer_id, reason=f"{category}_threshold_exceeded", path=self.data_path) + logger.warning( + "Banned peer %s: %s threshold (%d) exceeded", + peer_id, category, self.thresholds[category], + ) + + always_disconnect = status in (ValidationStatus.MALFORMED, ValidationStatus.INVALID) + if always_disconnect or exceeded: + await self.disconnect_peer(peer_addr) + + async def _decay_counters(self): + """ + Half-life decay: every decay_interval_minutes minutes divide all per-peer + counters by 2 (integer floor division). Runs for the lifetime of the node. + """ + interval_seconds = self.decay_interval_minutes * 60 + while True: + await asyncio.sleep(interval_seconds) + for counts in self._peer_counters.values(): + for key in counts: + counts[key] //= 2 + self._peer_counters = { + peer_id: counts + for peer_id, counts in self._peer_counters.items() + if any(v > 0 for v in counts.values()) + } + + # ── asyncio reader ─────────────────────────────────────────────────────── + async def _asyncio_reader(self): while True: - try: msg = await self.loop.run_in_executor(None, self._to_asyncio.get) - except Exception: continue - + try: + msg = await self.loop.run_in_executor(None, self._to_asyncio.get) + except Exception: + continue + if msg[0] == "MSG": data = msg[1] - msg_type, payload = data.get("type"), data.get("data") - if msg_type not in SUPPORTED_MESSAGE_TYPES or self._is_duplicate(msg_type, payload): continue - self._mark_seen(msg_type, payload) - if self._handler_callback: await self._handler_callback(data) + msg_type = data.get("type") + payload = data.get("data") + peer_addr = data.get("_peer_addr", "") + peer_id = ( + peer_addr[len("peer:"):] if peer_addr.startswith("peer:") else peer_addr + ) + + if msg_type not in SUPPORTED_MESSAGE_TYPES: + continue + try: + if self._is_duplicate(msg_type, payload): + continue + except Exception: + await self._handle_validation_status(peer_id, peer_addr, ValidationStatus.MALFORMED) + continue + + status = None + if self._handler_callback: + status = await self._handler_callback(data) + + # Only apply interception for content-bearing message types. + if msg_type in ("tx", "block") and status is not None: + await self._handle_validation_status(peer_id, peer_addr, status) + + if status is None or status == ValidationStatus.VALID: + try: + self._mark_seen(msg_type, payload) + except Exception: + pass + + elif msg[0] == "MALFORMED": + # JSON parse failure signalled from the Trio thread. + peer_addr = msg[1] + peer_id = ( + peer_addr[len("peer:"):] if peer_addr.startswith("peer:") else peer_addr + ) + await self._handle_validation_status(peer_id, peer_addr, ValidationStatus.MALFORMED) + elif msg[0] == "PEER_CONNECTED": class MockWriter: def write(self, data): self.data = data @@ -119,44 +258,70 @@ async def drain(self): pass if self._on_peer_connected: writer = MockWriter() await self._on_peer_connected(writer) - if hasattr(writer, 'data'): + if hasattr(writer, "data"): try: req = json.loads(writer.data.decode().strip()) await self._broadcast_raw(req) - except Exception: pass + except Exception: + pass + + # ── trio main ──────────────────────────────────────────────────────────── async def _trio_main(self): host = new_host() listen_addr = Multiaddr(f"/ip4/{self.host_addr}/tcp/{self.port}") await host.get_network().listen(listen_addr) print(f" Network Multiaddr: {listen_addr}/p2p/{host.get_id().to_string()}") - + streams = [] async def stream_handler(stream): + peer_id = str(stream.muxed_conn.peer_id) + addr = f"peer:{peer_id}" + + # Reject banned peers before doing anything else. + if is_peer_banned(peer_id, path=self.data_path): + logger.warning("Rejected connection from banned peer %s", peer_id) + try: + await stream.reset() + except Exception: + pass + return + streams.append(stream) with self._peer_count_lock: self._peer_count += 1 - peer_id = stream.muxed_conn.peer_id - addr = f"peer:{peer_id}" self._to_asyncio.put(("PEER_CONNECTED", None)) + try: + buffer = b"" while True: data = await stream.read(4096) - if not data: break - for line in data.split(b'\n'): - if not line: continue + if not data: + break + buffer += data + if len(buffer) > MAX_FRAME_BYTES: + self._to_asyncio.put(("MALFORMED", addr)) + break + *lines, buffer = buffer.split(b"\n") + for line in lines: + if not line.strip(): + continue try: - msg = json.loads(line.decode().strip()) - msg["_peer_addr"] = addr - self._to_asyncio.put(("MSG", msg)) - except Exception: pass - except Exception: pass + parsed = json.loads(line.decode().strip()) + parsed["_peer_addr"] = addr + self._to_asyncio.put(("MSG", parsed)) + except Exception: + # Signal the asyncio side to apply MALFORMED policy. + self._to_asyncio.put(("MALFORMED", addr)) + except Exception: + pass + if stream in streams: streams.remove(stream) with self._peer_count_lock: self._peer_count -= 1 - + host.set_stream_handler(PROTOCOL_ID, stream_handler) async def check_queue(): @@ -164,7 +329,8 @@ async def check_queue(): try: while not self._to_trio.empty(): cmd, arg = self._to_trio.get_nowait() - if cmd == "STOP": return True + if cmd == "STOP": + return True elif cmd == "CONNECT": try: maddr = Multiaddr(arg) @@ -177,27 +343,34 @@ async def check_queue(): elif cmd == "BROADCAST": msg = (canonical_json_dumps(arg) + "\n").encode() for s in list(streams): - try: await s.write(msg) - except Exception: pass + try: + await s.write(msg) + except Exception: + pass elif cmd == "UNICAST": target_addr, payload = arg msg = (canonical_json_dumps(payload) + "\n").encode() for s in list(streams): - addr = f"peer:{s.muxed_conn.peer_id}" - if addr == target_addr: - try: await s.write(msg) - except Exception: pass + s_addr = f"peer:{s.muxed_conn.peer_id}" + if s_addr == target_addr: + try: + await s.write(msg) + except Exception: + pass elif cmd == "DISCONNECT": for s in list(streams): - addr = f"peer:{s.muxed_conn.peer_id}" - if addr == arg: - try: await s.reset() - except Exception: pass + s_addr = f"peer:{s.muxed_conn.peer_id}" + if s_addr == arg: + try: + await s.reset() + except Exception: + pass if s in streams: streams.remove(s) with self._peer_count_lock: self._peer_count -= 1 - except Exception: pass + except Exception: + pass await trio.sleep(0.1) async with trio.open_nursery() as nursery: diff --git a/minichain/state.py b/minichain/state.py index 413fec5..bbe3ad6 100644 --- a/minichain/state.py +++ b/minichain/state.py @@ -91,9 +91,12 @@ def validate_and_apply(self, tx): Validate and apply a transaction. Returns: Receipt|None """ - # Semantic validation: amount must be an integer and non-negative + # Semantic validation: amount and fee must be non-negative integers if not isinstance(tx.amount, int) or tx.amount < 0: return None + fee = getattr(tx, "fee", 0) + if not isinstance(fee, int) or fee < 0: + return None return self.apply_transaction(tx) def validate_and_apply_with_status(self, tx): @@ -104,24 +107,38 @@ def validate_and_apply_with_status(self, tx): from .validators import ValidationStatus if not isinstance(tx.amount, int) or tx.amount < 0: return ValidationStatus.MALFORMED, None - + fee = getattr(tx, "fee", 0) + if not isinstance(fee, int) or fee < 0: + return ValidationStatus.MALFORMED, None + status = self.verify_transaction_logic(tx) if status != ValidationStatus.VALID: return status, None - - # We know it's valid, so apply_transaction will succeed and return a Receipt - return ValidationStatus.VALID, self.apply_transaction(tx) + + # verify_transaction_logic already passed — skip the second call inside apply_transaction. + return ValidationStatus.VALID, self._apply_validated_tx(tx) def apply_transaction(self, tx): """ - Applies transaction and mutates state. - Returns: Receipt object if mathematically valid, None if invalid. + Validates and applies a transaction. + Returns: Receipt object if valid, None if invalid. """ + if not isinstance(tx.amount, int) or tx.amount < 0: + return None + fee = getattr(tx, "fee", 0) + if not isinstance(fee, int) or fee < 0: + return None from .validators import ValidationStatus - status = self.verify_transaction_logic(tx) - if status != ValidationStatus.VALID: + if self.verify_transaction_logic(tx) != ValidationStatus.VALID: return None + return self._apply_validated_tx(tx) + def _apply_validated_tx(self, tx): + """ + Apply a transaction that has already passed verify_transaction_logic. + Mutates state and returns a Receipt. Never call this directly — use + apply_transaction() or validate_and_apply_with_status() instead. + """ sender = self.accounts[tx.sender] total_cost = tx.amount + getattr(tx, 'fee', 0) diff --git a/tests/test_difficulty.py b/tests/test_difficulty.py index 0176f9b..d15c853 100644 --- a/tests/test_difficulty.py +++ b/tests/test_difficulty.py @@ -1,6 +1,7 @@ import unittest from minichain import Blockchain, Block from minichain.pow import mine_block +from minichain.validators import ValidationStatus class TestEMADifficulty(unittest.TestCase): def test_difficulty_adjustment(self): @@ -16,7 +17,7 @@ def test_difficulty_adjustment(self): ts = chain.last_block.timestamp + 1 block1 = Block(index=1, previous_hash=chain.last_block.hash, transactions=[], timestamp=ts, difficulty=chain.current_difficulty, state_root=chain.state.state_root()) mined_block1 = mine_block(block1) - self.assertTrue(chain.add_block(mined_block1)) + self.assertEqual(chain.add_block(mined_block1), ValidationStatus.VALID) self.assertEqual(chain.current_difficulty, 4) # Slow mining: timestamp 5000ms apart @@ -24,7 +25,7 @@ def test_difficulty_adjustment(self): ts = chain.last_block.timestamp + 5000 block2 = Block(index=2, previous_hash=chain.last_block.hash, transactions=[], timestamp=ts, difficulty=chain.current_difficulty, state_root=chain.state.state_root()) mined_block2 = mine_block(block2) - self.assertTrue(chain.add_block(mined_block2)) + self.assertEqual(chain.add_block(mined_block2), ValidationStatus.VALID) self.assertEqual(chain.current_difficulty, 3) def test_reorg_difficulty_validation(self): diff --git a/tests/test_persistence_runtime.py b/tests/test_persistence_runtime.py index 894ccca..73265e5 100644 --- a/tests/test_persistence_runtime.py +++ b/tests/test_persistence_runtime.py @@ -12,7 +12,7 @@ class FakeNetwork: - def __init__(self): + def __init__(self, **kwargs): self.handler = None self.peer_count = 0 self._on_peer_connected = None @@ -84,7 +84,7 @@ async def test_run_node_loads_existing_sqlite_snapshot(self): chain = self._chain_with_tx() save(chain, self.tmpdir) - async def fake_cli_loop(sk, pk, loaded_chain, mempool, network): + async def fake_cli_loop(sk, pk, loaded_chain, mempool, network, datadir=None): self.assertEqual(len(loaded_chain.chain), len(chain.chain)) self.assertEqual(loaded_chain.last_block.hash, chain.last_block.hash) self.assertEqual(loaded_chain.state.accounts, chain.state.accounts) @@ -103,13 +103,13 @@ async def fake_cli_loop(sk, pk, loaded_chain, mempool, network): async def test_run_node_saves_sqlite_snapshot_on_shutdown(self): fixed_sk, fixed_pk = _make_keypair() - async def fake_cli_loop(sk, pk, chain, mempool, network): + async def fake_cli_loop(sk, pk, chain, mempool, network, datadir=None): self.assertEqual(pk, fixed_pk) self.assertEqual(chain.state.get_account(pk)["balance"], 25) with patch.object(main_module, "P2PNetwork", FakeNetwork), patch.object( main_module, "cli_loop", fake_cli_loop - ), patch.object(main_module, "create_wallet", return_value=(fixed_sk, fixed_pk)): + ), patch.object(main_module, "load_or_create_wallet", return_value=(fixed_sk, fixed_pk)): await main_module.run_node( port=9401, host="127.0.0.1",