From e9ba87ebd636f52c34d837ec2e91021e3c6d20dc Mon Sep 17 00:00:00 2001 From: smcio Date: Mon, 11 May 2026 21:11:48 +0200 Subject: [PATCH 1/4] improvements to BlockSource and various removals of inactive & obsolete code --- pkg/bankhash/bankhash.go | 85 +- pkg/bankhash/merkle.go | 1 - pkg/block/block.go | 1 - pkg/blockstream/block_source.go | 33 +- pkg/blockstream/block_source_test.go | 54 + pkg/replay/block.go | 74 +- pkg/replay/consensus_fallback.go | 14 + pkg/replay/consensus_fallback_test.go | 40 + pkg/replay/eah_workaround_test.go | 54 - pkg/replay/hash.go | 347 ----- pkg/replay/hash_test.go | 78 - pkg/replay/leader_schedule.go | 1772 ++++++++++++++++++++- pkg/replay/leader_schedule_local.go | 2072 ------------------------- pkg/replay/profile.go | 47 - 14 files changed, 1867 insertions(+), 2805 deletions(-) create mode 100644 pkg/replay/consensus_fallback.go create mode 100644 pkg/replay/consensus_fallback_test.go delete mode 100644 pkg/replay/eah_workaround_test.go delete mode 100644 pkg/replay/hash.go delete mode 100644 pkg/replay/hash_test.go delete mode 100644 pkg/replay/leader_schedule_local.go delete mode 100644 pkg/replay/profile.go diff --git a/pkg/bankhash/bankhash.go b/pkg/bankhash/bankhash.go index e57e46f5..d2342fb1 100644 --- a/pkg/bankhash/bankhash.go +++ b/pkg/bankhash/bankhash.go @@ -3,16 +3,13 @@ package bankhash import ( "crypto/sha256" "encoding/binary" - "fmt" "sort" "time" "github.com/Overclock-Validator/mithril/pkg/accounts" - "github.com/Overclock-Validator/mithril/pkg/accountsdb" "github.com/Overclock-Validator/mithril/pkg/features" "github.com/Overclock-Validator/mithril/pkg/metrics" "github.com/Overclock-Validator/mithril/pkg/mlog" - "github.com/Overclock-Validator/mithril/pkg/safemath" "github.com/Overclock-Validator/mithril/pkg/sealevel" "github.com/gagliardetto/solana-go" "github.com/zeebo/blake3" @@ -67,12 +64,7 @@ func CalculateBankHash(slotCtx *sealevel.SlotCtx, writableAccts []*accounts.Acco finalHasher.Write(acctsLtHashBytes) 
bankHash = finalHasher.Sum(nil) } else { - epochSchedule := sealevel.SysvarCache.EpochSchedule.Sysvar - if shouldIncludeEah(epochSchedule, slotCtx) { - panic("EAH no longer supported by mithril") - } else { - bankHash = hasher.Sum(nil) - } + bankHash = hasher.Sum(nil) } return bankHash @@ -114,31 +106,6 @@ func calculateSingleAcctHash(acct accounts.Account) acctHash { return newAcctHash(acct.Key, hasher.Sum(nil)) } -func calculateSingleAcctHashOnly(acct accounts.Account) []byte { - hasher := blake3.New() - - var lamportBytes [8]byte - binary.LittleEndian.PutUint64(lamportBytes[:], acct.Lamports) - _, _ = hasher.Write(lamportBytes[:]) - - var rentEpochBytes [8]byte - binary.LittleEndian.PutUint64(rentEpochBytes[:], acct.RentEpoch) - _, _ = hasher.Write(rentEpochBytes[:]) - - _, _ = hasher.Write(acct.Data) - - if acct.Executable { - _, _ = hasher.Write([]byte{1}) - } else { - _, _ = hasher.Write([]byte{0}) - } - - _, _ = hasher.Write(acct.Owner[:]) - _, _ = hasher.Write(acct.Key[:]) - - return hasher.Sum(nil) -} - func calculateAccountHashes(accts []*accounts.Account) []acctHash { pairs := make([]acctHash, 0, len(accts)) for _, acct := range accts { @@ -180,53 +147,3 @@ func calculateAcctsDeltaHash(accts []*accounts.Account) []byte { return computeMerkleRootLoop(hashes) } - -func calculateEpochAcctsHash(acctsDb *accountsdb.AccountsDb) []byte { - mlog.Log.Infof("computing EAH") - - // get all pubkeys in acctsdb - allKeys := acctsDb.AllKeys() - - // compute acct hashes - hashes := make([][]byte, len(allKeys)) - for i, pk := range allKeys { - pkObj := solana.PublicKeyFromBytes(pk) - - acct, err := acctsDb.GetAccount(0, pkObj) - if err != nil { - panic(fmt.Sprintf("unable to fetch acct in EAH calculation: %s", pkObj)) - } - if acct.Lamports != 0 { - hashes[i] = calculateSingleAcctHashOnly(*acct) - } - } - - // merkel root loop - return computeMerkleRootLoop(hashes) -} - -const maxLockoutHistory = 31 -const calculateIntervalBuffer = 150 -const minimumCalculationInterval = 
maxLockoutHistory + calculateIntervalBuffer - -func isEnabledThisEpoch(epochSchedule *sealevel.SysvarEpochSchedule, epoch uint64) bool { - slotsPerEpoch := epochSchedule.SlotsInEpoch(epoch) - calculationOffsetStart := slotsPerEpoch / 4 - calculationOffsetStop := (slotsPerEpoch / 4) * 3 - calculationInterval := safemath.SaturatingSubU64(calculationOffsetStop, calculationOffsetStart) - - return calculationInterval >= minimumCalculationInterval -} - -func shouldIncludeEah(epochSchedule *sealevel.SysvarEpochSchedule, slotCtx *sealevel.SlotCtx) bool { - if !isEnabledThisEpoch(epochSchedule, slotCtx.Epoch) { - return false - } - - slotsPerEpoch := epochSchedule.SlotsInEpoch(slotCtx.Epoch) - calculationOffsetStop := (slotsPerEpoch / 4) * 3 - firstSlotInEpoch := epochSchedule.FirstSlotInEpoch(slotCtx.Epoch) - stopSlot := safemath.SaturatingAddU64(firstSlotInEpoch, calculationOffsetStop) - - return slotCtx.ParentSlot < stopSlot && slotCtx.Slot >= stopSlot -} diff --git a/pkg/bankhash/merkle.go b/pkg/bankhash/merkle.go index 20912078..5640cbb3 100644 --- a/pkg/bankhash/merkle.go +++ b/pkg/bankhash/merkle.go @@ -2,7 +2,6 @@ package bankhash import "crypto/sha256" -const maxMerkleHeight = 16 const merkleFanout = 16 func divCeil(x uint64, y uint64) uint64 { diff --git a/pkg/block/block.go b/pkg/block/block.go index 30d6053c..30f087cc 100644 --- a/pkg/block/block.go +++ b/pkg/block/block.go @@ -19,7 +19,6 @@ type Block struct { Versions []uint8 Entries []*TxEntry BankHash [32]byte - EpochAcctsHash []byte EahWorkaroundBankhash []byte HasEahWorkaround bool ParentBankhash [32]byte diff --git a/pkg/blockstream/block_source.go b/pkg/blockstream/block_source.go index 1ae40cbe..fcc4930d 100644 --- a/pkg/blockstream/block_source.go +++ b/pkg/blockstream/block_source.go @@ -555,6 +555,25 @@ func (bs *BlockSource) currentModeString() string { return "catchup" } +func (bs *BlockSource) rewindConsensusManagedFrontierForRPCFallbackLocked() (waitingSlot uint64, previousWaitingSlot uint64) { + 
waitingSlot = bs.nextSlotToSend + previousWaitingSlot = waitingSlot + if !bs.consensusManagedLightbringer { + return waitingSlot, previousWaitingSlot + } + + replayNextSlot := bs.startSlot + if lastExecuted := bs.lastExecutedSlot.Load(); lastExecuted != 0 { + replayNextSlot = lastExecuted + 1 + } + if replayNextSlot == 0 || replayNextSlot >= waitingSlot { + return waitingSlot, previousWaitingSlot + } + + bs.nextSlotToSend = replayNextSlot + return replayNextSlot, previousWaitingSlot +} + func (bs *BlockSource) forceRPCForCatchup(gap uint64) { if bs.sourceType != BlockSourceLightbringer || bs.lightbringerEndpoint == "" { return @@ -564,13 +583,13 @@ func (bs *BlockSource) forceRPCForCatchup(gap uint64) { bs.lightbringerCooldownUntil.Store(0) oldHandoff := bs.lightbringerHandoffSlot.Swap(0) wasActive := bs.lightbringerActive.Swap(false) - bs.lightbringerNeedRPCResume.Store(false) + bs.lightbringerNeedRPCResume.Store(true) bs.clearLightbringerGapWatch() bs.resetLightbringerRepairSlot() clearedPrefetched := bs.clearBufferedLightbringerBlocks() bs.reorderMu.Lock() - waitingSlot := bs.nextSlotToSend + waitingSlot, previousWaitingSlot := bs.rewindConsensusManagedFrontierForRPCFallbackLocked() removedSlots := make([]uint64, 0) for slot, blk := range bs.reorderBuffer { if blk != nil && blk.FromLightbringer && slot >= waitingSlot { @@ -612,11 +631,21 @@ func (bs *BlockSource) forceRPCForCatchup(gap uint64) { } if wasActive { + if previousWaitingSlot != waitingSlot { + mlog.Log.Warnf("BLOCK SOURCE SWITCH: LIGHTBRINGER -> RPC at slot %d | reason=lost_tip | gap=%d | rewound_emission_frontier_from=%d | cleared_buffered_lightbringer=%d | dropped_prefetched_lightbringer=%d", + waitingSlot, gap, previousWaitingSlot, len(removedSlots), clearedPrefetched) + return + } mlog.Log.Warnf("BLOCK SOURCE SWITCH: LIGHTBRINGER -> RPC at slot %d | reason=lost_tip | gap=%d | cleared_buffered_lightbringer=%d | dropped_prefetched_lightbringer=%d", waitingSlot, gap, len(removedSlots), 
clearedPrefetched) return } if oldHandoff != 0 || len(removedSlots) > 0 || clearedPrefetched > 0 { + if previousWaitingSlot != waitingSlot { + mlog.Log.Warnf("BLOCK SOURCE STATUS: abandoning pending Lightbringer handoff and forcing RPC catchup | waiting_slot=%d | gap=%d | rewound_emission_frontier_from=%d | cleared_buffered_lightbringer=%d | dropped_prefetched_lightbringer=%d", + waitingSlot, gap, previousWaitingSlot, len(removedSlots), clearedPrefetched) + return + } mlog.Log.Warnf("BLOCK SOURCE STATUS: abandoning pending Lightbringer handoff and forcing RPC catchup | waiting_slot=%d | gap=%d | cleared_buffered_lightbringer=%d | dropped_prefetched_lightbringer=%d", waitingSlot, gap, len(removedSlots), clearedPrefetched) return diff --git a/pkg/blockstream/block_source_test.go b/pkg/blockstream/block_source_test.go index 075dc2a9..f0379d3c 100644 --- a/pkg/blockstream/block_source_test.go +++ b/pkg/blockstream/block_source_test.go @@ -634,6 +634,60 @@ func TestSetLastExecutedSlotAdvancesDeferredLightbringerFrontier(t *testing.T) { } } +func TestForceRPCForCatchupRewindsConsensusManagedFrontier(t *testing.T) { + bs := NewBlockSource(&BlockSourceOpts{ + SourceType: BlockSourceLightbringer, + LightbringerEndpoint: "127.0.0.1:50051", + StartSlot: 100, + EndSlot: 200, + ConsensusManagedLightbringer: true, + }) + + bs.lightbringerActive.Store(true) + bs.lastExecutedSlot.Store(120) + bs.nextSlotToSend = 150 + bs.reorderBuffer[121] = &b.Block{Slot: 121, FromLightbringer: true} + bs.reorderBuffer[149] = &b.Block{Slot: 149, FromLightbringer: true} + bs.reorderBuffer[151] = &b.Block{Slot: 151, FromLightbringer: false} + bs.slotState[121] = slotDone + bs.slotState[149] = slotDone + bs.slotState[151] = slotInflight + bs.retrySlots = []uint64{119, 121, 149, 151} + + bs.forceRPCForCatchup(64) + + if got := bs.nextSlotToSend; got != 121 { + t.Fatalf("expected RPC catchup frontier to rewind to replay's next slot 121, got %d", got) + } + if bs.lightbringerActive.Load() { + 
t.Fatalf("expected Lightbringer to be marked inactive") + } + if !bs.lightbringerNeedRPCResume.Load() { + t.Fatalf("expected scheduler to be told to resume RPC from the rewound frontier") + } + if _, exists := bs.reorderBuffer[121]; exists { + t.Fatalf("expected Lightbringer slot 121 to be dropped for RPC refetch") + } + if _, exists := bs.reorderBuffer[149]; exists { + t.Fatalf("expected Lightbringer slot 149 to be dropped for RPC refetch") + } + if _, exists := bs.reorderBuffer[151]; !exists { + t.Fatalf("expected RPC buffered slot 151 to remain") + } + if _, exists := bs.slotState[121]; exists { + t.Fatalf("expected slot state 121 to be cleared") + } + if _, exists := bs.slotState[149]; exists { + t.Fatalf("expected slot state 149 to be cleared") + } + if _, exists := bs.slotState[151]; exists { + t.Fatalf("expected consensus-managed catchup to clear future RPC slot state for rescheduling") + } + if len(bs.retrySlots) != 1 || bs.retrySlots[0] != 119 { + t.Fatalf("expected only retries before the replay frontier to remain, got %+v", bs.retrySlots) + } +} + func TestEmitOrderedBlocksDirectlyStreamsConsensusManagedLightbringerObservations(t *testing.T) { bs := NewBlockSource(&BlockSourceOpts{ SourceType: BlockSourceLightbringer, diff --git a/pkg/replay/block.go b/pkg/replay/block.go index 63c0a4b5..96f4e00d 100644 --- a/pkg/replay/block.go +++ b/pkg/replay/block.go @@ -359,18 +359,6 @@ func isNativeProgram(pubkey solana.PublicKey) bool { } } -func isSysvar(pubkey solana.PublicKey) bool { - if pubkey == sealevel.SysvarClockAddr || pubkey == sealevel.SysvarEpochScheduleAddr || - pubkey == sealevel.SysvarFeesAddr || pubkey == sealevel.SysvarInstructionsAddr || - pubkey == sealevel.SysvarRecentBlockHashesAddr || pubkey == sealevel.SysvarRentAddr || - pubkey == a.SysvarRewardsAddr || pubkey == sealevel.SysvarSlotHashesAddr || - pubkey == sealevel.SysvarSlotHistoryAddr || pubkey == sealevel.SysvarStakeHistoryAddr { - return true - } else { - return false - } -} - func 
cacheConstantSysvars(acctsDb *accountsdb.AccountsDb) { { acct, err := acctsDb.GetAccount(0, sealevel.SysvarEpochScheduleAddr) @@ -849,10 +837,7 @@ func setupInitialVoteAcctsAndStakeAccts(acctsDb *accountsdb.AccountsDb, block *b func configureInitialBlock(acctsDb *accountsdb.AccountsDb, block *b.Block, mithrilState *state.MithrilState, - epochCtx *ReplayCtx, - epochSchedule *sealevel.SysvarEpochSchedule, - rpcClient *rpcclient.RpcClient, - auxBackupEndpoints []string) error { + epochSchedule *sealevel.SysvarEpochSchedule) error { // Read from state file manifest_* fields (required) if mithrilState.ManifestParentBankhash == "" { @@ -891,8 +876,6 @@ func configureInitialBlock(acctsDb *accountsdb.AccountsDb, } block.LatestEvictedBlockhash = evictedHash - block.EpochAcctsHash = epochCtx.EpochAcctsHash - setupInitialVoteAcctsAndStakeAccts(acctsDb, block) configureGlobalCtx(block) @@ -931,11 +914,8 @@ func reconstructFeeRateGovernor(s *state.MithrilState) *sealevel.FeeRateGovernor } func configureBlock(block *b.Block, - epochCtx *ReplayCtx, lastSlotCtx *sealevel.SlotCtx, - epochSchedule *sealevel.SysvarEpochSchedule, - rpcClient *rpcclient.RpcClient, - auxBackupEndpoints []string) error { + epochSchedule *sealevel.SysvarEpochSchedule) error { copy(block.ParentBankhash[:], lastSlotCtx.FinalBankhash) block.AcctsLtHash = lastSlotCtx.AcctsLtHash @@ -943,7 +923,6 @@ func configureBlock(block *b.Block, block.EpochStakesPerVoteAcct = lastSlotCtx.VoteAccts block.ParentSlot = lastSlotCtx.Slot block.LatestEvictedBlockhash = lastSlotCtx.LatestEvictedBlockhash - block.EpochAcctsHash = epochCtx.EpochAcctsHash block.PrevFeeRateGovernor = lastSlotCtx.FeeRateGovernor block.PrevNumSignatures = lastSlotCtx.NumSignatures block.TotalEpochStake = lastSlotCtx.TotalEpochStake @@ -1005,16 +984,12 @@ func configureInitialBlockFromResume(acctsDb *accountsdb.AccountsDb, block *b.Block, resumeState *ResumeState, mithrilState *state.MithrilState, - epochCtx *ReplayCtx, - epochSchedule 
*sealevel.SysvarEpochSchedule, - rpcClient *rpcclient.RpcClient, - auxBackupEndpoints []string) error { + epochSchedule *sealevel.SysvarEpochSchedule) error { // Use resume state for parent info (the actual last replayed slot) copy(block.ParentBankhash[:], resumeState.ParentBankhash) block.ParentSlot = resumeState.ParentSlot block.AcctsLtHash = resumeState.AcctsLtHash - block.EpochAcctsHash = epochCtx.EpochAcctsHash // Reconstruct PrevFeeRateGovernor from state file static fields + resume dynamic fields prevFeeRateGovernor := reconstructFeeRateGovernor(mithrilState) @@ -1416,7 +1391,6 @@ func ReplayBlocks( var readyConsensusPath *pendingConsensusPath observedConsensusBlocks := make(map[uint64]*b.Block) - consensusCatchupHoldLogged := false var opts *blockstream.BlockSourceOpts if useLightbringer { @@ -1549,26 +1523,18 @@ func ReplayBlocks( stats := blockStream.GetFetchStats() if consensusBufferedExecutionActive && !stats.IsNearTip { anchorSlot := currentConsensusAnchorSlot() - hasObservedBlocks := len(observedConsensusBlocks) > 0 - hasReadyDecisions := readyConsensusPath != nil && len(readyConsensusPath.decisions) > 0 - hasUnresolvedGapAhead := stats.NextSlot > anchorSlot+1 - - if hasObservedBlocks || hasReadyDecisions || hasUnresolvedGapAhead { - if !consensusCatchupHoldLogged { - mlog.Log.Warnf("forkchoice: retaining buffered execution across catchup fallback at slot %d because anchor %d still has unresolved slots before next emitted slot %d", - triggerSlot, anchorSlot, stats.NextSlot) - consensusCatchupHoldLogged = true - } - return + discardedObservedBlocks := len(observedConsensusBlocks) + readyDecisionCount := 0 + if readyConsensusPath != nil { + readyDecisionCount = len(readyConsensusPath.decisions) } consensusBufferedExecutionActive = false readyConsensusPath = nil clearObservedConsensusBlocks() observeConsensusAnchor() - mlog.Log.Warnf("forkchoice: suspending buffered execution at slot %d because block source left near-tip mode (anchor=%d)", - triggerSlot, 
currentConsensusAnchorSlot()) - consensusCatchupHoldLogged = false + mlog.Log.Warnf("forkchoice: suspending buffered execution at slot %d because block source left near-tip mode; RPC catchup will continue from anchor %d (discarded_observed_blocks=%d discarded_ready_decisions=%d next_emitted_slot=%d)", + triggerSlot, anchorSlot, discardedObservedBlocks, readyDecisionCount, stats.NextSlot) } } @@ -1582,7 +1548,6 @@ func ReplayBlocks( return nil } consensusBufferedExecutionActive = true - consensusCatchupHoldLogged = false readyConsensusPath = nil observeConsensusAnchor() pruneObservedConsensusBlocks(currentConsensusAnchorSlot()) @@ -1702,6 +1667,19 @@ func ReplayBlocks( syncConsensusBufferedExecutionMode(block.Slot) + if block.FromLightbringer { + stats := blockStream.GetFetchStats() + if shouldDiscardLightbringerObservationAfterFallback(isLive, useLightbringer, block, stats) { + modeStr := "catchup" + if stats.IsNearTip { + modeStr = "near-tip" + } + mlog.Log.Warnf("forkchoice: discarding stale Lightbringer observation for slot %d after source fallback (mode=%s current_source=%s anchor=%d next_emitted_slot=%d)", + block.Slot, modeStr, stats.CurrentSource, currentConsensusAnchorSlot(), stats.NextSlot) + continue + } + } + if err := observeBlockForConsensus(block); err != nil { if errors.Is(err, forkchoice.ErrEquivocation) { result.Error = fmt.Errorf("forkchoice: equivocation detected at slot %d", block.Slot) @@ -1729,7 +1707,6 @@ func ReplayBlocks( mlog.Log.Warnf("forkchoice: failed to resolve a confirmed path from anchor %d after observing slot %d: %v", currentConsensusAnchorSlot(), block.Slot, err) result.Error = err - break } } if result.Error != nil { @@ -1790,13 +1767,13 @@ func ReplayBlocks( if lastSlotCtx == nil { if resumeState != nil { // RESUME: Use resume state + state file (for static fields) - configErr = configureInitialBlockFromResume(acctsDb, block, resumeState, mithrilState, replayCtx, epochSchedule, rpcc, rpcBackups) + configErr = 
configureInitialBlockFromResume(acctsDb, block, resumeState, mithrilState, epochSchedule) } else { // FRESH START: Use state file manifest_* fields - configErr = configureInitialBlock(acctsDb, block, mithrilState, replayCtx, epochSchedule, rpcc, rpcBackups) + configErr = configureInitialBlock(acctsDb, block, mithrilState, epochSchedule) } } else { - configErr = configureBlock(block, replayCtx, lastSlotCtx, epochSchedule, rpcc, rpcBackups) + configErr = configureBlock(block, lastSlotCtx, epochSchedule) } if configErr != nil { mlog.Log.Errorf("FATAL: block configuration failed: %v", configErr) @@ -2425,7 +2402,6 @@ func newSlotCtx(block *b.Block, accts accounts.Accounts, parentAccts accounts.Ac VoteTimestamps: block.VoteTimestamps, TotalEpochStake: block.TotalEpochStake, - EpochsAcctHash: block.EpochAcctsHash, EahWorkaroundBankhash: block.EahWorkaroundBankhash, HasEahWorkaround: block.HasEahWorkaround, diff --git a/pkg/replay/consensus_fallback.go b/pkg/replay/consensus_fallback.go new file mode 100644 index 00000000..d34c9eb4 --- /dev/null +++ b/pkg/replay/consensus_fallback.go @@ -0,0 +1,14 @@ +package replay + +import ( + b "github.com/Overclock-Validator/mithril/pkg/block" + "github.com/Overclock-Validator/mithril/pkg/blockstream" +) + +func shouldDiscardLightbringerObservationAfterFallback(isLive, useLightbringer bool, block *b.Block, stats blockstream.FetchStatsSnapshot) bool { + return isLive && + useLightbringer && + block != nil && + block.FromLightbringer && + (!stats.IsNearTip || stats.CurrentSource != "lightbringer") +} diff --git a/pkg/replay/consensus_fallback_test.go b/pkg/replay/consensus_fallback_test.go new file mode 100644 index 00000000..c03ccfae --- /dev/null +++ b/pkg/replay/consensus_fallback_test.go @@ -0,0 +1,40 @@ +package replay + +import ( + "testing" + + b "github.com/Overclock-Validator/mithril/pkg/block" + "github.com/Overclock-Validator/mithril/pkg/blockstream" +) + +func TestShouldDiscardLightbringerObservationAfterFallback(t 
*testing.T) { + lightbringerBlock := &b.Block{Slot: 123, FromLightbringer: true} + + if !shouldDiscardLightbringerObservationAfterFallback(true, true, lightbringerBlock, blockstream.FetchStatsSnapshot{ + IsNearTip: false, + CurrentSource: "rpc", + }) { + t.Fatalf("expected Lightbringer observation to be discarded after catchup fallback") + } + + if !shouldDiscardLightbringerObservationAfterFallback(true, true, lightbringerBlock, blockstream.FetchStatsSnapshot{ + IsNearTip: true, + CurrentSource: "rpc", + }) { + t.Fatalf("expected Lightbringer observation to be discarded while near-tip has not handed back to Lightbringer") + } + + if shouldDiscardLightbringerObservationAfterFallback(true, true, lightbringerBlock, blockstream.FetchStatsSnapshot{ + IsNearTip: true, + CurrentSource: "lightbringer", + }) { + t.Fatalf("expected active Lightbringer observations to be retained") + } + + if shouldDiscardLightbringerObservationAfterFallback(true, true, &b.Block{Slot: 123}, blockstream.FetchStatsSnapshot{ + IsNearTip: false, + CurrentSource: "rpc", + }) { + t.Fatalf("expected RPC block to be retained") + } +} diff --git a/pkg/replay/eah_workaround_test.go b/pkg/replay/eah_workaround_test.go deleted file mode 100644 index fbd672a7..00000000 --- a/pkg/replay/eah_workaround_test.go +++ /dev/null @@ -1,54 +0,0 @@ -package replay - -import ( - "fmt" - "testing" - - "github.com/Overclock-Validator/mithril/pkg/base58" - "github.com/Overclock-Validator/mithril/pkg/rpcclient" - "github.com/stretchr/testify/assert" -) - -func TestEahWorkaround(t *testing.T) { - client := rpcclient.NewRpcClient("https://api.mainnet-beta.solana.com/") - - // test #1 - bankHash, err := fetchBankhashForSlot(client, 337646774) - assert.NoError(t, err) - - bankHashStr := base58.Encode(bankHash) - assert.Equal(t, "GzahP43kqpouTJrufyEehMhpjbu5BDvzPjLxbkzD647z", bankHashStr) - fmt.Printf("bankhash: %s\n", base58.Encode(bankHash)) - - // test #2 - bankHash, err = fetchBankhashForSlot(client, 337646505) - 
assert.NoError(t, err) - - bankHashStr = base58.Encode(bankHash) - assert.Equal(t, "GdCmxQrHfh2dgVZwjVX6SvPWLZUG5TDMXB339fVyMuhh", bankHashStr) - fmt.Printf("bankhash: %s\n", base58.Encode(bankHash)) - - // test #3 - bankHash, err = fetchBankhashForSlot(client, 337646220) - assert.NoError(t, err) - - bankHashStr = base58.Encode(bankHash) - assert.Equal(t, "MEtyFqQajLAfskQmbw28kTMSy2ASrg9KyShiM1TT2t6", bankHashStr) - fmt.Printf("bankhash: %s\n", base58.Encode(bankHash)) - - // test #4 - bankHash, err = fetchBankhashForSlot(client, 337645795) - assert.NoError(t, err) - - bankHashStr = base58.Encode(bankHash) - assert.Equal(t, "4ojc7a9ad4SVzWteAx2cvGH3JUbMvz2rXUv4ogr5DXYD", bankHashStr) - fmt.Printf("bankhash: %s\n", base58.Encode(bankHash)) - - // test #5 - bankHash, err = fetchBankhashForSlot(client, 337638540) - assert.NoError(t, err) - - bankHashStr = base58.Encode(bankHash) - assert.Equal(t, "8nPhRvJwtPia6NGigPbWr89FFpKD8mzWHovAbjxx6doM", bankHashStr) - fmt.Printf("bankhash: %s\n", base58.Encode(bankHash)) -} diff --git a/pkg/replay/hash.go b/pkg/replay/hash.go deleted file mode 100644 index 0afe77b3..00000000 --- a/pkg/replay/hash.go +++ /dev/null @@ -1,347 +0,0 @@ -package replay - -import ( - "bytes" - "crypto/sha256" - "encoding/binary" - "fmt" - "runtime" - "slices" - "sync" - - "github.com/Overclock-Validator/mithril/pkg/accounts" - "github.com/Overclock-Validator/mithril/pkg/accountsdb" - "github.com/Overclock-Validator/mithril/pkg/block" - "github.com/Overclock-Validator/mithril/pkg/mlog" - "github.com/Overclock-Validator/mithril/pkg/rpcclient" - "github.com/Overclock-Validator/mithril/pkg/safemath" - "github.com/Overclock-Validator/mithril/pkg/sealevel" - bin "github.com/gagliardetto/binary" - "github.com/gagliardetto/solana-go" - "github.com/gagliardetto/solana-go/rpc" - "github.com/zeebo/blake3" -) - -type acctHash struct { - Pubkey solana.PublicKey - Hash [32]byte -} - -func newAcctHash(pubkey solana.PublicKey, hash []byte) acctHash { - pair := 
acctHash{Pubkey: pubkey} - copy(pair.Hash[:], hash) - return pair -} - -func calculateSingleAcctHash(acct accounts.Account) acctHash { - hasher := blake3.New() - - var lamportBytes [8]byte - binary.LittleEndian.PutUint64(lamportBytes[:], acct.Lamports) - _, _ = hasher.Write(lamportBytes[:]) - - var rentEpochBytes [8]byte - binary.LittleEndian.PutUint64(rentEpochBytes[:], acct.RentEpoch) - _, _ = hasher.Write(rentEpochBytes[:]) - - _, _ = hasher.Write(acct.Data) - - if acct.Executable { - _, _ = hasher.Write([]byte{1}) - } else { - _, _ = hasher.Write([]byte{0}) - } - - _, _ = hasher.Write(acct.Owner[:]) - _, _ = hasher.Write(acct.Key[:]) - - /*h := sha256.New() - h.Write(acct.Data) - - fmt.Printf("acct: pubkey %s, lamports %d, owner %s, rent_epoch %d, data hash: %s\n", acct.Key, acct.Lamports, solana.PublicKeyFromBytes(acct.Owner[:]), acct.RentEpoch, solana.HashFromBytes(h.Sum(nil)))*/ - - return newAcctHash(acct.Key, hasher.Sum(nil)) -} - -func calculateSingleAcctHashOnly(acct accounts.Account) []byte { - hasher := blake3.New() - - var lamportBytes [8]byte - binary.LittleEndian.PutUint64(lamportBytes[:], acct.Lamports) - _, _ = hasher.Write(lamportBytes[:]) - - var rentEpochBytes [8]byte - binary.LittleEndian.PutUint64(rentEpochBytes[:], acct.RentEpoch) - _, _ = hasher.Write(rentEpochBytes[:]) - - _, _ = hasher.Write(acct.Data) - - if acct.Executable { - _, _ = hasher.Write([]byte{1}) - } else { - _, _ = hasher.Write([]byte{0}) - } - - _, _ = hasher.Write(acct.Owner[:]) - _, _ = hasher.Write(acct.Key[:]) - - return hasher.Sum(nil) -} - -func calculateAccountHashes(accts []*accounts.Account) []acctHash { - if len(accts) == 0 { - return []acctHash{} - } - - numWorkers := runtime.NumCPU() - if numWorkers > len(accts) { - numWorkers = len(accts) - } - - pairs := make([]acctHash, len(accts)) - chunkSize := (len(accts) + numWorkers - 1) / numWorkers - - var wg sync.WaitGroup - for i := 0; i < numWorkers; i++ { - wg.Add(1) - go func(workerID int) { - defer wg.Done() - 
start := workerID * chunkSize - end := start + chunkSize - if end > len(accts) { - end = len(accts) - } - - for j := start; j < end; j++ { - acct := accts[j] - if acct.Lamports == 0 { - pairs[j] = newAcctHash(acct.Key, nil) - } else { - pairs[j] = calculateSingleAcctHash(*acct) - } - } - }(i) - } - - wg.Wait() - return pairs -} - -const maxMerkleHeight = 16 -const merkleFanout = 16 - -func divCeil(x uint64, y uint64) uint64 { - result := x / y - if (x % y) != 0 { - result++ - } - return result -} - -func computeMerkleRootLoop(acctHashes [][]byte) []byte { - if len(acctHashes) == 0 { - return nil - } - - totalHashes := uint64(len(acctHashes)) - chunks := divCeil(totalHashes, merkleFanout) - - results := make([][]byte, chunks) - - for i := uint64(0); i < chunks; i++ { - startIdx := i * merkleFanout - endIdx := min(startIdx+merkleFanout, totalHashes) - - hasher := sha256.New() - a := acctHashes[startIdx:endIdx] - - for _, h := range a { - hasher.Write(h) - } - - results[i] = hasher.Sum(nil) - } - - if len(results) == 1 { - return results[0] - } else { - return computeMerkleRootLoop(results) - } -} - -func calculateAcctsDeltaHash(accts []*accounts.Account) []byte { - acctHashes := calculateAccountHashes(accts) - - // sort by pubkey - slices.SortFunc(acctHashes, func(a, b acctHash) int { - return bytes.Compare(a.Pubkey[:], b.Pubkey[:]) - }) - - hashes := make([][]byte, len(acctHashes)) - for idx, ah := range acctHashes { - hashes[idx] = make([]byte, 32) - copy(hashes[idx], ah.Hash[:]) - } - - return computeMerkleRootLoop(hashes) -} - -func calculateEpochAcctsHash(acctsDb *accountsdb.AccountsDb) []byte { - mlog.Log.Infof("computing EAH") - - // get all pubkeys in acctsdb - allKeys := acctsDb.AllKeys() - - // compute acct hashes - hashes := make([][]byte, len(allKeys)) - for i, pk := range allKeys { - pkObj := solana.PublicKeyFromBytes(pk) - - acct, err := acctsDb.GetAccount(0, pkObj) - if err != nil { - panic(fmt.Sprintf("unable to fetch acct in EAH calculation: %s", 
pkObj)) - } - if acct.Lamports != 0 { - hashes[i] = calculateSingleAcctHashOnly(*acct) - } - } - - // merkel root loop - return computeMerkleRootLoop(hashes) -} - -const maxLockoutHistory = 31 -const calculateIntervalBuffer = 150 -const minimumCalculationInterval = maxLockoutHistory + calculateIntervalBuffer - -func isEnabledThisEpoch(epochSchedule *sealevel.SysvarEpochSchedule, epoch uint64) bool { - slotsPerEpoch := epochSchedule.SlotsInEpoch(epoch) - calculationOffsetStart := slotsPerEpoch / 4 - calculationOffsetStop := (slotsPerEpoch / 4) * 3 - calculationInterval := safemath.SaturatingSubU64(calculationOffsetStop, calculationOffsetStart) - - return calculationInterval >= minimumCalculationInterval -} - -func shouldIncludeEah(epochSchedule *sealevel.SysvarEpochSchedule, slotCtx *sealevel.SlotCtx) bool { - if !isEnabledThisEpoch(epochSchedule, slotCtx.Epoch) { - return false - } - - slotsPerEpoch := epochSchedule.SlotsInEpoch(slotCtx.Epoch) - calculationOffsetStop := (slotsPerEpoch / 4) * 3 - firstSlotInEpoch := epochSchedule.FirstSlotInEpoch(slotCtx.Epoch) - stopSlot := safemath.SaturatingAddU64(firstSlotInEpoch, calculationOffsetStop) - - return slotCtx.ParentSlot < stopSlot && slotCtx.Slot >= stopSlot -} - -func calculateBankHash(slotCtx *sealevel.SlotCtx, acctsDeltaHash []byte, parentBankHash [32]byte, numSigs uint64, blockHash [32]byte) []byte { - hasher := sha256.New() - hasher.Write(parentBankHash[:]) - hasher.Write(acctsDeltaHash[:]) - - var numSigsBytes [8]byte - binary.LittleEndian.PutUint64(numSigsBytes[:], numSigs) - - hasher.Write(numSigsBytes[:]) - hasher.Write(blockHash[:]) - - bankHash := hasher.Sum(nil) - - // EAH must be worked into the bankhash for the slot that is 3/4 through the epoch - epochSchedule := sealevel.SysvarCache.EpochSchedule.Sysvar - if shouldIncludeEah(epochSchedule, slotCtx) { - mlog.Log.Infof("EAH required for this bankhash") - hasher := sha256.New() - hasher.Write(bankHash) - hasher.Write(slotCtx.EpochsAcctHash) - bankHash = 
hasher.Sum(nil) - } - - return bankHash -} - -var maxBlockfetchAttempts = 10 - -func fetchBankhashForSlot(rpcc *rpcclient.RpcClient, slot uint64) ([]byte, error) { - var blockResult *rpc.GetBlockResult - var err error - var errCount uint64 - - slotToFetch := slot + 1 - for { - blockResult, err = rpcc.GetBlockFinalized(uint64(slotToFetch)) - if err == nil { - break - } else if err == rpcclient.SlotSkipped { - slotToFetch++ - } else { - if errCount == 10 { - return nil, fmt.Errorf("unable to get block: %s", err) - } - errCount++ - } - } - - block := block.FromBlockResult(blockResult, slot, rpcc) - - var count uint64 - for _, tx := range block.Transactions { - if tx.IsVote() { - - // skip first 400 votes. most of the first load of votes in a slot usually pertain to two slots back rather than - // the most recent parent slot. - count++ - if count < 400 { - continue - } - - if len(tx.Message.Instructions) < 1 { - continue - } - - instrData := tx.Message.Instructions[0].Data - decoder := bin.NewBinDecoder(instrData) - instructionType, err := decoder.ReadUint32(bin.LE) - if err != nil { - continue - } - - if instructionType != sealevel.VoteProgramInstrTypeTowerSync && instructionType != sealevel.VoteProgramInstrTypeTowerSyncSwitch { - continue - } - - var towerSyncInstr *sealevel.VoteInstrTowerSync - - if instructionType == sealevel.VoteProgramInstrTypeTowerSync { - var vote sealevel.VoteInstrTowerSync - err = vote.UnmarshalWithDecoder(decoder) - if err != nil { - continue - } - towerSyncInstr = &vote - - } else if instructionType == sealevel.VoteProgramInstrTypeTowerSyncSwitch { - var vote sealevel.VoteInstrTowerSyncSwitch - err = vote.UnmarshalWithDecoder(decoder) - if err != nil { - continue - } - towerSyncInstr = &vote.TowerSync - } - - lockoutsLen := towerSyncInstr.Lockouts.Len() - if lockoutsLen == 0 { - continue - } - - lockout := towerSyncInstr.Lockouts.PopBack() - if lockout.Slot == slot { - return towerSyncInstr.Hash[:], nil - } - } - } - - panic("unable to find 
a vote for the relevant slot") -} diff --git a/pkg/replay/hash_test.go b/pkg/replay/hash_test.go deleted file mode 100644 index b98b9c60..00000000 --- a/pkg/replay/hash_test.go +++ /dev/null @@ -1,78 +0,0 @@ -package replay - -import ( - "encoding/hex" - "encoding/json" - "fmt" - "strconv" - "testing" - - "github.com/Overclock-Validator/mithril/fixtures" - "github.com/Overclock-Validator/mithril/pkg/accounts" - "github.com/Overclock-Validator/mithril/pkg/base58" - "github.com/stretchr/testify/assert" -) - -// uses known good values to test if bankhash computes correctly -func Test_Compute_Bank_Hash(t *testing.T) { - acctsDeltaHash := []byte{148, 1, 99, 1, 94, 42, 27, 37, 216, 66, 0, 57, 116, 109, 251, 51, 250, 101, 228, 74, 44, 3, 94, 73, 120, 148, 27, 210, 78, 34, 112, 212} - parentBankHash := [32]byte{216, 24, 141, 114, 110, 72, 188, 246, 47, 80, 102, 40, 122, 219, 11, 94, 100, 159, 96, 122, 195, 101, 140, 19, 22, 225, 243, 127, 23, 182, 65, 90} - numSigs := uint64(2) - blockHash := [32]byte{113, 124, 28, 34, 197, 214, 189, 118, 67, 41, 212, 2, 122, 6, 74, 59, 124, 160, 185, 122, 37, 39, 142, 149, 224, 42, 26, 49, 215, 200, 16, 19} - - // correct bankhash for the above values - knownCorrectBankHash := []byte{190, 156, 54, 163, 252, 183, 243, 10, 147, 168, 42, 47, 214, 172, 160, 64, 86, 32, 203, 54, 119, 230, 201, 36, 164, 27, 30, 244, 96, 202, 88, 154} - - bankHash := calculateBankHash(nil, acctsDeltaHash, parentBankHash, numSigs, blockHash) - assert.Equal(t, bankHash, knownCorrectBankHash) -} - -type testAcct struct { - Lamports string `json:"lamports"` - Len uint64 `json:"data.len"` - Owner string `json:"owner"` - Executable bool `json:"executable"` - RentEpoch string `json:"rent_epoch"` - Data string `json:"data"` - Pubkey string `json:"pubkey"` -} - -func Test_Accounts_Delta_Hash_And_BankHash(t *testing.T) { - acctsJson := fixtures.Load(t, "hash", "accts.json") - - var testAccts []testAcct - err := json.Unmarshal(acctsJson, &testAccts) - assert.NoError(t, 
err) - - fmt.Printf("unmarshaled %d accts\n", len(testAccts)) - - accts := make([]*accounts.Account, 0) - for _, ta := range testAccts { - data, err := hex.DecodeString(ta.Data) - assert.NoError(t, err) - lamports, err := strconv.ParseUint(ta.Lamports, 10, 64) - rentEpoch, err := strconv.ParseUint(ta.RentEpoch, 10, 64) - a := &accounts.Account{Key: base58.MustDecodeFromString(ta.Pubkey), Lamports: lamports, Data: data, Executable: ta.Executable, Owner: base58.MustDecodeFromString(ta.Owner), RentEpoch: rentEpoch} - accts = append(accts, a) - } - - acctsDeltaHash := calculateAcctsDeltaHash(accts) - knownCorrectAcctsDeltaHash := []byte{159, 193, 234, 234, 232, 60, 116, 92, 110, 95, 206, 137, 221, 188, 150, 211, 233, 2, 24, 56, 20, 207, 125, 123, 135, 193, 5, 37, 114, 203, 108, 109} - - fmt.Printf("calculated accts delta hash: %d\n", acctsDeltaHash) - fmt.Printf("known accts delta hash: %d\n", knownCorrectAcctsDeltaHash) - - assert.Equal(t, acctsDeltaHash, knownCorrectAcctsDeltaHash) - - parentBankHash := [32]byte{89, 9, 149, 199, 126, 19, 109, 42, 164, 143, 181, 134, 72, 50, 37, 12, 232, 164, 118, 184, 89, 104, 82, 205, 254, 58, 135, 223, 67, 69, 131, 62} - numSigs := uint64(2) - blockHash := [32]byte{146, 202, 69, 18, 36, 202, 121, 99, 47, 1, 177, 105, 158, 183, 91, 218, 104, 146, 24, 15, 17, 59, 160, 158, 71, 187, 255, 20, 105, 124, 226, 82} - knownCorrectBankHash := []byte{119, 170, 167, 64, 81, 16, 52, 152, 70, 85, 198, 20, 1, 9, 69, 90, 128, 26, 216, 178, 224, 255, 106, 149, 70, 45, 52, 83, 69, 197, 64, 245} - - bankHash := calculateBankHash(nil, acctsDeltaHash, parentBankHash, numSigs, blockHash) - - fmt.Printf("calculated bankhash: %d\n", bankHash) - fmt.Printf("known bankhash: %d\n", knownCorrectBankHash) - - assert.Equal(t, bankHash, knownCorrectBankHash) -} diff --git a/pkg/replay/leader_schedule.go b/pkg/replay/leader_schedule.go index 71fe6ac7..9a267189 100644 --- a/pkg/replay/leader_schedule.go +++ b/pkg/replay/leader_schedule.go @@ -1,111 +1,1743 @@ 
package replay import ( - "errors" + "bufio" + "bytes" + "crypto/sha256" + "encoding/base64" "fmt" + "os" + "path/filepath" + "sort" + "sync" + "sync/atomic" "time" + "github.com/Overclock-Validator/mithril/pkg/accountsdb" + "github.com/Overclock-Validator/mithril/pkg/config" + "github.com/Overclock-Validator/mithril/pkg/epochstakes" "github.com/Overclock-Validator/mithril/pkg/global" "github.com/Overclock-Validator/mithril/pkg/leaderschedule" "github.com/Overclock-Validator/mithril/pkg/mlog" - "github.com/Overclock-Validator/mithril/pkg/rpcclient" "github.com/Overclock-Validator/mithril/pkg/sealevel" "github.com/gagliardetto/solana-go" + "github.com/panjf2000/ants/v2" ) -// ErrLeaderScheduleFailed is returned when leader schedule cannot be fetched from any endpoint -var ErrLeaderScheduleFailed = errors.New("leader schedule fetch failed from all RPC endpoints") +const ( + // NumConsecutiveLeaderSlots matches Solana's NUM_CONSECUTIVE_LEADER_SLOTS + NumConsecutiveLeaderSlots = 4 + // MaxMismatchLogsPerEpoch caps mismatch logging to avoid disk churn + MaxMismatchLogsPerEpoch = 100 + // SampleBoundarySlots is how many slots to check at epoch boundaries + SampleBoundarySlots = 2000 + // SampleRandomSlots is how many random slots to sample in the middle + SampleRandomSlots = 1000 + // MaxMissingVoteCacheStakePercent is the maximum percentage of stake that can be + // missing from VoteCache before we fail. Since local schedule is the source of truth, + // missing VoteCache entries mean that stake is dropped from the schedule, which would + // produce an incorrect schedule. + // Set to 0 for zero tolerance - any missing stake is a hard failure. + // The VoteCache rebuild from AccountsDB should ensure this never triggers. + MaxMissingVoteCacheStakePercent = 0.0 + // DefaultVoteCacheRebuildConcurrency is the default number of concurrent workers + // for rebuilding vote cache from AccountsDB at epoch boundaries. 
+ DefaultVoteCacheRebuildConcurrency = 16 +) + +var ( + mismatchLogOnce sync.Once + mismatchLogFile *os.File + mismatchLogWriter *bufio.Writer + mismatchLogMu sync.Mutex +) + +// defaultLogsDir is the fallback directory for mismatch logs +const defaultLogsDir = "/mnt/mithril-logs" + +// mismatchLogPath stores the resolved path for use in warnings +var mismatchLogPath string + +// resolveLogsDir returns the leader_schedule subdirectory within the run directory. +// Creates a dedicated subdirectory to keep leader schedule files organized. +func resolveLogsDir(logsDir string) string { + var baseDir string + // First try mlog's directory (for run ID correlation) + if mlogDir := mlog.GetLogDir(); mlogDir != "" { + baseDir = mlogDir + } else if logsDir != "" { + baseDir = logsDir + } else { + baseDir = defaultLogsDir + } + // Return leader_schedule subdirectory + return filepath.Join(baseDir, "leader_schedule") +} + +// initMismatchLog creates/opens the mismatch log file (once per process). +// Uses the same log directory as Mithril's main logs with run ID for correlation. 
+func initMismatchLog(logsDir string) { + mismatchLogOnce.Do(func() { + logsDir = resolveLogsDir(logsDir) + // Create directory if it doesn't exist + if err := os.MkdirAll(logsDir, 0755); err != nil { + mlog.Log.Warnf("failed to create mismatch log directory: %v", err) + return + } + + // Use run ID in filename for correlation with main log + runID := mlog.GetRunID() + var filename string + if runID != "" { + shortRunID := runID + if len(shortRunID) > 8 { + shortRunID = shortRunID[:8] + } + filename = fmt.Sprintf("mismatch_%s.log", shortRunID) + } else { + filename = "mismatch.log" + } + mismatchLogPath = filepath.Join(logsDir, filename) + + var err error + mismatchLogFile, err = os.OpenFile(mismatchLogPath, os.O_CREATE|os.O_APPEND|os.O_WRONLY, 0644) + if err != nil { + mlog.Log.Warnf("failed to open leader schedule mismatch log: %v", err) + return + } + mismatchLogWriter = bufio.NewWriter(mismatchLogFile) + mlog.Log.FileOnlyf("leader schedule mismatch log: %s", mismatchLogPath) + }) +} + +// getMismatchLogPath returns the path to the mismatch log file +func getMismatchLogPath() string { + if mismatchLogPath != "" { + return mismatchLogPath + } + return filepath.Join(resolveLogsDir(""), "leader_schedule_mismatch.log") +} + +// flushMismatchLog flushes buffered writes (call at end of epoch validation) +func flushMismatchLog() { + mismatchLogMu.Lock() + defer mismatchLogMu.Unlock() + if mismatchLogWriter != nil { + mismatchLogWriter.Flush() + } +} -// prepareLeaderSchedule fetches the leader schedule for the given epoch. -// It tries the primary RPC client first, then falls back to backup endpoints. -// Returns an error instead of panicking to allow graceful shutdown. 
-func prepareLeaderSchedule(epoch uint64, epochSchedule *sealevel.SysvarEpochSchedule, rpcClient *rpcclient.RpcClient) error { - return prepareLeaderScheduleWithBackups(epoch, epochSchedule, rpcClient, nil) +// VoteCacheRebuildError holds info about a failed vote account for logging +type VoteCacheRebuildError struct { + VoteAcct solana.PublicKey + Stake uint64 + Reason string + Err error } -// prepareLeaderScheduleWithBackups tries the primary client, then backup endpoints in order. -func prepareLeaderScheduleWithBackups(epoch uint64, epochSchedule *sealevel.SysvarEpochSchedule, rpcClient *rpcclient.RpcClient, backupEndpoints []string) error { - firstSlotInEpoch := epochSchedule.FirstSlotInEpoch(epoch) +// RebuildVoteCacheFromAccountsDB rebuilds the VoteCache from AccountsDB for all vote accounts +// in the stake map. This ensures correctness at epoch boundaries by reading the canonical +// state directly from AccountsDB. +// +// Parameters: +// - acctsDb: the AccountsDB instance +// - slot: the slot at which to read account state (typically lastSlotCtx.Slot) +// - voteAcctStakes: the stake map for the new epoch (vote account -> stake) +// - maxConcurrency: number of concurrent workers (0 = use default) +// +// Returns error if any vote account is missing or has invalid state. +// This is a blocking operation and should be called before building the leader schedule. 
+func RebuildVoteCacheFromAccountsDB( + acctsDb *accountsdb.AccountsDb, + slot uint64, + voteAcctStakes map[solana.PublicKey]uint64, + maxConcurrency int, +) error { + if maxConcurrency <= 0 { + maxConcurrency = DefaultVoteCacheRebuildConcurrency + } + + startTime := time.Now() + totalAccounts := len(voteAcctStakes) + var zeroStakeCount int + for _, stake := range voteAcctStakes { + if stake == 0 { + zeroStakeCount++ + } + } + nonZeroAccounts := totalAccounts - zeroStakeCount + + mlog.Log.FileOnlyf("vote cache rebuild: starting slot=%d accounts=%d (non-zero=%d) concurrency=%d", + slot, totalAccounts, nonZeroAccounts, maxConcurrency) + + // Counters for stats (use atomics for thread safety) + var successCount atomic.Int64 + var missingCount atomic.Int64 + var unmarshalErrCount atomic.Int64 + var zeroNodePkCount atomic.Int64 + var missingStake atomic.Uint64 + var unmarshalErrStake atomic.Uint64 + var zeroNodePkStake atomic.Uint64 + + // Track first few errors for each category (with mutex for thread safety) + const maxErrorsPerCategory = 5 + var errorsMu sync.Mutex + var missingErrors []VoteCacheRebuildError + var unmarshalErrors []VoteCacheRebuildError + var zeroNodePkErrors []VoteCacheRebuildError + + // Track first error for reporting (use sync.Once to capture exactly one error) + var firstError error + var firstErrorOnce sync.Once + + // Create worker pool + var wg sync.WaitGroup + pool, err := ants.NewPoolWithFunc(maxConcurrency, func(i interface{}) { + defer wg.Done() + + item := i.(struct { + pk solana.PublicKey + stake uint64 + }) + + // Read vote account from AccountsDB + voteAcct, err := acctsDb.GetAccount(slot, item.pk) + if err != nil { + global.DeleteVoteCacheItem(item.pk) + missingCount.Add(1) + missingStake.Add(item.stake) + errorsMu.Lock() + if len(missingErrors) < maxErrorsPerCategory { + missingErrors = append(missingErrors, VoteCacheRebuildError{ + VoteAcct: item.pk, + Stake: item.stake, + Reason: "not_found_in_accountsdb", + Err: err, + }) + } + 
errorsMu.Unlock() + firstErrorOnce.Do(func() { + firstError = fmt.Errorf("missing vote account %s (stake=%d): %w", item.pk, item.stake, err) + }) + return + } + + // Unmarshal vote state + versionedVoteState, err := sealevel.UnmarshalVersionedVoteState(voteAcct.Data) + if err != nil { + global.DeleteVoteCacheItem(item.pk) + unmarshalErrCount.Add(1) + unmarshalErrStake.Add(item.stake) + errorsMu.Lock() + if len(unmarshalErrors) < maxErrorsPerCategory { + unmarshalErrors = append(unmarshalErrors, VoteCacheRebuildError{ + VoteAcct: item.pk, + Stake: item.stake, + Reason: fmt.Sprintf("unmarshal_failed (data_len=%d)", len(voteAcct.Data)), + Err: err, + }) + } + errorsMu.Unlock() + firstErrorOnce.Do(func() { + firstError = fmt.Errorf("failed to unmarshal vote account %s (stake=%d): %w", item.pk, item.stake, err) + }) + return + } + + // Validate NodePubkey is non-zero + nodePk := versionedVoteState.NodePubkey() + var zeroPk solana.PublicKey + if nodePk == zeroPk { + global.DeleteVoteCacheItem(item.pk) + zeroNodePkCount.Add(1) + zeroNodePkStake.Add(item.stake) + errorsMu.Lock() + if len(zeroNodePkErrors) < maxErrorsPerCategory { + zeroNodePkErrors = append(zeroNodePkErrors, VoteCacheRebuildError{ + VoteAcct: item.pk, + Stake: item.stake, + Reason: "zero_nodepubkey", + }) + } + errorsMu.Unlock() + firstErrorOnce.Do(func() { + firstError = fmt.Errorf("vote account %s has zero NodePubkey (stake=%d)", item.pk, item.stake) + }) + return + } + + // Update VoteCache + global.PutVoteCacheItem(item.pk, versionedVoteState) + successCount.Add(1) + }) + if err != nil { + return fmt.Errorf("failed to create worker pool: %w", err) + } + defer pool.Release() + + // Submit all vote accounts to the pool + for pk, stake := range voteAcctStakes { + if stake == 0 { + continue // Skip zero-stake accounts + } + wg.Add(1) + item := struct { + pk solana.PublicKey + stake uint64 + }{pk: pk, stake: stake} + if err := pool.Invoke(item); err != nil { + wg.Done() + return fmt.Errorf("failed to submit 
work to pool: %w", err) + } + } + + // Wait for all workers to complete + wg.Wait() + + duration := time.Since(startTime) + + // Calculate total stake for percentage + var totalStake uint64 + for _, stake := range voteAcctStakes { + totalStake += stake + } + successStake := totalStake - missingStake.Load() - unmarshalErrStake.Load() - zeroNodePkStake.Load() + + // File only: single line summary + skipped := nonZeroAccounts - int(successCount.Load()) + mlog.Log.FileOnlyf("Vote cache: loaded=%d skipped=%d duration=%v", + successCount.Load(), skipped, duration) + + // File only: detailed results + mlog.Log.FileOnlyf("vote cache rebuild details: slot=%d duration=%v", slot, duration) + mlog.Log.FileOnlyf(" accounts: total=%d non_zero=%d success=%d", + totalAccounts, nonZeroAccounts, successCount.Load()) + mlog.Log.FileOnlyf(" stake: total=%d success=%d (%.2f%%)", + totalStake, successStake, float64(successStake)/float64(totalStake)*100) + + // Check for any failures + missing := missingCount.Load() + unmarshalErr := unmarshalErrCount.Load() + zeroNodePk := zeroNodePkCount.Load() + totalFailed := missing + unmarshalErr + zeroNodePk + + if totalFailed > 0 { + totalFailedStake := missingStake.Load() + unmarshalErrStake.Load() + zeroNodePkStake.Load() + failedPercent := float64(totalFailedStake) / float64(totalStake) * 100 + + // File only: detailed failure info (always log for debugging) + mlog.Log.FileOnlyf("vote cache rebuild failures:") + mlog.Log.FileOnlyf(" slot=%d", slot) + mlog.Log.FileOnlyf(" failures: missing=%d (stake=%d) unmarshal_err=%d (stake=%d) zero_nodepk=%d (stake=%d)", + missing, missingStake.Load(), unmarshalErr, unmarshalErrStake.Load(), zeroNodePk, zeroNodePkStake.Load()) + mlog.Log.FileOnlyf(" total_failed=%d total_failed_stake=%d (%.4f%% of total)", + totalFailed, totalFailedStake, failedPercent) + + // File only: first few errors in each category + if len(missingErrors) > 0 { + mlog.Log.FileOnlyf(" missing_accounts (first %d):", len(missingErrors)) 
+ for i, e := range missingErrors { + mlog.Log.FileOnlyf(" %d. vote=%s stake=%d err=%v", i+1, e.VoteAcct, e.Stake, e.Err) + } + } + if len(unmarshalErrors) > 0 { + mlog.Log.FileOnlyf(" unmarshal_errors (first %d):", len(unmarshalErrors)) + for i, e := range unmarshalErrors { + mlog.Log.FileOnlyf(" %d. vote=%s stake=%d reason=%s err=%v", i+1, e.VoteAcct, e.Stake, e.Reason, e.Err) + } + } + if len(zeroNodePkErrors) > 0 { + mlog.Log.FileOnlyf(" zero_nodepk_accounts (first %d):", len(zeroNodePkErrors)) + for i, e := range zeroNodePkErrors { + mlog.Log.FileOnlyf(" %d. vote=%s stake=%d", i+1, e.VoteAcct, e.Stake) + } + } + + // Small percentage of unavailable vote accounts is expected on mainnet (dead/closed validators) + // Only ERROR if significant stake is missing - otherwise it's just noise + if failedPercent > 5.0 { + mlog.Log.Errorf("VOTE CACHE REBUILD: slot=%d skipped=%d (%.4f%% stake) - exceeds threshold", + slot, totalFailed, failedPercent) + if firstError != nil { + return fmt.Errorf("vote cache rebuild: %d unavailable (%.4f%% stake): %w", + totalFailed, failedPercent, firstError) + } + return fmt.Errorf("vote cache rebuild: %d unavailable (%.4f%% stake)", + totalFailed, failedPercent) + } - // Try primary endpoint first - leaderMap, err := fetchLeaderScheduleWithRetry(rpcClient, 10) - if err == nil { - leaderSchedule := leaderschedule.NewLeaderScheduleFromKeyedSlots(leaderMap, firstSlotInEpoch) - global.SetLeaderSchedule(leaderSchedule) + // Expected mainnet behavior - log to file only + mlog.Log.FileOnlyf("vote cache rebuild: slot=%d skipped=%d unavailable vote accounts (%.4f%% stake)", + slot, totalFailed, failedPercent) return nil } - lastErr := err - mlog.Log.Errorf("leader schedule fetch failed on primary %s: %v", rpcClient.Endpoint(), err) + mlog.Log.FileOnlyf(" result: SUCCESS (all %d non-zero accounts rebuilt)", nonZeroAccounts) + return nil +} + +// StakeEntry holds a vote account and its stake for logging +type StakeEntry struct { + VoteAcct 
solana.PublicKey + NodePubkey solana.PublicKey + Stake uint64 + Reason string // For skipped entries: "zero_stake", "missing_vote_acct", "zero_nodepk" +} + +// dumpFullScheduleData writes complete validator data to CSV files for debugging. +// Creates epoch-specific files in the logs directory with ALL validators. +// Includes run ID in filename to prevent overwriting on re-runs. +func dumpFullScheduleData( + epoch uint64, + source string, // "snapshot", "vote_cache", or "rpc" + validEntries []StakeEntry, + skippedEntries []StakeEntry, + totalStake uint64, + logsDir string, +) { + logsDir = resolveLogsDir(logsDir) + if err := os.MkdirAll(logsDir, 0755); err != nil { + mlog.Log.Warnf("dumpFullScheduleData: failed to create logs dir: %v", err) + return + } - // Try backup endpoints - for i, endpoint := range backupEndpoints { - mlog.Log.Infof("trying backup RPC endpoint #%d for leader schedule: %s", i+1, endpoint) - backupClient := rpcclient.NewRpcClient(endpoint) - leaderMap, err := fetchLeaderScheduleWithRetry(backupClient, 5) // Fewer retries on backups - if err == nil { - mlog.Log.Infof("leader schedule fetched from backup endpoint %s", endpoint) - leaderSchedule := leaderschedule.NewLeaderScheduleFromKeyedSlots(leaderMap, firstSlotInEpoch) - global.SetLeaderSchedule(leaderSchedule) - return nil + // Get short run ID for filename (prevents overwriting on re-runs) + runID := mlog.GetRunID() + shortRunID := "" + if runID != "" { + shortRunID = runID + if len(shortRunID) > 8 { + shortRunID = shortRunID[:8] } - lastErr = err - mlog.Log.Errorf("leader schedule fetch failed on backup %s: %v", endpoint, err) + shortRunID = "_" + shortRunID } - // All endpoints failed - return fmt.Errorf("%w: last error: %v", ErrLeaderScheduleFailed, lastErr) + // Write validators CSV + validatorsFile := filepath.Join(logsDir, fmt.Sprintf("epoch%d_%s%s_validators.csv", epoch, source, shortRunID)) + if err := writeValidatorsCSV(validatorsFile, epoch, source, validEntries, totalStake); err 
!= nil { + mlog.Log.Warnf("dumpFullScheduleData: failed to write validators CSV: %v", err) + } else { + mlog.Log.FileOnlyf("leader schedule validators dumped to: %s (%d entries)", validatorsFile, len(validEntries)) + } + + // Write skipped CSV if there are any + if len(skippedEntries) > 0 { + skippedFile := filepath.Join(logsDir, fmt.Sprintf("epoch%d_%s%s_skipped.csv", epoch, source, shortRunID)) + if err := writeSkippedCSV(skippedFile, epoch, skippedEntries); err != nil { + mlog.Log.Warnf("dumpFullScheduleData: failed to write skipped CSV: %v", err) + } else { + mlog.Log.FileOnlyf("leader schedule skipped accounts dumped to: %s (%d entries)", skippedFile, len(skippedEntries)) + } + } } -// fetchLeaderScheduleWithRetry attempts to fetch leader schedule with exponential backoff -func fetchLeaderScheduleWithRetry(rpcClient *rpcclient.RpcClient, maxAttempts int) (map[solana.PublicKey][]uint64, error) { - var leaderMap map[solana.PublicKey][]uint64 - var err error +// dumpFullScheduleDataWithSummary writes validators CSV, skipped CSV, and a summary file. +// This is the preferred function when all metadata is available. +// Includes run ID in filenames to prevent overwriting on re-runs. 
+func dumpFullScheduleDataWithSummary( + validEntries []StakeEntry, + skippedEntries []StakeEntry, + summary ScheduleSummary, + logsDir string, +) { + logsDir = resolveLogsDir(logsDir) + if err := os.MkdirAll(logsDir, 0755); err != nil { + mlog.Log.Warnf("dumpFullScheduleDataWithSummary: failed to create logs dir: %v", err) + return + } + + epoch := summary.BlockEpoch + source := summary.Source + + // Get short run ID for filename (prevents overwriting on re-runs) + shortRunID := "" + if summary.RunID != "" { + shortRunID = summary.RunID + if len(shortRunID) > 8 { + shortRunID = shortRunID[:8] + } + shortRunID = "_" + shortRunID + } + + // Write validators CSV + validatorsFile := filepath.Join(logsDir, fmt.Sprintf("epoch%d_%s%s_validators.csv", epoch, source, shortRunID)) + if err := writeValidatorsCSV(validatorsFile, epoch, source, validEntries, summary.FilteredStake); err != nil { + mlog.Log.Warnf("dumpFullScheduleDataWithSummary: failed to write validators CSV: %v", err) + } else { + mlog.Log.FileOnlyf("leader schedule validators dumped to: %s (%d entries)", validatorsFile, len(validEntries)) + } + + // Write skipped CSV if there are any + if len(skippedEntries) > 0 { + skippedFile := filepath.Join(logsDir, fmt.Sprintf("epoch%d_%s%s_skipped.csv", epoch, source, shortRunID)) + if err := writeSkippedCSV(skippedFile, epoch, skippedEntries); err != nil { + mlog.Log.Warnf("dumpFullScheduleDataWithSummary: failed to write skipped CSV: %v", err) + } else { + mlog.Log.FileOnlyf("leader schedule skipped accounts dumped to: %s (%d entries)", skippedFile, len(skippedEntries)) + } + } + + // Write summary file + summaryFile := filepath.Join(logsDir, fmt.Sprintf("epoch%d_%s%s_summary.txt", epoch, source, shortRunID)) + if err := writeSummaryFile(summaryFile, summary); err != nil { + mlog.Log.Warnf("dumpFullScheduleDataWithSummary: failed to write summary: %v", err) + } else { + mlog.Log.FileOnlyf("leader schedule summary dumped to: %s", summaryFile) + } +} + +// 
DumpTieBreakDebug writes tie-break debugging info to a file. +// This verifies that equal-stake validators are sorted by pubkey DESC (Agave behavior). +func DumpTieBreakDebug( + epoch uint64, + voteAcctStakes map[solana.PublicKey]uint64, + voteAcctMap map[solana.PublicKey]*epochstakes.VoteAccount, + logsDir string, +) { + logsDir = resolveLogsDir(logsDir) + if err := os.MkdirAll(logsDir, 0755); err != nil { + mlog.Log.Warnf("DumpTieBreakDebug: failed to create logs dir: %v", err) + return + } + + runID := mlog.GetRunID() + shortRunID := "" + if runID != "" { + shortRunID = runID + if len(shortRunID) > 8 { + shortRunID = shortRunID[:8] + } + shortRunID = "_" + shortRunID + } + + filename := fmt.Sprintf("epoch%d_tiebreak%s.txt", epoch, shortRunID) + filePath := filepath.Join(logsDir, filename) - for attempt := 0; attempt < maxAttempts; attempt++ { - leaderMap, err = rpcClient.GetLeaderSchedule() - if err == nil { - return leaderMap, nil + f, err := os.Create(filePath) + if err != nil { + mlog.Log.Warnf("DumpTieBreakDebug: failed to create file: %v", err) + return + } + defer f.Close() + + w := bufio.NewWriter(f) + defer w.Flush() + + // Get sorted stakes with tie-break info + allEntries, tieGroups := leaderschedule.GetSortedStakesDebug(voteAcctMap, voteAcctStakes) + + w.WriteString("# Tie-Break Debug for Leader Schedule\n") + w.WriteString(fmt.Sprintf("# Epoch: %d\n", epoch)) + w.WriteString(fmt.Sprintf("# Total validators: %d\n", len(allEntries))) + w.WriteString(fmt.Sprintf("# Tie groups (equal stake): %d\n", len(tieGroups))) + w.WriteString(fmt.Sprintf("# Generated: %s\n", time.Now().UTC().Format(time.RFC3339))) + w.WriteString("#\n") + w.WriteString("# Expected behavior: within each tie group, pubkeys should be sorted DESC (higher bytes first)\n") + w.WriteString("# BytesCmp shows comparison vs previous entry: -1 means current < previous (correct for DESC)\n") + w.WriteString("#\n\n") + + if len(tieGroups) == 0 { + w.WriteString("No tie groups found - all 
validators have unique stake.\n") + mlog.Log.FileOnlyf("tie-break debug: epoch=%d no ties found", epoch) + return + } + + // Sort tie groups by stake descending for consistent output + type tieGroupInfo struct { + stake uint64 + entries []leaderschedule.TieBreakEntry + } + var sortedGroups []tieGroupInfo + for stake, entries := range tieGroups { + sortedGroups = append(sortedGroups, tieGroupInfo{stake: stake, entries: entries}) + } + sort.Slice(sortedGroups, func(i, j int) bool { + return sortedGroups[i].stake > sortedGroups[j].stake + }) + + for _, group := range sortedGroups { + w.WriteString(fmt.Sprintf("## Tie group: stake=%d (%d validators)\n", group.stake, len(group.entries))) + w.WriteString("rank,node_pubkey,stake,first_8_bytes_hex,bytes_cmp_vs_prev\n") + for _, entry := range group.entries { + w.WriteString(fmt.Sprintf("%d,%s,%d,%x,%d\n", + entry.Rank, entry.NodePk.String(), entry.Stake, entry.RawBytes, entry.BytesCmp)) + } + w.WriteString("\n") + } + + // Log to file only (not terminal) + mlog.Log.FileOnlyf("tie-break debug: epoch=%d tie_groups=%d written to %s", epoch, len(tieGroups), filePath) + + // Log the specific tie if we're looking for it (stake 2499999939665440) + if group, ok := tieGroups[2499999939665440]; ok { + mlog.Log.FileOnlyf("tie-break debug: found target tie group stake=2499999939665440:") + for _, entry := range group { + mlog.Log.FileOnlyf(" rank=%d node=%s bytes_cmp=%d", entry.Rank, entry.NodePk.String(), entry.BytesCmp) } - // Retry with exponential backoff - if attempt < maxAttempts-1 { - waitTime := time.Duration(1< 30*time.Second { - waitTime = 30 * time.Second + } + + // Diagnostic: Check specific vote account → node mappings for epoch 905 debugging + // Vote accounts that caused the tie-break mismatch: + debugVoteAccts := []struct { + vote string + expectedNode string + }{ + {"33hurzEz6aEnzfESL6pnNyR6DCgcKzssT1pwSzDCBTRQ", "Aw5wEMXhbygFLR7jHtHpih8QvxVBGAMTqsQ2SjWPk1ex"}, + {"BU3ZgGBXFJwNTrN6VUJ88k9SJ71SyWfBJTabYqRErm4F", 
"2GUnfxZavKoPfS9s3VSEjaWDzB3vNf5RojUhprCS1rSx"}, + } + for _, d := range debugVoteAccts { + votePk := solana.MustPublicKeyFromBase58(d.vote) + expectedNodePk := solana.MustPublicKeyFromBase58(d.expectedNode) + stake, hasStake := voteAcctStakes[votePk] + va := voteAcctMap[votePk] + if hasStake || va != nil { + var actualNode solana.PublicKey + if va != nil { + actualNode = va.NodePubkey + } + match := actualNode == expectedNodePk + mlog.Log.FileOnlyf("vote-node-mapping: vote=%s expected_node=%s actual_node=%s stake=%d match=%v", + d.vote, d.expectedNode, actualNode.String(), stake, match) + if !match { + mlog.Log.Warnf("VOTE-NODE MISMATCH: vote=%s expected=%s actual=%s stake=%d", + d.vote, d.expectedNode, actualNode.String(), stake) } - mlog.Log.Infof("leader schedule fetch from %s failed, retrying in %v (attempt %d/%d): %v", - rpcClient.Endpoint(), waitTime, attempt+1, maxAttempts, err) - time.Sleep(waitTime) + } + } +} + +// writeValidatorsCSV writes all validators to a CSV file +func writeValidatorsCSV(filepath string, epoch uint64, source string, entries []StakeEntry, totalStake uint64) error { + f, err := os.Create(filepath) + if err != nil { + return err + } + defer f.Close() + + w := bufio.NewWriter(f) + defer w.Flush() + + // Header comments + w.WriteString(fmt.Sprintf("# Leader Schedule - Epoch %d\n", epoch)) + w.WriteString(fmt.Sprintf("# Source: %s\n", source)) + w.WriteString(fmt.Sprintf("# Total Validators: %d\n", len(entries))) + w.WriteString(fmt.Sprintf("# Total Stake: %d\n", totalStake)) + w.WriteString(fmt.Sprintf("# Generated: %s\n", time.Now().UTC().Format(time.RFC3339))) + w.WriteString("#\n") + w.WriteString("rank,vote_account,node_pubkey,stake,stake_percent\n") + + // Write all entries (already sorted by stake descending) + for i, e := range entries { + var stakePercent float64 + if totalStake > 0 { + stakePercent = float64(e.Stake) / float64(totalStake) * 100.0 + } + w.WriteString(fmt.Sprintf("%d,%s,%s,%d,%.6f\n", + i+1, e.VoteAcct, 
e.NodePubkey, e.Stake, stakePercent)) + } + + return nil +} + +// writeSkippedCSV writes all skipped accounts to a CSV file +func writeSkippedCSV(filepath string, epoch uint64, entries []StakeEntry) error { + f, err := os.Create(filepath) + if err != nil { + return err + } + defer f.Close() + + w := bufio.NewWriter(f) + defer w.Flush() + + // Header comments + w.WriteString(fmt.Sprintf("# Leader Schedule Skipped Accounts - Epoch %d\n", epoch)) + w.WriteString(fmt.Sprintf("# Total Skipped: %d\n", len(entries))) + w.WriteString(fmt.Sprintf("# Generated: %s\n", time.Now().UTC().Format(time.RFC3339))) + w.WriteString("#\n") + w.WriteString("# Reasons:\n") + w.WriteString("# zero_stake - Vote account has 0 stake\n") + w.WriteString("# missing_vote_acct - Vote account not found in VoteAcctMap\n") + w.WriteString("# missing_vote_cache - Vote account not found in VoteCache\n") + w.WriteString("# zero_nodepk - Vote account has zero NodePubkey\n") + w.WriteString("#\n") + w.WriteString("# Note: node_pubkey is empty for missing_vote_acct/missing_vote_cache since\n") + w.WriteString("# the vote account data was not available to extract the NodePubkey.\n") + w.WriteString("#\n") + w.WriteString("vote_account,node_pubkey,stake,reason\n") + + for _, e := range entries { + w.WriteString(fmt.Sprintf("%s,%s,%d,%s\n", e.VoteAcct, e.NodePubkey, e.Stake, e.Reason)) + } + + return nil +} + +// ScheduleSummary holds all metadata for the summary file +type ScheduleSummary struct { + // Epoch info + BlockEpoch uint64 + ScheduleEpoch uint64 + FirstSlot uint64 + SlotsInEpoch uint64 + Repeat uint64 + + // Stake info + TotalInputStake uint64 // Total stake from EpochStakes (before filtering) + FilteredStake uint64 // Stake used in schedule (after filtering) + MissingStake uint64 // Stake skipped due to missing data + MissingStakePercent float64 + + // Validator counts + ValidatorsInput int // Total vote accounts in EpochStakes + ValidatorsUsed int // Validators included in schedule + 
ValidatorsSkipped int // Validators skipped (zero stake + missing + zero nodepk) + SkippedZeroStake int + SkippedMissingData int // missing_vote_acct or missing_vote_cache + SkippedZeroNodePk int + + // Hashes + LocalHash string + RPCHash string // Empty if RPC validation not enabled + + // Run info + RunID string + Source string // "snapshot" or "vote_cache" + Timestamp time.Time +} + +// writeSummaryFile writes a comprehensive summary file for the epoch +func writeSummaryFile(filepath string, summary ScheduleSummary) error { + f, err := os.Create(filepath) + if err != nil { + return err + } + defer f.Close() + + w := bufio.NewWriter(f) + defer w.Flush() + + w.WriteString("# Leader Schedule Summary\n") + w.WriteString(fmt.Sprintf("# Generated: %s\n", summary.Timestamp.Format(time.RFC3339))) + w.WriteString(fmt.Sprintf("# Run ID: %s\n", summary.RunID)) + w.WriteString("#\n") + + w.WriteString("## Epoch Info\n") + w.WriteString(fmt.Sprintf("block_epoch=%d\n", summary.BlockEpoch)) + w.WriteString(fmt.Sprintf("schedule_epoch=%d\n", summary.ScheduleEpoch)) + w.WriteString(fmt.Sprintf("first_slot=%d\n", summary.FirstSlot)) + w.WriteString(fmt.Sprintf("slots_in_epoch=%d\n", summary.SlotsInEpoch)) + w.WriteString(fmt.Sprintf("repeat=%d\n", summary.Repeat)) + w.WriteString(fmt.Sprintf("source=%s\n", summary.Source)) + w.WriteString("\n") + + w.WriteString("## Stake Info\n") + w.WriteString(fmt.Sprintf("total_input_stake=%d\n", summary.TotalInputStake)) + w.WriteString(fmt.Sprintf("filtered_stake=%d\n", summary.FilteredStake)) + w.WriteString(fmt.Sprintf("missing_stake=%d\n", summary.MissingStake)) + w.WriteString(fmt.Sprintf("missing_stake_percent=%.4f\n", summary.MissingStakePercent)) + w.WriteString("\n") + + w.WriteString("## Validator Counts\n") + w.WriteString(fmt.Sprintf("validators_input=%d\n", summary.ValidatorsInput)) + w.WriteString(fmt.Sprintf("validators_used=%d\n", summary.ValidatorsUsed)) + w.WriteString(fmt.Sprintf("validators_skipped=%d\n", 
summary.ValidatorsSkipped)) + w.WriteString(fmt.Sprintf("skipped_zero_stake=%d\n", summary.SkippedZeroStake)) + w.WriteString(fmt.Sprintf("skipped_missing_data=%d\n", summary.SkippedMissingData)) + w.WriteString(fmt.Sprintf("skipped_zero_nodepk=%d\n", summary.SkippedZeroNodePk)) + w.WriteString("\n") + + w.WriteString("## Hashes\n") + w.WriteString(fmt.Sprintf("local_hash=%s\n", summary.LocalHash)) + if summary.RPCHash != "" { + w.WriteString(fmt.Sprintf("rpc_hash=%s\n", summary.RPCHash)) + } + w.WriteString("\n") + + return nil +} + +// ValidationStats holds statistics from schedule validation +type ValidationStats struct { + SkippedZeroStake int + SkippedMissingNodePk int + SkippedMissingNodePkStake uint64 // Stake dropped due to zero NodePubkey + SkippedMissingVoteAcct int + SkippedMissingVoteAcctStake uint64 // Stake dropped due to missing VoteCache entries + TotalVoteAccts int + TotalStake uint64 + MinStake uint64 + MaxStake uint64 + ValidatorCount int // Validators with non-zero stake and valid NodePubkey + MismatchCount int + Capped bool + TopStakes []StakeEntry // Top 10 by stake + BottomStakes []StakeEntry // Bottom 10 by stake + MissingVoteAccts []StakeEntry // First few missing vote accounts (for debugging) + ZeroNodePkAccts []StakeEntry // First few zero NodePubkey accounts +} + +// logScheduleBuildSummary logs a comprehensive summary of the schedule build. +// Called once per epoch when building the leader schedule. +// Terminal output is minimal; detailed info goes to log file only. 
+func logScheduleBuildSummary( + epoch uint64, + scheduleEpoch uint64, + firstSlot uint64, + slotsInEpoch uint64, + source string, // "snapshot" or "vote_cache" + stats ValidationStats, + fullHash string, +) { + // File only: single line summary + mlog.Log.FileOnlyf("leader schedule: epoch=%d validators=%d stake=%d hash=%s", + epoch, stats.ValidatorCount, stats.TotalStake, fullHash) + + // File only: detailed build info + mlog.Log.FileOnlyf("leader schedule build details: epoch=%d schedule_epoch=%d first_slot=%d slots=%d repeat=%d source=%s", + epoch, scheduleEpoch, firstSlot, slotsInEpoch, NumConsecutiveLeaderSlots, source) + mlog.Log.FileOnlyf(" validators=%d total_stake=%d min_stake=%d max_stake=%d zero_stake_count=%d", + stats.ValidatorCount, stats.TotalStake, stats.MinStake, stats.MaxStake, stats.SkippedZeroStake) + mlog.Log.FileOnlyf(" hash=%s", fullHash) + mlog.Log.FileOnlyf(" skipped: missing_vote_acct=%d (stake=%d) missing_nodepk=%d (stake=%d)", + stats.SkippedMissingVoteAcct, stats.SkippedMissingVoteAcctStake, stats.SkippedMissingNodePk, stats.SkippedMissingNodePkStake) + + // File only: top 10 stakes + if len(stats.TopStakes) > 0 { + mlog.Log.FileOnlyf(" top_stakes (showing %d):", len(stats.TopStakes)) + for i, e := range stats.TopStakes { + mlog.Log.FileOnlyf(" %2d. vote=%s node=%s stake=%d", + i+1, e.VoteAcct, e.NodePubkey, e.Stake) + } + } + + // File only: bottom 10 stakes + if len(stats.BottomStakes) > 0 { + mlog.Log.FileOnlyf(" bottom_stakes (showing %d):", len(stats.BottomStakes)) + for i, e := range stats.BottomStakes { + mlog.Log.FileOnlyf(" %2d. vote=%s node=%s stake=%d", + i+1, e.VoteAcct, e.NodePubkey, e.Stake) } } - return nil, fmt.Errorf("failed after %d attempts: %w", maxAttempts, err) + // File only: offending accounts if any were skipped + if len(stats.MissingVoteAccts) > 0 { + mlog.Log.FileOnlyf(" missing_vote_accts (first %d):", len(stats.MissingVoteAccts)) + for i, e := range stats.MissingVoteAccts { + mlog.Log.FileOnlyf(" %d. 
vote=%s stake=%d", i+1, e.VoteAcct, e.Stake) + } + } + if len(stats.ZeroNodePkAccts) > 0 { + mlog.Log.FileOnlyf(" zero_nodepk_accts (first %d):", len(stats.ZeroNodePkAccts)) + for i, e := range stats.ZeroNodePkAccts { + mlog.Log.FileOnlyf(" %d. vote=%s stake=%d", i+1, e.VoteAcct, e.Stake) + } + } } -// fetchLeaderScheduleForEpochWithRetry fetches leader schedule for a specific epoch with retries. -// This is needed when validating historical epochs during catchup, since the default -// GetLeaderSchedule returns the RPC node's current epoch schedule. -// RPC method: getLeaderSchedule with epoch parameter -func fetchLeaderScheduleForEpochWithRetry(rpcClient *rpcclient.RpcClient, epoch uint64, maxAttempts int) (map[solana.PublicKey][]uint64, error) { - var leaderMap map[solana.PublicKey][]uint64 - var err error +// logHardFailContext logs detailed context when schedule build fails. +// Terminal shows brief error; file gets full details. +func logHardFailContext( + epoch uint64, + reason string, + stats ValidationStats, +) { + // Terminal: brief error + mlog.Log.Errorf("LEADER SCHEDULE BUILD FAILED: epoch=%d reason=%s", epoch, reason) + + // File only: detailed context + mlog.Log.FileOnlyf("LEADER SCHEDULE BUILD FAILED DETAILS:") + mlog.Log.FileOnlyf(" epoch=%d reason=%s", epoch, reason) + mlog.Log.FileOnlyf(" input_vote_accts=%d total_stake_available=%d", + stats.TotalVoteAccts, stats.TotalStake) + mlog.Log.FileOnlyf(" skipped: zero_stake=%d missing_vote_acct=%d (stake=%d) missing_nodepk=%d (stake=%d)", + stats.SkippedZeroStake, stats.SkippedMissingVoteAcct, stats.SkippedMissingVoteAcctStake, stats.SkippedMissingNodePk, stats.SkippedMissingNodePkStake) - for attempt := 0; attempt < maxAttempts; attempt++ { - leaderMap, err = rpcClient.GetLeaderScheduleForEpoch(epoch) - if err == nil { - return leaderMap, nil + // File only: first few offending accounts + if len(stats.MissingVoteAccts) > 0 { + mlog.Log.FileOnlyf(" missing_vote_accts (first %d):", 
len(stats.MissingVoteAccts)) + for i, e := range stats.MissingVoteAccts { + mlog.Log.FileOnlyf(" %d. vote=%s stake=%d", i+1, e.VoteAcct, e.Stake) } - // Retry with exponential backoff - if attempt < maxAttempts-1 { - waitTime := time.Duration(1< 30*time.Second { - waitTime = 30 * time.Second + } + if len(stats.ZeroNodePkAccts) > 0 { + mlog.Log.FileOnlyf(" zero_nodepk_accts (first %d):", len(stats.ZeroNodePkAccts)) + for i, e := range stats.ZeroNodePkAccts { + mlog.Log.FileOnlyf(" %d. vote=%s stake=%d", i+1, e.VoteAcct, e.Stake) + } + } + + // File only: valid top stakes for context + if len(stats.TopStakes) > 0 { + mlog.Log.FileOnlyf(" top_stakes_found (showing %d):", min(5, len(stats.TopStakes))) + for i := 0; i < min(5, len(stats.TopStakes)); i++ { + e := stats.TopStakes[i] + mlog.Log.FileOnlyf(" %d. vote=%s node=%s stake=%d", i+1, e.VoteAcct, e.NodePubkey, e.Stake) + } + } +} + +// buildLocalLeaderSchedule builds a leader schedule from local state. +// Returns nil schedule if no valid stakes are available. +// Also returns all valid and skipped entries for CSV dump. 
+func buildLocalLeaderSchedule( + epoch uint64, + epochSchedule *sealevel.SysvarEpochSchedule, + voteAcctStakes map[solana.PublicKey]uint64, + voteAcctMap map[solana.PublicKey]*epochstakes.VoteAccount, +) (*leaderschedule.LeaderSchedule, ValidationStats, []StakeEntry, []StakeEntry) { + stats := ValidationStats{ + TotalVoteAccts: len(voteAcctStakes), + MinStake: ^uint64(0), // Start with max value + } + + // Collect ALL valid and skipped entries for CSV dump + var validEntries []StakeEntry + var skippedEntries []StakeEntry + + // Filter and build epochVoteAccts map (only entries with stake > 0 and valid NodePubkey) + epochVoteAccts := make(map[solana.PublicKey]*epochstakes.VoteAccount) + filteredStakes := make(map[solana.PublicKey]uint64) + + for votePk, stake := range voteAcctStakes { + if stake == 0 { + stats.SkippedZeroStake++ + skippedEntries = append(skippedEntries, StakeEntry{ + VoteAcct: votePk, + Stake: stake, + Reason: "zero_stake", + }) + continue + } + + va := voteAcctMap[votePk] + if va == nil { + stats.SkippedMissingVoteAcct++ + stats.SkippedMissingVoteAcctStake += stake + skippedEntries = append(skippedEntries, StakeEntry{ + VoteAcct: votePk, + Stake: stake, + Reason: "missing_vote_acct", + }) + // Track first few for quick debugging in logs + if len(stats.MissingVoteAccts) < 5 { + stats.MissingVoteAccts = append(stats.MissingVoteAccts, StakeEntry{ + VoteAcct: votePk, + Stake: stake, + }) } - mlog.Log.Debugf("leader schedule fetch for epoch %d from %s failed, retrying in %v (attempt %d/%d): %v", - epoch, rpcClient.Endpoint(), waitTime, attempt+1, maxAttempts, err) - time.Sleep(waitTime) + continue } + + // Check for zero NodePubkey (missing) + var zeroPk solana.PublicKey + if va.NodePubkey == zeroPk { + stats.SkippedMissingNodePk++ + stats.SkippedMissingNodePkStake += stake + skippedEntries = append(skippedEntries, StakeEntry{ + VoteAcct: votePk, + Stake: stake, + Reason: "zero_nodepk", + }) + // Track first few for quick debugging in logs + if 
len(stats.ZeroNodePkAccts) < 5 { + stats.ZeroNodePkAccts = append(stats.ZeroNodePkAccts, StakeEntry{ + VoteAcct: votePk, + Stake: stake, + }) + } + continue + } + + epochVoteAccts[votePk] = va + filteredStakes[votePk] = stake + stats.TotalStake += stake + + // Track min/max + if stake < stats.MinStake { + stats.MinStake = stake + } + if stake > stats.MaxStake { + stats.MaxStake = stake + } + + validEntries = append(validEntries, StakeEntry{ + VoteAcct: votePk, + NodePubkey: va.NodePubkey, + Stake: stake, + }) } - return nil, fmt.Errorf("failed after %d attempts: %w", maxAttempts, err) + stats.ValidatorCount = len(validEntries) + + // Guard: empty stakes would panic in weightedrand + if len(filteredStakes) == 0 { + stats.MinStake = 0 // Reset since no valid entries + return nil, stats, validEntries, skippedEntries + } + + // Sort entries by stake descending, then node pubkey descending (matches schedule computation) + sort.Slice(validEntries, func(i, j int) bool { + if validEntries[i].Stake != validEntries[j].Stake { + return validEntries[i].Stake > validEntries[j].Stake + } + // Tie-break by node pubkey descending (higher bytes first) - matches Agave + return bytes.Compare(validEntries[i].NodePubkey[:], validEntries[j].NodePubkey[:]) > 0 + }) + + // Capture top 10 and bottom 10 for log summary + for i := 0; i < min(10, len(validEntries)); i++ { + stats.TopStakes = append(stats.TopStakes, validEntries[i]) + } + for i := max(0, len(validEntries)-10); i < len(validEntries); i++ { + stats.BottomStakes = append(stats.BottomStakes, validEntries[i]) + } + + // Get epoch length (handles warmup epochs correctly) + slotsInEpoch := epochSchedule.SlotsInEpoch(epoch) + + // Build the schedule using leaderschedule.New + ls := leaderschedule.New( + epochVoteAccts, + filteredStakes, + epochSchedule, + epoch, + slotsInEpoch, + NumConsecutiveLeaderSlots, + ) + + return ls, stats, validEntries, skippedEntries +} + +// buildLocalLeaderScheduleFromVoteCache builds schedule using 
global.VoteCache() for NodePubkey lookups. +// Used at epoch boundaries when epochVoteAcctsMap may not be available. +// Returns nil schedule if no valid stakes are available. +// Also returns all valid and skipped entries for CSV dump. +func buildLocalLeaderScheduleFromVoteCache( + epoch uint64, + epochSchedule *sealevel.SysvarEpochSchedule, + voteAcctStakes map[solana.PublicKey]uint64, +) (*leaderschedule.LeaderSchedule, ValidationStats, []StakeEntry, []StakeEntry) { + stats := ValidationStats{ + TotalVoteAccts: len(voteAcctStakes), + MinStake: ^uint64(0), // Start with max value + } + + voteCache := global.VoteCache() + + // Collect ALL valid and skipped entries for CSV dump + var validEntries []StakeEntry + var skippedEntries []StakeEntry + + // Build epochVoteAccts map from vote cache + epochVoteAccts := make(map[solana.PublicKey]*epochstakes.VoteAccount) + filteredStakes := make(map[solana.PublicKey]uint64) + + for votePk, stake := range voteAcctStakes { + if stake == 0 { + stats.SkippedZeroStake++ + skippedEntries = append(skippedEntries, StakeEntry{ + VoteAcct: votePk, + Stake: stake, + Reason: "zero_stake", + }) + continue + } + + vs := voteCache[votePk] + if vs == nil { + stats.SkippedMissingVoteAcct++ + stats.SkippedMissingVoteAcctStake += stake + skippedEntries = append(skippedEntries, StakeEntry{ + VoteAcct: votePk, + Stake: stake, + Reason: "missing_vote_cache", + }) + // Track first few for quick debugging in logs + if len(stats.MissingVoteAccts) < 5 { + stats.MissingVoteAccts = append(stats.MissingVoteAccts, StakeEntry{ + VoteAcct: votePk, + Stake: stake, + }) + } + continue + } + + nodePk := vs.NodePubkey() + var zeroPk solana.PublicKey + if nodePk == zeroPk { + stats.SkippedMissingNodePk++ + stats.SkippedMissingNodePkStake += stake + skippedEntries = append(skippedEntries, StakeEntry{ + VoteAcct: votePk, + Stake: stake, + Reason: "zero_nodepk", + }) + // Track first few for quick debugging in logs + if len(stats.ZeroNodePkAccts) < 5 { + 
stats.ZeroNodePkAccts = append(stats.ZeroNodePkAccts, StakeEntry{ + VoteAcct: votePk, + Stake: stake, + }) + } + continue + } + + // Create a VoteAccount with the NodePubkey + va := &epochstakes.VoteAccount{ + NodePubkey: nodePk, + } + epochVoteAccts[votePk] = va + filteredStakes[votePk] = stake + stats.TotalStake += stake + + // Track min/max + if stake < stats.MinStake { + stats.MinStake = stake + } + if stake > stats.MaxStake { + stats.MaxStake = stake + } + + validEntries = append(validEntries, StakeEntry{ + VoteAcct: votePk, + NodePubkey: nodePk, + Stake: stake, + }) + } + + stats.ValidatorCount = len(validEntries) + + // Guard: empty stakes would panic in weightedrand + if len(filteredStakes) == 0 { + stats.MinStake = 0 // Reset since no valid entries + return nil, stats, validEntries, skippedEntries + } + + // Sort entries by stake descending, then node pubkey descending (matches schedule computation) + sort.Slice(validEntries, func(i, j int) bool { + if validEntries[i].Stake != validEntries[j].Stake { + return validEntries[i].Stake > validEntries[j].Stake + } + // Tie-break by node pubkey descending (higher bytes first) - matches Agave + return bytes.Compare(validEntries[i].NodePubkey[:], validEntries[j].NodePubkey[:]) > 0 + }) + + // Capture top 10 and bottom 10 for log summary + for i := 0; i < min(10, len(validEntries)); i++ { + stats.TopStakes = append(stats.TopStakes, validEntries[i]) + } + for i := max(0, len(validEntries)-10); i < len(validEntries); i++ { + stats.BottomStakes = append(stats.BottomStakes, validEntries[i]) + } + + // Get epoch length (handles warmup epochs correctly) + slotsInEpoch := epochSchedule.SlotsInEpoch(epoch) + + ls := leaderschedule.New( + epochVoteAccts, + filteredStakes, + epochSchedule, + epoch, + slotsInEpoch, + NumConsecutiveLeaderSlots, + ) + + return ls, stats, validEntries, skippedEntries +} + +// scheduleFullHash computes a SHA256 hash of the entire leader schedule. 
+// Returns base64-encoded first 16 bytes of the hash. +// Takes ~20-50ms for a full epoch (432k slots). +func scheduleFullHash(ls *leaderschedule.LeaderSchedule, firstSlot uint64, numSlots uint64) string { + if ls == nil { + return "nil" + } + + h := sha256.New() + for i := uint64(0); i < numSlots; i++ { + slot := firstSlot + i + leader, ok := ls.LeaderForSlot(slot) + if ok { + h.Write(leader[:]) + } + } + + return base64.StdEncoding.EncodeToString(h.Sum(nil)[:16]) +} + +// PrepareLeaderScheduleLocal builds the leader schedule from local state and sets it as the source of truth. +// This is the primary entry point for leader schedule - no RPC dependency. +// Returns the schedule summary (for RPC validation) and error if schedule cannot be built. +func PrepareLeaderScheduleLocal( + epoch uint64, + epochSchedule *sealevel.SysvarEpochSchedule, + logsDir string, +) (*ScheduleSummary, error) { + voteAcctStakes := global.EpochStakes(epoch) + voteAcctMap := global.EpochStakesVoteAccts(epoch) + + // The RNG seed uses `epoch` directly (the epoch we're building the schedule for) + // Note: LeaderScheduleEpoch() returns something different (next epoch's prep slot) - don't use it here + firstSlot := epochSchedule.FirstSlotInEpoch(epoch) + numSlots := epochSchedule.SlotsInEpoch(epoch) + + if len(voteAcctStakes) == 0 { + mlog.Log.Errorf("LEADER SCHEDULE BUILD FAILED: epoch=%d reason=no_stake_data", epoch) + mlog.Log.FileOnlyf(" rng_epoch=%d first_slot=%d slots=%d", epoch, firstSlot, numSlots) + mlog.Log.FileOnlyf(" EpochStakes(%d) returned nil or empty", epoch) + return nil, fmt.Errorf("no stake data available for epoch %d", epoch) + } + + schedule, stats, validEntries, skippedEntries := buildLocalLeaderSchedule(epoch, epochSchedule, voteAcctStakes, voteAcctMap) + + // Calculate total input stake (before filtering) + var totalInputStake uint64 + for _, stake := range voteAcctStakes { + totalInputStake += stake + } + + if schedule == nil { + logHardFailContext(epoch, 
"no_valid_stakes_after_filtering", stats) + // Still dump whatever data we have for debugging even on failure + dumpFullScheduleData(epoch, "local", validEntries, skippedEntries, stats.TotalStake, logsDir) + return nil, fmt.Errorf("could not build leader schedule for epoch %d: no valid stakes after filtering (zero_stake=%d, missing_nodepk=%d, missing_vote_acct=%d)", + epoch, stats.SkippedZeroStake, stats.SkippedMissingNodePk, stats.SkippedMissingVoteAcct) + } + + // Set as source of truth + global.SetLeaderSchedule(schedule) + + // Compute hash for logging + fullHash := scheduleFullHash(schedule, firstSlot, numSlots) + + // Log comprehensive summary + logScheduleBuildSummary(epoch, epoch, firstSlot, numSlots, "snapshot", stats, fullHash) + + // Build summary with all metadata + // Include all missing stake: missing_vote_acct + zero_nodepk + missingStake := stats.SkippedMissingVoteAcctStake + stats.SkippedMissingNodePkStake + var missingPercent float64 + if totalInputStake > 0 { + missingPercent = float64(missingStake) / float64(totalInputStake) * 100.0 + } + summary := ScheduleSummary{ + BlockEpoch: epoch, + ScheduleEpoch: epoch, // RNG seed epoch = block epoch + FirstSlot: firstSlot, + SlotsInEpoch: numSlots, + Repeat: NumConsecutiveLeaderSlots, + TotalInputStake: totalInputStake, + FilteredStake: stats.TotalStake, + MissingStake: missingStake, + MissingStakePercent: missingPercent, + ValidatorsInput: stats.TotalVoteAccts, + ValidatorsUsed: stats.ValidatorCount, + ValidatorsSkipped: stats.SkippedZeroStake + stats.SkippedMissingVoteAcct + stats.SkippedMissingNodePk, + SkippedZeroStake: stats.SkippedZeroStake, + SkippedMissingData: stats.SkippedMissingVoteAcct, + SkippedZeroNodePk: stats.SkippedMissingNodePk, + LocalHash: fullHash, + RunID: mlog.GetRunID(), + Source: "snapshot", // From snapshot loading at startup + Timestamp: time.Now().UTC(), + } + + // Dump ALL validators, skipped accounts, and summary to files + dumpFullScheduleDataWithSummary(validEntries, 
skippedEntries, summary, logsDir) + + // Dump tie-break debug info (shows how equal-stake validators are ordered) + DumpTieBreakDebug(epoch, voteAcctStakes, voteAcctMap, logsDir) + + // Dump first 1000 slots if dump flag is set (for debugging against RPC) + if config.GetBool("replay.dump_leader_schedule") { + DumpLeaderSchedule(epoch, epochSchedule, schedule, logsDir, 1000) + } + + return &summary, nil +} + +// PrepareLeaderScheduleLocalFromVoteCache builds the leader schedule using vote cache for NodePubkey lookups. +// Used at epoch boundaries when EpochStakesVoteAccts may not have the new epoch's data yet. +// Returns the schedule summary (for RPC validation) and error if schedule cannot be built. +func PrepareLeaderScheduleLocalFromVoteCache( + epoch uint64, + epochSchedule *sealevel.SysvarEpochSchedule, + logsDir string, +) (*ScheduleSummary, error) { + voteAcctStakes := global.EpochStakes(epoch) + + // The RNG seed uses `epoch` directly (the epoch we're building the schedule for) + // Note: LeaderScheduleEpoch() returns something different (next epoch's prep slot) - don't use it here + firstSlot := epochSchedule.FirstSlotInEpoch(epoch) + numSlots := epochSchedule.SlotsInEpoch(epoch) + + if len(voteAcctStakes) == 0 { + mlog.Log.Errorf("LEADER SCHEDULE BUILD FAILED: epoch=%d reason=no_stake_data", epoch) + mlog.Log.FileOnlyf(" rng_epoch=%d first_slot=%d slots=%d source=vote_cache", epoch, firstSlot, numSlots) + mlog.Log.FileOnlyf(" EpochStakes(%d) returned nil or empty", epoch) + mlog.Log.FileOnlyf(" VoteCache size=%d", len(global.VoteCache())) + return nil, fmt.Errorf("no stake data available for epoch %d", epoch) + } + + schedule, stats, validEntries, skippedEntries := buildLocalLeaderScheduleFromVoteCache(epoch, epochSchedule, voteAcctStakes) + + // Calculate total input stake (before filtering) + var totalInputStake uint64 + for _, stake := range voteAcctStakes { + totalInputStake += stake + } + + if schedule == nil { + logHardFailContext(epoch, 
"no_valid_stakes_after_filtering (vote_cache)", stats) + // Still dump whatever data we have for debugging even on failure + dumpFullScheduleData(epoch, "local_vote_cache", validEntries, skippedEntries, stats.TotalStake, logsDir) + return nil, fmt.Errorf("could not build leader schedule for epoch %d: no valid stakes after filtering (zero_stake=%d, missing_nodepk=%d, missing_vote_state=%d)", + epoch, stats.SkippedZeroStake, stats.SkippedMissingNodePk, stats.SkippedMissingVoteAcct) + } + + // Safety check: fail if too much stake is missing from VoteCache. + // Since local schedule is the source of truth, missing entries produce incorrect schedules. + missingStake := stats.SkippedMissingVoteAcctStake + if totalInputStake > 0 && missingStake > 0 { + missingPercent := float64(missingStake) / float64(totalInputStake) * 100.0 + if missingPercent > MaxMissingVoteCacheStakePercent { + logHardFailContext(epoch, fmt.Sprintf("vote_cache_too_incomplete (%.2f%% > %.1f%%)", missingPercent, MaxMissingVoteCacheStakePercent), stats) + // Dump data even on failure for debugging + dumpFullScheduleData(epoch, "local_vote_cache", validEntries, skippedEntries, stats.TotalStake, logsDir) + return nil, fmt.Errorf("vote cache too incomplete for epoch %d: %.2f%% stake missing (threshold %.1f%%), missing_accts=%d missing_stake=%d total_stake=%d", + epoch, missingPercent, MaxMissingVoteCacheStakePercent, + stats.SkippedMissingVoteAcct, missingStake, totalInputStake) + } + // Log warning if any stake is missing, even below threshold + mlog.Log.Warnf("leader schedule: epoch=%d has %.2f%% stake missing from VoteCache (count=%d stake=%d)", + epoch, missingPercent, stats.SkippedMissingVoteAcct, missingStake) + } + + // Set as source of truth + global.SetLeaderSchedule(schedule) + + // Compute hash for logging + fullHash := scheduleFullHash(schedule, firstSlot, numSlots) + + // Log comprehensive summary + logScheduleBuildSummary(epoch, epoch, firstSlot, numSlots, "vote_cache", stats, fullHash) + + 
// Build summary with all metadata + // Include all missing stake: missing_vote_acct + zero_nodepk + totalMissingStake := stats.SkippedMissingVoteAcctStake + stats.SkippedMissingNodePkStake + var missingPercent float64 + if totalInputStake > 0 { + missingPercent = float64(totalMissingStake) / float64(totalInputStake) * 100.0 + } + summary := ScheduleSummary{ + BlockEpoch: epoch, + ScheduleEpoch: epoch, // RNG seed epoch = block epoch + FirstSlot: firstSlot, + SlotsInEpoch: numSlots, + Repeat: NumConsecutiveLeaderSlots, + TotalInputStake: totalInputStake, + FilteredStake: stats.TotalStake, + MissingStake: totalMissingStake, + MissingStakePercent: missingPercent, + ValidatorsInput: stats.TotalVoteAccts, + ValidatorsUsed: stats.ValidatorCount, + ValidatorsSkipped: stats.SkippedZeroStake + stats.SkippedMissingVoteAcct + stats.SkippedMissingNodePk, + SkippedZeroStake: stats.SkippedZeroStake, + SkippedMissingData: stats.SkippedMissingVoteAcct, + SkippedZeroNodePk: stats.SkippedMissingNodePk, + LocalHash: fullHash, + RunID: mlog.GetRunID(), + Source: "transition", // From epoch boundary transition + Timestamp: time.Now().UTC(), + } + + // Dump ALL validators, skipped accounts, and summary to files + dumpFullScheduleDataWithSummary(validEntries, skippedEntries, summary, logsDir) + + // Dump first 1000 slots if dump flag is set (for debugging against RPC) + if config.GetBool("replay.dump_leader_schedule") { + DumpLeaderSchedule(epoch, epochSchedule, schedule, logsDir, 1000) + } + + return &summary, nil +} + +// DumpLeaderSchedule writes the first N slots of the schedule to a file for debugging. +// File is written to logsDir/leader_schedule_dump_epoch.txt +// Useful for comparing against RPC getLeaderSchedule results. 
+func DumpLeaderSchedule(
+	epoch uint64,
+	epochSchedule *sealevel.SysvarEpochSchedule,
+	schedule *leaderschedule.LeaderSchedule,
+	logsDir string,
+	numSlots int,
+) {
+	// Best-effort debugging aid: every failure is logged as a warning and the
+	// function simply returns.
+	//
+	// numSlots is the requested number of slots to dump. It is clamped to
+	// [0, slots in epoch] so the header and the final log line report what was
+	// actually written rather than what was asked for.
+	if schedule == nil {
+		mlog.Log.Warnf("DumpLeaderSchedule: schedule is nil")
+		return
+	}
+
+	logsDir = resolveLogsDir(logsDir)
+	if err := os.MkdirAll(logsDir, 0755); err != nil {
+		mlog.Log.Warnf("DumpLeaderSchedule: failed to create logs dir: %v", err)
+		return
+	}
+
+	// Deliberately not named "filepath": that would shadow the path/filepath
+	// package for the remainder of the function.
+	dumpPath := filepath.Join(logsDir, fmt.Sprintf("leader_schedule_dump_epoch%d.txt", epoch))
+
+	f, err := os.Create(dumpPath)
+	if err != nil {
+		mlog.Log.Warnf("DumpLeaderSchedule: failed to create file: %v", err)
+		return
+	}
+	defer f.Close()
+
+	firstSlot := epochSchedule.FirstSlotInEpoch(epoch)
+	totalSlots := epochSchedule.SlotsInEpoch(epoch)
+
+	// Number of slots that will really be written.
+	var dumpCount uint64
+	if numSlots > 0 {
+		dumpCount = min(uint64(numSlots), totalSlots)
+	}
+
+	// bufio.Writer errors are sticky: once a write fails, every later write and
+	// the final Flush return that same error, so checking Flush is sufficient.
+	w := bufio.NewWriter(f)
+
+	// Write header
+	fmt.Fprintf(w, "# Leader Schedule Dump - Epoch %d\n", epoch)
+	fmt.Fprintf(w, "# First slot: %d\n", firstSlot)
+	fmt.Fprintf(w, "# Total slots in epoch: %d\n", totalSlots)
+	fmt.Fprintf(w, "# Dumping first %d slots\n", dumpCount)
+	w.WriteString("# Format: slot_offset,absolute_slot,leader_pubkey\n")
+	w.WriteString("#\n")
+
+	// Dump first N slots
+	for i := uint64(0); i < dumpCount; i++ {
+		slot := firstSlot + i
+		if leader, ok := schedule.LeaderForSlot(slot); ok {
+			fmt.Fprintf(w, "%d,%d,%s\n", i, slot, leader.String())
+		} else {
+			fmt.Fprintf(w, "%d,%d,NOT_FOUND\n", i, slot)
+		}
+	}
+
+	// Flush explicitly (instead of in a defer) so a failed or truncated dump is
+	// reported rather than logged as a success.
+	if err := w.Flush(); err != nil {
+		mlog.Log.Warnf("DumpLeaderSchedule: failed to write %s: %v", dumpPath, err)
+		return
+	}
+
+	mlog.Log.FileOnlyf("leader schedule dumped to: %s (first %d slots)", dumpPath, dumpCount)
+}
+
+// dumpScheduleSlotsCSV dumps the full schedule to a CSV for slot-by-slot comparison.
+// Format: slot,leader_pubkey (simple format for easy diffing)
+// Called when mismatch is detected or when replay.dump_leader_schedule is set.
+func dumpScheduleSlotsCSV(
+	epoch uint64,
+	source string, // "local" or "rpc"
+	schedule *leaderschedule.LeaderSchedule,
+	firstSlot uint64,
+	numSlots uint64,
+	logsDir string,
+) string {
+	// Returns the path of the written CSV, or "" when no usable file was
+	// produced (nil schedule, directory/file creation failure, write failure).
+	if schedule == nil {
+		return ""
+	}
+
+	logsDir = resolveLogsDir(logsDir)
+	if err := os.MkdirAll(logsDir, 0755); err != nil {
+		mlog.Log.Warnf("dumpScheduleSlotsCSV: failed to create logs dir: %v", err)
+		return ""
+	}
+
+	// Filename suffix: "_" plus at most the first 8 characters of the run ID;
+	// empty when no run ID is set.
+	runSuffix := ""
+	if runID := mlog.GetRunID(); runID != "" {
+		runSuffix = "_" + runID[:min(8, len(runID))]
+	}
+
+	filePath := filepath.Join(logsDir, fmt.Sprintf("epoch%d_%s_slots%s.csv", epoch, source, runSuffix))
+
+	f, err := os.Create(filePath)
+	if err != nil {
+		mlog.Log.Warnf("dumpScheduleSlotsCSV: failed to create file: %v", err)
+		return ""
+	}
+	defer f.Close()
+
+	// bufio.Writer errors are sticky, so the single Flush check below covers
+	// every write made through w.
+	w := bufio.NewWriter(f)
+
+	// Minimal header - just slot,leader for easy diffing
+	w.WriteString("slot,leader\n")
+
+	// Dump all slots
+	for i := uint64(0); i < numSlots; i++ {
+		slot := firstSlot + i
+		if leader, ok := schedule.LeaderForSlot(slot); ok {
+			fmt.Fprintf(w, "%d,%s\n", slot, leader.String())
+		} else {
+			fmt.Fprintf(w, "%d,\n", slot) // Empty leader for missing
+		}
+	}
+
+	// Do not hand back the path of a truncated file: callers use the returned
+	// path to tell the operator which files to diff.
+	if err := w.Flush(); err != nil {
+		mlog.Log.Warnf("dumpScheduleSlotsCSV: failed to write %s: %v", filePath, err)
+		return ""
+	}
+
+	mlog.Log.FileOnlyf("leader schedule slots dumped to: %s (%d slots)", filePath, numSlots)
+	return filePath
+}
+
+// DumpScheduleMismatch dumps both local and RPC schedules to CSV files for analysis.
+// Called when a hash mismatch is detected during validation.
+// Returns paths to local and RPC slot files.
+func DumpScheduleMismatch(
+	epoch uint64,
+	epochSchedule *sealevel.SysvarEpochSchedule,
+	localSchedule *leaderschedule.LeaderSchedule,
+	rpcSchedule *leaderschedule.LeaderSchedule,
+	logsDir string,
+) (localPath, rpcPath string) {
+	// Both dumps cover the whole epoch: same first slot, same slot count.
+	start := epochSchedule.FirstSlotInEpoch(epoch)
+	count := epochSchedule.SlotsInEpoch(epoch)
+
+	localPath = dumpScheduleSlotsCSV(epoch, "local", localSchedule, start, count, logsDir)
+	rpcPath = dumpScheduleSlotsCSV(epoch, "rpc", rpcSchedule, start, count, logsDir)
+
+	// The diff hint is only useful when both files exist; an empty path means
+	// that dump was not written.
+	if localPath == "" || rpcPath == "" {
+		return localPath, rpcPath
+	}
+
+	mlog.Log.FileOnlyf("schedule mismatch dumps: local=%s rpc=%s", localPath, rpcPath)
+	mlog.Log.FileOnlyf(" run: scripts/diff_leader_schedules.py %s %s", localPath, rpcPath)
+
+	return localPath, rpcPath
+}
+
+// dumpRPCValidatorList extracts validators from RPC schedule and dumps to CSV.
+// Since RPC only gives us slot -> leader, we count slot appearances per leader.
+// File is named epoch{N}_rpc_{runid}_validators.csv (the run-ID part is omitted
+// when no run ID is set) for comparison with local.
+func dumpRPCValidatorList(
+	epoch uint64,
+	rpcSchedule *leaderschedule.LeaderSchedule,
+	firstSlot uint64,
+	numSlots uint64,
+	logsDir string,
+) {
+	// Best-effort debugging aid: failures are logged as warnings and the
+	// function returns without writing (or without reporting) a file.
+	if rpcSchedule == nil {
+		return
+	}
+
+	logsDir = resolveLogsDir(logsDir)
+	if err := os.MkdirAll(logsDir, 0755); err != nil {
+		mlog.Log.Warnf("dumpRPCValidatorList: failed to create logs dir: %v", err)
+		return
+	}
+
+	// Count slot appearances per leader
+	leaderSlots := make(map[solana.PublicKey]uint64)
+	for i := uint64(0); i < numSlots; i++ {
+		if leader, ok := rpcSchedule.LeaderForSlot(firstSlot + i); ok {
+			leaderSlots[leader]++
+		}
+	}
+
+	// Build entries sorted by slot count (descending) for comparison with local
+	type rpcEntry struct {
+		leader    solana.PublicKey
+		slotCount uint64
+	}
+	entries := make([]rpcEntry, 0, len(leaderSlots))
+	for leader, count := range leaderSlots {
+		entries = append(entries, rpcEntry{leader: leader, slotCount: count})
+	}
+	sort.Slice(entries, func(i, j int) bool {
+		if entries[i].slotCount != entries[j].slotCount {
+			return entries[i].slotCount > entries[j].slotCount
+		}
+		// Tie-break by pubkey descending (matches local sort)
+		return bytes.Compare(entries[i].leader[:], entries[j].leader[:]) > 0
+	})
+
+	// Filename suffix: "_" plus at most the first 8 characters of the run ID;
+	// empty when no run ID is set.
+	runSuffix := ""
+	if runID := mlog.GetRunID(); runID != "" {
+		runSuffix = "_" + runID[:min(8, len(runID))]
+	}
+
+	filePath := filepath.Join(logsDir, fmt.Sprintf("epoch%d_rpc%s_validators.csv", epoch, runSuffix))
+
+	f, err := os.Create(filePath)
+	if err != nil {
+		mlog.Log.Warnf("dumpRPCValidatorList: failed to create file: %v", err)
+		return
+	}
+	defer f.Close()
+
+	// bufio.Writer errors are sticky, so the single Flush check below covers
+	// every write made through w.
+	w := bufio.NewWriter(f)
+
+	// Header
+	fmt.Fprintf(w, "# RPC Leader Schedule - Epoch %d\n", epoch)
+	w.WriteString("# Source: rpc\n")
+	fmt.Fprintf(w, "# Total Leaders: %d\n", len(entries))
+	fmt.Fprintf(w, "# Total Slots: %d\n", numSlots)
+	fmt.Fprintf(w, "# Generated: %s\n", time.Now().UTC().Format(time.RFC3339))
+	w.WriteString("#\n")
+	w.WriteString("# NOTE: RPC schedule only provides slot->leader mapping.\n")
+	w.WriteString("# Stake is not available from RPC, so we show slot_count instead.\n")
+	w.WriteString("# Compare slot_count with local schedule to identify discrepancies.\n")
+	w.WriteString("#\n")
+	w.WriteString("rank,node_pubkey,slot_count\n")
+
+	for i, e := range entries {
+		fmt.Fprintf(w, "%d,%s,%d\n", i+1, e.leader, e.slotCount)
+	}
+
+	// Flush explicitly (instead of in a defer) so a failed or truncated dump is
+	// reported rather than logged as a success.
+	if err := w.Flush(); err != nil {
+		mlog.Log.Warnf("dumpRPCValidatorList: failed to write %s: %v", filePath, err)
+		return
+	}
+
+	mlog.Log.FileOnlyf("RPC validator list dumped to: %s (%d leaders)", filePath, len(entries))
+}
+
+// BackgroundValidateAgainstRPC optionally validates local schedule against RPC in background.
+// This is purely for debugging and does not affect the source of truth.
+// Computes full SHA256 hash of entire schedule (~20-50ms) for complete comparison.
+// Always writes a validation summary file with full local summary and RPC hash.
+// Also dumps RPC-derived validator list for comparison.
+func BackgroundValidateAgainstRPC(
+	epoch uint64,
+	epochSchedule *sealevel.SysvarEpochSchedule,
+	localSchedule *leaderschedule.LeaderSchedule,
+	rpcSchedule *leaderschedule.LeaderSchedule,
+	localSummary *ScheduleSummary,
+	logsDir string,
+) {
+	// Nothing to compare unless both schedules are present.
+	if rpcSchedule == nil || localSchedule == nil {
+		return
+	}
+
+	firstSlot := epochSchedule.FirstSlotInEpoch(epoch)
+	numSlots := epochSchedule.SlotsInEpoch(epoch)
+
+	// A nil summary used to panic on the LocalHash access below. Fall back to a
+	// minimal summary so validation (a debugging aid) can never crash the caller.
+	if localSummary == nil {
+		localSummary = &ScheduleSummary{
+			BlockEpoch:    epoch,
+			ScheduleEpoch: epoch,
+			FirstSlot:     firstSlot,
+			SlotsInEpoch:  numSlots,
+			Repeat:        NumConsecutiveLeaderSlots,
+			RunID:         mlog.GetRunID(),
+		}
+	}
+
+	// Compute full hash for RPC schedule
+	rpcHash := scheduleFullHash(rpcSchedule, firstSlot, numSlots)
+
+	// Use local summary's hash if available, else compute it. The computed hash
+	// is stored back so the validation file shows the value that was compared
+	// instead of an empty local_hash.
+	if localSummary.LocalHash == "" {
+		localSummary.LocalHash = scheduleFullHash(localSchedule, firstSlot, numSlots)
+	}
+	localHash := localSummary.LocalHash
+
+	matched := localHash == rpcHash
+
+	// Update summary with RPC data and write validation file.
+	// NOTE(review): this mutates the caller's summary; confirm no other
+	// goroutine reads it concurrently.
+	localSummary.RPCHash = rpcHash
+
+	// Always write validation summary file with full local summary + RPC data
+	writeValidationSummary(localSummary, matched, logsDir)
+
+	if matched {
+		mlog.Log.FileOnlyf("leader schedule RPC validation: epoch=%d MATCH hash=%s", epoch, localHash)
+		return
+	}
+
+	// Only dump RPC validator list on mismatch (expensive I/O)
+	dumpRPCValidatorList(epoch, rpcSchedule, firstSlot, numSlots, logsDir)
+
+	// Hashes differ - log to mismatch file with details
+	initMismatchLog(logsDir)
+
+	mismatchLogMu.Lock()
+	if mismatchLogWriter != nil {
+		mismatchLogWriter.WriteString(fmt.Sprintf("\n[%s] RPC VALIDATION MISMATCH epoch=%d\n", time.Now().Format(time.RFC3339), epoch))
+		mismatchLogWriter.WriteString(fmt.Sprintf(" local_hash=%s rpc_hash=%s\n", localHash, rpcHash))
+	}
+	mismatchLogMu.Unlock()
+
+	mlog.Log.Warnf("leader schedule RPC validation: MISMATCH epoch=%d local_hash=%s rpc_hash=%s - see %s",
+		epoch, localHash, rpcHash, getMismatchLogPath())
+
+	flushMismatchLog()
+
+	// Dump both schedules to CSV for detailed analysis
+	DumpScheduleMismatch(epoch, epochSchedule, localSchedule, rpcSchedule, logsDir)
+}
+
+// writeValidationSummary writes a
summary file with full local summary and RPC comparison. +func writeValidationSummary(summary *ScheduleSummary, matched bool, logsDir string) { + logsDir = resolveLogsDir(logsDir) + if err := os.MkdirAll(logsDir, 0755); err != nil { + mlog.Log.Warnf("writeValidationSummary: failed to create logs dir: %v", err) + return + } + + shortRunID := "" + if summary.RunID != "" { + shortRunID = summary.RunID + if len(shortRunID) > 8 { + shortRunID = shortRunID[:8] + } + shortRunID = "_" + shortRunID + } + + filename := fmt.Sprintf("epoch%d_validation%s.txt", summary.BlockEpoch, shortRunID) + filePath := filepath.Join(logsDir, filename) + + f, err := os.Create(filePath) + if err != nil { + mlog.Log.Warnf("writeValidationSummary: failed to create file: %v", err) + return + } + defer f.Close() + + w := bufio.NewWriter(f) + defer w.Flush() + + status := "MATCH" + if !matched { + status = "MISMATCH" + } + + w.WriteString("# Leader Schedule Validation Summary\n") + w.WriteString(fmt.Sprintf("# Generated: %s\n", time.Now().UTC().Format(time.RFC3339))) + w.WriteString(fmt.Sprintf("# Run ID: %s\n", summary.RunID)) + w.WriteString("#\n") + w.WriteString(fmt.Sprintf("## Result: %s\n\n", status)) + + // Epoch Info (same as local summary) + w.WriteString("## Epoch Info\n") + w.WriteString(fmt.Sprintf("block_epoch=%d\n", summary.BlockEpoch)) + w.WriteString(fmt.Sprintf("schedule_epoch=%d\n", summary.ScheduleEpoch)) + w.WriteString(fmt.Sprintf("first_slot=%d\n", summary.FirstSlot)) + w.WriteString(fmt.Sprintf("slots_in_epoch=%d\n", summary.SlotsInEpoch)) + w.WriteString(fmt.Sprintf("repeat=%d\n", summary.Repeat)) + w.WriteString(fmt.Sprintf("source=%s\n", summary.Source)) + w.WriteString("\n") + + // Stake Info + w.WriteString("## Stake Info\n") + w.WriteString(fmt.Sprintf("total_input_stake=%d\n", summary.TotalInputStake)) + w.WriteString(fmt.Sprintf("filtered_stake=%d\n", summary.FilteredStake)) + w.WriteString(fmt.Sprintf("missing_stake=%d\n", summary.MissingStake)) + 
w.WriteString(fmt.Sprintf("missing_stake_percent=%.4f\n", summary.MissingStakePercent)) + w.WriteString("\n") + + // Validator Counts + w.WriteString("## Validator Counts\n") + w.WriteString(fmt.Sprintf("validators_input=%d\n", summary.ValidatorsInput)) + w.WriteString(fmt.Sprintf("validators_used=%d\n", summary.ValidatorsUsed)) + w.WriteString(fmt.Sprintf("validators_skipped=%d\n", summary.ValidatorsSkipped)) + w.WriteString(fmt.Sprintf("skipped_zero_stake=%d\n", summary.SkippedZeroStake)) + w.WriteString(fmt.Sprintf("skipped_missing_data=%d\n", summary.SkippedMissingData)) + w.WriteString(fmt.Sprintf("skipped_zero_nodepk=%d\n", summary.SkippedZeroNodePk)) + w.WriteString("\n") + + // Hashes - local and RPC side by side + w.WriteString("## Comparison\n") + w.WriteString(fmt.Sprintf("local_hash=%s\n", summary.LocalHash)) + w.WriteString(fmt.Sprintf("rpc_hash=%s\n", summary.RPCHash)) + w.WriteString(fmt.Sprintf("\nstatus=%s\n", status)) + + mlog.Log.FileOnlyf("leader schedule validation summary written to: %s", filePath) } diff --git a/pkg/replay/leader_schedule_local.go b/pkg/replay/leader_schedule_local.go deleted file mode 100644 index 7812bd34..00000000 --- a/pkg/replay/leader_schedule_local.go +++ /dev/null @@ -1,2072 +0,0 @@ -package replay - -import ( - "bufio" - "bytes" - "crypto/sha256" - "encoding/base64" - "fmt" - "os" - "path/filepath" - "sort" - "sync" - "sync/atomic" - "time" - - "github.com/Overclock-Validator/mithril/pkg/accountsdb" - "github.com/Overclock-Validator/mithril/pkg/config" - "github.com/Overclock-Validator/mithril/pkg/epochstakes" - "github.com/Overclock-Validator/mithril/pkg/global" - "github.com/Overclock-Validator/mithril/pkg/leaderschedule" - "github.com/Overclock-Validator/mithril/pkg/mlog" - "github.com/Overclock-Validator/mithril/pkg/rpcclient" - "github.com/Overclock-Validator/mithril/pkg/sealevel" - "github.com/gagliardetto/solana-go" - "github.com/panjf2000/ants/v2" - "golang.org/x/exp/rand" -) - -const ( - // 
NumConsecutiveLeaderSlots matches Solana's NUM_CONSECUTIVE_LEADER_SLOTS - NumConsecutiveLeaderSlots = 4 - // MaxMismatchLogsPerEpoch caps mismatch logging to avoid disk churn - MaxMismatchLogsPerEpoch = 100 - // SampleBoundarySlots is how many slots to check at epoch boundaries - SampleBoundarySlots = 2000 - // SampleRandomSlots is how many random slots to sample in the middle - SampleRandomSlots = 1000 - // MaxMissingVoteCacheStakePercent is the maximum percentage of stake that can be - // missing from VoteCache before we fail. Since local schedule is the source of truth, - // missing VoteCache entries mean that stake is dropped from the schedule, which would - // produce an incorrect schedule. - // Set to 0 for zero tolerance - any missing stake is a hard failure. - // The VoteCache rebuild from AccountsDB should ensure this never triggers. - MaxMissingVoteCacheStakePercent = 0.0 - // DefaultVoteCacheRebuildConcurrency is the default number of concurrent workers - // for rebuilding vote cache from AccountsDB at epoch boundaries. - DefaultVoteCacheRebuildConcurrency = 16 -) - -var ( - mismatchLogOnce sync.Once - mismatchLogFile *os.File - mismatchLogWriter *bufio.Writer - mismatchLogMu sync.Mutex -) - -// defaultLogsDir is the fallback directory for mismatch logs -const defaultLogsDir = "/mnt/mithril-logs" - -// mismatchLogPath stores the resolved path for use in warnings -var mismatchLogPath string - -// resolveLogsDir returns the leader_schedule subdirectory within the run directory. -// Creates a dedicated subdirectory to keep leader schedule files organized. 
-func resolveLogsDir(logsDir string) string { - var baseDir string - // First try mlog's directory (for run ID correlation) - if mlogDir := mlog.GetLogDir(); mlogDir != "" { - baseDir = mlogDir - } else if logsDir != "" { - baseDir = logsDir - } else { - baseDir = defaultLogsDir - } - // Return leader_schedule subdirectory - return filepath.Join(baseDir, "leader_schedule") -} - -// initMismatchLog creates/opens the mismatch log file (once per process). -// Uses the same log directory as Mithril's main logs with run ID for correlation. -func initMismatchLog(logsDir string) { - mismatchLogOnce.Do(func() { - logsDir = resolveLogsDir(logsDir) - // Create directory if it doesn't exist - if err := os.MkdirAll(logsDir, 0755); err != nil { - mlog.Log.Warnf("failed to create mismatch log directory: %v", err) - return - } - - // Use run ID in filename for correlation with main log - runID := mlog.GetRunID() - var filename string - if runID != "" { - shortRunID := runID - if len(shortRunID) > 8 { - shortRunID = shortRunID[:8] - } - filename = fmt.Sprintf("mismatch_%s.log", shortRunID) - } else { - filename = "mismatch.log" - } - mismatchLogPath = filepath.Join(logsDir, filename) - - var err error - mismatchLogFile, err = os.OpenFile(mismatchLogPath, os.O_CREATE|os.O_APPEND|os.O_WRONLY, 0644) - if err != nil { - mlog.Log.Warnf("failed to open leader schedule mismatch log: %v", err) - return - } - mismatchLogWriter = bufio.NewWriter(mismatchLogFile) - mlog.Log.FileOnlyf("leader schedule mismatch log: %s", mismatchLogPath) - }) -} - -// getMismatchLogPath returns the path to the mismatch log file -func getMismatchLogPath() string { - if mismatchLogPath != "" { - return mismatchLogPath - } - return filepath.Join(resolveLogsDir(""), "leader_schedule_mismatch.log") -} - -// flushMismatchLog flushes buffered writes (call at end of epoch validation) -func flushMismatchLog() { - mismatchLogMu.Lock() - defer mismatchLogMu.Unlock() - if mismatchLogWriter != nil { - 
mismatchLogWriter.Flush() - } -} - -// logMismatch writes a mismatch entry (capped per epoch to avoid disk churn) -func logMismatch(epoch, slot uint64, localLeader, rpcLeader solana.PublicKey, - voteAcct solana.PublicKey, stake uint64, mismatchCount *int) { - if mismatchLogWriter == nil || *mismatchCount >= MaxMismatchLogsPerEpoch { - return - } - *mismatchCount++ - mismatchLogMu.Lock() - defer mismatchLogMu.Unlock() - entry := fmt.Sprintf("[%s] epoch=%d slot=%d local=%s rpc=%s vote_acct=%s stake=%d\n", - time.Now().Format(time.RFC3339), epoch, slot, localLeader, rpcLeader, voteAcct, stake) - mismatchLogWriter.WriteString(entry) -} - -// logInputSnapshot writes the top stakes to the mismatch log for debugging -func logInputSnapshot(epoch uint64, voteAcctStakes map[solana.PublicKey]uint64, - voteAcctMap map[solana.PublicKey]*epochstakes.VoteAccount) { - if mismatchLogWriter == nil { - return - } - - // Sort by stake descending to get top 10 - type stakeEntry struct { - voteAcct solana.PublicKey - stake uint64 - nodePk solana.PublicKey - } - entries := make([]stakeEntry, 0, len(voteAcctStakes)) - for pk, stake := range voteAcctStakes { - var nodePk solana.PublicKey - if va := voteAcctMap[pk]; va != nil { - nodePk = va.NodePubkey - } - entries = append(entries, stakeEntry{voteAcct: pk, stake: stake, nodePk: nodePk}) - } - sort.Slice(entries, func(i, j int) bool { - return entries[i].stake > entries[j].stake - }) - - mismatchLogMu.Lock() - defer mismatchLogMu.Unlock() - - mismatchLogWriter.WriteString(fmt.Sprintf("\n[INPUTS] epoch=%d top_stakes:\n", epoch)) - for i := 0; i < min(10, len(entries)); i++ { - e := entries[i] - mismatchLogWriter.WriteString(fmt.Sprintf(" %d. 
vote=%s node=%s stake=%d\n", - i+1, e.voteAcct, e.nodePk, e.stake)) - } -} - -// VoteCacheRebuildError holds info about a failed vote account for logging -type VoteCacheRebuildError struct { - VoteAcct solana.PublicKey - Stake uint64 - Reason string - Err error -} - -// RebuildVoteCacheFromAccountsDB rebuilds the VoteCache from AccountsDB for all vote accounts -// in the stake map. This ensures correctness at epoch boundaries by reading the canonical -// state directly from AccountsDB. -// -// Parameters: -// - acctsDb: the AccountsDB instance -// - slot: the slot at which to read account state (typically lastSlotCtx.Slot) -// - voteAcctStakes: the stake map for the new epoch (vote account -> stake) -// - maxConcurrency: number of concurrent workers (0 = use default) -// -// Returns error if any vote account is missing or has invalid state. -// This is a blocking operation and should be called before building the leader schedule. -func RebuildVoteCacheFromAccountsDB( - acctsDb *accountsdb.AccountsDb, - slot uint64, - voteAcctStakes map[solana.PublicKey]uint64, - maxConcurrency int, -) error { - if maxConcurrency <= 0 { - maxConcurrency = DefaultVoteCacheRebuildConcurrency - } - - startTime := time.Now() - totalAccounts := len(voteAcctStakes) - var zeroStakeCount int - for _, stake := range voteAcctStakes { - if stake == 0 { - zeroStakeCount++ - } - } - nonZeroAccounts := totalAccounts - zeroStakeCount - - mlog.Log.FileOnlyf("vote cache rebuild: starting slot=%d accounts=%d (non-zero=%d) concurrency=%d", - slot, totalAccounts, nonZeroAccounts, maxConcurrency) - - // Counters for stats (use atomics for thread safety) - var successCount atomic.Int64 - var missingCount atomic.Int64 - var unmarshalErrCount atomic.Int64 - var zeroNodePkCount atomic.Int64 - var missingStake atomic.Uint64 - var unmarshalErrStake atomic.Uint64 - var zeroNodePkStake atomic.Uint64 - - // Track first few errors for each category (with mutex for thread safety) - const maxErrorsPerCategory = 5 - 
var errorsMu sync.Mutex - var missingErrors []VoteCacheRebuildError - var unmarshalErrors []VoteCacheRebuildError - var zeroNodePkErrors []VoteCacheRebuildError - - // Track first error for reporting (use sync.Once to capture exactly one error) - var firstError error - var firstErrorOnce sync.Once - - // Create worker pool - var wg sync.WaitGroup - pool, err := ants.NewPoolWithFunc(maxConcurrency, func(i interface{}) { - defer wg.Done() - - item := i.(struct { - pk solana.PublicKey - stake uint64 - }) - - // Read vote account from AccountsDB - voteAcct, err := acctsDb.GetAccount(slot, item.pk) - if err != nil { - global.DeleteVoteCacheItem(item.pk) - missingCount.Add(1) - missingStake.Add(item.stake) - errorsMu.Lock() - if len(missingErrors) < maxErrorsPerCategory { - missingErrors = append(missingErrors, VoteCacheRebuildError{ - VoteAcct: item.pk, - Stake: item.stake, - Reason: "not_found_in_accountsdb", - Err: err, - }) - } - errorsMu.Unlock() - firstErrorOnce.Do(func() { - firstError = fmt.Errorf("missing vote account %s (stake=%d): %w", item.pk, item.stake, err) - }) - return - } - - // Unmarshal vote state - versionedVoteState, err := sealevel.UnmarshalVersionedVoteState(voteAcct.Data) - if err != nil { - global.DeleteVoteCacheItem(item.pk) - unmarshalErrCount.Add(1) - unmarshalErrStake.Add(item.stake) - errorsMu.Lock() - if len(unmarshalErrors) < maxErrorsPerCategory { - unmarshalErrors = append(unmarshalErrors, VoteCacheRebuildError{ - VoteAcct: item.pk, - Stake: item.stake, - Reason: fmt.Sprintf("unmarshal_failed (data_len=%d)", len(voteAcct.Data)), - Err: err, - }) - } - errorsMu.Unlock() - firstErrorOnce.Do(func() { - firstError = fmt.Errorf("failed to unmarshal vote account %s (stake=%d): %w", item.pk, item.stake, err) - }) - return - } - - // Validate NodePubkey is non-zero - nodePk := versionedVoteState.NodePubkey() - var zeroPk solana.PublicKey - if nodePk == zeroPk { - global.DeleteVoteCacheItem(item.pk) - zeroNodePkCount.Add(1) - 
zeroNodePkStake.Add(item.stake) - errorsMu.Lock() - if len(zeroNodePkErrors) < maxErrorsPerCategory { - zeroNodePkErrors = append(zeroNodePkErrors, VoteCacheRebuildError{ - VoteAcct: item.pk, - Stake: item.stake, - Reason: "zero_nodepubkey", - }) - } - errorsMu.Unlock() - firstErrorOnce.Do(func() { - firstError = fmt.Errorf("vote account %s has zero NodePubkey (stake=%d)", item.pk, item.stake) - }) - return - } - - // Update VoteCache - global.PutVoteCacheItem(item.pk, versionedVoteState) - successCount.Add(1) - }) - if err != nil { - return fmt.Errorf("failed to create worker pool: %w", err) - } - defer pool.Release() - - // Submit all vote accounts to the pool - for pk, stake := range voteAcctStakes { - if stake == 0 { - continue // Skip zero-stake accounts - } - wg.Add(1) - item := struct { - pk solana.PublicKey - stake uint64 - }{pk: pk, stake: stake} - if err := pool.Invoke(item); err != nil { - wg.Done() - return fmt.Errorf("failed to submit work to pool: %w", err) - } - } - - // Wait for all workers to complete - wg.Wait() - - duration := time.Since(startTime) - - // Calculate total stake for percentage - var totalStake uint64 - for _, stake := range voteAcctStakes { - totalStake += stake - } - successStake := totalStake - missingStake.Load() - unmarshalErrStake.Load() - zeroNodePkStake.Load() - - // File only: single line summary - skipped := nonZeroAccounts - int(successCount.Load()) - mlog.Log.FileOnlyf("Vote cache: loaded=%d skipped=%d duration=%v", - successCount.Load(), skipped, duration) - - // File only: detailed results - mlog.Log.FileOnlyf("vote cache rebuild details: slot=%d duration=%v", slot, duration) - mlog.Log.FileOnlyf(" accounts: total=%d non_zero=%d success=%d", - totalAccounts, nonZeroAccounts, successCount.Load()) - mlog.Log.FileOnlyf(" stake: total=%d success=%d (%.2f%%)", - totalStake, successStake, float64(successStake)/float64(totalStake)*100) - - // Check for any failures - missing := missingCount.Load() - unmarshalErr := 
unmarshalErrCount.Load() - zeroNodePk := zeroNodePkCount.Load() - totalFailed := missing + unmarshalErr + zeroNodePk - - if totalFailed > 0 { - totalFailedStake := missingStake.Load() + unmarshalErrStake.Load() + zeroNodePkStake.Load() - failedPercent := float64(totalFailedStake) / float64(totalStake) * 100 - - // File only: detailed failure info (always log for debugging) - mlog.Log.FileOnlyf("vote cache rebuild failures:") - mlog.Log.FileOnlyf(" slot=%d", slot) - mlog.Log.FileOnlyf(" failures: missing=%d (stake=%d) unmarshal_err=%d (stake=%d) zero_nodepk=%d (stake=%d)", - missing, missingStake.Load(), unmarshalErr, unmarshalErrStake.Load(), zeroNodePk, zeroNodePkStake.Load()) - mlog.Log.FileOnlyf(" total_failed=%d total_failed_stake=%d (%.4f%% of total)", - totalFailed, totalFailedStake, failedPercent) - - // File only: first few errors in each category - if len(missingErrors) > 0 { - mlog.Log.FileOnlyf(" missing_accounts (first %d):", len(missingErrors)) - for i, e := range missingErrors { - mlog.Log.FileOnlyf(" %d. vote=%s stake=%d err=%v", i+1, e.VoteAcct, e.Stake, e.Err) - } - } - if len(unmarshalErrors) > 0 { - mlog.Log.FileOnlyf(" unmarshal_errors (first %d):", len(unmarshalErrors)) - for i, e := range unmarshalErrors { - mlog.Log.FileOnlyf(" %d. vote=%s stake=%d reason=%s err=%v", i+1, e.VoteAcct, e.Stake, e.Reason, e.Err) - } - } - if len(zeroNodePkErrors) > 0 { - mlog.Log.FileOnlyf(" zero_nodepk_accounts (first %d):", len(zeroNodePkErrors)) - for i, e := range zeroNodePkErrors { - mlog.Log.FileOnlyf(" %d. 
vote=%s stake=%d", i+1, e.VoteAcct, e.Stake) - } - } - - // Small percentage of unavailable vote accounts is expected on mainnet (dead/closed validators) - // Only ERROR if significant stake is missing - otherwise it's just noise - if failedPercent > 5.0 { - mlog.Log.Errorf("VOTE CACHE REBUILD: slot=%d skipped=%d (%.4f%% stake) - exceeds threshold", - slot, totalFailed, failedPercent) - if firstError != nil { - return fmt.Errorf("vote cache rebuild: %d unavailable (%.4f%% stake): %w", - totalFailed, failedPercent, firstError) - } - return fmt.Errorf("vote cache rebuild: %d unavailable (%.4f%% stake)", - totalFailed, failedPercent) - } - - // Expected mainnet behavior - log to file only - mlog.Log.FileOnlyf("vote cache rebuild: slot=%d skipped=%d unavailable vote accounts (%.4f%% stake)", - slot, totalFailed, failedPercent) - return nil - } - - mlog.Log.FileOnlyf(" result: SUCCESS (all %d non-zero accounts rebuilt)", nonZeroAccounts) - return nil -} - -// StakeEntry holds a vote account and its stake for logging -type StakeEntry struct { - VoteAcct solana.PublicKey - NodePubkey solana.PublicKey - Stake uint64 - Reason string // For skipped entries: "zero_stake", "missing_vote_acct", "zero_nodepk" -} - -// dumpFullScheduleData writes complete validator data to CSV files for debugging. -// Creates epoch-specific files in the logs directory with ALL validators. -// Includes run ID in filename to prevent overwriting on re-runs. 
-func dumpFullScheduleData( - epoch uint64, - source string, // "snapshot", "vote_cache", or "rpc" - validEntries []StakeEntry, - skippedEntries []StakeEntry, - totalStake uint64, - logsDir string, -) { - logsDir = resolveLogsDir(logsDir) - if err := os.MkdirAll(logsDir, 0755); err != nil { - mlog.Log.Warnf("dumpFullScheduleData: failed to create logs dir: %v", err) - return - } - - // Get short run ID for filename (prevents overwriting on re-runs) - runID := mlog.GetRunID() - shortRunID := "" - if runID != "" { - shortRunID = runID - if len(shortRunID) > 8 { - shortRunID = shortRunID[:8] - } - shortRunID = "_" + shortRunID - } - - // Write validators CSV - validatorsFile := filepath.Join(logsDir, fmt.Sprintf("epoch%d_%s%s_validators.csv", epoch, source, shortRunID)) - if err := writeValidatorsCSV(validatorsFile, epoch, source, validEntries, totalStake); err != nil { - mlog.Log.Warnf("dumpFullScheduleData: failed to write validators CSV: %v", err) - } else { - mlog.Log.FileOnlyf("leader schedule validators dumped to: %s (%d entries)", validatorsFile, len(validEntries)) - } - - // Write skipped CSV if there are any - if len(skippedEntries) > 0 { - skippedFile := filepath.Join(logsDir, fmt.Sprintf("epoch%d_%s%s_skipped.csv", epoch, source, shortRunID)) - if err := writeSkippedCSV(skippedFile, epoch, skippedEntries); err != nil { - mlog.Log.Warnf("dumpFullScheduleData: failed to write skipped CSV: %v", err) - } else { - mlog.Log.FileOnlyf("leader schedule skipped accounts dumped to: %s (%d entries)", skippedFile, len(skippedEntries)) - } - } -} - -// dumpFullScheduleDataWithSummary writes validators CSV, skipped CSV, and a summary file. -// This is the preferred function when all metadata is available. -// Includes run ID in filenames to prevent overwriting on re-runs. 
-func dumpFullScheduleDataWithSummary( - validEntries []StakeEntry, - skippedEntries []StakeEntry, - summary ScheduleSummary, - logsDir string, -) { - logsDir = resolveLogsDir(logsDir) - if err := os.MkdirAll(logsDir, 0755); err != nil { - mlog.Log.Warnf("dumpFullScheduleDataWithSummary: failed to create logs dir: %v", err) - return - } - - epoch := summary.BlockEpoch - source := summary.Source - - // Get short run ID for filename (prevents overwriting on re-runs) - shortRunID := "" - if summary.RunID != "" { - shortRunID = summary.RunID - if len(shortRunID) > 8 { - shortRunID = shortRunID[:8] - } - shortRunID = "_" + shortRunID - } - - // Write validators CSV - validatorsFile := filepath.Join(logsDir, fmt.Sprintf("epoch%d_%s%s_validators.csv", epoch, source, shortRunID)) - if err := writeValidatorsCSV(validatorsFile, epoch, source, validEntries, summary.FilteredStake); err != nil { - mlog.Log.Warnf("dumpFullScheduleDataWithSummary: failed to write validators CSV: %v", err) - } else { - mlog.Log.FileOnlyf("leader schedule validators dumped to: %s (%d entries)", validatorsFile, len(validEntries)) - } - - // Write skipped CSV if there are any - if len(skippedEntries) > 0 { - skippedFile := filepath.Join(logsDir, fmt.Sprintf("epoch%d_%s%s_skipped.csv", epoch, source, shortRunID)) - if err := writeSkippedCSV(skippedFile, epoch, skippedEntries); err != nil { - mlog.Log.Warnf("dumpFullScheduleDataWithSummary: failed to write skipped CSV: %v", err) - } else { - mlog.Log.FileOnlyf("leader schedule skipped accounts dumped to: %s (%d entries)", skippedFile, len(skippedEntries)) - } - } - - // Write summary file - summaryFile := filepath.Join(logsDir, fmt.Sprintf("epoch%d_%s%s_summary.txt", epoch, source, shortRunID)) - if err := writeSummaryFile(summaryFile, summary); err != nil { - mlog.Log.Warnf("dumpFullScheduleDataWithSummary: failed to write summary: %v", err) - } else { - mlog.Log.FileOnlyf("leader schedule summary dumped to: %s", summaryFile) - } -} - -// 
DumpTieBreakDebug writes tie-break debugging info to a file. -// This verifies that equal-stake validators are sorted by pubkey DESC (Agave behavior). -func DumpTieBreakDebug( - epoch uint64, - voteAcctStakes map[solana.PublicKey]uint64, - voteAcctMap map[solana.PublicKey]*epochstakes.VoteAccount, - logsDir string, -) { - logsDir = resolveLogsDir(logsDir) - if err := os.MkdirAll(logsDir, 0755); err != nil { - mlog.Log.Warnf("DumpTieBreakDebug: failed to create logs dir: %v", err) - return - } - - runID := mlog.GetRunID() - shortRunID := "" - if runID != "" { - shortRunID = runID - if len(shortRunID) > 8 { - shortRunID = shortRunID[:8] - } - shortRunID = "_" + shortRunID - } - - filename := fmt.Sprintf("epoch%d_tiebreak%s.txt", epoch, shortRunID) - filePath := filepath.Join(logsDir, filename) - - f, err := os.Create(filePath) - if err != nil { - mlog.Log.Warnf("DumpTieBreakDebug: failed to create file: %v", err) - return - } - defer f.Close() - - w := bufio.NewWriter(f) - defer w.Flush() - - // Get sorted stakes with tie-break info - allEntries, tieGroups := leaderschedule.GetSortedStakesDebug(voteAcctMap, voteAcctStakes) - - w.WriteString("# Tie-Break Debug for Leader Schedule\n") - w.WriteString(fmt.Sprintf("# Epoch: %d\n", epoch)) - w.WriteString(fmt.Sprintf("# Total validators: %d\n", len(allEntries))) - w.WriteString(fmt.Sprintf("# Tie groups (equal stake): %d\n", len(tieGroups))) - w.WriteString(fmt.Sprintf("# Generated: %s\n", time.Now().UTC().Format(time.RFC3339))) - w.WriteString("#\n") - w.WriteString("# Expected behavior: within each tie group, pubkeys should be sorted DESC (higher bytes first)\n") - w.WriteString("# BytesCmp shows comparison vs previous entry: -1 means current < previous (correct for DESC)\n") - w.WriteString("#\n\n") - - if len(tieGroups) == 0 { - w.WriteString("No tie groups found - all validators have unique stake.\n") - mlog.Log.FileOnlyf("tie-break debug: epoch=%d no ties found", epoch) - return - } - - // Sort tie groups by stake 
descending for consistent output - type tieGroupInfo struct { - stake uint64 - entries []leaderschedule.TieBreakEntry - } - var sortedGroups []tieGroupInfo - for stake, entries := range tieGroups { - sortedGroups = append(sortedGroups, tieGroupInfo{stake: stake, entries: entries}) - } - sort.Slice(sortedGroups, func(i, j int) bool { - return sortedGroups[i].stake > sortedGroups[j].stake - }) - - for _, group := range sortedGroups { - w.WriteString(fmt.Sprintf("## Tie group: stake=%d (%d validators)\n", group.stake, len(group.entries))) - w.WriteString("rank,node_pubkey,stake,first_8_bytes_hex,bytes_cmp_vs_prev\n") - for _, entry := range group.entries { - w.WriteString(fmt.Sprintf("%d,%s,%d,%x,%d\n", - entry.Rank, entry.NodePk.String(), entry.Stake, entry.RawBytes, entry.BytesCmp)) - } - w.WriteString("\n") - } - - // Log to file only (not terminal) - mlog.Log.FileOnlyf("tie-break debug: epoch=%d tie_groups=%d written to %s", epoch, len(tieGroups), filePath) - - // Log the specific tie if we're looking for it (stake 2499999939665440) - if group, ok := tieGroups[2499999939665440]; ok { - mlog.Log.FileOnlyf("tie-break debug: found target tie group stake=2499999939665440:") - for _, entry := range group { - mlog.Log.FileOnlyf(" rank=%d node=%s bytes_cmp=%d", entry.Rank, entry.NodePk.String(), entry.BytesCmp) - } - } - - // Diagnostic: Check specific vote account → node mappings for epoch 905 debugging - // Vote accounts that caused the tie-break mismatch: - debugVoteAccts := []struct { - vote string - expectedNode string - }{ - {"33hurzEz6aEnzfESL6pnNyR6DCgcKzssT1pwSzDCBTRQ", "Aw5wEMXhbygFLR7jHtHpih8QvxVBGAMTqsQ2SjWPk1ex"}, - {"BU3ZgGBXFJwNTrN6VUJ88k9SJ71SyWfBJTabYqRErm4F", "2GUnfxZavKoPfS9s3VSEjaWDzB3vNf5RojUhprCS1rSx"}, - } - for _, d := range debugVoteAccts { - votePk := solana.MustPublicKeyFromBase58(d.vote) - expectedNodePk := solana.MustPublicKeyFromBase58(d.expectedNode) - stake, hasStake := voteAcctStakes[votePk] - va := voteAcctMap[votePk] - if hasStake || va 
!= nil { - var actualNode solana.PublicKey - if va != nil { - actualNode = va.NodePubkey - } - match := actualNode == expectedNodePk - mlog.Log.FileOnlyf("vote-node-mapping: vote=%s expected_node=%s actual_node=%s stake=%d match=%v", - d.vote, d.expectedNode, actualNode.String(), stake, match) - if !match { - mlog.Log.Warnf("VOTE-NODE MISMATCH: vote=%s expected=%s actual=%s stake=%d", - d.vote, d.expectedNode, actualNode.String(), stake) - } - } - } -} - -// writeValidatorsCSV writes all validators to a CSV file -func writeValidatorsCSV(filepath string, epoch uint64, source string, entries []StakeEntry, totalStake uint64) error { - f, err := os.Create(filepath) - if err != nil { - return err - } - defer f.Close() - - w := bufio.NewWriter(f) - defer w.Flush() - - // Header comments - w.WriteString(fmt.Sprintf("# Leader Schedule - Epoch %d\n", epoch)) - w.WriteString(fmt.Sprintf("# Source: %s\n", source)) - w.WriteString(fmt.Sprintf("# Total Validators: %d\n", len(entries))) - w.WriteString(fmt.Sprintf("# Total Stake: %d\n", totalStake)) - w.WriteString(fmt.Sprintf("# Generated: %s\n", time.Now().UTC().Format(time.RFC3339))) - w.WriteString("#\n") - w.WriteString("rank,vote_account,node_pubkey,stake,stake_percent\n") - - // Write all entries (already sorted by stake descending) - for i, e := range entries { - var stakePercent float64 - if totalStake > 0 { - stakePercent = float64(e.Stake) / float64(totalStake) * 100.0 - } - w.WriteString(fmt.Sprintf("%d,%s,%s,%d,%.6f\n", - i+1, e.VoteAcct, e.NodePubkey, e.Stake, stakePercent)) - } - - return nil -} - -// writeSkippedCSV writes all skipped accounts to a CSV file -func writeSkippedCSV(filepath string, epoch uint64, entries []StakeEntry) error { - f, err := os.Create(filepath) - if err != nil { - return err - } - defer f.Close() - - w := bufio.NewWriter(f) - defer w.Flush() - - // Header comments - w.WriteString(fmt.Sprintf("# Leader Schedule Skipped Accounts - Epoch %d\n", epoch)) - w.WriteString(fmt.Sprintf("# Total 
Skipped: %d\n", len(entries))) - w.WriteString(fmt.Sprintf("# Generated: %s\n", time.Now().UTC().Format(time.RFC3339))) - w.WriteString("#\n") - w.WriteString("# Reasons:\n") - w.WriteString("# zero_stake - Vote account has 0 stake\n") - w.WriteString("# missing_vote_acct - Vote account not found in VoteAcctMap\n") - w.WriteString("# missing_vote_cache - Vote account not found in VoteCache\n") - w.WriteString("# zero_nodepk - Vote account has zero NodePubkey\n") - w.WriteString("#\n") - w.WriteString("# Note: node_pubkey is empty for missing_vote_acct/missing_vote_cache since\n") - w.WriteString("# the vote account data was not available to extract the NodePubkey.\n") - w.WriteString("#\n") - w.WriteString("vote_account,node_pubkey,stake,reason\n") - - for _, e := range entries { - w.WriteString(fmt.Sprintf("%s,%s,%d,%s\n", e.VoteAcct, e.NodePubkey, e.Stake, e.Reason)) - } - - return nil -} - -// ScheduleSummary holds all metadata for the summary file -type ScheduleSummary struct { - // Epoch info - BlockEpoch uint64 - ScheduleEpoch uint64 - FirstSlot uint64 - SlotsInEpoch uint64 - Repeat uint64 - - // Stake info - TotalInputStake uint64 // Total stake from EpochStakes (before filtering) - FilteredStake uint64 // Stake used in schedule (after filtering) - MissingStake uint64 // Stake skipped due to missing data - MissingStakePercent float64 - - // Validator counts - ValidatorsInput int // Total vote accounts in EpochStakes - ValidatorsUsed int // Validators included in schedule - ValidatorsSkipped int // Validators skipped (zero stake + missing + zero nodepk) - SkippedZeroStake int - SkippedMissingData int // missing_vote_acct or missing_vote_cache - SkippedZeroNodePk int - - // Hashes - LocalHash string - RPCHash string // Empty if RPC validation not enabled - - // Run info - RunID string - Source string // "snapshot" or "vote_cache" - Timestamp time.Time -} - -// writeSummaryFile writes a comprehensive summary file for the epoch -func writeSummaryFile(filepath 
string, summary ScheduleSummary) error { - f, err := os.Create(filepath) - if err != nil { - return err - } - defer f.Close() - - w := bufio.NewWriter(f) - defer w.Flush() - - w.WriteString("# Leader Schedule Summary\n") - w.WriteString(fmt.Sprintf("# Generated: %s\n", summary.Timestamp.Format(time.RFC3339))) - w.WriteString(fmt.Sprintf("# Run ID: %s\n", summary.RunID)) - w.WriteString("#\n") - - w.WriteString("## Epoch Info\n") - w.WriteString(fmt.Sprintf("block_epoch=%d\n", summary.BlockEpoch)) - w.WriteString(fmt.Sprintf("schedule_epoch=%d\n", summary.ScheduleEpoch)) - w.WriteString(fmt.Sprintf("first_slot=%d\n", summary.FirstSlot)) - w.WriteString(fmt.Sprintf("slots_in_epoch=%d\n", summary.SlotsInEpoch)) - w.WriteString(fmt.Sprintf("repeat=%d\n", summary.Repeat)) - w.WriteString(fmt.Sprintf("source=%s\n", summary.Source)) - w.WriteString("\n") - - w.WriteString("## Stake Info\n") - w.WriteString(fmt.Sprintf("total_input_stake=%d\n", summary.TotalInputStake)) - w.WriteString(fmt.Sprintf("filtered_stake=%d\n", summary.FilteredStake)) - w.WriteString(fmt.Sprintf("missing_stake=%d\n", summary.MissingStake)) - w.WriteString(fmt.Sprintf("missing_stake_percent=%.4f\n", summary.MissingStakePercent)) - w.WriteString("\n") - - w.WriteString("## Validator Counts\n") - w.WriteString(fmt.Sprintf("validators_input=%d\n", summary.ValidatorsInput)) - w.WriteString(fmt.Sprintf("validators_used=%d\n", summary.ValidatorsUsed)) - w.WriteString(fmt.Sprintf("validators_skipped=%d\n", summary.ValidatorsSkipped)) - w.WriteString(fmt.Sprintf("skipped_zero_stake=%d\n", summary.SkippedZeroStake)) - w.WriteString(fmt.Sprintf("skipped_missing_data=%d\n", summary.SkippedMissingData)) - w.WriteString(fmt.Sprintf("skipped_zero_nodepk=%d\n", summary.SkippedZeroNodePk)) - w.WriteString("\n") - - w.WriteString("## Hashes\n") - w.WriteString(fmt.Sprintf("local_hash=%s\n", summary.LocalHash)) - if summary.RPCHash != "" { - w.WriteString(fmt.Sprintf("rpc_hash=%s\n", summary.RPCHash)) - } - 
w.WriteString("\n") - - return nil -} - -// ValidationStats holds statistics from schedule validation -type ValidationStats struct { - SkippedZeroStake int - SkippedMissingNodePk int - SkippedMissingNodePkStake uint64 // Stake dropped due to zero NodePubkey - SkippedMissingVoteAcct int - SkippedMissingVoteAcctStake uint64 // Stake dropped due to missing VoteCache entries - TotalVoteAccts int - TotalStake uint64 - MinStake uint64 - MaxStake uint64 - ValidatorCount int // Validators with non-zero stake and valid NodePubkey - MismatchCount int - Capped bool - TopStakes []StakeEntry // Top 10 by stake - BottomStakes []StakeEntry // Bottom 10 by stake - MissingVoteAccts []StakeEntry // First few missing vote accounts (for debugging) - ZeroNodePkAccts []StakeEntry // First few zero NodePubkey accounts -} - -// logScheduleBuildSummary logs a comprehensive summary of the schedule build. -// Called once per epoch when building the leader schedule. -// Terminal output is minimal; detailed info goes to log file only. 
-func logScheduleBuildSummary( - epoch uint64, - scheduleEpoch uint64, - firstSlot uint64, - slotsInEpoch uint64, - source string, // "snapshot" or "vote_cache" - stats ValidationStats, - fullHash string, -) { - // File only: single line summary - mlog.Log.FileOnlyf("leader schedule: epoch=%d validators=%d stake=%d hash=%s", - epoch, stats.ValidatorCount, stats.TotalStake, fullHash) - - // File only: detailed build info - mlog.Log.FileOnlyf("leader schedule build details: epoch=%d schedule_epoch=%d first_slot=%d slots=%d repeat=%d source=%s", - epoch, scheduleEpoch, firstSlot, slotsInEpoch, NumConsecutiveLeaderSlots, source) - mlog.Log.FileOnlyf(" validators=%d total_stake=%d min_stake=%d max_stake=%d zero_stake_count=%d", - stats.ValidatorCount, stats.TotalStake, stats.MinStake, stats.MaxStake, stats.SkippedZeroStake) - mlog.Log.FileOnlyf(" hash=%s", fullHash) - mlog.Log.FileOnlyf(" skipped: missing_vote_acct=%d (stake=%d) missing_nodepk=%d (stake=%d)", - stats.SkippedMissingVoteAcct, stats.SkippedMissingVoteAcctStake, stats.SkippedMissingNodePk, stats.SkippedMissingNodePkStake) - - // File only: top 10 stakes - if len(stats.TopStakes) > 0 { - mlog.Log.FileOnlyf(" top_stakes (showing %d):", len(stats.TopStakes)) - for i, e := range stats.TopStakes { - mlog.Log.FileOnlyf(" %2d. vote=%s node=%s stake=%d", - i+1, e.VoteAcct, e.NodePubkey, e.Stake) - } - } - - // File only: bottom 10 stakes - if len(stats.BottomStakes) > 0 { - mlog.Log.FileOnlyf(" bottom_stakes (showing %d):", len(stats.BottomStakes)) - for i, e := range stats.BottomStakes { - mlog.Log.FileOnlyf(" %2d. vote=%s node=%s stake=%d", - i+1, e.VoteAcct, e.NodePubkey, e.Stake) - } - } - - // File only: offending accounts if any were skipped - if len(stats.MissingVoteAccts) > 0 { - mlog.Log.FileOnlyf(" missing_vote_accts (first %d):", len(stats.MissingVoteAccts)) - for i, e := range stats.MissingVoteAccts { - mlog.Log.FileOnlyf(" %d. 
vote=%s stake=%d", i+1, e.VoteAcct, e.Stake) - } - } - if len(stats.ZeroNodePkAccts) > 0 { - mlog.Log.FileOnlyf(" zero_nodepk_accts (first %d):", len(stats.ZeroNodePkAccts)) - for i, e := range stats.ZeroNodePkAccts { - mlog.Log.FileOnlyf(" %d. vote=%s stake=%d", i+1, e.VoteAcct, e.Stake) - } - } -} - -// logHardFailContext logs detailed context when schedule build fails. -// Terminal shows brief error; file gets full details. -func logHardFailContext( - epoch uint64, - reason string, - stats ValidationStats, - voteAcctStakes map[solana.PublicKey]uint64, -) { - // Terminal: brief error - mlog.Log.Errorf("LEADER SCHEDULE BUILD FAILED: epoch=%d reason=%s", epoch, reason) - - // File only: detailed context - mlog.Log.FileOnlyf("LEADER SCHEDULE BUILD FAILED DETAILS:") - mlog.Log.FileOnlyf(" epoch=%d reason=%s", epoch, reason) - mlog.Log.FileOnlyf(" input_vote_accts=%d total_stake_available=%d", - stats.TotalVoteAccts, stats.TotalStake) - mlog.Log.FileOnlyf(" skipped: zero_stake=%d missing_vote_acct=%d (stake=%d) missing_nodepk=%d (stake=%d)", - stats.SkippedZeroStake, stats.SkippedMissingVoteAcct, stats.SkippedMissingVoteAcctStake, stats.SkippedMissingNodePk, stats.SkippedMissingNodePkStake) - - // File only: first few offending accounts - if len(stats.MissingVoteAccts) > 0 { - mlog.Log.FileOnlyf(" missing_vote_accts (first %d):", len(stats.MissingVoteAccts)) - for i, e := range stats.MissingVoteAccts { - mlog.Log.FileOnlyf(" %d. vote=%s stake=%d", i+1, e.VoteAcct, e.Stake) - } - } - if len(stats.ZeroNodePkAccts) > 0 { - mlog.Log.FileOnlyf(" zero_nodepk_accts (first %d):", len(stats.ZeroNodePkAccts)) - for i, e := range stats.ZeroNodePkAccts { - mlog.Log.FileOnlyf(" %d. 
vote=%s stake=%d", i+1, e.VoteAcct, e.Stake) - } - } - - // File only: valid top stakes for context - if len(stats.TopStakes) > 0 { - mlog.Log.FileOnlyf(" top_stakes_found (showing %d):", min(5, len(stats.TopStakes))) - for i := 0; i < min(5, len(stats.TopStakes)); i++ { - e := stats.TopStakes[i] - mlog.Log.FileOnlyf(" %d. vote=%s node=%s stake=%d", i+1, e.VoteAcct, e.NodePubkey, e.Stake) - } - } -} - -// buildLocalLeaderSchedule builds a leader schedule from local state. -// Returns nil schedule if no valid stakes are available. -// Also returns all valid and skipped entries for CSV dump. -func buildLocalLeaderSchedule( - epoch uint64, - epochSchedule *sealevel.SysvarEpochSchedule, - voteAcctStakes map[solana.PublicKey]uint64, - voteAcctMap map[solana.PublicKey]*epochstakes.VoteAccount, -) (*leaderschedule.LeaderSchedule, ValidationStats, []StakeEntry, []StakeEntry) { - stats := ValidationStats{ - TotalVoteAccts: len(voteAcctStakes), - MinStake: ^uint64(0), // Start with max value - } - - // Collect ALL valid and skipped entries for CSV dump - var validEntries []StakeEntry - var skippedEntries []StakeEntry - - // Filter and build epochVoteAccts map (only entries with stake > 0 and valid NodePubkey) - epochVoteAccts := make(map[solana.PublicKey]*epochstakes.VoteAccount) - filteredStakes := make(map[solana.PublicKey]uint64) - - for votePk, stake := range voteAcctStakes { - if stake == 0 { - stats.SkippedZeroStake++ - skippedEntries = append(skippedEntries, StakeEntry{ - VoteAcct: votePk, - Stake: stake, - Reason: "zero_stake", - }) - continue - } - - va := voteAcctMap[votePk] - if va == nil { - stats.SkippedMissingVoteAcct++ - stats.SkippedMissingVoteAcctStake += stake - skippedEntries = append(skippedEntries, StakeEntry{ - VoteAcct: votePk, - Stake: stake, - Reason: "missing_vote_acct", - }) - // Track first few for quick debugging in logs - if len(stats.MissingVoteAccts) < 5 { - stats.MissingVoteAccts = append(stats.MissingVoteAccts, StakeEntry{ - VoteAcct: 
votePk, - Stake: stake, - }) - } - continue - } - - // Check for zero NodePubkey (missing) - var zeroPk solana.PublicKey - if va.NodePubkey == zeroPk { - stats.SkippedMissingNodePk++ - stats.SkippedMissingNodePkStake += stake - skippedEntries = append(skippedEntries, StakeEntry{ - VoteAcct: votePk, - Stake: stake, - Reason: "zero_nodepk", - }) - // Track first few for quick debugging in logs - if len(stats.ZeroNodePkAccts) < 5 { - stats.ZeroNodePkAccts = append(stats.ZeroNodePkAccts, StakeEntry{ - VoteAcct: votePk, - Stake: stake, - }) - } - continue - } - - epochVoteAccts[votePk] = va - filteredStakes[votePk] = stake - stats.TotalStake += stake - - // Track min/max - if stake < stats.MinStake { - stats.MinStake = stake - } - if stake > stats.MaxStake { - stats.MaxStake = stake - } - - validEntries = append(validEntries, StakeEntry{ - VoteAcct: votePk, - NodePubkey: va.NodePubkey, - Stake: stake, - }) - } - - stats.ValidatorCount = len(validEntries) - - // Guard: empty stakes would panic in weightedrand - if len(filteredStakes) == 0 { - stats.MinStake = 0 // Reset since no valid entries - return nil, stats, validEntries, skippedEntries - } - - // Sort entries by stake descending, then node pubkey descending (matches schedule computation) - sort.Slice(validEntries, func(i, j int) bool { - if validEntries[i].Stake != validEntries[j].Stake { - return validEntries[i].Stake > validEntries[j].Stake - } - // Tie-break by node pubkey descending (higher bytes first) - matches Agave - return bytes.Compare(validEntries[i].NodePubkey[:], validEntries[j].NodePubkey[:]) > 0 - }) - - // Capture top 10 and bottom 10 for log summary - for i := 0; i < min(10, len(validEntries)); i++ { - stats.TopStakes = append(stats.TopStakes, validEntries[i]) - } - for i := max(0, len(validEntries)-10); i < len(validEntries); i++ { - stats.BottomStakes = append(stats.BottomStakes, validEntries[i]) - } - - // Get epoch length (handles warmup epochs correctly) - slotsInEpoch := 
epochSchedule.SlotsInEpoch(epoch) - - // Build the schedule using leaderschedule.New - ls := leaderschedule.New( - epochVoteAccts, - filteredStakes, - epochSchedule, - epoch, - slotsInEpoch, - NumConsecutiveLeaderSlots, - ) - - return ls, stats, validEntries, skippedEntries -} - -// buildLocalLeaderScheduleFromVoteCache builds schedule using global.VoteCache() for NodePubkey lookups. -// Used at epoch boundaries when epochVoteAcctsMap may not be available. -// Returns nil schedule if no valid stakes are available. -// Also returns all valid and skipped entries for CSV dump. -func buildLocalLeaderScheduleFromVoteCache( - epoch uint64, - epochSchedule *sealevel.SysvarEpochSchedule, - voteAcctStakes map[solana.PublicKey]uint64, -) (*leaderschedule.LeaderSchedule, ValidationStats, []StakeEntry, []StakeEntry) { - stats := ValidationStats{ - TotalVoteAccts: len(voteAcctStakes), - MinStake: ^uint64(0), // Start with max value - } - - voteCache := global.VoteCache() - - // Collect ALL valid and skipped entries for CSV dump - var validEntries []StakeEntry - var skippedEntries []StakeEntry - - // Build epochVoteAccts map from vote cache - epochVoteAccts := make(map[solana.PublicKey]*epochstakes.VoteAccount) - filteredStakes := make(map[solana.PublicKey]uint64) - - for votePk, stake := range voteAcctStakes { - if stake == 0 { - stats.SkippedZeroStake++ - skippedEntries = append(skippedEntries, StakeEntry{ - VoteAcct: votePk, - Stake: stake, - Reason: "zero_stake", - }) - continue - } - - vs := voteCache[votePk] - if vs == nil { - stats.SkippedMissingVoteAcct++ - stats.SkippedMissingVoteAcctStake += stake - skippedEntries = append(skippedEntries, StakeEntry{ - VoteAcct: votePk, - Stake: stake, - Reason: "missing_vote_cache", - }) - // Track first few for quick debugging in logs - if len(stats.MissingVoteAccts) < 5 { - stats.MissingVoteAccts = append(stats.MissingVoteAccts, StakeEntry{ - VoteAcct: votePk, - Stake: stake, - }) - } - continue - } - - nodePk := vs.NodePubkey() - 
var zeroPk solana.PublicKey - if nodePk == zeroPk { - stats.SkippedMissingNodePk++ - stats.SkippedMissingNodePkStake += stake - skippedEntries = append(skippedEntries, StakeEntry{ - VoteAcct: votePk, - Stake: stake, - Reason: "zero_nodepk", - }) - // Track first few for quick debugging in logs - if len(stats.ZeroNodePkAccts) < 5 { - stats.ZeroNodePkAccts = append(stats.ZeroNodePkAccts, StakeEntry{ - VoteAcct: votePk, - Stake: stake, - }) - } - continue - } - - // Create a VoteAccount with the NodePubkey - va := &epochstakes.VoteAccount{ - NodePubkey: nodePk, - } - epochVoteAccts[votePk] = va - filteredStakes[votePk] = stake - stats.TotalStake += stake - - // Track min/max - if stake < stats.MinStake { - stats.MinStake = stake - } - if stake > stats.MaxStake { - stats.MaxStake = stake - } - - validEntries = append(validEntries, StakeEntry{ - VoteAcct: votePk, - NodePubkey: nodePk, - Stake: stake, - }) - } - - stats.ValidatorCount = len(validEntries) - - // Guard: empty stakes would panic in weightedrand - if len(filteredStakes) == 0 { - stats.MinStake = 0 // Reset since no valid entries - return nil, stats, validEntries, skippedEntries - } - - // Sort entries by stake descending, then node pubkey descending (matches schedule computation) - sort.Slice(validEntries, func(i, j int) bool { - if validEntries[i].Stake != validEntries[j].Stake { - return validEntries[i].Stake > validEntries[j].Stake - } - // Tie-break by node pubkey descending (higher bytes first) - matches Agave - return bytes.Compare(validEntries[i].NodePubkey[:], validEntries[j].NodePubkey[:]) > 0 - }) - - // Capture top 10 and bottom 10 for log summary - for i := 0; i < min(10, len(validEntries)); i++ { - stats.TopStakes = append(stats.TopStakes, validEntries[i]) - } - for i := max(0, len(validEntries)-10); i < len(validEntries); i++ { - stats.BottomStakes = append(stats.BottomStakes, validEntries[i]) - } - - // Get epoch length (handles warmup epochs correctly) - slotsInEpoch := 
epochSchedule.SlotsInEpoch(epoch) - - ls := leaderschedule.New( - epochVoteAccts, - filteredStakes, - epochSchedule, - epoch, - slotsInEpoch, - NumConsecutiveLeaderSlots, - ) - - return ls, stats, validEntries, skippedEntries -} - -// scheduleFullHash computes a SHA256 hash of the entire leader schedule. -// Returns base64-encoded first 16 bytes of the hash. -// Takes ~20-50ms for a full epoch (432k slots). -func scheduleFullHash(ls *leaderschedule.LeaderSchedule, firstSlot uint64, numSlots uint64) string { - if ls == nil { - return "nil" - } - - h := sha256.New() - for i := uint64(0); i < numSlots; i++ { - slot := firstSlot + i - leader, ok := ls.LeaderForSlot(slot) - if ok { - h.Write(leader[:]) - } - } - - return base64.StdEncoding.EncodeToString(h.Sum(nil)[:16]) -} - -// validateLeaderSchedule compares local vs RPC schedule and logs mismatches. -// Does NOT return error - mismatches are logged but don't stop replay. -func validateLeaderSchedule( - blockEpoch uint64, - epochSchedule *sealevel.SysvarEpochSchedule, - rpcSchedule *leaderschedule.LeaderSchedule, - logsDir string, -) { - if rpcSchedule == nil { - mlog.Log.Warnf("leader schedule validation: rpc schedule is nil, skipping") - return - } - - // Initialize mismatch log file (once per process) - initMismatchLog(logsDir) - - // Stakes are stored under the epoch they're EFFECTIVE for, not the boundary epoch. - // E.g., stakes frozen at end of epoch 499 are effective during epoch 500, - // so they're stored as EpochStakes(500). LeaderScheduleEpoch returns 499 (the - // boundary), but lookup should use blockEpoch (500). 
- firstSlot := epochSchedule.FirstSlotInEpoch(blockEpoch) - - // Fetch stakes for blockEpoch (stored under the epoch they're effective for) - voteAcctStakes := global.EpochStakes(blockEpoch) - voteAcctMap := global.EpochStakesVoteAccts(blockEpoch) - - // Guard: skip if no stake data available for this epoch - if len(voteAcctStakes) == 0 { - mlog.Log.Warnf("leader schedule validation: no stake data for epoch=%d, skipping", blockEpoch) - return - } - - // Use blockEpoch for slot count (this is the epoch we're building schedule for) - numSlots := epochSchedule.SlotsInEpoch(blockEpoch) - - // Log input snapshot for debugging - logInputSnapshot(blockEpoch, voteAcctStakes, voteAcctMap) - - // Build local schedule for blockEpoch (same as RPC) - localSchedule, stats, _, _ := buildLocalLeaderSchedule(blockEpoch, epochSchedule, voteAcctStakes, voteAcctMap) - - // Guard: skip if local schedule couldn't be built (empty stakes after filtering) - if localSchedule == nil { - mlog.Log.Warnf("leader schedule validation: could not build local schedule (no valid stakes), skipping") - mlog.Log.FileOnlyf(" epoch=%d vote_accts=%d skipped: zero_stake=%d missing_nodepk=%d missing_vote_acct=%d", - blockEpoch, stats.TotalVoteAccts, stats.SkippedZeroStake, stats.SkippedMissingNodePk, stats.SkippedMissingVoteAcct) - return - } - - // Sample and compare slots - mismatchCount := 0 - - // Generate slots to sample: first 2k, last 2k, plus random 1k in middle - slotsToSample := make([]uint64, 0, SampleBoundarySlots*2+SampleRandomSlots) - - // First boundary slots - for i := uint64(0); i < min(SampleBoundarySlots, numSlots); i++ { - slotsToSample = append(slotsToSample, firstSlot+i) - } - - // Last boundary slots - if numSlots > SampleBoundarySlots { - startLast := numSlots - min(SampleBoundarySlots, numSlots) - for i := startLast; i < numSlots; i++ { - slotsToSample = append(slotsToSample, firstSlot+i) - } - } - - // Random slots in middle (deterministic based on blockEpoch for reproducibility) - 
if numSlots > SampleBoundarySlots*2 { - rng := rand.New(rand.NewSource(blockEpoch)) - middleStart := uint64(SampleBoundarySlots) - middleEnd := numSlots - SampleBoundarySlots - for i := 0; i < SampleRandomSlots && middleEnd > middleStart; i++ { - offset := rng.Uint64() % (middleEnd - middleStart) - slotsToSample = append(slotsToSample, firstSlot+middleStart+offset) - } - } - - // Compare sampled slots - for _, slot := range slotsToSample { - localLeader, localOk := localSchedule.LeaderForSlot(slot) - rpcLeader, rpcOk := rpcSchedule.LeaderForSlot(slot) - - if !localOk || !rpcOk { - continue // Skip slots not in schedules - } - - if localLeader != rpcLeader { - // Find the vote account that maps to the local leader for debugging - var matchingVoteAcct solana.PublicKey - var matchingStake uint64 - for votePk, va := range voteAcctMap { - if va != nil && va.NodePubkey == localLeader { - matchingVoteAcct = votePk - matchingStake = voteAcctStakes[votePk] - break - } - } - logMismatch(blockEpoch, slot, localLeader, rpcLeader, matchingVoteAcct, matchingStake, &mismatchCount) - } - } - - stats.MismatchCount = mismatchCount - stats.Capped = mismatchCount >= MaxMismatchLogsPerEpoch - - // Flush mismatch log - flushMismatchLog() - - // File only: per-epoch validation summary - mlog.Log.FileOnlyf("leader schedule validation: epoch=%d first_slot=%d slots=%d", - blockEpoch, firstSlot, numSlots) - mlog.Log.FileOnlyf(" vote_accts=%d total_stake=%d skipped: zero_stake=%d missing_nodepk=%d missing_vote_acct=%d", - stats.TotalVoteAccts, stats.TotalStake, stats.SkippedZeroStake, stats.SkippedMissingNodePk, stats.SkippedMissingVoteAcct) - mlog.Log.FileOnlyf(" sampled=%d mismatches=%d (capped=%v)", len(slotsToSample), stats.MismatchCount, stats.Capped) - - // Terminal: only warn on mismatches - if stats.MismatchCount > 0 { - mlog.Log.Warnf("leader schedule validation: %d MISMATCHES epoch=%d - see %s", - stats.MismatchCount, blockEpoch, getMismatchLogPath()) - } -} - -// 
validateLeaderScheduleFromVoteCache validates using global.VoteCache() for NodePubkey lookups. -// Used at epoch boundaries when epochVoteAcctsMap may not be available from snapshot. -func validateLeaderScheduleFromVoteCache( - blockEpoch uint64, - epochSchedule *sealevel.SysvarEpochSchedule, - rpcSchedule *leaderschedule.LeaderSchedule, - logsDir string, -) { - if rpcSchedule == nil { - mlog.Log.Warnf("leader schedule validation: rpc schedule is nil, skipping") - return - } - - // Initialize mismatch log file (once per process) - initMismatchLog(logsDir) - - // Stakes are stored under the epoch they're EFFECTIVE for, not the boundary epoch. - firstSlot := epochSchedule.FirstSlotInEpoch(blockEpoch) - - // Fetch stakes for blockEpoch (stored under the epoch they're effective for) - voteAcctStakes := global.EpochStakes(blockEpoch) - - // Guard: skip if no stake data available for this epoch - if len(voteAcctStakes) == 0 { - mlog.Log.Warnf("leader schedule validation: no stake data for epoch=%d, skipping", blockEpoch) - return - } - - // Use blockEpoch for slot count (this is the epoch we're building schedule for) - numSlots := epochSchedule.SlotsInEpoch(blockEpoch) - - // Build local schedule for blockEpoch (same as RPC) - localSchedule, stats, _, _ := buildLocalLeaderScheduleFromVoteCache(blockEpoch, epochSchedule, voteAcctStakes) - - // Guard: skip if local schedule couldn't be built (empty stakes after filtering) - if localSchedule == nil { - mlog.Log.Warnf("leader schedule validation: could not build local schedule (no valid stakes), skipping") - mlog.Log.FileOnlyf(" epoch=%d vote_accts=%d skipped: zero_stake=%d missing_nodepk=%d missing_vote_state=%d", - blockEpoch, stats.TotalVoteAccts, stats.SkippedZeroStake, stats.SkippedMissingNodePk, stats.SkippedMissingVoteAcct) - return - } - - // Sample and compare slots - mismatchCount := 0 - - // Generate slots to sample: first 2k, last 2k, plus random 1k in middle - slotsToSample := make([]uint64, 0, 
SampleBoundarySlots*2+SampleRandomSlots) - - // First boundary slots - for i := uint64(0); i < min(SampleBoundarySlots, numSlots); i++ { - slotsToSample = append(slotsToSample, firstSlot+i) - } - - // Last boundary slots - if numSlots > SampleBoundarySlots { - startLast := numSlots - min(SampleBoundarySlots, numSlots) - for i := startLast; i < numSlots; i++ { - slotsToSample = append(slotsToSample, firstSlot+i) - } - } - - // Random slots in middle (deterministic based on blockEpoch for reproducibility) - if numSlots > SampleBoundarySlots*2 { - rng := rand.New(rand.NewSource(blockEpoch)) - middleStart := uint64(SampleBoundarySlots) - middleEnd := numSlots - SampleBoundarySlots - for i := 0; i < SampleRandomSlots && middleEnd > middleStart; i++ { - offset := rng.Uint64() % (middleEnd - middleStart) - slotsToSample = append(slotsToSample, firstSlot+middleStart+offset) - } - } - - // Compare sampled slots - voteCache := global.VoteCache() - for _, slot := range slotsToSample { - localLeader, localOk := localSchedule.LeaderForSlot(slot) - rpcLeader, rpcOk := rpcSchedule.LeaderForSlot(slot) - - if !localOk || !rpcOk { - continue - } - - if localLeader != rpcLeader { - // Find the vote account that maps to the local leader - var matchingVoteAcct solana.PublicKey - var matchingStake uint64 - for votePk, vs := range voteCache { - if vs != nil && vs.NodePubkey() == localLeader { - matchingVoteAcct = votePk - matchingStake = voteAcctStakes[votePk] - break - } - } - logMismatch(blockEpoch, slot, localLeader, rpcLeader, matchingVoteAcct, matchingStake, &mismatchCount) - } - } - - stats.MismatchCount = mismatchCount - stats.Capped = mismatchCount >= MaxMismatchLogsPerEpoch - - flushMismatchLog() - - // File only: per-epoch validation summary - mlog.Log.FileOnlyf("leader schedule validation (vote cache): epoch=%d first_slot=%d slots=%d", - blockEpoch, firstSlot, numSlots) - mlog.Log.FileOnlyf(" vote_accts=%d total_stake=%d skipped: zero_stake=%d missing_nodepk=%d 
missing_vote_state=%d", - stats.TotalVoteAccts, stats.TotalStake, stats.SkippedZeroStake, stats.SkippedMissingNodePk, stats.SkippedMissingVoteAcct) - mlog.Log.FileOnlyf(" sampled=%d mismatches=%d (capped=%v)", len(slotsToSample), stats.MismatchCount, stats.Capped) - - // Terminal: only warn on mismatches - if stats.MismatchCount > 0 { - mlog.Log.Warnf("leader schedule validation: %d MISMATCHES epoch=%d - see %s", - stats.MismatchCount, blockEpoch, getMismatchLogPath()) - } -} - -// PrepareLeaderScheduleLocal builds the leader schedule from local state and sets it as the source of truth. -// This is the primary entry point for leader schedule - no RPC dependency. -// Returns the schedule summary (for RPC validation) and error if schedule cannot be built. -func PrepareLeaderScheduleLocal( - epoch uint64, - epochSchedule *sealevel.SysvarEpochSchedule, - logsDir string, -) (*ScheduleSummary, error) { - voteAcctStakes := global.EpochStakes(epoch) - voteAcctMap := global.EpochStakesVoteAccts(epoch) - - // The RNG seed uses `epoch` directly (the epoch we're building the schedule for) - // Note: LeaderScheduleEpoch() returns something different (next epoch's prep slot) - don't use it here - firstSlot := epochSchedule.FirstSlotInEpoch(epoch) - numSlots := epochSchedule.SlotsInEpoch(epoch) - - if len(voteAcctStakes) == 0 { - mlog.Log.Errorf("LEADER SCHEDULE BUILD FAILED: epoch=%d reason=no_stake_data", epoch) - mlog.Log.FileOnlyf(" rng_epoch=%d first_slot=%d slots=%d", epoch, firstSlot, numSlots) - mlog.Log.FileOnlyf(" EpochStakes(%d) returned nil or empty", epoch) - return nil, fmt.Errorf("no stake data available for epoch %d", epoch) - } - - schedule, stats, validEntries, skippedEntries := buildLocalLeaderSchedule(epoch, epochSchedule, voteAcctStakes, voteAcctMap) - - // Calculate total input stake (before filtering) - var totalInputStake uint64 - for _, stake := range voteAcctStakes { - totalInputStake += stake - } - - if schedule == nil { - logHardFailContext(epoch, 
"no_valid_stakes_after_filtering", stats, voteAcctStakes) - // Still dump whatever data we have for debugging even on failure - dumpFullScheduleData(epoch, "local", validEntries, skippedEntries, stats.TotalStake, logsDir) - return nil, fmt.Errorf("could not build leader schedule for epoch %d: no valid stakes after filtering (zero_stake=%d, missing_nodepk=%d, missing_vote_acct=%d)", - epoch, stats.SkippedZeroStake, stats.SkippedMissingNodePk, stats.SkippedMissingVoteAcct) - } - - // Set as source of truth - global.SetLeaderSchedule(schedule) - - // Compute hash for logging - fullHash := scheduleFullHash(schedule, firstSlot, numSlots) - - // Log comprehensive summary - logScheduleBuildSummary(epoch, epoch, firstSlot, numSlots, "snapshot", stats, fullHash) - - // Build summary with all metadata - // Include all missing stake: missing_vote_acct + zero_nodepk - missingStake := stats.SkippedMissingVoteAcctStake + stats.SkippedMissingNodePkStake - var missingPercent float64 - if totalInputStake > 0 { - missingPercent = float64(missingStake) / float64(totalInputStake) * 100.0 - } - summary := ScheduleSummary{ - BlockEpoch: epoch, - ScheduleEpoch: epoch, // RNG seed epoch = block epoch - FirstSlot: firstSlot, - SlotsInEpoch: numSlots, - Repeat: NumConsecutiveLeaderSlots, - TotalInputStake: totalInputStake, - FilteredStake: stats.TotalStake, - MissingStake: missingStake, - MissingStakePercent: missingPercent, - ValidatorsInput: stats.TotalVoteAccts, - ValidatorsUsed: stats.ValidatorCount, - ValidatorsSkipped: stats.SkippedZeroStake + stats.SkippedMissingVoteAcct + stats.SkippedMissingNodePk, - SkippedZeroStake: stats.SkippedZeroStake, - SkippedMissingData: stats.SkippedMissingVoteAcct, - SkippedZeroNodePk: stats.SkippedMissingNodePk, - LocalHash: fullHash, - RunID: mlog.GetRunID(), - Source: "snapshot", // From snapshot loading at startup - Timestamp: time.Now().UTC(), - } - - // Dump ALL validators, skipped accounts, and summary to files - 
dumpFullScheduleDataWithSummary(validEntries, skippedEntries, summary, logsDir) - - // Dump tie-break debug info (shows how equal-stake validators are ordered) - DumpTieBreakDebug(epoch, voteAcctStakes, voteAcctMap, logsDir) - - // Dump first 1000 slots if dump flag is set (for debugging against RPC) - if config.GetBool("replay.dump_leader_schedule") { - DumpLeaderSchedule(epoch, epochSchedule, schedule, logsDir, 1000) - } - - return &summary, nil -} - -// PrepareLeaderScheduleLocalFromVoteCache builds the leader schedule using vote cache for NodePubkey lookups. -// Used at epoch boundaries when EpochStakesVoteAccts may not have the new epoch's data yet. -// Returns the schedule summary (for RPC validation) and error if schedule cannot be built. -func PrepareLeaderScheduleLocalFromVoteCache( - epoch uint64, - epochSchedule *sealevel.SysvarEpochSchedule, - logsDir string, -) (*ScheduleSummary, error) { - voteAcctStakes := global.EpochStakes(epoch) - - // The RNG seed uses `epoch` directly (the epoch we're building the schedule for) - // Note: LeaderScheduleEpoch() returns something different (next epoch's prep slot) - don't use it here - firstSlot := epochSchedule.FirstSlotInEpoch(epoch) - numSlots := epochSchedule.SlotsInEpoch(epoch) - - if len(voteAcctStakes) == 0 { - mlog.Log.Errorf("LEADER SCHEDULE BUILD FAILED: epoch=%d reason=no_stake_data", epoch) - mlog.Log.FileOnlyf(" rng_epoch=%d first_slot=%d slots=%d source=vote_cache", epoch, firstSlot, numSlots) - mlog.Log.FileOnlyf(" EpochStakes(%d) returned nil or empty", epoch) - mlog.Log.FileOnlyf(" VoteCache size=%d", len(global.VoteCache())) - return nil, fmt.Errorf("no stake data available for epoch %d", epoch) - } - - schedule, stats, validEntries, skippedEntries := buildLocalLeaderScheduleFromVoteCache(epoch, epochSchedule, voteAcctStakes) - - // Calculate total input stake (before filtering) - var totalInputStake uint64 - for _, stake := range voteAcctStakes { - totalInputStake += stake - } - - if schedule == 
nil { - logHardFailContext(epoch, "no_valid_stakes_after_filtering (vote_cache)", stats, voteAcctStakes) - // Still dump whatever data we have for debugging even on failure - dumpFullScheduleData(epoch, "local_vote_cache", validEntries, skippedEntries, stats.TotalStake, logsDir) - return nil, fmt.Errorf("could not build leader schedule for epoch %d: no valid stakes after filtering (zero_stake=%d, missing_nodepk=%d, missing_vote_state=%d)", - epoch, stats.SkippedZeroStake, stats.SkippedMissingNodePk, stats.SkippedMissingVoteAcct) - } - - // Safety check: fail if too much stake is missing from VoteCache. - // Since local schedule is the source of truth, missing entries produce incorrect schedules. - missingStake := stats.SkippedMissingVoteAcctStake - if totalInputStake > 0 && missingStake > 0 { - missingPercent := float64(missingStake) / float64(totalInputStake) * 100.0 - if missingPercent > MaxMissingVoteCacheStakePercent { - logHardFailContext(epoch, fmt.Sprintf("vote_cache_too_incomplete (%.2f%% > %.1f%%)", missingPercent, MaxMissingVoteCacheStakePercent), stats, voteAcctStakes) - // Dump data even on failure for debugging - dumpFullScheduleData(epoch, "local_vote_cache", validEntries, skippedEntries, stats.TotalStake, logsDir) - return nil, fmt.Errorf("vote cache too incomplete for epoch %d: %.2f%% stake missing (threshold %.1f%%), missing_accts=%d missing_stake=%d total_stake=%d", - epoch, missingPercent, MaxMissingVoteCacheStakePercent, - stats.SkippedMissingVoteAcct, missingStake, totalInputStake) - } - // Log warning if any stake is missing, even below threshold - mlog.Log.Warnf("leader schedule: epoch=%d has %.2f%% stake missing from VoteCache (count=%d stake=%d)", - epoch, missingPercent, stats.SkippedMissingVoteAcct, missingStake) - } - - // Set as source of truth - global.SetLeaderSchedule(schedule) - - // Compute hash for logging - fullHash := scheduleFullHash(schedule, firstSlot, numSlots) - - // Log comprehensive summary - 
logScheduleBuildSummary(epoch, epoch, firstSlot, numSlots, "vote_cache", stats, fullHash) - - // Build summary with all metadata - // Include all missing stake: missing_vote_acct + zero_nodepk - totalMissingStake := stats.SkippedMissingVoteAcctStake + stats.SkippedMissingNodePkStake - var missingPercent float64 - if totalInputStake > 0 { - missingPercent = float64(totalMissingStake) / float64(totalInputStake) * 100.0 - } - summary := ScheduleSummary{ - BlockEpoch: epoch, - ScheduleEpoch: epoch, // RNG seed epoch = block epoch - FirstSlot: firstSlot, - SlotsInEpoch: numSlots, - Repeat: NumConsecutiveLeaderSlots, - TotalInputStake: totalInputStake, - FilteredStake: stats.TotalStake, - MissingStake: totalMissingStake, - MissingStakePercent: missingPercent, - ValidatorsInput: stats.TotalVoteAccts, - ValidatorsUsed: stats.ValidatorCount, - ValidatorsSkipped: stats.SkippedZeroStake + stats.SkippedMissingVoteAcct + stats.SkippedMissingNodePk, - SkippedZeroStake: stats.SkippedZeroStake, - SkippedMissingData: stats.SkippedMissingVoteAcct, - SkippedZeroNodePk: stats.SkippedMissingNodePk, - LocalHash: fullHash, - RunID: mlog.GetRunID(), - Source: "transition", // From epoch boundary transition - Timestamp: time.Now().UTC(), - } - - // Dump ALL validators, skipped accounts, and summary to files - dumpFullScheduleDataWithSummary(validEntries, skippedEntries, summary, logsDir) - - // Dump first 1000 slots if dump flag is set (for debugging against RPC) - if config.GetBool("replay.dump_leader_schedule") { - DumpLeaderSchedule(epoch, epochSchedule, schedule, logsDir, 1000) - } - - return &summary, nil -} - -// DumpLeaderSchedule writes the first N slots of the schedule to a file for debugging. -// File is written to logsDir/leader_schedule_dump_epoch.txt -// Useful for comparing against RPC getLeaderSchedule results. 
-func DumpLeaderSchedule( - epoch uint64, - epochSchedule *sealevel.SysvarEpochSchedule, - schedule *leaderschedule.LeaderSchedule, - logsDir string, - numSlots int, -) { - if schedule == nil { - mlog.Log.Warnf("DumpLeaderSchedule: schedule is nil") - return - } - - logsDir = resolveLogsDir(logsDir) - if err := os.MkdirAll(logsDir, 0755); err != nil { - mlog.Log.Warnf("DumpLeaderSchedule: failed to create logs dir: %v", err) - return - } - - filename := fmt.Sprintf("leader_schedule_dump_epoch%d.txt", epoch) - filepath := filepath.Join(logsDir, filename) - - f, err := os.Create(filepath) - if err != nil { - mlog.Log.Warnf("DumpLeaderSchedule: failed to create file: %v", err) - return - } - defer f.Close() - - w := bufio.NewWriter(f) - defer w.Flush() - - firstSlot := epochSchedule.FirstSlotInEpoch(epoch) - totalSlots := epochSchedule.SlotsInEpoch(epoch) - - // Write header - w.WriteString(fmt.Sprintf("# Leader Schedule Dump - Epoch %d\n", epoch)) - w.WriteString(fmt.Sprintf("# First slot: %d\n", firstSlot)) - w.WriteString(fmt.Sprintf("# Total slots in epoch: %d\n", totalSlots)) - w.WriteString(fmt.Sprintf("# Dumping first %d slots\n", numSlots)) - w.WriteString(fmt.Sprintf("# Format: slot_offset,absolute_slot,leader_pubkey\n")) - w.WriteString("#\n") - - // Dump first N slots - for i := 0; i < numSlots && uint64(i) < totalSlots; i++ { - slot := firstSlot + uint64(i) - leader, ok := schedule.LeaderForSlot(slot) - if ok { - w.WriteString(fmt.Sprintf("%d,%d,%s\n", i, slot, leader.String())) - } else { - w.WriteString(fmt.Sprintf("%d,%d,NOT_FOUND\n", i, slot)) - } - } - - mlog.Log.FileOnlyf("leader schedule dumped to: %s (first %d slots)", filepath, numSlots) -} - -// dumpScheduleSlotsCSV dumps the full schedule to a CSV for slot-by-slot comparison. -// Format: slot,leader_pubkey (simple format for easy diffing) -// Called when mismatch is detected or when replay.dump_leader_schedule is set. 
-func dumpScheduleSlotsCSV( - epoch uint64, - source string, // "local" or "rpc" - schedule *leaderschedule.LeaderSchedule, - firstSlot uint64, - numSlots uint64, - logsDir string, -) string { - if schedule == nil { - return "" - } - - logsDir = resolveLogsDir(logsDir) - if err := os.MkdirAll(logsDir, 0755); err != nil { - mlog.Log.Warnf("dumpScheduleSlotsCSV: failed to create logs dir: %v", err) - return "" - } - - // Get short run ID for filename - runID := mlog.GetRunID() - shortRunID := "" - if runID != "" { - shortRunID = runID - if len(shortRunID) > 8 { - shortRunID = shortRunID[:8] - } - shortRunID = "_" + shortRunID - } - - filename := fmt.Sprintf("epoch%d_%s_slots%s.csv", epoch, source, shortRunID) - filePath := filepath.Join(logsDir, filename) - - f, err := os.Create(filePath) - if err != nil { - mlog.Log.Warnf("dumpScheduleSlotsCSV: failed to create file: %v", err) - return "" - } - defer f.Close() - - w := bufio.NewWriter(f) - defer w.Flush() - - // Minimal header - just slot,leader for easy diffing - w.WriteString("slot,leader\n") - - // Dump all slots - for i := uint64(0); i < numSlots; i++ { - slot := firstSlot + i - leader, ok := schedule.LeaderForSlot(slot) - if ok { - w.WriteString(fmt.Sprintf("%d,%s\n", slot, leader.String())) - } else { - w.WriteString(fmt.Sprintf("%d,\n", slot)) // Empty leader for missing - } - } - - mlog.Log.FileOnlyf("leader schedule slots dumped to: %s (%d slots)", filePath, numSlots) - return filePath -} - -// DumpScheduleMismatch dumps both local and RPC schedules to CSV files for analysis. -// Called when a hash mismatch is detected during validation. -// Returns paths to local and RPC slot files. 
-func DumpScheduleMismatch( - epoch uint64, - epochSchedule *sealevel.SysvarEpochSchedule, - localSchedule *leaderschedule.LeaderSchedule, - rpcSchedule *leaderschedule.LeaderSchedule, - logsDir string, -) (localPath, rpcPath string) { - firstSlot := epochSchedule.FirstSlotInEpoch(epoch) - numSlots := epochSchedule.SlotsInEpoch(epoch) - - localPath = dumpScheduleSlotsCSV(epoch, "local", localSchedule, firstSlot, numSlots, logsDir) - rpcPath = dumpScheduleSlotsCSV(epoch, "rpc", rpcSchedule, firstSlot, numSlots, logsDir) - - if localPath != "" && rpcPath != "" { - mlog.Log.FileOnlyf("schedule mismatch dumps: local=%s rpc=%s", localPath, rpcPath) - mlog.Log.FileOnlyf(" run: scripts/diff_leader_schedules.py %s %s", localPath, rpcPath) - } - - return localPath, rpcPath -} - -// dumpRPCValidatorList extracts validators from RPC schedule and dumps to CSV. -// Since RPC only gives us slot -> leader, we count slot appearances per leader. -// File is named epoch_rpc__validators.csv for comparison with local. 
-func dumpRPCValidatorList( - epoch uint64, - rpcSchedule *leaderschedule.LeaderSchedule, - firstSlot uint64, - numSlots uint64, - logsDir string, -) { - if rpcSchedule == nil { - return - } - - logsDir = resolveLogsDir(logsDir) - if err := os.MkdirAll(logsDir, 0755); err != nil { - mlog.Log.Warnf("dumpRPCValidatorList: failed to create logs dir: %v", err) - return - } - - // Count slot appearances per leader - leaderSlots := make(map[solana.PublicKey]uint64) - for i := uint64(0); i < numSlots; i++ { - slot := firstSlot + i - leader, ok := rpcSchedule.LeaderForSlot(slot) - if ok { - leaderSlots[leader]++ - } - } - - // Build entries sorted by slot count (descending) for comparison with local - type rpcEntry struct { - leader solana.PublicKey - slotCount uint64 - } - entries := make([]rpcEntry, 0, len(leaderSlots)) - for leader, count := range leaderSlots { - entries = append(entries, rpcEntry{leader: leader, slotCount: count}) - } - sort.Slice(entries, func(i, j int) bool { - if entries[i].slotCount != entries[j].slotCount { - return entries[i].slotCount > entries[j].slotCount - } - // Tie-break by pubkey descending (matches local sort) - return bytes.Compare(entries[i].leader[:], entries[j].leader[:]) > 0 - }) - - // Get short run ID for filename - runID := mlog.GetRunID() - shortRunID := "" - if runID != "" { - shortRunID = runID - if len(shortRunID) > 8 { - shortRunID = shortRunID[:8] - } - shortRunID = "_" + shortRunID - } - - filename := fmt.Sprintf("epoch%d_rpc%s_validators.csv", epoch, shortRunID) - filePath := filepath.Join(logsDir, filename) - - f, err := os.Create(filePath) - if err != nil { - mlog.Log.Warnf("dumpRPCValidatorList: failed to create file: %v", err) - return - } - defer f.Close() - - w := bufio.NewWriter(f) - defer w.Flush() - - // Header - w.WriteString(fmt.Sprintf("# RPC Leader Schedule - Epoch %d\n", epoch)) - w.WriteString(fmt.Sprintf("# Source: rpc\n")) - w.WriteString(fmt.Sprintf("# Total Leaders: %d\n", len(entries))) - 
w.WriteString(fmt.Sprintf("# Total Slots: %d\n", numSlots)) - w.WriteString(fmt.Sprintf("# Generated: %s\n", time.Now().UTC().Format(time.RFC3339))) - w.WriteString("#\n") - w.WriteString("# NOTE: RPC schedule only provides slot->leader mapping.\n") - w.WriteString("# Stake is not available from RPC, so we show slot_count instead.\n") - w.WriteString("# Compare slot_count with local schedule to identify discrepancies.\n") - w.WriteString("#\n") - w.WriteString("rank,node_pubkey,slot_count\n") - - for i, e := range entries { - w.WriteString(fmt.Sprintf("%d,%s,%d\n", i+1, e.leader, e.slotCount)) - } - - mlog.Log.FileOnlyf("RPC validator list dumped to: %s (%d leaders)", filePath, len(entries)) -} - -// BackgroundValidateAgainstRPC optionally validates local schedule against RPC in background. -// This is purely for debugging and does not affect the source of truth. -// Computes full SHA256 hash of entire schedule (~20-50ms) for complete comparison. -// Always writes a validation summary file with full local summary and RPC hash. -// Also dumps RPC-derived validator list for comparison. 
-func BackgroundValidateAgainstRPC( - epoch uint64, - epochSchedule *sealevel.SysvarEpochSchedule, - localSchedule *leaderschedule.LeaderSchedule, - rpcSchedule *leaderschedule.LeaderSchedule, - localSummary *ScheduleSummary, - logsDir string, -) { - if rpcSchedule == nil || localSchedule == nil { - return - } - - firstSlot := epochSchedule.FirstSlotInEpoch(epoch) - numSlots := epochSchedule.SlotsInEpoch(epoch) - - // Compute full hash for RPC schedule - rpcHash := scheduleFullHash(rpcSchedule, firstSlot, numSlots) - - // Use local summary's hash if available, else compute - localHash := localSummary.LocalHash - if localHash == "" { - localHash = scheduleFullHash(localSchedule, firstSlot, numSlots) - } - - matched := localHash == rpcHash - - // Update summary with RPC data and write validation file - localSummary.RPCHash = rpcHash - - // Always write validation summary file with full local summary + RPC data - writeValidationSummary(localSummary, matched, logsDir) - - if matched { - mlog.Log.FileOnlyf("leader schedule RPC validation: epoch=%d MATCH hash=%s", epoch, localHash) - return - } - - // Only dump RPC validator list on mismatch (expensive I/O) - dumpRPCValidatorList(epoch, rpcSchedule, firstSlot, numSlots, logsDir) - - // Hashes differ - log to mismatch file with details - initMismatchLog(logsDir) - - mismatchLogMu.Lock() - if mismatchLogWriter != nil { - mismatchLogWriter.WriteString(fmt.Sprintf("\n[%s] RPC VALIDATION MISMATCH epoch=%d\n", time.Now().Format(time.RFC3339), epoch)) - mismatchLogWriter.WriteString(fmt.Sprintf(" local_hash=%s rpc_hash=%s\n", localHash, rpcHash)) - } - mismatchLogMu.Unlock() - - mlog.Log.Warnf("leader schedule RPC validation: MISMATCH epoch=%d local_hash=%s rpc_hash=%s - see %s", - epoch, localHash, rpcHash, getMismatchLogPath()) - - flushMismatchLog() - - // Dump both schedules to CSV for detailed analysis - DumpScheduleMismatch(epoch, epochSchedule, localSchedule, rpcSchedule, logsDir) -} - -// writeValidationSummary writes a 
summary file with full local summary and RPC comparison. -func writeValidationSummary(summary *ScheduleSummary, matched bool, logsDir string) { - logsDir = resolveLogsDir(logsDir) - if err := os.MkdirAll(logsDir, 0755); err != nil { - mlog.Log.Warnf("writeValidationSummary: failed to create logs dir: %v", err) - return - } - - shortRunID := "" - if summary.RunID != "" { - shortRunID = summary.RunID - if len(shortRunID) > 8 { - shortRunID = shortRunID[:8] - } - shortRunID = "_" + shortRunID - } - - filename := fmt.Sprintf("epoch%d_validation%s.txt", summary.BlockEpoch, shortRunID) - filePath := filepath.Join(logsDir, filename) - - f, err := os.Create(filePath) - if err != nil { - mlog.Log.Warnf("writeValidationSummary: failed to create file: %v", err) - return - } - defer f.Close() - - w := bufio.NewWriter(f) - defer w.Flush() - - status := "MATCH" - if !matched { - status = "MISMATCH" - } - - w.WriteString("# Leader Schedule Validation Summary\n") - w.WriteString(fmt.Sprintf("# Generated: %s\n", time.Now().UTC().Format(time.RFC3339))) - w.WriteString(fmt.Sprintf("# Run ID: %s\n", summary.RunID)) - w.WriteString("#\n") - w.WriteString(fmt.Sprintf("## Result: %s\n\n", status)) - - // Epoch Info (same as local summary) - w.WriteString("## Epoch Info\n") - w.WriteString(fmt.Sprintf("block_epoch=%d\n", summary.BlockEpoch)) - w.WriteString(fmt.Sprintf("schedule_epoch=%d\n", summary.ScheduleEpoch)) - w.WriteString(fmt.Sprintf("first_slot=%d\n", summary.FirstSlot)) - w.WriteString(fmt.Sprintf("slots_in_epoch=%d\n", summary.SlotsInEpoch)) - w.WriteString(fmt.Sprintf("repeat=%d\n", summary.Repeat)) - w.WriteString(fmt.Sprintf("source=%s\n", summary.Source)) - w.WriteString("\n") - - // Stake Info - w.WriteString("## Stake Info\n") - w.WriteString(fmt.Sprintf("total_input_stake=%d\n", summary.TotalInputStake)) - w.WriteString(fmt.Sprintf("filtered_stake=%d\n", summary.FilteredStake)) - w.WriteString(fmt.Sprintf("missing_stake=%d\n", summary.MissingStake)) - 
w.WriteString(fmt.Sprintf("missing_stake_percent=%.4f\n", summary.MissingStakePercent)) - w.WriteString("\n") - - // Validator Counts - w.WriteString("## Validator Counts\n") - w.WriteString(fmt.Sprintf("validators_input=%d\n", summary.ValidatorsInput)) - w.WriteString(fmt.Sprintf("validators_used=%d\n", summary.ValidatorsUsed)) - w.WriteString(fmt.Sprintf("validators_skipped=%d\n", summary.ValidatorsSkipped)) - w.WriteString(fmt.Sprintf("skipped_zero_stake=%d\n", summary.SkippedZeroStake)) - w.WriteString(fmt.Sprintf("skipped_missing_data=%d\n", summary.SkippedMissingData)) - w.WriteString(fmt.Sprintf("skipped_zero_nodepk=%d\n", summary.SkippedZeroNodePk)) - w.WriteString("\n") - - // Hashes - local and RPC side by side - w.WriteString("## Comparison\n") - w.WriteString(fmt.Sprintf("local_hash=%s\n", summary.LocalHash)) - w.WriteString(fmt.Sprintf("rpc_hash=%s\n", summary.RPCHash)) - w.WriteString(fmt.Sprintf("\nstatus=%s\n", status)) - - mlog.Log.FileOnlyf("leader schedule validation summary written to: %s", filePath) -} - -// fetchLeaderScheduleFromRPC fetches leader schedule from RPC for validation purposes. -// Does NOT set it as the global schedule - this is for background validation only. -// Tries primary endpoint first, then backups with fewer retries. -// Uses the epoch-aware RPC method to ensure correct schedule during historical catchup. 
-// RPC method: getLeaderSchedule with epoch parameter -func fetchLeaderScheduleFromRPC( - epoch uint64, - epochSchedule *sealevel.SysvarEpochSchedule, - rpcClient *rpcclient.RpcClient, - backupEndpoints []string, -) (*leaderschedule.LeaderSchedule, error) { - firstSlotInEpoch := epochSchedule.FirstSlotInEpoch(epoch) - - // Try primary endpoint first (fewer retries since this is background validation) - // Pass epoch explicitly to get correct schedule during catchup - leaderMap, err := fetchLeaderScheduleForEpochWithRetry(rpcClient, epoch, 3) - if err == nil { - return leaderschedule.NewLeaderScheduleFromKeyedSlots(leaderMap, firstSlotInEpoch), nil - } - - lastErr := err - mlog.Log.Debugf("RPC leader schedule fetch (validation) for epoch %d failed on primary %s: %v", epoch, rpcClient.Endpoint(), err) - - // Try backup endpoints with fewer retries - for _, endpoint := range backupEndpoints { - backupClient := rpcclient.NewRpcClient(endpoint) - leaderMap, err := fetchLeaderScheduleForEpochWithRetry(backupClient, epoch, 2) - if err == nil { - mlog.Log.Debugf("RPC leader schedule for epoch %d fetched from backup %s (for validation)", epoch, endpoint) - return leaderschedule.NewLeaderScheduleFromKeyedSlots(leaderMap, firstSlotInEpoch), nil - } - lastErr = err - } - - return nil, fmt.Errorf("RPC leader schedule fetch for epoch %d failed from all endpoints: %w", epoch, lastErr) -} diff --git a/pkg/replay/profile.go b/pkg/replay/profile.go deleted file mode 100644 index 8536f1c3..00000000 --- a/pkg/replay/profile.go +++ /dev/null @@ -1,47 +0,0 @@ -package replay - -import ( - "os" - "os/signal" - "runtime/pprof" - "syscall" - - _ "net/http/pprof" - - "github.com/Overclock-Validator/mithril/pkg/accountsdb" - "github.com/Overclock-Validator/mithril/pkg/mlog" -) - -func installProfilerAndSignalHandler(acctsDb *accountsdb.AccountsDb) *os.File { - sigChan := make(chan os.Signal, 1) - signal.Notify(sigChan, - syscall.SIGHUP, - syscall.SIGINT, - syscall.SIGTERM, - 
syscall.SIGQUIT) - - f, err := os.Create("../mithril.prof") - if err != nil { - panic("unable to create profile file") - } - - pprof.StartCPUProfile(f) - - go func() { - for { - s := <-sigChan - switch s { - case syscall.SIGINT: - { - mlog.Log.Infof("signal received. shutting down mithril.") - pprof.StopCPUProfile() - f.Close() - acctsDb.CloseDb() - os.Exit(0) - } - } - } - }() - - return f -} From 857794a3b647bcf49872221baf5dc86058b687ba Mon Sep 17 00:00:00 2001 From: smcio Date: Wed, 13 May 2026 22:13:16 +0200 Subject: [PATCH 2/4] block source hardening --- pkg/blockstream/block_source.go | 35 +++++++++++++- pkg/blockstream/block_source_test.go | 71 ++++++++++++++++++++++++++++ pkg/replay/block.go | 6 +++ 3 files changed, 111 insertions(+), 1 deletion(-) diff --git a/pkg/blockstream/block_source.go b/pkg/blockstream/block_source.go index fcc4930d..85a141f5 100644 --- a/pkg/blockstream/block_source.go +++ b/pkg/blockstream/block_source.go @@ -589,7 +589,11 @@ func (bs *BlockSource) forceRPCForCatchup(gap uint64) { clearedPrefetched := bs.clearBufferedLightbringerBlocks() bs.reorderMu.Lock() - waitingSlot, previousWaitingSlot := bs.rewindConsensusManagedFrontierForRPCFallbackLocked() + waitingSlot := bs.nextSlotToSend + previousWaitingSlot := waitingSlot + if wasActive { + waitingSlot, previousWaitingSlot = bs.rewindConsensusManagedFrontierForRPCFallbackLocked() + } removedSlots := make([]uint64, 0) for slot, blk := range bs.reorderBuffer { if blk != nil && blk.FromLightbringer && slot >= waitingSlot { @@ -2722,6 +2726,16 @@ func (bs *BlockSource) emitOrderedBlocks() { var shouldFallbackToRPC bool var emitObservationDirect bool + if result.slot < bs.nextSlotToSend { + bs.slotStateMu.Lock() + delete(bs.slotState, result.slot) + delete(bs.inflightStart, result.slot) + bs.slotStateMu.Unlock() + bs.clearSlotErrors(result.slot) + bs.reorderMu.Unlock() + continue + } + if result.block != nil && result.block.FromLightbringer { handoffSlot := 
bs.lightbringerHandoffSlot.Load() if handoffSlot != 0 && result.slot >= handoffSlot { @@ -3019,6 +3033,9 @@ func (bs *BlockSource) canScheduleMore(slot uint64) bool { if !bs.shouldUseRPCForSlot(slot) { return false } + if bs.slotBeforeEmissionFrontier(slot) { + return false + } if bs.isNearTip.Load() { // Near-tip mode: allow scheduling up to nearTipLookahead slots ahead @@ -3059,11 +3076,21 @@ func (bs *BlockSource) canScheduleMore(slot uint64) bool { return pending < defaultMaxPending } +func (bs *BlockSource) slotBeforeEmissionFrontier(slot uint64) bool { + bs.reorderMu.Lock() + nextToSend := bs.nextSlotToSend + bs.reorderMu.Unlock() + return nextToSend != 0 && slot < nextToSend +} + // scheduleSlot schedules a slot if not already scheduled func (bs *BlockSource) scheduleSlot(slot uint64) bool { if !bs.shouldUseRPCForSlot(slot) { return false } + if bs.slotBeforeEmissionFrontier(slot) { + return false + } bs.slotStateMu.Lock() if _, exists := bs.slotState[slot]; exists { @@ -3088,6 +3115,9 @@ func (bs *BlockSource) scheduleBackupRequest(slot uint64) bool { if !bs.shouldUseRPCForSlot(slot) { return false } + if bs.slotBeforeEmissionFrontier(slot) { + return false + } select { case bs.workQueue <- slot: @@ -3263,6 +3293,9 @@ func (bs *BlockSource) scheduler() { bs.reorderMu.Unlock() for _, slot := range bs.getRetrySlots() { + if waitingSlot != 0 && slot < waitingSlot { + continue + } if !bs.shouldUseRPCForSlot(slot) { continue } diff --git a/pkg/blockstream/block_source_test.go b/pkg/blockstream/block_source_test.go index f0379d3c..40a2e874 100644 --- a/pkg/blockstream/block_source_test.go +++ b/pkg/blockstream/block_source_test.go @@ -688,6 +688,42 @@ func TestForceRPCForCatchupRewindsConsensusManagedFrontier(t *testing.T) { } } +func TestForceRPCForCatchupKeepsPendingHandoffEmissionFrontier(t *testing.T) { + bs := NewBlockSource(&BlockSourceOpts{ + SourceType: BlockSourceLightbringer, + LightbringerEndpoint: "127.0.0.1:50051", + StartSlot: 100, + EndSlot: 200, + 
ConsensusManagedLightbringer: true, + }) + + bs.lightbringerHandoffSlot.Store(121) + bs.lastExecutedSlot.Store(120) + bs.nextSlotToSend = 150 + bs.reorderBuffer[150] = &b.Block{Slot: 150, FromLightbringer: true} + bs.reorderBuffer[151] = &b.Block{Slot: 151, FromLightbringer: false} + bs.slotState[150] = slotDone + bs.slotState[151] = slotInflight + + bs.forceRPCForCatchup(64) + + if got := bs.nextSlotToSend; got != 150 { + t.Fatalf("expected pending handoff fallback to keep emitted RPC frontier 150, got %d", got) + } + if got := bs.lightbringerHandoffSlot.Load(); got != 0 { + t.Fatalf("expected pending handoff to be cleared, got %d", got) + } + if !bs.lightbringerNeedRPCResume.Load() { + t.Fatalf("expected scheduler to resume RPC from the current emission frontier") + } + if _, exists := bs.reorderBuffer[150]; exists { + t.Fatalf("expected pending Lightbringer slot 150 to be dropped") + } + if _, exists := bs.reorderBuffer[151]; !exists { + t.Fatalf("expected buffered RPC slot 151 to remain") + } +} + func TestEmitOrderedBlocksDirectlyStreamsConsensusManagedLightbringerObservations(t *testing.T) { bs := NewBlockSource(&BlockSourceOpts{ SourceType: BlockSourceLightbringer, @@ -727,6 +763,41 @@ func TestEmitOrderedBlocksDirectlyStreamsConsensusManagedLightbringerObservation } } +func TestEmitOrderedBlocksDropsResultsBehindEmissionFrontier(t *testing.T) { + bs := NewBlockSource(&BlockSourceOpts{ + SourceType: BlockSourceRpc, + StartSlot: 100, + EndSlot: 200, + }) + + bs.nextSlotToSend = 105 + bs.slotState[103] = slotInflight + bs.inflightStart[103] = time.Now() + + done := make(chan struct{}) + go func() { + bs.emitOrderedBlocks() + close(done) + }() + + bs.resultQueue <- fetchResult{ + slot: 103, + block: &b.Block{Slot: 103}, + } + close(bs.resultQueue) + <-done + + if len(bs.streamChan) != 0 { + t.Fatalf("expected stale result to be dropped without emission") + } + if _, exists := bs.reorderBuffer[103]; exists { + t.Fatalf("expected stale result not to enter reorder 
buffer") + } + if _, exists := bs.slotState[103]; exists { + t.Fatalf("expected stale slot state to be cleared") + } +} + func TestIsLightbringerReconnectCancelRecognizesGrpcCanceledStatus(t *testing.T) { err := status.Error(codes.Canceled, "context canceled") if !isLightbringerReconnectCancel(err) { diff --git a/pkg/replay/block.go b/pkg/replay/block.go index 96f4e00d..09b5c2fd 100644 --- a/pkg/replay/block.go +++ b/pkg/replay/block.go @@ -1665,6 +1665,12 @@ func ReplayBlocks( break } + if anchorSlot := currentConsensusAnchorSlot(); anchorSlot != 0 && block.Slot <= anchorSlot { + mlog.Log.Warnf("replay: discarding stale block source emission for slot %d; already executed through slot %d", + block.Slot, anchorSlot) + continue + } + syncConsensusBufferedExecutionMode(block.Slot) if block.FromLightbringer { From a07d0c6ca8c7e7892385378c08c861553d9ce921 Mon Sep 17 00:00:00 2001 From: smcio Date: Fri, 15 May 2026 23:43:04 +0200 Subject: [PATCH 3/4] implement direct account mapping, SIMD-0178, SIMD-0189, SIMD-0377, SIMD-0459, SIMD-0460, setup conformance testing, and fix conformance bugs --- Makefile | 8 +- conformance/debug_fixture_test.go | 147 ++++++ conformance/elf_loader_fb_test.go | 84 ++-- conformance/firedancer_fixture_test.go | 311 ++++++++++++ conformance/test_common.go | 39 +- conformance/vm_programs_test.go | 434 ++++++++++++++++ pkg/accountsdb/accountsdb.go | 12 + pkg/features/features.go | 8 +- pkg/features/features_test.go | 47 ++ pkg/features/gates.go | 15 +- pkg/sbpf/asm.go | 2 +- pkg/sbpf/interpreter.go | 307 +++++++++++- pkg/sbpf/interpreter_v3_test.go | 111 +++++ pkg/sbpf/loader/copy.go | 75 ++- pkg/sbpf/loader/loader.go | 41 +- pkg/sbpf/loader/parse.go | 162 +++++- pkg/sbpf/loader/relocate.go | 53 +- pkg/sbpf/loader/relocate_test.go | 11 + pkg/sbpf/loader/strict_v3_test.go | 147 ++++++ pkg/sbpf/opcode_test.go | 12 +- pkg/sbpf/program.go | 1 + pkg/sbpf/sbpf.go | 4 + pkg/sbpf/sbpfver/sbpf_version.go | 55 ++- pkg/sbpf/sbpfver/sbpf_version_test.go | 29 
++ pkg/sbpf/stack.go | 26 +- pkg/sbpf/vasa_test.go | 102 ++++ pkg/sbpf/verifier.go | 82 ++-- pkg/sbpf/vm.go | 14 + pkg/sealevel/bpf_loader.go | 501 +++++++++++++++---- pkg/sealevel/bpf_loader_error_test.go | 36 ++ pkg/sealevel/errors.go | 164 ++++++- pkg/sealevel/execution_ctx.go | 2 + pkg/sealevel/syscalls.go | 1 + pkg/sealevel/syscalls_common.go | 42 +- pkg/sealevel/syscalls_cpi.go | 600 ++++++++++++++++++----- pkg/sealevel/syscalls_cpi_0459_test.go | 58 +++ pkg/sealevel/syscalls_curve.go | 8 +- pkg/sealevel/syscalls_gen/main.go | 2 +- pkg/sealevel/syscalls_hash_test.go | 21 + pkg/sealevel/syscalls_log.go | 4 + pkg/sealevel/syscalls_log_test.go | 18 + pkg/sealevel/syscalls_pda.go | 6 +- pkg/sealevel/syscalls_pda_test.go | 23 + pkg/sealevel/syscalls_sysvar.go | 65 ++- pkg/sealevel/sysvar_epoch_rewards.go | 7 +- pkg/sealevel/sysvar_last_restart_slot.go | 15 +- pkg/sealevel/types.go | 9 +- 47 files changed, 3484 insertions(+), 437 deletions(-) create mode 100644 conformance/debug_fixture_test.go create mode 100644 conformance/firedancer_fixture_test.go create mode 100644 conformance/vm_programs_test.go create mode 100644 pkg/sbpf/interpreter_v3_test.go create mode 100644 pkg/sbpf/loader/strict_v3_test.go create mode 100644 pkg/sbpf/sbpfver/sbpf_version_test.go create mode 100644 pkg/sbpf/vasa_test.go create mode 100644 pkg/sealevel/bpf_loader_error_test.go create mode 100644 pkg/sealevel/syscalls_cpi_0459_test.go create mode 100644 pkg/sealevel/syscalls_hash_test.go create mode 100644 pkg/sealevel/syscalls_log_test.go create mode 100644 pkg/sealevel/syscalls_pda_test.go diff --git a/Makefile b/Makefile index 5a04af3d..2a579c80 100644 --- a/Makefile +++ b/Makefile @@ -8,7 +8,7 @@ LDFLAGS := -X github.com/Overclock-Validator/mithril/pkg/version.Version=$(VERSI -X github.com/Overclock-Validator/mithril/pkg/version.GitBranch=$(GIT_BRANCH) \ -X github.com/Overclock-Validator/mithril/pkg/version.BuildDate=$(BUILD_DATE) -.PHONY: build release clean server-setup disk-setup 
tune test-conformance-elf +.PHONY: build release clean server-setup disk-setup tune test-conformance-elf test-conformance-vm-programs test-conformance-sbpf build: go build -ldflags "$(LDFLAGS)" -o mithril ./cmd/mithril @@ -31,3 +31,9 @@ tune: test-conformance-elf: go test ./conformance/ -run TestConformance_ElfLoader_Firedancer -v + +test-conformance-vm-programs: + go test ./conformance/ -run TestConformance_VMPrograms_Firedancer -v + +test-conformance-sbpf: + go test ./conformance/ -run '^(TestConformance_ElfLoader_Firedancer|TestConformance_VMPrograms_Firedancer)$$' -v diff --git a/conformance/debug_fixture_test.go b/conformance/debug_fixture_test.go new file mode 100644 index 00000000..823b276f --- /dev/null +++ b/conformance/debug_fixture_test.go @@ -0,0 +1,147 @@ +package conformance + +import ( + "fmt" + "os" + "path/filepath" + "sort" + "strconv" + "strings" + "testing" + + "github.com/Overclock-Validator/mithril/pkg/sbpf" + "github.com/Overclock-Validator/mithril/pkg/sbpf/loader" + sealevelPkg "github.com/Overclock-Validator/mithril/pkg/sealevel" + "github.com/gagliardetto/solana-go" +) + +func TestDebugDumpVMProgramFixture(t *testing.T) { + filter := os.Getenv("MITHRIL_CONFORMANCE_DUMP_FIXTURE") + if filter == "" { + t.Skip("set MITHRIL_CONFORMANCE_DUMP_FIXTURE") + } + + basePath := "test-vectors/instr/fixtures/vm-programs" + entries, err := os.ReadDir(basePath) + if err != nil { + t.Skipf("test-vectors not available: %v", err) + } + + var fixtures []string + for _, entry := range entries { + if strings.HasSuffix(entry.Name(), ".fix") && strings.Contains(entry.Name(), filter) { + fixtures = append(fixtures, filepath.Join(basePath, entry.Name())) + } + } + sort.Strings(fixtures) + if len(fixtures) == 0 { + t.Fatalf("no fixture matching %q", filter) + } + + data, err := os.ReadFile(fixtures[0]) + if err != nil { + t.Fatal(err) + } + fixture, err := unmarshalFiredancerInstrFixture(data) + if err != nil { + t.Fatal(err) + } + + programID := 
solana.PublicKeyFromBytes(fixture.GetInput().GetProgramId()) + t.Logf("fixture=%s program_id=%s input_cu=%d output_cu=%d output_result=%d custom=%d data=%x", filepath.Base(fixtures[0]), programID, fixture.GetInput().GetCuAvail(), fixture.GetOutput().GetCuAvail(), fixture.GetOutput().GetResult(), fixture.GetOutput().GetCustomErr(), fixture.GetInput().GetData()) + t.Logf("%s", fixtureProgramSummary(fixture)) + for i, acct := range fixture.GetInput().GetAccounts() { + key := solana.PublicKeyFromBytes(acct.GetAddress()) + owner := solana.PublicKeyFromBytes(acct.GetOwner()) + prefixLen := min(len(acct.GetData()), 8) + prefix := acct.GetData()[:prefixLen] + t.Logf("acct[%d] key=%s owner=%s exec=%v lamports=%d data_len=%d data_prefix=%x", i, key, owner, acct.GetExecutable(), acct.GetLamports(), len(acct.GetData()), prefix) + } + for i, acct := range fixture.GetInput().GetInstrAccounts() { + t.Logf("instr_acct[%d] index=%d writable=%v signer=%v", i, acct.GetIndex(), acct.GetIsWritable(), acct.GetIsSigner()) + } + + execCtx, instrAccts, programIndices, err := newVMProgramExecCtxAndInstrAccts(fixture) + if err != nil { + t.Fatal(err) + } + t.Logf("features=%v", execCtx.Features.AllEnabled()) + + if idxStr := os.Getenv("MITHRIL_CONFORMANCE_DUMP_PROGRAM_INDEX"); idxStr != "" { + for i, acct := range fixture.GetInput().GetAccounts() { + if idxStr == strconv.Itoa(i) { + programID = solana.PublicKeyFromBytes(acct.GetAddress()) + break + } + } + } + var programBytes []byte + for _, acct := range fixture.GetInput().GetAccounts() { + if solana.PublicKeyFromBytes(acct.GetAddress()) == programID { + programBytes = acct.GetData() + break + } + } + if len(programBytes) == 0 { + t.Fatalf("program account %s has no bytes", programID) + } + + var program *sbpf.Program + if os.Getenv("MITHRIL_CONFORMANCE_DUMP_NO_DISASM") == "" { + syscalls := func(u uint32) (sbpf.Syscall, bool) { + syscall, ok := sealevelPkg.Syscalls(&execCtx.Features, false, u) + if !ok { + return nil, false + } + return 
debugSyscall{t: t, hash: u, inner: syscall}, true + } + l, err := loader.NewLoaderWithSyscalls(programBytes, syscalls, false, &execCtx.Features) + if err != nil { + t.Fatalf("loader: %v", err) + } + program, err = l.Load() + if err != nil { + t.Fatalf("load: %v", err) + } + + t.Logf("sbpf_version=%v entry=%d text_slots=%d funcs=%d", program.SbpfVersion, program.Entrypoint, len(program.Text), len(program.Funcs)) + t.Logf("verify_err=%v", program.Verify()) + } + execCtx.RecordInnerInstructions = true + execCtx.SetCurrentTopLevelInstr(0) + runErr := execCtx.ProcessInstruction(fixture.GetInput().GetData(), instrAccts, programIndices) + t.Logf("run_err=%v translated_result=%d remaining_cu=%d", runErr, instrResultFromErr(runErr), execCtx.ComputeMeter.Remaining()) + if recorder, ok := execCtx.Log.(*sealevelPkg.LogRecorder); ok { + for i, log := range recorder.Logs { + t.Logf("log[%d]=%s", i, log) + } + } + for i, inner := range execCtx.InnerInstrs { + t.Logf("inner[%d] stack=%d program_index=%d accounts=%v data=%x", i, inner.StackHeight, inner.ProgramIdIndex, inner.Accounts, inner.Data) + } + if program != nil { + limit := len(program.Text) + for pc := 0; pc < limit; pc++ { + slot := program.Text[pc] + t.Logf("%03d raw=%016x op=%02x dst=%d src=%d off=%d imm=%d uimm=%d", pc, uint64(slot), slot.Op(), slot.Dst(), slot.Src(), slot.Off(), slot.Imm(), slot.Uimm()) + } + } +} + +type debugSyscall struct { + t *testing.T + hash uint32 + inner sbpf.Syscall +} + +func (d debugSyscall) Invoke(vm sbpf.VM, r1, r2, r3, r4, r5 uint64) (uint64, error) { + before := vm.ComputeMeter().Remaining() + r0, err := d.inner.Invoke(vm, r1, r2, r3, r4, r5) + after := vm.ComputeMeter().Remaining() + d.t.Logf("syscall hash=0x%08x args=[0x%x 0x%x 0x%x 0x%x 0x%x] before_cu=%d after_cu=%d r0=%d err=%v", d.hash, r1, r2, r3, r4, r5, before, after, r0, err) + return r0, err +} + +func (d debugSyscall) String() string { + return fmt.Sprintf("debugSyscall(0x%08x)", d.hash) +} diff --git 
a/conformance/elf_loader_fb_test.go b/conformance/elf_loader_fb_test.go index 5544b330..0e1e3084 100644 --- a/conformance/elf_loader_fb_test.go +++ b/conformance/elf_loader_fb_test.go @@ -9,21 +9,15 @@ import ( "strings" "testing" - "github.com/Overclock-Validator/mithril/conformance/sealevel" "github.com/Overclock-Validator/mithril/pkg/features" "github.com/Overclock-Validator/mithril/pkg/sbpf" "github.com/Overclock-Validator/mithril/pkg/sbpf/loader" sealevelPkg "github.com/Overclock-Validator/mithril/pkg/sealevel" ) -func parseFBFeatures(fbFeatures *sealevel.FeatureSet) *features.Features { +func parseFeatureIds(featureIds []uint64) *features.Features { f := features.NewFeaturesDefault() - if fbFeatures == nil { - return f - } - n := fbFeatures.FeaturesLength() - for i := 0; i < n; i++ { - ftr := fbFeatures.Features(i) + for _, ftr := range featureIds { for _, featureGate := range features.AllFeatureGates { featureIdInt := binary.LittleEndian.Uint64(featureGate.Address[:8]) if featureIdInt == ftr { @@ -34,6 +28,13 @@ func parseFBFeatures(fbFeatures *sealevel.FeatureSet) *features.Features { return f } +func parsePBFeatures(pbFeatures *FeatureSet) *features.Features { + if pbFeatures == nil { + return features.NewFeaturesDefault() + } + return parseFeatureIds(pbFeatures.GetFeatures()) +} + func TestConformance_ElfLoader_Firedancer(t *testing.T) { basePath := "test-vectors/elf_loader/fixtures" @@ -82,32 +83,26 @@ func TestConformance_ElfLoader_Firedancer(t *testing.T) { continue } - fixture := sealevel.GetRootAsELFLoaderFixture(data, 0) - if fixture == nil { - parseErrors++ - continue - } - - input := fixture.Input(nil) - if input == nil { + fixture, err := unmarshalFiredancerELFLoaderFixture(data) + if err != nil { + failures = append(failures, fmt.Sprintf("PARSE_ERROR %s: %v", name, err)) parseErrors++ continue } - elfData := input.ElfDataBytes() - if elfData == nil { + if len(fixture.ElfData) == 0 { + failures = append(failures, fmt.Sprintf("PARSE_ERROR %s: 
missing ELF data", name)) parseErrors++ continue } - output := fixture.Output(nil) - fixtureExpectsSuccess := output != nil && output.ErrCode() == 0 + output := fixture.Output + fixtureExpectsSuccess := output.expectsSuccess() - fbFeatures := input.Features(nil) - f := parseFBFeatures(fbFeatures) + f := parsePBFeatures(fixture.Features) syscalls := sbpf.SyscallRegistry(func(hash uint32) (sbpf.Syscall, bool) { - return sealevelPkg.Syscalls(f, input.DeployChecks(), hash) + return sealevelPkg.Syscalls(f, fixture.DeployChecks, hash) }) var program *sbpf.Program @@ -123,7 +118,7 @@ func TestConformance_ElfLoader_Firedancer(t *testing.T) { } }() - l, err := loader.NewLoaderWithSyscalls(elfData, syscalls, input.DeployChecks(), f) + l, err := loader.NewLoaderWithSyscalls(fixture.ElfData, syscalls, fixture.DeployChecks, f) if err != nil { loadErr = err return @@ -139,18 +134,22 @@ func TestConformance_ElfLoader_Firedancer(t *testing.T) { passPass++ if output != nil { - entryTotal++ - if program.Entrypoint == output.EntryPc() { - entryMatch++ - } else { - failures = append(failures, fmt.Sprintf("ENTRY_MISMATCH %s: got=%d want=%d", name, program.Entrypoint, output.EntryPc())) + if output.HasEntryPc { + entryTotal++ + if program.Entrypoint == output.EntryPc { + entryMatch++ + } else { + failures = append(failures, fmt.Sprintf("ENTRY_MISMATCH %s: got=%d want=%d", name, program.Entrypoint, output.EntryPc)) + } } - textTotal++ - if uint64(len(program.Text)) == output.TextCnt() { - textMatch++ - } else { - failures = append(failures, fmt.Sprintf("TEXT_CNT_MISMATCH %s: got=%d want=%d", name, len(program.Text), output.TextCnt())) + if output.HasTextCnt { + textTotal++ + if uint64(len(program.Text)) == output.TextCnt { + textMatch++ + } else { + failures = append(failures, fmt.Sprintf("TEXT_CNT_MISMATCH %s: got=%d want=%d", name, len(program.Text), output.TextCnt)) + } } } } else if loadErr != nil && !fixtureExpectsSuccess { @@ -160,7 +159,11 @@ func 
TestConformance_ElfLoader_Firedancer(t *testing.T) { failures = append(failures, fmt.Sprintf("FALSE_PASS %s: loaded OK but fixture expects failure", name)) } else { falseFail++ - failures = append(failures, fmt.Sprintf("FALSE_FAIL %s: %v (entry_pc=%d text_cnt=%d)", name, loadErr, output.EntryPc(), output.TextCnt())) + if output != nil { + failures = append(failures, fmt.Sprintf("FALSE_FAIL %s: %v (entry_pc=%d text_cnt=%d err_code=%d)", name, loadErr, output.EntryPc, output.TextCnt, output.ErrCode)) + } else { + failures = append(failures, fmt.Sprintf("FALSE_FAIL %s: %v", name, loadErr)) + } } } @@ -205,6 +208,15 @@ func TestConformance_ElfLoader_Firedancer(t *testing.T) { t.Errorf("CRITICAL: %d fixtures caused panics in the loader", panics) } if disagree > 0 { - t.Logf("WARNING: %d disagreements found", disagree) + t.Errorf("%d ELF loader acceptance disagreements found", disagree) + } + if parseErrors > 0 { + t.Errorf("%d ELF loader fixture parse errors found", parseErrors) + } + if entryMatch != entryTotal { + t.Errorf("%d ELF loader entry PC mismatches found", entryTotal-entryMatch) + } + if textMatch != textTotal { + t.Errorf("%d ELF loader text count mismatches found", textTotal-textMatch) } } diff --git a/conformance/firedancer_fixture_test.go b/conformance/firedancer_fixture_test.go new file mode 100644 index 00000000..aca1e009 --- /dev/null +++ b/conformance/firedancer_fixture_test.go @@ -0,0 +1,311 @@ +package conformance + +import ( + "encoding/binary" + + legacyproto "github.com/golang/protobuf/proto" + "google.golang.org/protobuf/encoding/protowire" + "google.golang.org/protobuf/proto" +) + +type firedancerFeatureSet struct { + Features []uint64 `protobuf:"fixed64,1,rep,packed,name=features,proto3" json:"features,omitempty"` +} + +func (x *firedancerFeatureSet) Reset() { *x = firedancerFeatureSet{} } +func (x *firedancerFeatureSet) String() string { return legacyproto.CompactTextString(x) } +func (*firedancerFeatureSet) ProtoMessage() {} + +type 
firedancerCurrentELFLoaderCtx struct { + ElfData []byte `protobuf:"bytes,1,opt,name=elf_data,json=elfData,proto3" json:"elf_data,omitempty"` + Features *firedancerFeatureSet `protobuf:"bytes,2,opt,name=features,proto3" json:"features,omitempty"` + DeployChecks bool `protobuf:"varint,3,opt,name=deploy_checks,json=deployChecks,proto3" json:"deploy_checks,omitempty"` +} + +func (x *firedancerCurrentELFLoaderCtx) Reset() { *x = firedancerCurrentELFLoaderCtx{} } +func (x *firedancerCurrentELFLoaderCtx) String() string { return legacyproto.CompactTextString(x) } +func (*firedancerCurrentELFLoaderCtx) ProtoMessage() {} + +type firedancerCurrentELFLoaderEffects struct { + ErrCode *uint32 `protobuf:"varint,1,opt,name=err_code,json=errCode" json:"err_code,omitempty"` + RodataHash *uint64 `protobuf:"fixed64,2,opt,name=rodata_hash,json=rodataHash" json:"rodata_hash,omitempty"` + TextCnt *uint64 `protobuf:"varint,3,opt,name=text_cnt,json=textCnt" json:"text_cnt,omitempty"` + TextOff *uint64 `protobuf:"varint,4,opt,name=text_off,json=textOff" json:"text_off,omitempty"` + EntryPc *uint64 `protobuf:"varint,5,opt,name=entry_pc,json=entryPc" json:"entry_pc,omitempty"` + CalldestsHash *uint64 `protobuf:"fixed64,6,opt,name=calldests_hash,json=calldestsHash" json:"calldests_hash,omitempty"` +} + +func (x *firedancerCurrentELFLoaderEffects) Reset() { + *x = firedancerCurrentELFLoaderEffects{} +} +func (x *firedancerCurrentELFLoaderEffects) String() string { + return legacyproto.CompactTextString(x) +} +func (*firedancerCurrentELFLoaderEffects) ProtoMessage() {} + +type firedancerCurrentELFLoaderFixture struct { + Input *firedancerCurrentELFLoaderCtx `protobuf:"bytes,2,opt,name=input,proto3" json:"input,omitempty"` + Output *firedancerCurrentELFLoaderEffects `protobuf:"bytes,3,opt,name=output,proto3" json:"output,omitempty"` +} + +func (x *firedancerCurrentELFLoaderFixture) Reset() { + *x = firedancerCurrentELFLoaderFixture{} +} +func (x *firedancerCurrentELFLoaderFixture) String() 
string { + return legacyproto.CompactTextString(x) +} +func (*firedancerCurrentELFLoaderFixture) ProtoMessage() {} + +type firedancerELFLoaderFixtureCompat struct { + ElfData []byte + Features *FeatureSet + DeployChecks bool + Output *firedancerELFLoaderOutputCompat +} + +type firedancerELFLoaderOutputCompat struct { + ErrCode uint32 + HasErrCode bool + TextCnt uint64 + HasTextCnt bool + TextOff uint64 + HasTextOff bool + EntryPc uint64 + HasEntryPc bool + RodataHash uint64 + HasRodataHash bool + CalldestsHash uint64 + HasCallHash bool +} + +func (o *firedancerELFLoaderOutputCompat) expectsSuccess() bool { + if o == nil { + return false + } + return !o.HasErrCode || o.ErrCode == 0 +} + +type firedancerInstrFixture struct { + Input *InstrContext `protobuf:"bytes,2,opt,name=input,proto3" json:"input,omitempty"` + Output *InstrEffects `protobuf:"bytes,3,opt,name=output,proto3" json:"output,omitempty"` +} + +func (x *firedancerInstrFixture) Reset() { *x = firedancerInstrFixture{} } +func (x *firedancerInstrFixture) String() string { return legacyproto.CompactTextString(x) } +func (*firedancerInstrFixture) ProtoMessage() {} + +func unmarshalFiredancerELFLoaderFixture(data []byte) (*firedancerELFLoaderFixtureCompat, error) { + fixture := &ELFLoaderFixture{} + if err := proto.Unmarshal(data, fixture); err == nil && isELFData(fixture.GetInput().GetElf().GetData()) { + var output *firedancerELFLoaderOutputCompat + if fixture.GetOutput() != nil { + output = &firedancerELFLoaderOutputCompat{ + TextCnt: fixture.GetOutput().GetTextCnt(), + HasTextCnt: true, + TextOff: fixture.GetOutput().GetTextOff(), + HasTextOff: true, + EntryPc: fixture.GetOutput().GetEntryPc(), + HasEntryPc: true, + } + } + return &firedancerELFLoaderFixtureCompat{ + ElfData: fixture.GetInput().GetElf().GetData(), + Features: fixture.GetInput().GetFeatures(), + DeployChecks: fixture.GetInput().GetDeployChecks(), + Output: output, + }, nil + } + + currentFixture := &firedancerCurrentELFLoaderFixture{} + if 
err := legacyproto.Unmarshal(data, currentFixture); err != nil { + return nil, err + } + if currentFixture.Input == nil { + return nil, legacyproto.ErrNil + } + var features *FeatureSet + if currentFixture.Input != nil && currentFixture.Input.Features != nil { + features = &FeatureSet{Features: currentFixture.Input.Features.Features} + } + var output *firedancerELFLoaderOutputCompat + if currentFixture.Output != nil { + output = &firedancerELFLoaderOutputCompat{} + if currentFixture.Output.ErrCode != nil { + output.ErrCode = *currentFixture.Output.ErrCode + output.HasErrCode = true + } + if currentFixture.Output.TextCnt != nil { + output.TextCnt = *currentFixture.Output.TextCnt + output.HasTextCnt = true + } + if currentFixture.Output.TextOff != nil { + output.TextOff = *currentFixture.Output.TextOff + output.HasTextOff = true + } + if currentFixture.Output.EntryPc != nil { + output.EntryPc = *currentFixture.Output.EntryPc + output.HasEntryPc = true + } + if currentFixture.Output.RodataHash != nil { + output.RodataHash = *currentFixture.Output.RodataHash + output.HasRodataHash = true + } + if currentFixture.Output.CalldestsHash != nil { + output.CalldestsHash = *currentFixture.Output.CalldestsHash + output.HasCallHash = true + } + } + return &firedancerELFLoaderFixtureCompat{ + ElfData: currentFixture.Input.ElfData, + Features: features, + DeployChecks: currentFixture.Input.DeployChecks, + Output: output, + }, nil +} + +func isELFData(data []byte) bool { + return len(data) >= 4 && data[0] == 0x7f && data[1] == 'E' && data[2] == 'L' && data[3] == 'F' +} + +func unmarshalFiredancerInstrFixture(data []byte) (*InstrFixture, error) { + fixture := &InstrFixture{} + if err := proto.Unmarshal(data, fixture); err == nil && len(fixture.GetInput().GetProgramId()) == 32 { + return fixture, nil + } + + currentFixture := &firedancerInstrFixture{} + if err := legacyproto.Unmarshal(data, currentFixture); err != nil { + return nil, err + } + if currentFixture.Input == nil { + 
return nil, legacyproto.ErrNil + } + features, ok := currentInstrFeatures(data) + if ok { + currentFixture.Input.EpochContext = &EpochContext{Features: &FeatureSet{Features: features}} + } else if currentFixture.Input.EpochContext == nil { + currentFixture.Input.EpochContext = &EpochContext{} + } + if currentFixture.Input.SlotContext == nil { + currentFixture.Input.SlotContext = &SlotContext{} + } + return &InstrFixture{ + Input: currentFixture.Input, + Output: currentFixture.Output, + }, nil +} + +func currentInstrFeatures(data []byte) ([]uint64, bool) { + input, ok := consumeBytesField(data, 2) + if !ok { + return nil, false + } + epochContext, ok := consumeBytesField(input, 10) + if !ok { + return nil, false + } + featureSet, ok := consumeBytesField(epochContext, 1) + if !ok { + return nil, false + } + return consumePackedFeatureIds(featureSet), true +} + +func consumeBytesField(data []byte, want protowire.Number) ([]byte, bool) { + for len(data) > 0 { + num, typ, tagLen := protowire.ConsumeTag(data) + if tagLen < 0 { + return nil, false + } + data = data[tagLen:] + if typ == protowire.BytesType { + value, valueLen := protowire.ConsumeBytes(data) + if valueLen < 0 { + return nil, false + } + if num == want { + return value, true + } + data = data[valueLen:] + continue + } + valueLen := protowire.ConsumeFieldValue(num, typ, data) + if valueLen < 0 { + return nil, false + } + data = data[valueLen:] + } + return nil, false +} + +func consumeFeatureIds(data []byte) []uint64 { + var featureIds []uint64 + for len(data) > 0 { + num, typ, tagLen := protowire.ConsumeTag(data) + if tagLen < 0 { + return featureIds + } + data = data[tagLen:] + if num != 1 { + valueLen := protowire.ConsumeFieldValue(num, typ, data) + if valueLen < 0 { + return featureIds + } + data = data[valueLen:] + continue + } + switch typ { + case protowire.VarintType: + value, valueLen := protowire.ConsumeVarint(data) + if valueLen < 0 { + return featureIds + } + featureIds = append(featureIds, value) 
+ data = data[valueLen:] + case protowire.Fixed64Type: + value, valueLen := protowire.ConsumeFixed64(data) + if valueLen < 0 { + return featureIds + } + featureIds = append(featureIds, value) + data = data[valueLen:] + case protowire.BytesType: + value, valueLen := protowire.ConsumeBytes(data) + if valueLen < 0 { + return featureIds + } + featureIds = append(featureIds, consumePackedFeatureIds(value)...) + data = data[valueLen:] + default: + valueLen := protowire.ConsumeFieldValue(num, typ, data) + if valueLen < 0 { + return featureIds + } + data = data[valueLen:] + } + } + return featureIds +} + +func consumePackedFeatureIds(data []byte) []uint64 { + var featureIds []uint64 + if len(data)%8 == 0 { + for len(data) > 0 { + featureIds = append(featureIds, binary.LittleEndian.Uint64(data[:8])) + data = data[8:] + } + return featureIds + } + + remaining := data + for len(remaining) > 0 { + value, valueLen := protowire.ConsumeVarint(remaining) + if valueLen < 0 { + featureIds = featureIds[:0] + break + } + featureIds = append(featureIds, value) + remaining = remaining[valueLen:] + } + return featureIds +} diff --git a/conformance/test_common.go b/conformance/test_common.go index 79f53edf..d082ace7 100644 --- a/conformance/test_common.go +++ b/conformance/test_common.go @@ -11,6 +11,7 @@ import ( "github.com/Overclock-Validator/mithril/pkg/cu" "github.com/Overclock-Validator/mithril/pkg/features" "github.com/Overclock-Validator/mithril/pkg/sealevel" + bin "github.com/gagliardetto/binary" "github.com/gagliardetto/solana-go" "github.com/stretchr/testify/assert" ) @@ -49,6 +50,14 @@ func instructionAcctsFromFixture(fixture *InstrFixture, transactionAccts sealeve } func configureSysvars(execCtx *sealevel.ExecutionCtx, fixture *InstrFixture) { + configureSysvarsWithDefaults(execCtx, fixture, true) +} + +func configureSysvarsFromFixture(execCtx *sealevel.ExecutionCtx, fixture *InstrFixture) { + configureSysvarsWithDefaults(execCtx, fixture, false) +} + +func 
configureSysvarsWithDefaults(execCtx *sealevel.ExecutionCtx, fixture *InstrFixture, synthesizeDefaults bool) { /// clock var foundClockSysvar bool for _, acct := range fixture.Input.Accounts { @@ -67,7 +76,7 @@ func configureSysvars(execCtx *sealevel.ExecutionCtx, fixture *InstrFixture) { } } - if !foundClockSysvar { + if !foundClockSysvar && synthesizeDefaults { fmt.Printf("******** setting default clock sysvar\n") var clock sealevel.SysvarClock clock.Slot = 10 @@ -98,7 +107,7 @@ func configureSysvars(execCtx *sealevel.ExecutionCtx, fixture *InstrFixture) { } } - if !foundRentSysvar { + if !foundRentSysvar && synthesizeDefaults { var rent sealevel.SysvarRent rent.LamportsPerUint8Year = 3480 rent.ExemptionThreshold = 2.0 @@ -137,20 +146,17 @@ func configureSysvars(execCtx *sealevel.ExecutionCtx, fixture *InstrFixture) { if solana.PublicKeyFromBytes(acct.Address) == sealevel.SysvarEpochScheduleAddr { fmt.Printf("adding state for sysvar: SysvarEpochSchedule\n") epochScheduleAcct := fixtureAcctStateToAccount(acct) - if len(epochScheduleAcct.Data) < sealevel.SysvarEpochScheduleStructLen { - fmt.Printf("******** epoch schedule data less than SysvarEpochScheduleStructLen\n") - break - } - execCtx.Accounts.SetAccount(&sealevel.SysvarEpochScheduleAddr, &epochScheduleAcct) - _, err := sealevel.ReadEpochScheduleSysvar(execCtx) + var epochSchedule sealevel.SysvarEpochSchedule + err := epochSchedule.UnmarshalWithDecoder(bin.NewBinDecoder(epochScheduleAcct.Data)) if err == nil { + execCtx.Accounts.SetAccount(&sealevel.SysvarEpochScheduleAddr, &epochScheduleAcct) foundEpochScheduleSysvar = true } } } - if !foundEpochScheduleSysvar { + if !foundEpochScheduleSysvar && synthesizeDefaults { fmt.Printf("******** adding default epoch schedule sysvar\n") epochSchedule := sealevel.SysvarEpochSchedule{SlotsPerEpoch: 432000, LeaderScheduleSlotOffset: 432000, Warmup: true, FirstNormalEpoch: 14, FirstNormalSlot: 524256} @@ -165,12 +171,23 @@ func configureSysvars(execCtx 
*sealevel.ExecutionCtx, fixture *InstrFixture) { if solana.PublicKeyFromBytes(acct.Address) == sealevel.SysvarEpochRewardsAddr { fmt.Printf("adding state for sysvar: SysvarEpochRewards\n") epochRewardsAcct := fixtureAcctStateToAccount(acct) - if len(epochRewardsAcct.Data) == sealevel.SysvarEpochRewardsStructLen { + var epochRewards sealevel.SysvarEpochRewards + if err := epochRewards.UnmarshalWithDecoder(bin.NewBinDecoder(epochRewardsAcct.Data)); err == nil { execCtx.Accounts.SetAccount(&sealevel.SysvarEpochRewardsAddr, &epochRewardsAcct) } } } + /// LastRestartSlot + for _, acct := range fixture.Input.Accounts { + if solana.PublicKeyFromBytes(acct.Address) == sealevel.SysvarLastRestartSlotAddr { + lastRestartSlotAcct := fixtureAcctStateToAccount(acct) + if len(lastRestartSlotAcct.Data) == sealevel.SysvarLastRestartSlotStructLen { + execCtx.Accounts.SetAccount(&sealevel.SysvarLastRestartSlotAddr, &lastRestartSlotAcct) + } + } + } + /// RecentBlockhashes for _, acct := range fixture.Input.Accounts { if solana.PublicKeyFromBytes(acct.Address) == sealevel.SysvarRecentBlockHashesAddr { @@ -226,7 +243,7 @@ func newExecCtxAndInstrAcctsFromFixture(fixture *InstrFixture) (*sealevel.Execut instr := sealevel.Instruction{Data: fixture.Input.Data} txCtx.AllInstructions = append(txCtx.AllInstructions, instr) - execCtx := sealevel.ExecutionCtx{TransactionContext: txCtx, ComputeMeter: cu.NewComputeMeter(fixture.Input.CuAvail)} + execCtx := sealevel.ExecutionCtx{TransactionContext: txCtx, ComputeMeter: cu.NewComputeMeter(fixture.Input.CuAvail), Log: &sealevel.LogRecorder{}} execCtx.Accounts = accounts.NewMemAccounts() configureSysvars(&execCtx, fixture) parseAndConfigureFeatures(&execCtx, fixture) diff --git a/conformance/vm_programs_test.go b/conformance/vm_programs_test.go new file mode 100644 index 00000000..bc7486a4 --- /dev/null +++ b/conformance/vm_programs_test.go @@ -0,0 +1,434 @@ +package conformance + +import ( + "bytes" + "fmt" + "os" + "path/filepath" + "runtime/debug" 
+ "sort" + "strconv" + "strings" + "sync" + "testing" + + "github.com/Overclock-Validator/mithril/pkg/accounts" + "github.com/Overclock-Validator/mithril/pkg/accountsdb" + "github.com/Overclock-Validator/mithril/pkg/cu" + sealevelPkg "github.com/Overclock-Validator/mithril/pkg/sealevel" + "github.com/gagliardetto/solana-go" +) + +func withoutConformanceStdout(fn func()) { + if os.Getenv("MITHRIL_CONFORMANCE_VERBOSE") != "" { + fn() + return + } + + devNull, err := os.Open(os.DevNull) + if err != nil { + fn() + return + } + defer devNull.Close() + + oldStdout := os.Stdout + os.Stdout = devNull + defer func() { + os.Stdout = oldStdout + }() + + fn() +} + +func instrReturnMatches(fixture *InstrFixture, err error) bool { + output := fixture.GetOutput() + if output == nil { + return err != nil + } + if err == nil { + return output.GetResult() == 0 + } + if output.GetResult() == 0 { + return false + } + if output.GetResult() == 26 && sealevelPkg.IsCustomErr(err) { + return uint32(sealevelPkg.TranslateErrToErrCode(err)) == output.GetCustomErr() + } + return int32(sealevelPkg.TranslateErrToErrCode(err)+1) == output.GetResult() +} + +var firedancerInstrResultNames = map[int32]string{ + 0: "Success", + 1: "GenericError", + 2: "InvalidArgument", + 3: "InvalidInstructionData", + 4: "InvalidAccountData", + 5: "AccountDataTooSmall", + 6: "InsufficientFunds", + 7: "IncorrectProgramId", + 8: "MissingRequiredSignature", + 9: "AccountAlreadyInitialized", + 10: "UninitializedAccount", + 11: "UnbalancedInstruction", + 12: "ModifiedProgramId", + 13: "ExternalAccountLamportSpend", + 14: "ExternalAccountDataModified", + 15: "ReadonlyLamportChange", + 16: "ReadonlyDataModified", + 17: "DuplicateAccountIndex", + 18: "ExecutableModified", + 19: "RentEpochModified", + 20: "NotEnoughAccountKeys", + 21: "AccountDataSizeChanged", + 22: "AccountNotExecutable", + 23: "AccountBorrowFailed", + 24: "AccountBorrowOutstanding", + 25: "DuplicateAccountOutOfSync", + 26: "Custom", + 27: "InvalidError", + 
28: "ExecutableDataModified", + 29: "ExecutableLamportChange", + 30: "ExecutableAccountNotRentExempt", + 31: "UnsupportedProgramId", + 32: "CallDepth", + 33: "MissingAccount", + 34: "ReentrancyNotAllowed", + 35: "MaxSeedLengthExceeded", + 36: "InvalidSeeds", + 37: "InvalidRealloc", + 38: "ComputationalBudgetExceeded", + 39: "PrivilegeEscalation", + 40: "ProgramEnvironmentSetupFailure", + 41: "ProgramFailedToComplete", + 42: "ProgramFailedToCompile", + 43: "Immutable", + 44: "IncorrectAuthority", + 45: "BorshIoError", + 46: "AccountNotRentExempt", + 47: "InvalidAccountOwner", + 48: "ArithmeticOverflow", + 49: "UnsupportedSysvar", + 50: "IllegalOwner", + 51: "MaxAccountsDataAllocationsExceeded", + 52: "MaxAccountsExceeded", + 53: "MaxInstructionTraceLengthExceeded", + 54: "BuiltinProgramsMustConsumeComputeUnits", +} + +func firedancerInstrResultName(result int32) string { + if name, ok := firedancerInstrResultNames[result]; ok { + return name + } + return fmt.Sprintf("UnknownResult(%d)", result) +} + +func instrResultFromErr(err error) int32 { + if err == nil { + return 0 + } + if sealevelPkg.IsCustomErr(err) { + return 26 + } + return int32(sealevelPkg.TranslateErrToErrCode(err) + 1) +} + +type conformanceBucket struct { + key string + count int +} + +func topConformanceBuckets(counts map[string]int, limit int) []conformanceBucket { + buckets := make([]conformanceBucket, 0, len(counts)) + for key, count := range counts { + buckets = append(buckets, conformanceBucket{key: key, count: count}) + } + sort.Slice(buckets, func(i, j int) bool { + if buckets[i].count != buckets[j].count { + return buckets[i].count > buckets[j].count + } + return buckets[i].key < buckets[j].key + }) + if len(buckets) > limit { + buckets = buckets[:limit] + } + return buckets +} + +func returnMismatchBucket(fixture *InstrFixture, err error) string { + gotResult := instrResultFromErr(err) + gotCode := 0 + if err != nil { + gotCode = sealevelPkg.TranslateErrToErrCode(err) + } + wantResult := 
fixture.GetOutput().GetResult() + return fmt.Sprintf("got=%s(%d) got_code=%d got_err=%v want=%s(%d) want_custom=%d", + firedancerInstrResultName(gotResult), gotResult, gotCode, err, + firedancerInstrResultName(wantResult), wantResult, fixture.GetOutput().GetCustomErr()) +} + +func fixtureProgramSummary(fixture *InstrFixture) string { + input := fixture.GetInput() + if input == nil { + return "missing_input" + } + programId := solana.PublicKeyFromBytes(input.GetProgramId()) + for i, acct := range input.GetAccounts() { + key := solana.PublicKeyFromBytes(acct.GetAddress()) + if key != programId { + continue + } + owner := solana.PublicKeyFromBytes(acct.GetOwner()) + prefixLen := min(len(acct.GetData()), 8) + return fmt.Sprintf("program_idx=%d owner=%s exec=%v lamports=%d data_len=%d data_prefix=%x", + i, owner, acct.GetExecutable(), acct.GetLamports(), len(acct.GetData()), acct.GetData()[:prefixLen]) + } + return fmt.Sprintf("program=%s missing_from_accounts", programId) +} + +func newVMProgramExecCtxAndInstrAccts(fixture *InstrFixture) (*sealevelPkg.ExecutionCtx, []sealevelPkg.InstructionAccount, []uint64, error) { + input := fixture.GetInput() + if input == nil { + return nil, nil, nil, fmt.Errorf("missing input") + } + + accts := make([]accounts.Account, 0, len(input.GetAccounts())) + for _, acctState := range input.GetAccounts() { + accts = append(accts, fixtureAcctStateToAccount(acctState)) + } + + transactionAccts := sealevelPkg.NewTransactionAccounts(accts) + instrAccts := instructionAcctsFromFixture(fixture, *transactionAccts) + + txCtx := sealevelPkg.NewTransactionCtx(*transactionAccts, 8, 128) + txCtx.ComputeBudgetLimits = &sealevelPkg.ComputeBudgetLimits{ + UpdatedHeapBytes: sealevelPkg.MinHeapFrameBytes, + ComputeUnitLimit: sealevelPkg.MaxComputeUnitLimit, + LoadedAccountBytes: sealevelPkg.MaxLoadedAccountsDataSizeBytes, + } + + programId := solana.PublicKeyFromBytes(input.GetProgramId()) + txCtx.AllInstructions = append(txCtx.AllInstructions, 
sealevelPkg.Instruction{ + Data: input.GetData(), + ProgramId: programId, + }) + + execCtx := sealevelPkg.ExecutionCtx{ + TransactionContext: txCtx, + ComputeMeter: cu.NewComputeMeter(input.GetCuAvail()), + Log: &sealevelPkg.LogRecorder{}, + } + execCtx.Accounts = accounts.NewMemAccounts() + execCtx.Features = *parsePBFeatures(input.GetEpochContext().GetFeatures()) + + withoutConformanceStdout(func() { + configureSysvarsFromFixture(&execCtx, fixture) + }) + + programCacheDb := &accountsdb.AccountsDb{} + programCacheDb.InitCaches() + + slotAccounts := accounts.NewMemAccounts() + for i := range accts { + acct := accts[i] + key := [32]byte(acct.Key) + if err := slotAccounts.SetAccount(&key, &acct); err != nil { + return nil, nil, nil, err + } + } + + slot := input.GetSlotContext().GetSlot() + if slot == 0 { + slot = ^uint64(0) + } + execCtx.SlotCtx = &sealevelPkg.SlotCtx{ + Accounts: slotAccounts, + ParentAccts: accounts.NewMemAccounts(), + AccountsDb: programCacheDb, + Slot: slot, + AcctMapsMu: &sync.Mutex{}, + ModifiedAccts: make(map[solana.PublicKey]bool), + WritableAccts: make(map[solana.PublicKey]bool), + Features: &execCtx.Features, + } + + programIndex, err := txCtx.IndexOfAccount(programId) + if err != nil { + return nil, nil, nil, err + } + + return &execCtx, instrAccts, []uint64{programIndex}, nil +} + +func TestConformance_VMPrograms_Firedancer(t *testing.T) { + basePath := "test-vectors/instr/fixtures/vm-programs" + + entries, err := os.ReadDir(basePath) + if err != nil { + t.Skipf("test-vectors not available: %v", err) + } + + var fixtures []string + filter := os.Getenv("MITHRIL_CONFORMANCE_FIXTURE") + for _, entry := range entries { + if strings.HasSuffix(entry.Name(), ".fix") { + if filter != "" && !strings.Contains(entry.Name(), filter) { + continue + } + fixtures = append(fixtures, filepath.Join(basePath, entry.Name())) + } + } + sort.Strings(fixtures) + + if len(fixtures) == 0 { + t.Skip("no .fix fixtures found") + } + + t.Logf("Found %d VM program 
fixtures", len(fixtures)) + + var ( + total int + pass int + parseErrors int + setupErrors int + returnMismatches int + accountMismatches int + returnDataMismatches int + panics int + ) + var failures []string + panicBuckets := make(map[string]int) + returnBuckets := make(map[string]int) + + for _, fixturePath := range fixtures { + total++ + name := filepath.Base(fixturePath) + + data, err := os.ReadFile(fixturePath) + if err != nil { + parseErrors++ + failures = append(failures, fmt.Sprintf("READ_ERROR %s: %v", name, err)) + continue + } + + fixture, err := unmarshalFiredancerInstrFixture(data) + if err != nil { + parseErrors++ + failures = append(failures, fmt.Sprintf("PARSE_ERROR %s: %v", name, err)) + continue + } + + var execCtx *sealevelPkg.ExecutionCtx + var execErr error + var accountStateMatches bool + var returnDataMatches bool + var didPanic bool + + func() { + defer func() { + if r := recover(); r != nil { + didPanic = true + panics++ + panicBuckets[fmt.Sprint(r)]++ + panicMsg := fmt.Sprintf("PANIC %s: %v", name, r) + if os.Getenv("MITHRIL_CONFORMANCE_STACKS") != "" { + panicMsg = fmt.Sprintf("%s\n%s", panicMsg, debug.Stack()) + } + failures = append(failures, panicMsg) + } + }() + + var instrAccts []sealevelPkg.InstructionAccount + var programIndices []uint64 + execCtx, instrAccts, programIndices, err = newVMProgramExecCtxAndInstrAccts(fixture) + if err != nil { + setupErrors++ + failures = append(failures, fmt.Sprintf("SETUP_ERROR %s: %v", name, err)) + return + } + + withoutConformanceStdout(func() { + execErr = execCtx.ProcessInstruction(fixture.GetInput().GetData(), instrAccts, programIndices) + }) + + if execErr == nil { + accountStateMatches = accountStateChangesMatch(t, execCtx, fixture) + _, gotReturnData := execCtx.TransactionContext.ReturnData() + returnDataMatches = bytes.Equal(gotReturnData, fixture.GetOutput().GetReturnData()) + } + }() + + if didPanic || err != nil { + continue + } + + if !instrReturnMatches(fixture, execErr) { + 
returnMismatches++ + returnBuckets[returnMismatchBucket(fixture, execErr)]++ + gotResult := instrResultFromErr(execErr) + wantResult := fixture.GetOutput().GetResult() + failures = append(failures, fmt.Sprintf("RETURN_MISMATCH %s: got=%s(%d) got_err=%v want=%s(%d) custom=%d %s", + name, firedancerInstrResultName(gotResult), gotResult, execErr, + firedancerInstrResultName(wantResult), wantResult, fixture.GetOutput().GetCustomErr(), fixtureProgramSummary(fixture))) + continue + } + if execErr == nil && !accountStateMatches { + accountMismatches++ + failures = append(failures, fmt.Sprintf("ACCOUNT_MISMATCH %s", name)) + continue + } + if execErr == nil && !returnDataMatches { + returnDataMismatches++ + failures = append(failures, fmt.Sprintf("RETURN_DATA_MISMATCH %s", name)) + continue + } + + pass++ + } + + sort.Strings(failures) + + t.Logf("\n=== VM Program Conformance Results ===") + t.Logf("Total fixtures: %d", total) + t.Logf("Passed: %d", pass) + t.Logf("Parse errors: %d", parseErrors) + t.Logf("Setup errors: %d", setupErrors) + t.Logf("Return mismatches: %d", returnMismatches) + t.Logf("Account mismatches: %d", accountMismatches) + t.Logf("Return data mismatches: %d", returnDataMismatches) + t.Logf("Panics: %d", panics) + + if len(failures) > 0 { + limit := 50 + if envLimit := os.Getenv("MITHRIL_CONFORMANCE_FAILURE_LIMIT"); envLimit != "" { + if parsed, err := strconv.Atoi(envLimit); err == nil && parsed > 0 { + limit = parsed + } + } + t.Logf("\n=== First %d failures ===", limit) + if len(failures) < limit { + limit = len(failures) + } + for _, failure := range failures[:limit] { + t.Logf(" %s", failure) + } + } + + if len(panicBuckets) > 0 { + t.Logf("\n=== Panic Buckets ===") + for _, bucket := range topConformanceBuckets(panicBuckets, 20) { + t.Logf(" %dx %s", bucket.count, bucket.key) + } + } + + if len(returnBuckets) > 0 { + t.Logf("\n=== Return Mismatch Buckets ===") + for _, bucket := range topConformanceBuckets(returnBuckets, 20) { + t.Logf(" %dx %s", 
bucket.count, bucket.key) + } + } + + if pass != total { + t.Errorf("VM program conformance failures: %d/%d failed", total-pass, total) + } +} diff --git a/pkg/accountsdb/accountsdb.go b/pkg/accountsdb/accountsdb.go index 11693e38..a057c83d 100644 --- a/pkg/accountsdb/accountsdb.go +++ b/pkg/accountsdb/accountsdb.go @@ -163,6 +163,9 @@ func (accountsDb *AccountsDb) CloseDb() { func (accountsDb *AccountsDb) InitCaches() { var err error + if accountsDb.inProgressStoreRequests == nil { + accountsDb.inProgressStoreRequests = list.New() + } accountsDb.VoteAcctCache, err = otter.MustBuilder[solana.PublicKey, *accounts.Account](2500). Cost(func(key solana.PublicKey, acct *accounts.Account) uint32 { return 1 @@ -197,6 +200,9 @@ type ProgramCacheEntry struct { } func (accountsDb *AccountsDb) MaybeGetProgramFromCache(pubkey solana.PublicKey) (*ProgramCacheEntry, bool) { + if accountsDb == nil { + return nil, false + } return accountsDb.ProgramCache.Get(pubkey) } @@ -209,6 +215,9 @@ func (accountsDb *AccountsDb) RemoveProgramFromCache(pubkey solana.PublicKey) { } func (accountsDb *AccountsDb) GetAccount(slot uint64, pubkey solana.PublicKey) (*accounts.Account, error) { + if accountsDb == nil { + return nil, ErrNoAccount + } accts := accountsDb.getStoreInProgressAccounts([]solana.PublicKey{pubkey}) if accts[0] != nil { return accts[0], nil @@ -217,6 +226,9 @@ func (accountsDb *AccountsDb) GetAccount(slot uint64, pubkey solana.PublicKey) ( } func (accountsDb *AccountsDb) getStoredAccount(slot uint64, pubkey solana.PublicKey) (*accounts.Account, error) { + if accountsDb.Index == nil { + return nil, ErrNoAccount + } r := trace.StartRegion(context.Background(), "GetStoredAccountCache") cachedAcct, hasAcct := accountsDb.VoteAcctCache.Get(pubkey) if hasAcct { diff --git a/pkg/features/features.go b/pkg/features/features.go index 3c7820d2..8d2a7f04 100644 --- a/pkg/features/features.go +++ b/pkg/features/features.go @@ -2,6 +2,8 @@ package features import ( "fmt" + + 
"github.com/Overclock-Validator/mithril/pkg/base58" ) type FeatureGate struct { @@ -37,6 +39,10 @@ func (f *Features) IsActive(gate FeatureGate) bool { } } +func (f *Features) IsSbpfV3DeploymentAndExecutionActive() bool { + return f.IsActive(EnableSbpfV3DeploymentAndExecution) +} + func (f *Features) ActivationSlot(gate FeatureGate) (uint64, bool) { if !f.IsActive(gate) { return 0, false @@ -48,7 +54,7 @@ func (f *Features) AllEnabled() []string { enabledFeatureStrs := make([]string, 0) for feat, enabled := range *f { if enabled.Enabled { - enabledFeatureStrs = append(enabledFeatureStrs, fmt.Sprintf("feature %s (%s) enabled", feat.Name, feat.Address)) + enabledFeatureStrs = append(enabledFeatureStrs, fmt.Sprintf("feature %s (%s) enabled", feat.Name, base58.Encode(feat.Address[:]))) } } return enabledFeatureStrs diff --git a/pkg/features/features_test.go b/pkg/features/features_test.go index e90d3929..76d2dc65 100644 --- a/pkg/features/features_test.go +++ b/pkg/features/features_test.go @@ -3,6 +3,7 @@ package features import ( "testing" + "github.com/Overclock-Validator/mithril/pkg/base58" "github.com/stretchr/testify/assert" ) @@ -25,3 +26,49 @@ func TestFflags_ListEnabled(t *testing.T) { f.EnableFeature(StopTruncatingStringsInSyscalls, 0) assert.Equal(t, f.AllEnabled(), []string{"feature StopTruncatingStringsInSyscalls (16FMCmgLzCNNz6eTwGanbyN2ZxvTBSLuQ6DZhgeMshg) enabled"}) } + +func TestValidateChainedBlockIdFeatureGate(t *testing.T) { + assert.Equal(t, "ValidateChainedBlockId", ValidateChainedBlockId.Name) + assert.Equal(t, base58.MustDecodeFromString("vcmrbYbiMVKaq1snKP6eCacNDcr6qZvpCNUjmk6gxvZ"), ValidateChainedBlockId.Address) + assert.Contains(t, AllFeatureGates, ValidateChainedBlockId) +} + +func TestDiscardUnexpectedDataCompleteShredsFeatureGate(t *testing.T) { + assert.Equal(t, "DiscardUnexpectedDataCompleteShreds", DiscardUnexpectedDataCompleteShreds.Name) + assert.Equal(t, base58.MustDecodeFromString("dcomRRWHXP1FVWPqi9Mm4oxJhF4ehC795SvAtUdA9os"), 
DiscardUnexpectedDataCompleteShreds.Address) + assert.Contains(t, AllFeatureGates, DiscardUnexpectedDataCompleteShreds) +} + +func TestEnableSbpfV3DeploymentAndExecutionFeatureGates(t *testing.T) { + assert.Equal(t, "EnableSbpfV3DeploymentAndExecution", EnableSbpfV3DeploymentAndExecution.Name) + assert.Equal(t, base58.MustDecodeFromString("5cC3foj77CWun58pC51ebHFUWavHWKarWyR5UUik7dnC"), EnableSbpfV3DeploymentAndExecution.Address) + assert.Contains(t, AllFeatureGates, EnableSbpfV3DeploymentAndExecution) + + f := NewFeaturesDefault() + f.EnableFeature(EnableSbpfV3DeploymentAndExecution, 0) + assert.True(t, f.IsSbpfV3DeploymentAndExecutionActive()) +} + +func TestSyscallParameterAddressRestrictionsFeatureGate(t *testing.T) { + assert.Equal(t, "SyscallParameterAddressRestrictions", SyscallParameterAddressRestrictions.Name) + assert.Equal(t, base58.MustDecodeFromString("EDGMC5kxFxGk4ixsNkGt8bW7QL5hDMXnbwaZvYMwNfzF"), SyscallParameterAddressRestrictions.Address) + assert.Contains(t, AllFeatureGates, SyscallParameterAddressRestrictions) +} + +func TestBlake3SyscallEnabledFeatureGate(t *testing.T) { + assert.Equal(t, "Blake3SyscallEnabled", Blake3SyscallEnabled.Name) + assert.Equal(t, base58.MustDecodeFromString("HTW2pSyErTj4BV6KBM9NZ9VBUJVxt7sacNWcf76wtzb3"), Blake3SyscallEnabled.Address) + assert.Contains(t, AllFeatureGates, Blake3SyscallEnabled) +} + +func TestVirtualAddressSpaceAdjustmentsFeatureGate(t *testing.T) { + assert.Equal(t, "VirtualAddressSpaceAdjustments", VirtualAddressSpaceAdjustments.Name) + assert.Equal(t, base58.MustDecodeFromString("7VgiehxNxu53KdxgLspGQY8myE6f7UokaWa4jsGcaSz"), VirtualAddressSpaceAdjustments.Address) + assert.Contains(t, AllFeatureGates, VirtualAddressSpaceAdjustments) +} + +func TestAccountDataDirectMappingFeatureGate(t *testing.T) { + assert.Equal(t, "AccountDataDirectMapping", AccountDataDirectMapping.Name) + assert.Equal(t, base58.MustDecodeFromString("CR3dVN2Yoo95Y96kLSTaziWDAQT2MNEpiWh5cqVq2pNE"), 
AccountDataDirectMapping.Address) + assert.Contains(t, AllFeatureGates, AccountDataDirectMapping) +} diff --git a/pkg/features/gates.go b/pkg/features/gates.go index 4830b818..d8973013 100644 --- a/pkg/features/gates.go +++ b/pkg/features/gates.go @@ -26,6 +26,7 @@ var RelaxAuthoritySignerCheckForLookupTableCreation = FeatureGate{Name: "RelaxAu var DedupeConfigProgramSigners = FeatureGate{Name: "DedupeConfigProgramSigners", Address: base58.MustDecodeFromString("8kEuAshXLsgkUEdcFVLqrjCGGHVWFW99ZZpxvAzzMtBp")} var Ed25519PrecompileVerifyStrict = FeatureGate{Name: "Ed25519PrecompileVerifyStrict", Address: base58.MustDecodeFromString("ed9tNscbWLYBooxWA7FE2B5KHWs8A6sxfY8EzezEcoo")} var AbortOnInvalidCurve = FeatureGate{Name: "AbortOnInvalidCurve", Address: base58.MustDecodeFromString("FuS3FPfJDKSNot99ECLXtp3rueq36hMNStJkPJwWodLh")} +var Blake3SyscallEnabled = FeatureGate{Name: "Blake3SyscallEnabled", Address: base58.MustDecodeFromString("HTW2pSyErTj4BV6KBM9NZ9VBUJVxt7sacNWcf76wtzb3")} var Curve25519SyscallEnabled = FeatureGate{Name: "Curve25519SyscallEnabled", Address: base58.MustDecodeFromString("7rcw5UtqgDTBBv2EcynNfYckgdAaH1MAsCjKgXMkN7Ri")} var SimplifyAltBn128SyscallErrorCodes = FeatureGate{Name: "SimplityAltBn128SyscallErrorCodes", Address: base58.MustDecodeFromString("JDn5q3GBeqzvUa7z67BbmVHVdE3EbUAjvFep3weR3jxX")} var EnableAltbn128CompressionSyscall = FeatureGate{Name: "EnableAltbn128CompressionSyscall", Address: base58.MustDecodeFromString("EJJewYSddEEtSZHiqugnvhQHiWyZKjkFDQASd7oKSagn")} @@ -56,7 +57,7 @@ var RemoveAccountsDeltaHash = FeatureGate{Name: "RemoveAccountsDeltaHash", Addre var EnableLoaderV4 = FeatureGate{Name: "EnableLoaderV4", Address: base58.MustDecodeFromString("2aQJYqER2aKyb3cZw22v4SL2xMX7vwXBRWfvS4pTrtED")} var EnableSbpfV1DeploymentAndExecution = FeatureGate{Name: "EnableSbpfV1DeploymentAndExecution", Address: base58.MustDecodeFromString("JE86WkYvTrzW8HgNmrHY7dFYpCmSptUpKupbo2AdQ9cG")} var EnableSbpfV2DeploymentAndExecution = 
FeatureGate{Name: "EnableSbpfV2DeploymentAndExecution", Address: base58.MustDecodeFromString("F6UVKh1ujTEFK3en2SyAL3cdVnqko1FVEXWhmdLRu6WP")} -var EnableSbpfV3DeploymentAndExecution = FeatureGate{Name: "EnableSbpfV3DeploymentAndExecution", Address: base58.MustDecodeFromString("C8XZNs1bfzaiT3YDeXZJ7G5swQWQv7tVzDnCxtHvnSpw")} +var EnableSbpfV3DeploymentAndExecution = FeatureGate{Name: "EnableSbpfV3DeploymentAndExecution", Address: base58.MustDecodeFromString("5cC3foj77CWun58pC51ebHFUWavHWKarWyR5UUik7dnC")} var DisableSbpfV0Execution = FeatureGate{Name: "DisableSbpfV0Execution", Address: base58.MustDecodeFromString("TestFeature11111111111111111111111111111111")} var ReenableSbpfV0Execution = FeatureGate{Name: "ReenableSbpfV0Execution", Address: base58.MustDecodeFromString("TestFeature21111111111111111111111111111111")} var FormalizeLoadedTransactionDataSize = FeatureGate{Name: "FormalizeLoadedTransactionDataSize", Address: base58.MustDecodeFromString("DeS7sR48ZcFTUmt5FFEVDr1v1bh73aAbZiZq3SYr8Eh8")} @@ -66,6 +67,9 @@ var PoseidonEnforcePadding = FeatureGate{Name: "PoseidonEnforcePadding", Address var FixAltBn128PairingLengthCheck = FeatureGate{Name: "FixAltBn128PairingLengthCheck", Address: base58.MustDecodeFromString("bnYzodLwmybj7e1HAe98yZrdJTd7we69eMMLgCXqKZm")} var DeprecateRentExemptionThreshold = FeatureGate{Name: "DeprecateRentExemptionThreshold", Address: base58.MustDecodeFromString("rent6iVy6PDoViPBeJ6k5EJQrkj62h7DPyLbWGHwjrC")} var ProvideInstructionDataOffsetInVmR2 = FeatureGate{Name: "ProvideInstructionDataOffsetInVmR2", Address: base58.MustDecodeFromString("5xXZc66h4UdB6Yq7FzdBxBiRAFMMScMLwHxk2QZDaNZL")} +var SyscallParameterAddressRestrictions = FeatureGate{Name: "SyscallParameterAddressRestrictions", Address: base58.MustDecodeFromString("EDGMC5kxFxGk4ixsNkGt8bW7QL5hDMXnbwaZvYMwNfzF")} +var VirtualAddressSpaceAdjustments = FeatureGate{Name: "VirtualAddressSpaceAdjustments", Address: 
base58.MustDecodeFromString("7VgiehxNxu53KdxgLspGQY8myE6f7UokaWa4jsGcaSz")} +var AccountDataDirectMapping = FeatureGate{Name: "AccountDataDirectMapping", Address: base58.MustDecodeFromString("CR3dVN2Yoo95Y96kLSTaziWDAQT2MNEpiWh5cqVq2pNE")} var VoteStateV4 = FeatureGate{Name: "VoteStateV4", Address: base58.MustDecodeFromString("Gx4XFcrVMt4HUvPzTpTSVkdDVgcDSjKhDN1RqRS6KDuZ")} var RelaxProgramdataAccountCheckMigration = FeatureGate{Name: "RelaxProgramdataAccountCheckMigration", Address: base58.MustDecodeFromString("rexav5eNTUSNT1K2N7cfRjnthwhcP5BC25v2tA4rW4h")} var ReplaceSplTokenWithPToken = FeatureGate{Name: "ReplaceSplTokenWithPToken", Address: base58.MustDecodeFromString("ptokFjwyJtrwCa9Kgo9xoDS59V4QccBGEaRFnRPnSdP")} @@ -78,6 +82,8 @@ var EnableAltBn128G2Syscalls = FeatureGate{Name: "EnableAltBn128G2Syscalls", Add var EnableBls12_381Syscall = FeatureGate{Name: "EnableBls12_381Syscall", Address: base58.MustDecodeFromString("b1sgUiJ3qu7hYm3tNDyyqZNQd6gLGJmJppnLNa93PCQ")} var UpgradeBpfStakeProgramToV5 = FeatureGate{Name: "UpgradeBpfStakeProgramToV5", Address: base58.MustDecodeFromString("STk5Xj8hdAx3sTzmtJ3QysKkq6X2A3yj73JtxttiRyk")} var DelayCommissionUpdates = FeatureGate{Name: "DelayCommissionUpdates", Address: base58.MustDecodeFromString("76dHtohc2s5dR3ahJyBxs7eJJVipFkaPdih9CLgTTb4B")} +var ValidateChainedBlockId = FeatureGate{Name: "ValidateChainedBlockId", Address: base58.MustDecodeFromString("vcmrbYbiMVKaq1snKP6eCacNDcr6qZvpCNUjmk6gxvZ")} +var DiscardUnexpectedDataCompleteShreds = FeatureGate{Name: "DiscardUnexpectedDataCompleteShreds", Address: base58.MustDecodeFromString("dcomRRWHXP1FVWPqi9Mm4oxJhF4ehC795SvAtUdA9os")} var AllFeatureGates = []FeatureGate{StopTruncatingStringsInSyscalls, EnablePartitionedEpochReward, EnablePartitionedEpochRewardsSuperfeature, LastRestartSlotSysvar, Libsecp256k1FailOnBadCount, Libsecp256k1FailOnBadCount2, EnableBpfLoaderSetAuthorityCheckedIx, @@ -85,7 +91,7 @@ var AllFeatureGates = 
[]FeatureGate{StopTruncatingStringsInSyscalls, EnableParti CommissionUpdatesOnlyAllowedInFirstHalfOfEpoch, TimelyVoteCredits, ReduceStakeWarmupCooldown, StakeRaiseMinimumDelegationTo1Sol, StakeRedelegateInstruction, RequireRentExemptSplitDestination, DeprecateExecutableMetaUpdateInBpfLoader, RelaxAuthoritySignerCheckForLookupTableCreation, DedupeConfigProgramSigners, - Ed25519PrecompileVerifyStrict, AbortOnInvalidCurve, Curve25519SyscallEnabled, SimplifyAltBn128SyscallErrorCodes, + Ed25519PrecompileVerifyStrict, AbortOnInvalidCurve, Blake3SyscallEnabled, Curve25519SyscallEnabled, SimplifyAltBn128SyscallErrorCodes, EnableAltbn128CompressionSyscall, EnableAltBn128Syscall, DisableRentFeesCollection, DeprecateUnusedLegacyVotePlumbing, RewardFullPriorityFee, StakeMinimumDelegationForRewards, MoveStakeAndMoveLamportsIxs, GetSysvarSyscallEnabled, AddNewReservedAccountKeys, EnableSecp256r1Precompile, FixAltBn128MultiplicationInputLength, EnableTowerSyncIx, SkipRentRewrites, @@ -94,6 +100,7 @@ var AllFeatureGates = []FeatureGate{StopTruncatingStringsInSyscalls, EnableParti AccountsLtHash, RemoveAccountsDeltaHash, EnableLoaderV4, EnableSbpfV1DeploymentAndExecution, EnableSbpfV2DeploymentAndExecution, EnableSbpfV3DeploymentAndExecution, DisableSbpfV0Execution, ReenableSbpfV0Execution, FormalizeLoadedTransactionDataSize, IncreaseCpiAccountInfoLimit, StaticInstructionLimit, PoseidonEnforcePadding, FixAltBn128PairingLengthCheck, DeprecateRentExemptionThreshold, - ProvideInstructionDataOffsetInVmR2, VoteStateV4, RelaxProgramdataAccountCheckMigration, ReplaceSplTokenWithPToken, CreateAccountAllowPrefund, + ProvideInstructionDataOffsetInVmR2, SyscallParameterAddressRestrictions, VirtualAddressSpaceAdjustments, AccountDataDirectMapping, + VoteStateV4, RelaxProgramdataAccountCheckMigration, ReplaceSplTokenWithPToken, CreateAccountAllowPrefund, RemoveSimpleVoteFromCostModel, DisableZkElgamalProofProgram, ReenableZkElgamalProofProgram, AltBn128LittleEndian, EnableAltBn128G2Syscalls, - 
EnableBls12_381Syscall, UpgradeBpfStakeProgramToV5, DelayCommissionUpdates} + EnableBls12_381Syscall, UpgradeBpfStakeProgramToV5, DelayCommissionUpdates, ValidateChainedBlockId, DiscardUnexpectedDataCompleteShreds} diff --git a/pkg/sbpf/asm.go b/pkg/sbpf/asm.go index f92f421c..f10c4c43 100644 --- a/pkg/sbpf/asm.go +++ b/pkg/sbpf/asm.go @@ -154,7 +154,7 @@ func (ip *Interpreter) disassemble(slot Slot, slot2 Slot) string { case OpJeqReg, OpJgtReg, OpJgeReg, OpJltReg, OpJleReg, OpJsetReg, OpJneReg, OpJsgtReg, OpJsgeReg, OpJsltReg, OpJsleReg: return fmt.Sprintf("%s r%d, r%d", mnemonic, slot.Dst(), slot.Src()) case OpCall: - return fmt.Sprintf("call") + return fmt.Sprintf("call %#x", slot.Uimm()) case OpCallx: return fmt.Sprintf("callx") case OpExit: diff --git a/pkg/sbpf/interpreter.go b/pkg/sbpf/interpreter.go index 41f07442..4fa0f784 100644 --- a/pkg/sbpf/interpreter.go +++ b/pkg/sbpf/interpreter.go @@ -21,12 +21,14 @@ import ( // Interpreter implements the SBF core in pure Go. type Interpreter struct { - textVA uint64 - text []Slot - ro []byte - stack Stack - heap []byte + textVA uint64 + textBytes []byte + text []Slot + ro []byte + stack Stack + heap []byte input []byte + inputRegions []InputRegion inputDataVaddr uint64 entry uint64 @@ -81,11 +83,13 @@ func NewInterpreter(p *Program, opts *VMOpts) *Interpreter { return &Interpreter{ textVA: p.TextVA, + textBytes: p.TextBytes, text: p.Text, ro: p.RO, - stack: NewStack(p.SbpfVersion), + stack: NewStack(p.SbpfVersion, opts.DisableStackFrameGaps), heap: heap, input: opts.Input, + inputRegions: opts.InputRegions, inputDataVaddr: opts.InputDataVaddr, entry: p.Entrypoint, syscalls: opts.Syscalls, @@ -109,6 +113,89 @@ func (ip *Interpreter) Finish() { ip.stack.Finish() } +func (ip *Interpreter) executeJmp32(ins Slot, pc int64, r *[11]uint64) (int64, error) { + var taken bool + dst := uint32(r[ins.Dst()]) + src := uint32(r[ins.Src()]) + imm := ins.Uimm() + + switch ins.Op() & 0xf0 { + case JumpEq: + if ins.Op()&SrcX != 0 { 
+ taken = dst == src + } else { + taken = dst == imm + } + case JumpGt: + if ins.Op()&SrcX != 0 { + taken = dst > src + } else { + taken = dst > imm + } + case JumpGe: + if ins.Op()&SrcX != 0 { + taken = dst >= src + } else { + taken = dst >= imm + } + case JumpLt: + if ins.Op()&SrcX != 0 { + taken = dst < src + } else { + taken = dst < imm + } + case JumpLe: + if ins.Op()&SrcX != 0 { + taken = dst <= src + } else { + taken = dst <= imm + } + case JumpSet: + if ins.Op()&SrcX != 0 { + taken = dst&src != 0 + } else { + taken = dst&imm != 0 + } + case JumpNe: + if ins.Op()&SrcX != 0 { + taken = dst != src + } else { + taken = dst != imm + } + case JumpSgt: + if ins.Op()&SrcX != 0 { + taken = int32(dst) > int32(src) + } else { + taken = int32(dst) > ins.Imm() + } + case JumpSge: + if ins.Op()&SrcX != 0 { + taken = int32(dst) >= int32(src) + } else { + taken = int32(dst) >= ins.Imm() + } + case JumpSlt: + if ins.Op()&SrcX != 0 { + taken = int32(dst) < int32(src) + } else { + taken = int32(dst) < ins.Imm() + } + case JumpSle: + if ins.Op()&SrcX != 0 { + taken = int32(dst) <= int32(src) + } else { + taken = int32(dst) <= ins.Imm() + } + default: + return pc, ExcUnsupportedInstruction + } + + if taken { + pc += int64(ins.Off()) + } + return pc + 1, nil +} + // Run executes the program. // // This function may panic given code that doesn't pass the static verifier. 
@@ -135,7 +222,7 @@ mainLoop: if pc < 0 || pc >= int64(len(ip.text)) { return 0, 0, &Exception{ PC: pc, - Detail: fmt.Errorf("tx: %s, programId: %s - %s:", ip.txSignature, ip.programId, ExcExecutionOverrun), + Detail: fmt.Errorf("tx: %s, programId: %s - %w:", ip.txSignature, ip.programId, ExcExecutionOverrun), } } ins := ip.getSlot(pc) @@ -152,6 +239,10 @@ mainLoop: } // Execute + if ip.sbpfVersion.EnableJmp32() && ins.Op()&0x07 == ClassPqr { + pc, err = ip.executeJmp32(ins, pc, &r) + goto postExecute + } switch ins.Op() { case OpLdxb: if ip.sbpfVersion.MoveMemoryInstructionClasses() { @@ -933,39 +1024,68 @@ mainLoop: } pc++ case OpCall: - if sc, ok := ip.syscalls(ins.Uimm()); ok { - r[0], err = sc.Invoke(ip, r[1], r[2], r[3], r[4], r[5]) - if err != nil { - err = ExcSyscallError{Err: err} - } - pc++ - } else if target, ok := ip.funcs[ins.Uimm()]; ok { - ok = ip.stack.Push(r[:], pc+1) - if !ok { - err = ExcCallDepth + if ip.sbpfVersion.EnableStaticSyscalls() { + if ins.Src() == 0 { + sc, ok := ip.syscalls(ins.Uimm()) + if !ok { + err = ExcCallDest{ins.Uimm()} + break + } + r[0], err = sc.Invoke(ip, r[1], r[2], r[3], r[4], r[5]) + if err != nil { + err = ExcSyscallError{Err: err} + } + pc++ + } else if ins.Src() == 1 { + targetPC := ip.sbpfVersion.CalculateCallImmTargetPC(pc, ins.Imm()) + if targetPC < 0 || targetPC >= int64(len(ip.text)) { + err = ExcCallDest{uint32(targetPC)} + break + } + if ok := ip.stack.Push(r[:], pc+1); !ok { + err = ExcCallDepth + } + pc = targetPC + } else { + err = ExcUnsupportedInstruction } - pc = target } else { - err = ExcCallDest{ins.Uimm()} + if sc, ok := ip.syscalls(ins.Uimm()); ok { + r[0], err = sc.Invoke(ip, r[1], r[2], r[3], r[4], r[5]) + if err != nil { + err = ExcSyscallError{Err: err} + } + pc++ + } else if target, ok := ip.funcs[ins.Uimm()]; ok { + ok = ip.stack.Push(r[:], pc+1) + if !ok { + err = ExcCallDepth + } + pc = target + } else { + err = ExcCallDest{ins.Uimm()} + } } case OpCallx: var target uint64 if 
ip.sbpfVersion.CallXUsesSrcReg() { target = r[ins.Src()] + } else if ip.sbpfVersion.CallXUsesDstReg() { + target = r[ins.Dst()] } else { target = r[ins.Uimm()] } - target &= ^(uint64(0x7)) if target < ip.textVA || target >= VaddrStack || target >= ip.textVA+uint64(len(ip.text)*8) { err = NewExcBadAccess(target, 8, false, "jump out-of-bounds") break } + targetPC := int64((target - ip.textVA) / 8) if ok := ip.stack.Push(r[:], pc+1); !ok { err = ExcCallDepth break } - pc = int64((target - ip.textVA) / 8) + pc = targetPC case OpExit: var ok bool pc, ok = ip.stack.Pop(r[:]) @@ -979,6 +1099,7 @@ mainLoop: } // Post execute + postExecute: if err == cu.ErrComputeExceeded { err = ExcOutOfCU } @@ -986,7 +1107,7 @@ mainLoop: if err != nil { exc := &Exception{ PC: pc, - Detail: fmt.Errorf("tx: %s, programId: %s - %s:", ip.txSignature, ip.programId, err), + Detail: fmt.Errorf("tx: %s, programId: %s - %w:", ip.txSignature, ip.programId, err), } if IsLongIns(ins.Op()) { exc.PC-- // fix reported PC @@ -1035,7 +1156,30 @@ var emptySlice = reflect.ValueOf(emptyArray[:]).UnsafePointer() func (ip *Interpreter) translateInternal(addr uint64, size uint64, write bool) (unsafe.Pointer, error) { hi, lo := addr>>32, addr&math.MaxUint32 switch hi { + case 0: + if !ip.sbpfVersion.EnableLowerRodataVaddr() { + if size == 0 { + return emptySlice, nil + } + return nil, NewExcBadAccess(addr, size, write, "unmapped region") + } + if write { + return nil, NewExcBadAccess(addr, size, write, "write to program") + } + if size == 0 { + return emptySlice, nil + } + if addr+size < addr || addr+size > uint64(len(ip.ro)) { + return nil, NewExcBadAccess(addr, size, write, "out-of-bounds program read") + } + return unsafe.Pointer(&ip.ro[addr]), nil case VaddrProgram >> 32: + if ip.sbpfVersion.EnableLowerRodataVaddr() { + if size == 0 { + return emptySlice, nil + } + return nil, NewExcBadAccess(addr, size, write, "unmapped region") + } if write { return nil, NewExcBadAccess(addr, size, write, "write to 
program") } @@ -1067,6 +1211,9 @@ func (ip *Interpreter) translateInternal(addr uint64, size uint64, write bool) ( if size == 0 { return emptySlice, nil } + if len(ip.inputRegions) != 0 { + return ip.translateInputRegion(lo, size, write) + } if lo+size > uint64(len(ip.input)) { return nil, NewExcBadAccess(addr, size, write, "out-of-bounds input access") } @@ -1079,6 +1226,122 @@ func (ip *Interpreter) translateInternal(addr uint64, size uint64, write bool) ( } } +func (ip *Interpreter) inputRegionIndex(offset uint64) int { + idx, found := slices.BinarySearchFunc(ip.inputRegions, offset, func(region InputRegion, target uint64) int { + if target < region.Offset { + return 1 + } + if target >= region.Offset+region.AddressSpaceReserved { + return -1 + } + return 0 + }) + if !found { + return -1 + } + return idx +} + +func (ip *Interpreter) translateInputRegion(offset, size uint64, write bool) (unsafe.Pointer, error) { + idx := ip.inputRegionIndex(offset) + if idx < 0 { + return nil, NewExcBadAccess(VaddrInput+offset, size, write, "unmapped input region") + } + + region := &ip.inputRegions[idx] + regionOffset := offset - region.Offset + requestedLen := regionOffset + size + if requestedLen < regionOffset || requestedLen > region.AddressSpaceReserved { + return nil, NewExcBadAccess(VaddrInput+offset, size, write, "out-of-bounds input access") + } + if write && (!region.Writable || requestedLen > region.RegionSize) && region.OnWrite != nil { + if err := region.OnWrite(region, requestedLen); err != nil { + return nil, err + } + } + if requestedLen > region.RegionSize { + if !write || !region.Writable { + return nil, NewExcBadAccess(VaddrInput+offset, size, write, "out-of-bounds input access") + } + region.RegionSize = region.AddressSpaceReserved + } + if write && !region.Writable { + return nil, NewExcBadAccess(VaddrInput+offset, size, write, "write to readonly input region") + } + if region.Data != nil { + if requestedLen > uint64(len(region.Data)) { + return nil, 
NewExcBadAccess(VaddrInput+offset, size, write, "out-of-bounds input access") + } + return unsafe.Pointer(&region.Data[regionOffset]), nil + } + + hostOffset := region.HostOffset + regionOffset + if hostOffset < region.HostOffset || hostOffset+size < hostOffset || hostOffset+size > uint64(len(ip.input)) { + return nil, NewExcBadAccess(VaddrInput+offset, size, write, "out-of-bounds input access") + } + return unsafe.Pointer(&ip.input[hostOffset]), nil +} + +func (ip *Interpreter) TranslateInput(addr uint64, size uint64) ([]byte, error) { + if size == 0 { + return nil, nil + } + if addr < VaddrInput { + return nil, NewExcBadAccess(addr, size, false, "unmapped input region") + } + offset := addr - VaddrInput + if len(ip.inputRegions) != 0 { + idx := ip.inputRegionIndex(offset) + if idx < 0 { + return nil, NewExcBadAccess(addr, size, false, "unmapped input region") + } + region := ip.inputRegions[idx] + regionOffset := offset - region.Offset + if regionOffset+size < regionOffset || regionOffset+size > region.AddressSpaceReserved { + return nil, NewExcBadAccess(addr, size, false, "out-of-bounds input access") + } + if region.Data != nil { + if regionOffset+size > uint64(len(region.Data)) { + return nil, NewExcBadAccess(addr, size, false, "out-of-bounds input access") + } + return region.Data[regionOffset : regionOffset+size], nil + } + hostOffset := region.HostOffset + regionOffset + if hostOffset < region.HostOffset || hostOffset+size < hostOffset || hostOffset+size > uint64(len(ip.input)) { + return nil, NewExcBadAccess(addr, size, false, "out-of-bounds input access") + } + return ip.input[hostOffset : hostOffset+size], nil + } + if offset+size < offset || offset+size > uint64(len(ip.input)) { + return nil, NewExcBadAccess(addr, size, false, "out-of-bounds input access") + } + return ip.input[offset : offset+size], nil +} + +func (ip *Interpreter) SetInputRegionData(addr uint64, data []byte, length uint64, writable bool) bool { + if addr < VaddrInput ||
len(ip.inputRegions) == 0 { + return false + } + idx := ip.inputRegionIndex(addr - VaddrInput) + if idx < 0 { + return false + } + region := &ip.inputRegions[idx] + if addr != VaddrInput+region.Offset || length > region.AddressSpaceReserved { + return false + } + if data != nil { + region.Data = data + } + region.RegionSize = length + region.Writable = writable + return true +} + +func (ip *Interpreter) SetInputRegionLength(addr uint64, length uint64, writable bool) bool { + return ip.SetInputRegionData(addr, nil, length, writable) +} + func (ip *Interpreter) Translate(addr uint64, size uint64, write bool) ([]byte, error) { if size == 0 { return nil, nil diff --git a/pkg/sbpf/interpreter_v3_test.go b/pkg/sbpf/interpreter_v3_test.go new file mode 100644 index 00000000..5b7f2d19 --- /dev/null +++ b/pkg/sbpf/interpreter_v3_test.go @@ -0,0 +1,111 @@ +package sbpf + +import ( + "encoding/binary" + "testing" + + "github.com/Overclock-Validator/mithril/pkg/cu" + "github.com/Overclock-Validator/mithril/pkg/sbpf/sbpfver" + "github.com/stretchr/testify/require" +) + +func testSlot(op uint8, dst uint8, src uint8, off int16, imm uint32) Slot { + return Slot(op) | Slot(dst)<<8 | Slot(src)<<12 | Slot(uint16(off))<<16 | Slot(imm)<<32 +} + +func testSlotsToBytes(slots []Slot) []byte { + out := make([]byte, len(slots)*SlotSize) + for i, slot := range slots { + binary.LittleEndian.PutUint64(out[i*SlotSize:], uint64(slot)) + } + return out +} + +func testV3Program(text []Slot, funcs map[uint32]int64) *Program { + return &Program{ + TextBytes: testSlotsToBytes(text), + Text: text, + TextVA: VaddrProgram, + Entrypoint: 0, + Funcs: funcs, + SbpfVersion: sbpfver.SbpfVersion{Version: sbpfver.SbpfVersionV3}, + } +} + +func TestInterpreterV3StaticSyscallAndReturn(t *testing.T) { + hash := SymbolHash("test_syscall") + text := []Slot{ + testSlot(OpCall, 0, 0, 0, hash), + testSlot(OpExit, 0, 0, 0, 0), + } + program := testV3Program(text, nil) + require.NoError(t, program.Verify()) + + called 
:= false + syscalls := SyscallRegistry(func(got uint32) (Syscall, bool) { + if got != hash { + return nil, false + } + return SyscallFunc0(func(VM) (uint64, error) { + called = true + return 7, nil + }), true + }) + computeMeter := cu.NewComputeMeter(100) + interpreter := NewInterpreter(program, &VMOpts{ + HeapMax: 1024, + Syscalls: syscalls, + ComputeMeter: &computeMeter, + }) + defer interpreter.Finish() + + ret, _, err := interpreter.Run() + require.NoError(t, err) + require.True(t, called) + require.Equal(t, uint64(7), ret) +} + +func TestInterpreterV3RelativeCall(t *testing.T) { + text := []Slot{ + testSlot(OpCall, 0, 1, 0, 1), + testSlot(OpExit, 0, 0, 0, 0), + testSlot(OpMov64Imm, 0, 0, 0, 42), + testSlot(OpExit, 0, 0, 0, 0), + } + program := testV3Program(text, nil) + require.NoError(t, program.Verify()) + + computeMeter := cu.NewComputeMeter(100) + interpreter := NewInterpreter(program, &VMOpts{ + HeapMax: 1024, + Syscalls: func(uint32) (Syscall, bool) { return nil, false }, + ComputeMeter: &computeMeter, + }) + defer interpreter.Finish() + + ret, _, err := interpreter.Run() + require.NoError(t, err) + require.Equal(t, uint64(42), ret) +} + +func TestInterpreterV3RodataRegionZero(t *testing.T) { + text := []Slot{ + testSlot(OpExit, 0, 0, 0, 0), + } + program := testV3Program(text, nil) + program.RO = []byte{1, 2, 3, 4} + computeMeter := cu.NewComputeMeter(100) + interpreter := NewInterpreter(program, &VMOpts{ + HeapMax: 1024, + Syscalls: func(uint32) (Syscall, bool) { return nil, false }, + ComputeMeter: &computeMeter, + }) + defer interpreter.Finish() + + bytes, err := interpreter.Translate(0, uint64(len(program.RO)), false) + require.NoError(t, err) + require.Equal(t, program.RO, bytes) + + _, err = interpreter.translateInternal(VaddrProgram, SlotSize, false) + require.Error(t, err) +} diff --git a/pkg/sbpf/loader/copy.go b/pkg/sbpf/loader/copy.go index b5e4a42c..6cf42f2d 100644 --- a/pkg/sbpf/loader/copy.go +++ b/pkg/sbpf/loader/copy.go @@ -8,6 +8,11 @@ 
import ( "github.com/Overclock-Validator/mithril/pkg/sbpf" ) +type sectionMapping struct { + src addrRange + dst addrRange +} + // The following ELF loading rules seem mostly arbitrary. // For the sake of cleanliness, this loader doesn't process // some badly malformed ELFs that would pass on Solana mainnet. @@ -15,8 +20,13 @@ import ( // copy allocates program buffers and copies ELF contents. func (l *Loader) copy() error { + if l.enableStricterElfHeaders() { + return l.copyStrict() + } + l.progRange = newAddrRange() l.rodatas = make([]addrRange, 0, 4) + l.rodataMappings = make([]sectionMapping, 0, 4) if err := l.getText(); err != nil { return err } @@ -29,6 +39,26 @@ func (l *Loader) copy() error { return nil } +func (l *Loader) copyStrict() error { + l.progRange = addrRange{min: 0, max: l.rodataRange.len()} + l.rodatas = nil + l.rodataMappings = nil + + l.program = make([]byte, l.rodataRange.len()) + if l.rodataRange.len() != 0 { + if err := l.readSection(l.rodataRange, l.program); err != nil { + return err + } + } + + l.text = make([]byte, l.textRange.len()) + if err := l.readSection(l.textRange, l.text); err != nil { + return err + } + + return nil +} + // getText remembers the range of .text in the program buffer func (l *Loader) getText() error { if err := l.checkSectionAddrs(l.shText); err != nil { @@ -65,23 +95,30 @@ func (l *Loader) mapSections() error { } // Section overlap check & bounds tracking - section := addrRange{min: sh.Off, max: sh.Off + sh.Size} - if section.len() == 0 { + src := addrRange{min: sh.Off, max: sh.Off + sh.Size} + dst := l.sectionProgramRange(sectionName, &sh) + if dst.len() == 0 { continue } - if l.progRange.containsRange(section) { + if l.progRange.containsRange(dst) { // TODO rbpf probably doesn't have this restriction return fmt.Errorf("rodata section %d overlaps with other section", i) } - l.progRange.insert(section) + l.progRange.insert(dst) - if section.min != l.textRange.min { - l.rodatas = append(l.rodatas, section) + if 
sectionName != ".text" { + l.rodatas = append(l.rodatas, dst) + l.rodataMappings = append(l.rodataMappings, sectionMapping{src: src, dst: dst}) } } return iter.Err() } +func (l *Loader) sectionProgramRange(sectionName string, sh *elf.Section64) addrRange { + src := addrRange{min: sh.Off, max: sh.Off + sh.Size} + return src +} + func (l *Loader) checkSectionAddrs(sh *elf.Section64) error { if sh.Size > l.fileSize { return io.ErrUnexpectedEOF @@ -109,28 +146,36 @@ func (l *Loader) copySections() error { l.progRange.extendToFit(0) // Allocate! - l.program = make([]byte, l.fileSize) + programSize := l.fileSize + if l.progRange.max > programSize { + programSize = l.progRange.max + } + l.program = make([]byte, programSize) // Read data from ELF file - for _, section := range l.rodatas { - if err := l.copySection(section); err != nil { + for _, section := range l.rodataMappings { + if err := l.copySection(section.src, section.dst); err != nil { return err } } - if err := l.copySection(l.textRange); err != nil { - return err - } // Special sub-slice for text + if err := l.copySection(l.textRange, l.textRange); err != nil { + return err + } l.text = l.getRange(l.textRange) return nil } -func (l *Loader) copySection(section addrRange) (err error) { - off, size := int64(section.min), int64(section.len()) +func (l *Loader) copySection(src addrRange, dst addrRange) error { + return l.readSection(src, l.program[dst.min:dst.max]) +} + +func (l *Loader) readSection(src addrRange, dst []byte) (err error) { + off, size := int64(src.min), int64(src.len()) rd := io.NewSectionReader(l.rd, off, size) - _, err = io.ReadFull(rd, l.program[section.min:section.max]) + _, err = io.ReadFull(rd, dst) return } diff --git a/pkg/sbpf/loader/loader.go b/pkg/sbpf/loader/loader.go index c2cb4a4c..7b9dc52c 100644 --- a/pkg/sbpf/loader/loader.go +++ b/pkg/sbpf/loader/loader.go @@ -50,15 +50,18 @@ type Loader struct { // Program section/segment mappings // Uses physical addressing - rodatas []addrRange 
- textRange addrRange - progRange addrRange + rodatas []addrRange + rodataRange addrRange + textRange addrRange + progRange addrRange + textAddr uint64 // Contains most of ELF (.text and rodata-like) // Non-loaded sections are zeroed - program []byte - text []byte - entrypoint uint64 // program counter + program []byte + text []byte + rodataMappings []sectionMapping + entrypoint uint64 // program counter // Symbols funcs map[uint32]int64 @@ -109,6 +112,27 @@ func NewLoaderWithSyscalls(buf []byte, syscalls sbpf.SyscallRegistry, elfDeployC return l, nil } +func (l *Loader) sbpfVersion() sbpfver.SbpfVersion { + return sbpfver.SbpfVersion{Version: l.eh.Flags} +} + +func (l *Loader) enableStaticSyscalls() bool { + ver := l.sbpfVersion() + return ver.EnableStaticSyscalls() +} + +func (l *Loader) enableStricterElfHeaders() bool { + ver := l.sbpfVersion() + return ver.EnableStricterElfHeaders() +} + +func (l *Loader) textVA() uint64 { + if l.enableStricterElfHeaders() { + return l.textAddr + } + return sbpf.VaddrProgram + l.textRange.min +} + // Load parses, loads, and relocates an SBF program. // // This loader differs from rbpf in a few ways: @@ -140,10 +164,11 @@ func parseSlots(bs []byte) []sbpf.Slot { func (l *Loader) getProgram() *sbpf.Program { return &sbpf.Program{ RO: l.program, + TextBytes: l.text, Text: parseSlots(l.text), - TextVA: sbpf.VaddrProgram + l.textRange.min, + TextVA: l.textVA(), Entrypoint: l.entrypoint, Funcs: l.funcs, - SbpfVersion: sbpfver.SbpfVersion{Version: l.eh.Flags}, + SbpfVersion: l.sbpfVersion(), } } diff --git a/pkg/sbpf/loader/parse.go b/pkg/sbpf/loader/parse.go index bf1b87a5..c3b6e067 100644 --- a/pkg/sbpf/loader/parse.go +++ b/pkg/sbpf/loader/parse.go @@ -9,6 +9,8 @@ import ( "math" "math/bits" "strings" + + "github.com/Overclock-Validator/mithril/pkg/sbpf" ) // parse checks ELF file for validity and loads metadata with minimal allocations. 
@@ -16,6 +18,9 @@ func (l *Loader) parse() error { if err := l.readHeader(); err != nil { return err } + if l.enableStricterElfHeaders() { + return l.parseStrict() + } if err := l.validateElfHeader(); err != nil { return err } @@ -57,6 +62,69 @@ const ( EF_SBPF_V2 = 32 ) +const ( + elfIdentABIVersion = 8 + elfIdentPadStart = 9 +) + +const ( + progFlagX = 0x1 + progFlagR = 0x4 +) + +func (l *Loader) parseStrict() error { + if err := l.validateStrictElfHeader(); err != nil { + return err + } + + programHeaders, err := l.readProgramHeaders() + if err != nil { + return err + } + + expectedProgramHeaders := []struct { + flags uint32 + vaddr uint64 + }{ + {progFlagR, 0}, + {progFlagX, sbpf.VaddrProgram}, + } + + skipRodataProgramHeader := programHeaders[0].Flags != expectedProgramHeaders[0].flags + if skipRodataProgramHeader { + expectedProgramHeaders = expectedProgramHeaders[1:] + } else if l.eh.Phnum < 2 { + return fmt.Errorf("invalid ELF file") + } + + expectedOffset := uint64(ehLen) + uint64(l.eh.Phnum)*phEntLen + for i, expected := range expectedProgramHeaders { + ph := programHeaders[i] + if err := l.validateStrictProgramHeader(ph, expected.flags, expected.vaddr, expectedOffset); err != nil { + return err + } + expectedOffset = clampAddUint64(expectedOffset, ph.Filesz) + } + + var bytecodeHeader elf.Prog64 + if skipRodataProgramHeader { + l.rodataRange = addrRange{min: uint64(ehLen) + uint64(l.eh.Phnum)*phEntLen, max: uint64(ehLen) + uint64(l.eh.Phnum)*phEntLen} + bytecodeHeader = programHeaders[0] + } else { + rodataHeader := programHeaders[0] + l.rodataRange = addrRange{min: rodataHeader.Off, max: rodataHeader.Off + rodataHeader.Filesz} + bytecodeHeader = programHeaders[1] + } + + l.textAddr = bytecodeHeader.Vaddr + l.textRange = addrRange{min: bytecodeHeader.Off, max: bytecodeHeader.Off + bytecodeHeader.Filesz} + if !bytecodeHeaderContainsEntrypoint(bytecodeHeader, l.eh.Entry) || l.eh.Entry%sbpf.SlotSize != 0 { + return fmt.Errorf("invalid ELF file") + } + + 
return nil +} + func (l *Loader) newShTableIter() *shTableIter { eh := &l.eh return &shTableIter{ @@ -170,6 +238,93 @@ func (l *Loader) validateElfHeader() error { return nil } +func (l *Loader) validateSbpfVersion() error { + eh := &l.eh + if eh.Flags == EF_SBF_V2 { + return fmt.Errorf("invalid sbpf version") + } + if eh.Flags < l.minSbpfVersion || eh.Flags > l.maxSbpfVersion { + return fmt.Errorf("invalid sbpf version") + } + return nil +} + +func (l *Loader) validateStrictElfHeader() error { + eh := &l.eh + ident := &eh.Ident + + if err := l.validateSbpfVersion(); err != nil { + return err + } + + if string(ident[:elf.EI_CLASS]) != elf.ELFMAG || + elf.Class(ident[elf.EI_CLASS]) != elf.ELFCLASS64 || + elf.Data(ident[elf.EI_DATA]) != elf.ELFDATA2LSB || + elf.Version(ident[elf.EI_VERSION]) != elf.EV_CURRENT || + elf.OSABI(ident[elf.EI_OSABI]) != elf.ELFOSABI_NONE || + ident[elfIdentABIVersion] != 0 || + !allZeroBytes(ident[elfIdentPadStart:]) || + elf.Machine(eh.Machine) != elf.EM_BPF || + eh.Version != uint32(elf.EV_CURRENT) || + eh.Phoff != ehLen || + eh.Ehsize != ehLen || + eh.Phentsize != phEntLen || + eh.Phnum == 0 { + return fmt.Errorf("invalid ELF file") + } + + phTableEnd := uint64(ehLen) + uint64(eh.Phnum)*phEntLen + if phTableEnd > l.fileSize { + return fmt.Errorf("invalid ELF file") + } + + return nil +} + +func allZeroBytes(bytes []byte) bool { + for _, b := range bytes { + if b != 0 { + return false + } + } + return true +} + +func (l *Loader) readProgramHeaders() ([]elf.Prog64, error) { + programHeaders := make([]elf.Prog64, 0, l.eh.Phnum) + iter := l.newPhTableIter() + for iter.Next() && iter.Err() == nil { + programHeaders = append(programHeaders, iter.Item()) + } + if err := iter.Err(); err != nil { + return nil, err + } + return programHeaders, nil +} + +func (l *Loader) validateStrictProgramHeader(ph elf.Prog64, expectedFlags uint32, expectedVaddr uint64, expectedOffset uint64) error { + if elf.ProgType(ph.Type) != elf.PT_LOAD || + ph.Flags != 
expectedFlags || + ph.Off != expectedOffset || + ph.Off >= l.fileSize || + ph.Off%sbpf.SlotSize != 0 || + ph.Vaddr != expectedVaddr || + ph.Paddr != expectedVaddr || + ph.Filesz != ph.Memsz || + ph.Filesz > l.fileSize-ph.Off || + ph.Filesz%sbpf.SlotSize != 0 || + ph.Memsz >= sbpf.VaddrProgram { + return fmt.Errorf("invalid program header") + } + return nil +} + +func bytecodeHeaderContainsEntrypoint(ph elf.Prog64, entrypoint uint64) bool { + entrypointEnd := clampAddUint64(entrypoint, sbpf.SlotSize-1) + bytecodeEnd := clampAddUint64(ph.Vaddr, ph.Memsz) + return ph.Vaddr <= entrypointEnd && entrypointEnd < bytecodeEnd +} + // scan the program header table and remember the last PT_LOAD segment func (l *Loader) loadProgramHeaderTable() error { iter := l.newPhTableIter() @@ -212,7 +367,6 @@ func (l *Loader) loadProgramHeaderTable() error { func (l *Loader) readSectionHeaderTable() error { eh := &l.eh iter := l.newShTableIter() - sectionDataOff := uint64(0) if !iter.Next() { return fmt.Errorf("missing section 0") @@ -261,10 +415,6 @@ func (l *Loader) readSectionHeaderTable() error { return fmt.Errorf("section %d overlaps with section header", i) } - // More checks - if eh.Shoff < sectionDataOff { - return fmt.Errorf("sections not in order") - } if shend > l.fileSize { return fmt.Errorf("section %d out of bounds", i) } @@ -273,8 +423,6 @@ func (l *Loader) readSectionHeaderTable() error { if eh.Shstrndx != uint16(elf.SHN_UNDEF) && uint32(eh.Shstrndx) == i { l.shShstrtab = sh } - - sectionDataOff = shend } // TODO validate offset and size (?) 
if elf.SectionType(l.shShstrtab.Type) != elf.SHT_STRTAB { diff --git a/pkg/sbpf/loader/relocate.go b/pkg/sbpf/loader/relocate.go index 997e38dd..3426e7f6 100644 --- a/pkg/sbpf/loader/relocate.go +++ b/pkg/sbpf/loader/relocate.go @@ -12,11 +12,13 @@ import ( func (l *Loader) relocate() error { l.funcs = make(map[uint32]int64) l.funcName = make(map[uint32]int64) - if err := l.fixupRelativeCalls(); err != nil { - return err - } - if err := l.applyDynamicRelocs(); err != nil { - return err + if !l.enableStaticSyscalls() { + if err := l.fixupRelativeCalls(); err != nil { + return err + } + if err := l.applyDynamicRelocs(); err != nil { + return err + } } if err := l.getEntrypoint(); err != nil { return err @@ -55,9 +57,12 @@ func (l *Loader) fixupRelativeCalls() error { } func (l *Loader) registerFunc(target uint64) (uint32, error) { - hash := sbpf.PCHash(target) + hash := uint32(target) + if !l.enableStaticSyscalls() { + hash = sbpf.PCHash(target) + } - if l.syscalls != nil && l.syscalls.ExistsByHash(hash) { + if !l.enableStaticSyscalls() && l.syscalls != nil && l.syscalls.ExistsByHash(hash) { return 0, fmt.Errorf("symbol hash collision with syscall") } @@ -69,6 +74,13 @@ func (l *Loader) registerFunc(target uint64) (uint32, error) { return hash, nil } +func (l *Loader) normalizeVaddr(addr uint64) uint64 { + if addr < sbpf.VaddrProgram { + return clampAddUint64(addr, sbpf.VaddrProgram) + } + return addr +} + func (l *Loader) applyDynamicRelocs() error { iter := l.relocsIter if iter == nil { @@ -104,9 +116,7 @@ func (l *Loader) applyReloc(reloc *elf.Rel64) error { relAddr := binary.LittleEndian.Uint32(l.program[rOff+4 : rOff+8]) addr := clampAddUint64(sym.Value, uint64(relAddr)) - if addr < sbpf.VaddrProgram { - addr += sbpf.VaddrProgram - } + addr = l.normalizeVaddr(addr) // Write to imm field of two slots binary.LittleEndian.PutUint32(l.program[rOff+4:rOff+8], uint32(addr)) @@ -120,9 +130,7 @@ func (l *Loader) applyReloc(reloc *elf.Rel64) error { if addr == 0 { return 
fmt.Errorf("invalid R_BPF_64_RELATIVE") } - if addr < sbpf.VaddrProgram { - addr += sbpf.VaddrProgram - } + addr = l.normalizeVaddr(addr) // Write to imm field of two slots binary.LittleEndian.PutUint32(l.program[rOff+4:rOff+8], uint32(addr)) @@ -131,12 +139,10 @@ func (l *Loader) applyReloc(reloc *elf.Rel64) error { var addr uint64 if l.eh.Flags == EF_SBF_V2 { addr = binary.LittleEndian.Uint64(l.program[rOff : rOff+8]) - if addr < sbpf.VaddrProgram { - addr += sbpf.VaddrProgram - } + addr = l.normalizeVaddr(addr) } else { addr = uint64(binary.LittleEndian.Uint32(l.program[rOff+4 : rOff+8])) - addr = clampAddUint64(addr, sbpf.VaddrProgram) + addr = l.normalizeVaddr(addr) } binary.LittleEndian.PutUint64(l.program[rOff:rOff+8], addr) } @@ -153,10 +159,11 @@ func (l *Loader) applyReloc(reloc *elf.Rel64) error { var hash uint32 if elf.ST_TYPE(sym.Info) == elf.STT_FUNC && sym.Value != 0 { // Function call - if !l.textRange.contains(sym.Value) { + textVMRange := addrRange{min: l.shText.Addr, max: clampAddUint64(l.shText.Addr, l.shText.Size)} + if !textVMRange.contains(sym.Value) { return fmt.Errorf("out-of-bounds R_BPF_64_32 function ref") } - target := (sym.Value - l.textRange.min) / 8 + target := (sym.Value - textVMRange.min) / 8 nameHash := sbpf.SymbolHash(name) if existing, ok := l.funcName[nameHash]; ok && existing != int64(target) { @@ -186,7 +193,13 @@ func (l *Loader) applyReloc(reloc *elf.Rel64) error { } func (l *Loader) getEntrypoint() error { - offset := l.eh.Entry - l.shText.Addr + textAddr := uint64(0) + if l.enableStricterElfHeaders() { + textAddr = l.textAddr + } else { + textAddr = l.shText.Addr + } + offset := l.eh.Entry - textAddr if offset%sbpf.SlotSize != 0 { return fmt.Errorf("invalid entrypoint") } diff --git a/pkg/sbpf/loader/relocate_test.go b/pkg/sbpf/loader/relocate_test.go index 3fc1a0bd..bcecf626 100644 --- a/pkg/sbpf/loader/relocate_test.go +++ b/pkg/sbpf/loader/relocate_test.go @@ -1,12 +1,23 @@ package loader import ( + "debug/elf" 
"testing" "github.com/Overclock-Validator/mithril/pkg/sbpf" + "github.com/Overclock-Validator/mithril/pkg/sbpf/sbpfver" "github.com/stretchr/testify/assert" ) func TestSymbolHash_Entrypoint(t *testing.T) { assert.Equal(t, sbpf.EntrypointHash, sbpf.SymbolHash("entrypoint")) } + +func TestSectionProgramRangeUsesFileOffsets(t *testing.T) { + l := &Loader{eh: elf.Header64{Flags: sbpfver.SbpfVersionV2}} + rodata := elf.Section64{Addr: 128, Off: 128, Size: 4} + text := elf.Section64{Addr: 64, Off: 64, Size: 8} + + assert.Equal(t, addrRange{min: 128, max: 132}, l.sectionProgramRange(".rodata", &rodata)) + assert.Equal(t, addrRange{min: 64, max: 72}, l.sectionProgramRange(".text", &text)) +} diff --git a/pkg/sbpf/loader/strict_v3_test.go b/pkg/sbpf/loader/strict_v3_test.go new file mode 100644 index 00000000..7f93ce87 --- /dev/null +++ b/pkg/sbpf/loader/strict_v3_test.go @@ -0,0 +1,147 @@ +package loader + +import ( + "debug/elf" + "encoding/binary" + "testing" + + "github.com/Overclock-Validator/mithril/pkg/features" + "github.com/Overclock-Validator/mithril/pkg/sbpf" + "github.com/stretchr/testify/require" +) + +func strictV3Features() *features.Features { + f := features.NewFeaturesDefault() + f.EnableFeature(features.EnableSbpfV3DeploymentAndExecution, 0) + return f +} + +func strictV3ELF(t *testing.T, rodata []byte, text []sbpf.Slot, mutate func([]byte)) []byte { + t.Helper() + + textBytes := make([]byte, len(text)*sbpf.SlotSize) + for i, slot := range text { + binary.LittleEndian.PutUint64(textBytes[i*sbpf.SlotSize:], uint64(slot)) + } + + phnum := uint16(1) + if rodata != nil { + phnum = 2 + } + phTableEnd := uint64(ehLen) + uint64(phnum)*phEntLen + rodataOffset := phTableEnd + textOffset := phTableEnd + uint64(len(rodata)) + fileSize := textOffset + uint64(len(textBytes)) + buf := make([]byte, fileSize) + + copy(buf[0:4], []byte{0x7f, 'E', 'L', 'F'}) + buf[elf.EI_CLASS] = byte(elf.ELFCLASS64) + buf[elf.EI_DATA] = byte(elf.ELFDATA2LSB) + buf[elf.EI_VERSION] = 
byte(elf.EV_CURRENT) + buf[elf.EI_OSABI] = byte(elf.ELFOSABI_NONE) + + binary.LittleEndian.PutUint16(buf[16:18], uint16(elf.ET_DYN)) + binary.LittleEndian.PutUint16(buf[18:20], uint16(elf.EM_BPF)) + binary.LittleEndian.PutUint32(buf[20:24], uint32(elf.EV_CURRENT)) + binary.LittleEndian.PutUint64(buf[24:32], sbpf.VaddrProgram) + binary.LittleEndian.PutUint64(buf[32:40], ehLen) + binary.LittleEndian.PutUint32(buf[48:52], 3) + binary.LittleEndian.PutUint16(buf[52:54], ehLen) + binary.LittleEndian.PutUint16(buf[54:56], phEntLen) + binary.LittleEndian.PutUint16(buf[56:58], phnum) + + phoff := ehLen + if rodata != nil { + writeProgHeader(buf[phoff:], uint32(elf.PT_LOAD), progFlagR, rodataOffset, 0, uint64(len(rodata))) + phoff += phEntLen + copy(buf[int(rodataOffset):int(textOffset)], rodata) + } + writeProgHeader(buf[phoff:], uint32(elf.PT_LOAD), progFlagX, textOffset, sbpf.VaddrProgram, uint64(len(textBytes))) + copy(buf[int(textOffset):], textBytes) + + if mutate != nil { + mutate(buf) + } + return buf +} + +func writeProgHeader(buf []byte, typ uint32, flags uint32, offset uint64, vaddr uint64, size uint64) { + binary.LittleEndian.PutUint32(buf[0:4], typ) + binary.LittleEndian.PutUint32(buf[4:8], flags) + binary.LittleEndian.PutUint64(buf[8:16], offset) + binary.LittleEndian.PutUint64(buf[16:24], vaddr) + binary.LittleEndian.PutUint64(buf[24:32], vaddr) + binary.LittleEndian.PutUint64(buf[32:40], size) + binary.LittleEndian.PutUint64(buf[40:48], size) +} + +func TestStrictV3LoaderUsesProgramHeaders(t *testing.T) { + rodata := []byte{1, 2, 3, 4, 5, 6, 7, 8} + text := []sbpf.Slot{sbpf.Slot(sbpf.OpExit)} + elfBytes := strictV3ELF(t, rodata, text, nil) + + l, err := NewLoaderWithSyscalls(elfBytes, nil, true, strictV3Features()) + require.NoError(t, err) + program, err := l.Load() + require.NoError(t, err) + + require.Equal(t, sbpf.VaddrProgram, program.TextVA) + require.Equal(t, rodata, program.RO) + require.Equal(t, text, program.Text) + require.Equal(t, uint64(0), 
program.Entrypoint) + require.NoError(t, program.Verify()) +} + +func TestStrictV3LoaderAllowsMissingRodataProgramHeader(t *testing.T) { + text := []sbpf.Slot{sbpf.Slot(sbpf.OpExit)} + elfBytes := strictV3ELF(t, nil, text, nil) + + l, err := NewLoaderWithSyscalls(elfBytes, nil, true, strictV3Features()) + require.NoError(t, err) + program, err := l.Load() + require.NoError(t, err) + + require.Empty(t, program.RO) + require.Equal(t, sbpf.VaddrProgram, program.TextVA) + require.Equal(t, text, program.Text) +} + +func TestStrictV3LoaderDoesNotRequireDynElfType(t *testing.T) { + text := []sbpf.Slot{sbpf.Slot(sbpf.OpExit)} + elfBytes := strictV3ELF(t, nil, text, func(buf []byte) { + binary.LittleEndian.PutUint16(buf[16:18], uint16(elf.ET_EXEC)) + }) + + l, err := NewLoaderWithSyscalls(elfBytes, nil, true, strictV3Features()) + require.NoError(t, err) + _, err = l.Load() + require.NoError(t, err) +} + +func TestStrictV3LoaderRejectsAgaveStrictProgramHeaderMismatches(t *testing.T) { + text := []sbpf.Slot{sbpf.Slot(sbpf.OpExit)} + + tests := map[string]func([]byte){ + "wrong machine": func(buf []byte) { + binary.LittleEndian.PutUint16(buf[18:20], uint16(EM_SBPF)) + }, + "wrong text offset": func(buf []byte) { + textProgramHeader := ehLen + binary.LittleEndian.PutUint64(buf[textProgramHeader+8:textProgramHeader+16], uint64(ehLen)) + }, + "wrong text vaddr": func(buf []byte) { + textProgramHeader := ehLen + binary.LittleEndian.PutUint64(buf[textProgramHeader+16:textProgramHeader+24], 0) + }, + } + + for name, mutate := range tests { + t.Run(name, func(t *testing.T) { + elfBytes := strictV3ELF(t, nil, text, mutate) + l, err := NewLoaderWithSyscalls(elfBytes, nil, true, strictV3Features()) + require.NoError(t, err) + _, err = l.Load() + require.Error(t, err) + }) + } +} diff --git a/pkg/sbpf/opcode_test.go b/pkg/sbpf/opcode_test.go index 8404f6f6..d5e80fd5 100644 --- a/pkg/sbpf/opcode_test.go +++ b/pkg/sbpf/opcode_test.go @@ -80,10 +80,14 @@ func TestOpcodes(t *testing.T) { 
{0xcf, OpArsh64Reg}, {0xd4, OpLe}, {0xdc, OpBe}, - {0xe4, OpSdiv32Imm}, - {0xe7, OpSdiv64Imm}, - {0xec, OpSdiv32Reg}, - {0xef, OpSdiv64Reg}, + {0xc6, OpSdiv32Imm}, + {0xd6, OpSdiv64Imm}, + {0xce, OpSdiv32Reg}, + {0xde, OpSdiv64Reg}, + {0xe6, OpSrem32Imm}, + {0xf6, OpSrem64Imm}, + {0xee, OpSrem32Reg}, + {0xfe, OpSrem64Reg}, {0x05, OpJa}, {0x15, OpJeqImm}, diff --git a/pkg/sbpf/program.go b/pkg/sbpf/program.go index c4f323d0..5f1c14dd 100644 --- a/pkg/sbpf/program.go +++ b/pkg/sbpf/program.go @@ -7,6 +7,7 @@ import ( // Program is a loaded SBF program. type Program struct { RO []byte // read-only segment containing text and ELFs + TextBytes []byte Text []Slot TextVA uint64 Entrypoint uint64 // PC diff --git a/pkg/sbpf/sbpf.go b/pkg/sbpf/sbpf.go index 8d0065a9..f338ee99 100644 --- a/pkg/sbpf/sbpf.go +++ b/pkg/sbpf/sbpf.go @@ -22,6 +22,10 @@ func IsLongIns(op uint8) bool { return op == OpLddw } +func IsFunctionStartMarker(slot Slot) bool { + return slot.Op() == OpAdd64Imm && slot.Dst() == 10 +} + // Slot holds the content of one instruction slot. 
type Slot uint64 diff --git a/pkg/sbpf/sbpfver/sbpf_version.go b/pkg/sbpf/sbpfver/sbpf_version.go index 20578e99..62bb3118 100644 --- a/pkg/sbpf/sbpfver/sbpf_version.go +++ b/pkg/sbpf/sbpfver/sbpf_version.go @@ -14,7 +14,11 @@ type SbpfVersion struct { } func (ver *SbpfVersion) DynamicStackFrames() bool { - return ver.Version >= SbpfVersionV1 + return ver.Version == SbpfVersionV1 || ver.Version == SbpfVersionV2 +} + +func (ver *SbpfVersion) StackFrameGaps() bool { + return ver.Version == SbpfVersionV0 } func (ver *SbpfVersion) RejectRodataStackOverlap() bool { @@ -26,48 +30,79 @@ func (ver *SbpfVersion) EnableElfVAddr() bool { } func (ver *SbpfVersion) EnablePqr() bool { - return ver.Version >= SbpfVersionV2 + return ver.Version == SbpfVersionV2 } func (ver *SbpfVersion) ExplicitSignExtensionOfResults() bool { - return ver.Version >= SbpfVersionV2 + return ver.Version == SbpfVersionV2 } func (ver *SbpfVersion) SwapSubRegImmOperands() bool { - return ver.Version >= SbpfVersionV2 + return ver.Version == SbpfVersionV2 } func (ver *SbpfVersion) DisableNeg() bool { - return ver.Version >= SbpfVersionV2 + return ver.Version == SbpfVersionV2 } func (ver *SbpfVersion) CallXUsesSrcReg() bool { - return ver.Version >= SbpfVersionV2 + return ver.Version == SbpfVersionV2 +} + +func (ver *SbpfVersion) CallXUsesDstReg() bool { + return ver.Version >= SbpfVersionV3 } func (ver *SbpfVersion) DisableLe() bool { - return ver.Version >= SbpfVersionV2 + return ver.Version == SbpfVersionV2 } func (ver *SbpfVersion) DisableLddw() bool { - return ver.Version >= SbpfVersionV2 + return ver.Version == SbpfVersionV2 } func (ver *SbpfVersion) MoveMemoryInstructionClasses() bool { - return ver.Version >= SbpfVersionV2 + return ver.Version == SbpfVersionV2 } func (ver *SbpfVersion) EnableStaticSyscalls() bool { return ver.Version >= SbpfVersionV3 } +func (ver *SbpfVersion) EnableStricterElfHeaders() bool { + return ver.Version >= SbpfVersionV3 +} + +func (ver *SbpfVersion) 
EnableStricterVerification() bool { + return false +} + +func (ver *SbpfVersion) EnableLowerRodataVaddr() bool { + return ver.Version >= SbpfVersionV3 +} + +func (ver *SbpfVersion) EnableLowerBytecodeVaddr() bool { + return false +} + +func (ver *SbpfVersion) EnableJmp32() bool { + return ver.Version >= SbpfVersionV3 +} + +func (ver *SbpfVersion) CalculateCallImmTargetPC(pc int64, imm int32) int64 { + if ver.EnableStaticSyscalls() { + return pc + int64(imm) + 1 + } + return int64(uint32(imm)) +} + func GetMinAndMaxSbpfVersions(f *features.Features) (uint32, uint32) { disableSbpfV0 := f.IsActive(features.DisableSbpfV0Execution) reenableSbpfV0 := f.IsActive(features.ReenableSbpfV0Execution) enableSbpfV0 := !disableSbpfV0 || reenableSbpfV0 enableSbpfV1 := f.IsActive(features.EnableSbpfV1DeploymentAndExecution) enableSbpfV2 := f.IsActive(features.EnableSbpfV2DeploymentAndExecution) - enableSbpfV3 := f.IsActive(features.EnableSbpfV3DeploymentAndExecution) + enableSbpfV3 := f.IsSbpfV3DeploymentAndExecutionActive() var maxVer, minVer uint32 diff --git a/pkg/sbpf/sbpfver/sbpf_version_test.go b/pkg/sbpf/sbpfver/sbpf_version_test.go new file mode 100644 index 00000000..49d5a228 --- /dev/null +++ b/pkg/sbpf/sbpfver/sbpf_version_test.go @@ -0,0 +1,29 @@ +package sbpfver + +import ( + "testing" + + "github.com/Overclock-Validator/mithril/pkg/features" + "github.com/stretchr/testify/require" +) + +func TestGetMinAndMaxSbpfVersionsAcceptsSbpfV3Gate(t *testing.T) { + f := features.NewFeaturesDefault() + f.EnableFeature(features.EnableSbpfV3DeploymentAndExecution, 0) + + _, maxVersion := GetMinAndMaxSbpfVersions(f) + require.Equal(t, uint32(SbpfVersionV3), maxVersion) +} + +func TestSbpfV3VersionPredicates(t *testing.T) { + v3 := SbpfVersion{Version: SbpfVersionV3} + + require.False(t, v3.DynamicStackFrames()) + require.False(t, v3.StackFrameGaps()) + require.False(t, v3.EnablePqr()) + require.False(t, v3.DisableLddw()) + require.True(t, v3.EnableStaticSyscalls()) + require.True(t, 
v3.EnableStricterElfHeaders()) + require.True(t, v3.EnableJmp32()) + require.True(t, v3.CallXUsesDstReg()) +} diff --git a/pkg/sbpf/stack.go b/pkg/sbpf/stack.go index 6134cd96..3f3331b6 100644 --- a/pkg/sbpf/stack.go +++ b/pkg/sbpf/stack.go @@ -35,6 +35,7 @@ type Stack struct { sp uint64 shadow []Frame dynamicStackFrames bool + stackFrameGaps bool } // Frame is an entry on the shadow stack. @@ -79,7 +80,7 @@ var ( }} ) -func NewStack(sbpfVer sbpfver.SbpfVersion) Stack { +func NewStack(sbpfVer sbpfver.SbpfVersion, disableStackFrameGaps bool) Stack { var m []byte var sh []Frame if UsePool { @@ -102,6 +103,7 @@ func NewStack(sbpfVer sbpfver.SbpfVersion) Stack { s.dynamicStackFrames = true } else { sz = StackFrameSize + s.stackFrameGaps = sbpfVer.StackFrameGaps() && !disableStackFrameGaps } s.shadow[0] = Frame{ @@ -132,14 +134,16 @@ func (s *Stack) GetFrame(addr uint32) []byte { off := uint64(addr & math.MaxUint32) if !s.dynamicStackFrames { - // disallow addressing a gap - hi := addr / StackFrameSize - if hi%2 == 1 { - return nil + if s.stackFrameGaps { + // disallow addressing a gap + hi := addr / StackFrameSize + if hi%2 == 1 { + return nil + } + + // account for gapping in virtual addr space but not in the underlying memory + off = ((off & GapMask) >> 1) | (off & ^GapMask) } - - // account for gapping in virtual addr space but not in the underlying memory - off = ((off & GapMask) >> 1) | (off & ^GapMask) } if off > StackMax { @@ -166,7 +170,11 @@ func (s *Stack) Push(regs []uint64, ret int64) bool { s.shadow = append(s.shadow, frame) if !s.dynamicStackFrames { - regs[10] += StackFrameSize * 2 + if s.stackFrameGaps { + regs[10] += StackFrameSize * 2 + } else { + regs[10] += StackFrameSize + } } return true diff --git a/pkg/sbpf/vasa_test.go b/pkg/sbpf/vasa_test.go new file mode 100644 index 00000000..e7d64034 --- /dev/null +++ b/pkg/sbpf/vasa_test.go @@ -0,0 +1,102 @@ +package sbpf + +import ( + "testing" + + "github.com/Overclock-Validator/mithril/pkg/sbpf/sbpfver" 
+ "github.com/stretchr/testify/require" +) + +func TestInputRegionVirtualAddressSpaceAdjustments(t *testing.T) { + input := make([]byte, 32) + ip := &Interpreter{ + input: input, + inputRegions: []InputRegion{ + {Offset: 0, RegionSize: 8, AddressSpaceReserved: 8, Writable: true, AccountIndex: -1}, + {Offset: 8, RegionSize: 4, AddressSpaceReserved: 12, Writable: true, AccountIndex: 0}, + {Offset: 20, RegionSize: 4, AddressSpaceReserved: 4, Writable: false, AccountIndex: 1}, + }, + } + + _, err := ip.Read8(VaddrInput + 8 + 4) + require.Error(t, err) + + require.NoError(t, ip.Write8(VaddrInput+8+4, 0xaa)) + require.Equal(t, uint64(12), ip.inputRegions[1].RegionSize) + + value, err := ip.Read8(VaddrInput + 8 + 4) + require.NoError(t, err) + require.Equal(t, uint8(0xaa), value) + + require.Error(t, ip.Write8(VaddrInput+20, 0xbb)) +} + +func TestInputRegionDirectMappedAccountData(t *testing.T) { + backing := []byte{1, 2, 0, 0, 0} + onWriteCalls := 0 + ip := &Interpreter{ + input: make([]byte, 1), + inputRegions: []InputRegion{ + { + Offset: 0, + RegionSize: 2, + AddressSpaceReserved: 5, + Writable: false, + AccountIndex: 0, + Data: backing[:2], + OnWrite: func(region *InputRegion, requestedLen uint64) error { + onWriteCalls++ + require.Equal(t, uint64(3), requestedLen) + region.Data = backing + region.RegionSize = uint64(len(backing)) + region.Writable = true + return nil + }, + }, + }, + } + + _, err := ip.Read8(VaddrInput + 2) + require.Error(t, err) + + require.NoError(t, ip.Write8(VaddrInput+2, 0xcc)) + require.Equal(t, 1, onWriteCalls) + require.Equal(t, uint8(0xcc), backing[2]) + + data, err := ip.TranslateInput(VaddrInput, uint64(len(backing))) + require.NoError(t, err) + require.Equal(t, backing, data) +} + +func TestStackFrameGapsCanBeDisabled(t *testing.T) { + version := sbpfver.SbpfVersion{Version: sbpfver.SbpfVersionV0} + + gapped := NewStack(version, false) + defer gapped.Finish() + gappedRegs := make([]uint64, 11) + gappedRegs[10] = VaddrStack + 
StackFrameSize + require.True(t, gapped.Push(gappedRegs, 0)) + require.Equal(t, VaddrStack+StackFrameSize*3, gappedRegs[10]) + require.Nil(t, gapped.GetFrame(StackFrameSize)) + + contiguous := NewStack(version, true) + defer contiguous.Finish() + contiguousRegs := make([]uint64, 11) + contiguousRegs[10] = VaddrStack + StackFrameSize + require.True(t, contiguous.Push(contiguousRegs, 0)) + require.Equal(t, VaddrStack+StackFrameSize*2, contiguousRegs[10]) + require.NotNil(t, contiguous.GetFrame(StackFrameSize)) +} + +func TestStackFrameGapsAreLegacyOnly(t *testing.T) { + version := sbpfver.SbpfVersion{Version: sbpfver.SbpfVersionV3} + + stack := NewStack(version, false) + defer stack.Finish() + + regs := make([]uint64, 11) + regs[10] = VaddrStack + StackFrameSize + require.True(t, stack.Push(regs, 0)) + require.Equal(t, VaddrStack+StackFrameSize*2, regs[10]) + require.NotNil(t, stack.GetFrame(StackFrameSize)) +} diff --git a/pkg/sbpf/verifier.go b/pkg/sbpf/verifier.go index 009e7a17..fd5b8864 100644 --- a/pkg/sbpf/verifier.go +++ b/pkg/sbpf/verifier.go @@ -2,7 +2,6 @@ package sbpf import ( "fmt" - "sort" ) type Verifier struct { @@ -26,8 +25,6 @@ const ( verifyInvalid verifyCheckCallReg verifyCheckCallRegDepr - verifyCheckCallImm - verifyCheckSyscall verifyCheckJmpV0 ) @@ -39,31 +36,7 @@ func (v *Verifier) VerifyProgram() error { return fmt.Errorf("empty text") } - var funcStarts []int64 - if v.Program.SbpfVersion.EnableStaticSyscalls() { - for _, pc := range v.Program.Funcs { - funcStarts = append(funcStarts, pc) - } - sort.Slice(funcStarts, func(i, j int) bool { return funcStarts[i] < funcStarts[j] }) - } - - functionStart := int64(0) - functionNext := int64(len(text)) - funcIdx := 0 - for pc := 0; pc < len(text); pc++ { - if v.Program.SbpfVersion.EnableStaticSyscalls() { - for funcIdx < len(funcStarts) && int64(pc) >= funcStarts[funcIdx] { - functionStart = funcStarts[funcIdx] - if funcIdx+1 < len(funcStarts) { - functionNext = funcStarts[funcIdx+1] - } else { - 
functionNext = int64(len(text)) - } - funcIdx++ - } - } - ins := text[pc] if ins.Src() > 10 { @@ -93,7 +66,7 @@ func (v *Verifier) VerifyProgram() error { case verifyCheckJmpV3: { dst := int64(pc) + int64(ins.Off()) + 1 - if dst < functionStart || dst >= functionNext { + if dst < 0 || dst >= int64(len(text)) { return fmt.Errorf("jump out of code") } } @@ -140,28 +113,22 @@ func (v *Verifier) VerifyProgram() error { case verifyCheckCallReg: { - if ins.Src() > 9 { + reg := ins.Src() + if v.Program.SbpfVersion.CallXUsesDstReg() { + reg = ins.Dst() + } + if reg > 9 { return fmt.Errorf("invalid register") } } case verifyCheckCallRegDepr: { - if ins.Imm() > 9 { + if ins.Imm() < 0 || ins.Imm() > 9 { return fmt.Errorf("invalid register") } } - case verifyCheckCallImm: - { - if _, ok := v.Program.Funcs[ins.Uimm()]; !ok { - return fmt.Errorf("invalid function") - } - } - - case verifyCheckSyscall: - // nothing to do - already verified by the loader - case verifyInvalid: fallthrough default: @@ -193,9 +160,6 @@ func (v *Verifier) VerifyProgram() error { func (v *Verifier) buildValidationMap() { checkJmp := verifyCheckJmpV0 - if v.Program.SbpfVersion.EnableStaticSyscalls() { - checkJmp = verifyCheckJmpV3 - } v.validationMap = [256]int{ /* 0x00 */ verifyInvalid /* 0x01 */, verifyInvalid /* 0x02 */, verifyInvalid /* 0x03 */, verifyInvalid, @@ -231,13 +195,13 @@ func (v *Verifier) buildValidationMap() { /* 0x78 */ verifyInvalid /* 0x79 */, verifyInvalid /* 0x7a */, verifyInvalid /* 0x7b */, verifyInvalid, /* 0x7c */ verifyValid /* 0x7d */, checkJmp /* 0x7e */, verifyValid /* 0x7f */, verifyValid, /* 0x80 */ verifyInvalid /* 0x81 */, verifyInvalid /* 0x82 */, verifyInvalid /* 0x83 */, verifyInvalid, - /* 0x84 */ verifyInvalid /* 0x85 */, verifyCheckCallImm /*0x86*/, verifyValid /* 0x87 */, verifyCheckSt, + /* 0x84 */ verifyInvalid /* 0x85 */, verifyValid /*0x86*/, verifyValid /* 0x87 */, verifyCheckSt, /* 0x88 */ verifyInvalid /* 0x89 */, verifyInvalid /* 0x8a */, verifyInvalid /* 
0x8b */, verifyInvalid, /* 0x8c */ verifyValid /* 0x8d */, verifyCheckCallReg /*0x8e*/, verifyValid /* 0x8f */, verifyCheckSt, /* 0x90 */ verifyInvalid /* 0x91 */, verifyInvalid /* 0x92 */, verifyInvalid /* 0x93 */, verifyInvalid, - /* 0x94 */ verifyInvalid /* 0x95 */, verifyCheckSyscall /*0x96*/, verifyValid /* 0x97 */, verifyCheckSt, + /* 0x94 */ verifyInvalid /* 0x95 */, verifyValid /*0x96*/, verifyValid /* 0x97 */, verifyCheckSt, /* 0x98 */ verifyInvalid /* 0x99 */, verifyInvalid /* 0x9a */, verifyInvalid /* 0x9b */, verifyInvalid, - /* 0x9c */ verifyValid /* 0x9d */, verifyValid /* 0x9e */, verifyValid /* 0x9f */, verifyCheckSt, + /* 0x9c */ verifyValid /* 0x9d */, verifyInvalid /* 0x9e */, verifyValid /* 0x9f */, verifyCheckSt, /* 0xa0 */ verifyInvalid /* 0xa1 */, verifyInvalid /* 0xa2 */, verifyInvalid /* 0xa3 */, verifyInvalid, /* 0xa4 */ verifyValid /* 0xa5 */, checkJmp /* 0xa6 */, verifyInvalid /* 0xa7 */, verifyValid, /* 0xa8 */ verifyInvalid /* 0xa9 */, verifyInvalid /* 0xaa */, verifyInvalid /* 0xab */, verifyInvalid, @@ -345,6 +309,8 @@ func (v *Verifier) buildValidationMap() { /* SIMD-0173: CALLX */ if v.Program.SbpfVersion.CallXUsesSrcReg() { v.validationMap[0x8d] = verifyCheckCallReg + } else if v.Program.SbpfVersion.CallXUsesDstReg() { + v.validationMap[0x8d] = verifyCheckCallReg } else { v.validationMap[0x8d] = verifyCheckCallRegDepr } @@ -420,11 +386,29 @@ func (v *Verifier) buildValidationMap() { v.validationMap[0xfe] = verifyInvalid } + if v.Program.SbpfVersion.EnableJmp32() { + for _, op := range []uint8{ + ClassPqr | SrcK | JumpEq, ClassPqr | SrcX | JumpEq, + ClassPqr | SrcK | JumpGt, ClassPqr | SrcX | JumpGt, + ClassPqr | SrcK | JumpGe, ClassPqr | SrcX | JumpGe, + ClassPqr | SrcK | JumpLt, ClassPqr | SrcX | JumpLt, + ClassPqr | SrcK | JumpLe, ClassPqr | SrcX | JumpLe, + ClassPqr | SrcK | JumpSet, ClassPqr | SrcX | JumpSet, + ClassPqr | SrcK | JumpNe, ClassPqr | SrcX | JumpNe, + ClassPqr | SrcK | JumpSgt, ClassPqr | SrcX | JumpSgt, + 
ClassPqr | SrcK | JumpSge, ClassPqr | SrcX | JumpSge, + ClassPqr | SrcK | JumpSlt, ClassPqr | SrcX | JumpSlt, + ClassPqr | SrcK | JumpSle, ClassPqr | SrcX | JumpSle, + } { + v.validationMap[op] = verifyCheckJmpV0 + } + } + /* SIMD-0178: static syscalls */ if v.Program.SbpfVersion.EnableStaticSyscalls() { - v.validationMap[0x85] = verifyCheckCallImm - v.validationMap[0x95] = verifyCheckSyscall - v.validationMap[0x9d] = verifyValid + v.validationMap[0x85] = verifyValid + v.validationMap[0x95] = verifyValid + v.validationMap[0x9d] = verifyInvalid } else { v.validationMap[0x85] = verifyValid v.validationMap[0x95] = verifyValid diff --git a/pkg/sbpf/vm.go b/pkg/sbpf/vm.go index 255c4f6d..40296a35 100644 --- a/pkg/sbpf/vm.go +++ b/pkg/sbpf/vm.go @@ -49,13 +49,27 @@ type VMOpts struct { MaxCU int ComputeMeter *cu.ComputeMeter Input []byte // mapped at VaddrInput + InputRegions []InputRegion InputDataVaddr uint64 // VM address of instruction data within Input (SIMD-0321) + // DisableStackFrameGaps is used by SIMD-0460 virtual address space adjustments. 
+ DisableStackFrameGaps bool // Debug ProgramId solana.PublicKey TxSignature solana.Signature } +type InputRegion struct { + Offset uint64 + HostOffset uint64 + RegionSize uint64 + AddressSpaceReserved uint64 + Writable bool + AccountIndex int + Data []byte + OnWrite func(region *InputRegion, requestedLen uint64) error +} + type Exception struct { PC int64 Detail error diff --git a/pkg/sealevel/bpf_loader.go b/pkg/sealevel/bpf_loader.go index 2bccf493..b80ab885 100644 --- a/pkg/sealevel/bpf_loader.go +++ b/pkg/sealevel/bpf_loader.go @@ -3,6 +3,7 @@ package sealevel import ( "bytes" "encoding/binary" + "errors" "fmt" "math" "time" @@ -472,21 +473,116 @@ type serializedAcctMetadata struct { vmOwnerAddr uint64 } -func serializeParametersAligned(execCtx *ExecutionCtx) ([]byte, []uint64, uint64, error) { +const bpfAlignOfU128 = 8 + +func appendVasaMetadataRegion(regions *[]sbpf.InputRegion, vmStart, vmEnd, hostStart uint64) { + if vmEnd <= vmStart { + return + } + *regions = append(*regions, sbpf.InputRegion{ + Offset: vmStart, + HostOffset: hostStart, + RegionSize: vmEnd - vmStart, + AddressSpaceReserved: vmEnd - vmStart, + Writable: true, + AccountIndex: -1, + }) +} + +func appendVasaAccountDataRegion(regions *[]sbpf.InputRegion, offset, hostOffset, regionSize, reserved uint64, writable bool, accountIndex int, data []byte, onWrite func(*sbpf.InputRegion, uint64) error) { + if reserved == 0 { + return + } + *regions = append(*regions, sbpf.InputRegion{ + Offset: offset, + HostOffset: hostOffset, + RegionSize: regionSize, + AddressSpaceReserved: reserved, + Writable: writable, + AccountIndex: accountIndex, + Data: data, + OnWrite: onWrite, + }) +} + +func accountDataRegionWritable(acct *BorrowedAccount, f features.Features) bool { + return acct.DataCanBeChanged(f) == nil +} + +func accountDataDirectMappingActive(execCtx *ExecutionCtx) bool { + return execCtx.Features.IsActive(features.VirtualAddressSpaceAdjustments) && + 
execCtx.Features.IsActive(features.AccountDataDirectMapping) +} + +func directMappedAccountData(execCtx *ExecutionCtx, acct *BorrowedAccount, reserved uint64) ([]byte, bool, func(*sbpf.InputRegion, uint64) error, error) { + canChange := acct.DataCanBeChanged(execCtx.Features) == nil + data := acct.Data() + writable := false + + onWrite := func(region *sbpf.InputRegion, requestedLen uint64) error { + if !canChange { + return InstrErrReadonlyDataModified + } + if requestedLen > reserved { + return InstrErrInvalidRealloc + } + + touchedAcct, err := acct.TxCtx.Accounts.Touch(acct.IndexInTransaction) + if err != nil { + return err + } + acct.Account = touchedAcct + + oldLen := uint64(len(touchedAcct.Data)) + if requestedLen > oldLen { + newLen := reserved + if newLen > MaxPermittedDataLength { + newLen = MaxPermittedDataLength + } + if requestedLen > newLen { + return InstrErrInvalidRealloc + } + acct.UpdateAccountsResizeDelta(newLen) + touchedAcct.Resize(newLen, 0) + } + + region.Data = touchedAcct.Data + region.RegionSize = uint64(len(touchedAcct.Data)) + region.Writable = true + return nil + } + + if canChange { + if acct.TxCtx != nil && int(acct.IndexInTransaction) < len(acct.TxCtx.Accounts.Shared) && !acct.TxCtx.Accounts.Shared[acct.IndexInTransaction] { + touchedAcct, err := acct.TxCtx.Accounts.Touch(acct.IndexInTransaction) + if err != nil { + return nil, false, nil, err + } + acct.Account = touchedAcct + data = touchedAcct.Data + writable = true + } + return data, writable, onWrite, nil + } + + return data, false, nil, nil +} + +func serializeParametersAligned(execCtx *ExecutionCtx) ([]byte, []uint64, uint64, []serializedAcctMetadata, []sbpf.InputRegion, error) { txCtx := execCtx.TransactionContext instrCtx, err := txCtx.CurrentInstructionCtx() if err != nil { - return nil, nil, 0, err + return nil, nil, 0, nil, nil, err } numIxAccts := instrCtx.NumberOfInstructionAccounts() if numIxAccts > MaxInstructionAccounts { - return nil, nil, 0, 
InstrErrMaxAccountsExceeded + return nil, nil, 0, nil, nil, InstrErrMaxAccountsExceeded } programAcct, err := instrCtx.BorrowLastProgramAccount(txCtx) if err != nil { - return nil, nil, 0, err + return nil, nil, 0, nil, nil, err } programId := programAcct.Key() programAcct.Drop() @@ -504,20 +600,23 @@ func serializeParametersAligned(execCtx *ExecutionCtx) ([]byte, []uint64, uint64 for instrAcctIdx := uint64(0); instrAcctIdx < instrCtx.NumberOfInstructionAccounts(); instrAcctIdx++ { isDupe, idxInCallee, err := instrCtx.IsInstructionAccountDuplicate(instrAcctIdx) if err != nil { - return nil, nil, 0, err + return nil, nil, 0, nil, nil, err } if isDupe { accts[int(instrAcctIdx)] = serializeAcct{isDuplicate: true, indexOfAcct: idxInCallee} } else { acct, err := instrCtx.BorrowInstructionAccount(txCtx, instrAcctIdx) if err != nil { - return nil, nil, 0, err + return nil, nil, 0, nil, nil, err } accts[int(instrAcctIdx)] = serializeAcct{indexOfAcct: instrAcctIdx, acct: acct} } } + vasa := execCtx.Features.IsActive(features.VirtualAddressSpaceAdjustments) + directMapping := accountDataDirectMappingActive(execCtx) + size := uint64(8) for _, acct := range accts { @@ -537,9 +636,13 @@ func serializeParametersAligned(execCtx *ExecutionCtx) ([]byte, []uint64, uint64 size += solana.PublicKeyLength // owner size += 8 // lamports size += 8 // data len - size += MaxPermittedDataIncrease + if directMapping { + size += bpfAlignOfU128 + } else { + size += MaxPermittedDataIncrease + size += alignedDataLen + } size += 8 // rent epoch - size += alignedDataLen } } @@ -556,25 +659,37 @@ func serializeParametersAligned(execCtx *ExecutionCtx) ([]byte, []uint64, uint64 serializedData = binary.LittleEndian.AppendUint64(serializedData, uint64(len(accts))) preLens := make([]uint64, len(accts)) + accountMetadatas := make([]serializedAcctMetadata, len(accts)) + var inputRegions []sbpf.InputRegion + var regionStart uint64 + var hostRegionStart uint64 + virtualOffset := uint64(8) for i, acct := 
range accts { borrowedAcct := acct.acct l := len(serializedData) + vmAcctOffset := virtualOffset if acct.isDuplicate { // duplicate serializedData = serializedData[:l+8] position := acct.indexOfAcct serializedData[l] = byte(position) preLens[i] = preLens[position] + accountMetadatas[i] = accountMetadatas[position] + virtualOffset += 8 } else { // not a duplicate dataLen := uint64(len(borrowedAcct.Data())) - numPaddingBytes := ReallocSpace + util.AlignUp(dataLen, 8) - dataLen + alignmentOffset := util.AlignUp(dataLen, 8) - dataLen + reserved := safemath.SaturatingAddU64(dataLen, MaxPermittedDataIncrease) + payloadLen := dataLen + MaxPermittedDataIncrease + alignmentOffset + if directMapping { + payloadLen = bpfAlignOfU128 + } serializedData = serializedData[:l+ 8+ /*not duplicate, signer, writable, executable, 4 bytes padding*/ 32+ /*account pubkey*/ 32+ /*owner pubkey*/ 8+ /*lamports*/ 8+ /*acct data len*/ - len(borrowedAcct.Data())+ /*acct data*/ - int(numPaddingBytes)+ + int(payloadLen)+ 8 /*rent epoch*/] serializedData[l] = 0xff if borrowedAcct.IsSigner() { @@ -611,14 +726,42 @@ func serializeParametersAligned(execCtx *ExecutionCtx) ([]byte, []uint64, uint64 // acct data len preLens[i] = dataLen binary.LittleEndian.PutUint64(serializedData[l+80:l+88], dataLen) + accountMetadatas[i] = serializedAcctMetadata{ + originalDataLen: dataLen, + vmDataAddr: sbpf.VaddrInput + vmAcctOffset + 88, + vmKeyAddr: sbpf.VaddrInput + vmAcctOffset + 8, + vmLamportsAddr: sbpf.VaddrInput + vmAcctOffset + 72, + vmOwnerAddr: sbpf.VaddrInput + vmAcctOffset + 40, + } + + if vasa { + dataStart := vmAcctOffset + 88 + appendVasaMetadataRegion(&inputRegions, regionStart, dataStart, hostRegionStart) + if directMapping { + data, writable, onWrite, err := directMappedAccountData(execCtx, borrowedAcct, reserved) + if err != nil { + return nil, nil, 0, nil, nil, err + } + appendVasaAccountDataRegion(&inputRegions, dataStart, 0, dataLen, reserved, writable, int(acct.indexOfAcct), data, onWrite) + 
hostRegionStart = uint64(l) + 88 + (bpfAlignOfU128 - alignmentOffset) + } else { + appendVasaAccountDataRegion(&inputRegions, dataStart, uint64(l)+88, dataLen, reserved, accountDataRegionWritable(borrowedAcct, execCtx.Features), int(acct.indexOfAcct), nil, nil) + hostRegionStart = uint64(l) + 88 + reserved + } + regionStart = safemath.SaturatingAddU64(dataStart, reserved) + } - // data in account - copy(serializedData[l+88:l+88+len(borrowedAcct.Data())], borrowedAcct.Data()) + if directMapping { + clear(serializedData[l+88 : l+88+bpfAlignOfU128]) + } else { + // data in account + copy(serializedData[l+88:l+88+len(borrowedAcct.Data())], borrowedAcct.Data()) - // zero the padding - paddingStart := l + 88 + len(borrowedAcct.Data()) - paddingEnd := l + 88 + len(borrowedAcct.Data()) + int(numPaddingBytes) - clear(serializedData[paddingStart:paddingEnd]) + // zero the padding + paddingStart := l + 88 + len(borrowedAcct.Data()) + paddingEnd := l + 88 + len(borrowedAcct.Data()) + int(MaxPermittedDataIncrease+alignmentOffset) + clear(serializedData[paddingStart:paddingEnd]) + } // rent epoch var rentEpoch uint64 @@ -628,11 +771,12 @@ func serializeParametersAligned(execCtx *ExecutionCtx) ([]byte, []uint64, uint64 rentEpoch = borrowedAcct.RentEpoch() } binary.LittleEndian.PutUint64(serializedData[len(serializedData)-8:], rentEpoch) + virtualOffset += 88 + reserved + alignmentOffset + 8 } } l := len(serializedData) - instructionDataOffset := uint64(l) + 8 // offset of actual instruction data bytes (past length prefix) + instructionDataOffset := virtualOffset + 8 // offset of actual instruction data bytes (past length prefix) serializedData = serializedData[:len(serializedData)+ 8+ /*instr data len*/ len(instrData)+ /*instr data*/ @@ -641,13 +785,18 @@ func serializeParametersAligned(execCtx *ExecutionCtx) ([]byte, []uint64, uint64 binary.LittleEndian.PutUint64(serializedData[l:l+8], uint64(len(instrData))) copy(serializedData[l+8:l+8+len(instrData)], instrData) 
copy(serializedData[len(serializedData)-32:], programId[:]) + virtualOffset += 8 + uint64(len(instrData)) + solana.PublicKeyLength + + if vasa { + appendVasaMetadataRegion(&inputRegions, regionStart, virtualOffset, hostRegionStart) + } // sanity check for expected len vs. serialized data size if uint64(len(serializedData)) != size { panic(fmt.Sprintf("mismatch between serialized data and expected length: len(serializedData) = %d, expected size = %d", uint64(len(serializedData)), size)) } - return serializedData, preLens, instructionDataOffset, nil + return serializedData, preLens, instructionDataOffset, accountMetadatas, inputRegions, nil } func deserializeParametersAligned(execCtx *ExecutionCtx, parameterBytes []byte, preLens []uint64) error { @@ -656,6 +805,8 @@ func deserializeParametersAligned(execCtx *ExecutionCtx, parameterBytes []byte, if err != nil { return err } + vasa := execCtx.Features.IsActive(features.VirtualAddressSpaceAdjustments) + directMapping := accountDataDirectMappingActive(execCtx) var off uint64 @@ -718,30 +869,55 @@ func deserializeParametersAligned(execCtx *ExecutionCtx, parameterBytes []byte, //alignmentMask := uint64(7) // (alignment - 1) alignmentOffset := util.AlignUp(preLen, 8) - preLen - if uint64(len(parameterBytes)) < (off + postLen) { - return InstrErrInvalidArgument + var data []byte + if !directMapping { + if uint64(len(parameterBytes)) < (off + postLen) { + return InstrErrInvalidArgument + } + data = parameterBytes[off : off+postLen] } - data := parameterBytes[off : off+postLen] - resizeErr := borrowedAcct.CanDataBeResized(postLen) - changedErr := borrowedAcct.DataCanBeChanged(execCtx.Features) + if !vasa { + resizeErr := borrowedAcct.CanDataBeResized(postLen) + changedErr := borrowedAcct.DataCanBeChanged(execCtx.Features) - if resizeErr != nil || changedErr != nil { - acctBytes := borrowedAcct.Data() - if !bytes.Equal(acctBytes, data) { - return fmt.Errorf("data cannot be changed, but did anyway") + if resizeErr != nil || 
changedErr != nil { + acctBytes := borrowedAcct.Data() + if !bytes.Equal(acctBytes, data) { + return fmt.Errorf("data cannot be changed, but did anyway") + } + } else { + err = borrowedAcct.SetData(execCtx.Features, data) + if err != nil { + return err + } } - } else { + } else if directMapping { + if uint64(len(borrowedAcct.Data())) != postLen { + err = borrowedAcct.SetDataLength(postLen, execCtx.Features) + if err != nil { + return err + } + } + } else if borrowedAcct.DataCanBeChanged(execCtx.Features) == nil { err = borrowedAcct.SetData(execCtx.Features, data) if err != nil { return err } + } else if uint64(len(borrowedAcct.Data())) != postLen { + err = borrowedAcct.SetDataLength(postLen, execCtx.Features) + if err != nil { + return err + } } - off += preLen - - off += MaxPermittedDataIncrease - off += alignmentOffset + if directMapping { + off += bpfAlignOfU128 + } else { + off += preLen + off += MaxPermittedDataIncrease + off += alignmentOffset + } off += 8 // rent epoch ownerPk := solana.PublicKeyFromBytes(owner) @@ -757,21 +933,21 @@ func deserializeParametersAligned(execCtx *ExecutionCtx, parameterBytes []byte, return nil } -func serializeParametersUnaligned(execCtx *ExecutionCtx) ([]byte, []uint64, uint64, error) { +func serializeParametersUnaligned(execCtx *ExecutionCtx) ([]byte, []uint64, uint64, []serializedAcctMetadata, []sbpf.InputRegion, error) { txCtx := execCtx.TransactionContext instrCtx, err := txCtx.CurrentInstructionCtx() if err != nil { - return nil, nil, 0, err + return nil, nil, 0, nil, nil, err } numIxAccts := instrCtx.NumberOfInstructionAccounts() if numIxAccts > MaxInstructionAccounts { - return nil, nil, 0, InstrErrMaxAccountsExceeded + return nil, nil, 0, nil, nil, InstrErrMaxAccountsExceeded } programAcct, err := instrCtx.BorrowLastProgramAccount(txCtx) if err != nil { - return nil, nil, 0, err + return nil, nil, 0, nil, nil, err } programId := programAcct.Key() programAcct.Drop() @@ -783,7 +959,7 @@ func 
serializeParametersUnaligned(execCtx *ExecutionCtx) ([]byte, []uint64, uint for instrAcctIdx := uint64(0); instrAcctIdx < instrCtx.NumberOfInstructionAccounts(); instrAcctIdx++ { isDupe, idxInCallee, err := instrCtx.IsInstructionAccountDuplicate(instrAcctIdx) if err != nil { - return nil, nil, 0, err + return nil, nil, 0, nil, nil, err } if isDupe { sa := serializeAcct{isDuplicate: true, indexOfAcct: idxInCallee} @@ -791,7 +967,7 @@ func serializeParametersUnaligned(execCtx *ExecutionCtx) ([]byte, []uint64, uint } else { acct, err := instrCtx.BorrowInstructionAccount(txCtx, instrAcctIdx) if err != nil { - return nil, nil, 0, err + return nil, nil, 0, nil, nil, err } defer acct.Drop() @@ -800,6 +976,9 @@ func serializeParametersUnaligned(execCtx *ExecutionCtx) ([]byte, []uint64, uint } } + vasa := execCtx.Features.IsActive(features.VirtualAddressSpaceAdjustments) + directMapping := accountDataDirectMappingActive(execCtx) + size := uint64(8) for _, acct := range accts { @@ -816,7 +995,9 @@ func serializeParametersUnaligned(execCtx *ExecutionCtx) ([]byte, []uint64, uint size += solana.PublicKeyLength // owner size += 1 // executable size += 8 // rent epoch - size += dataLen + if !directMapping { + size += dataLen + } } } @@ -831,13 +1012,22 @@ func serializeParametersUnaligned(execCtx *ExecutionCtx) ([]byte, []uint64, uint serializedData = make([]byte, 0, size) // No arena configured } serializedData = binary.LittleEndian.AppendUint64(serializedData, uint64(len(accts))) + accountMetadatas := make([]serializedAcctMetadata, 0, len(accts)) + var inputRegions []sbpf.InputRegion + var regionStart uint64 + var hostRegionStart uint64 + virtualOffset := uint64(8) for _, acct := range accts { borrowedAcct := acct.acct + l := len(serializedData) + vmAcctOffset := virtualOffset if acct.isDuplicate { // duplicate position := acct.indexOfAcct serializedData = append(serializedData, byte(position)) preLens = append(preLens, preLens[position]) + accountMetadatas = 
append(accountMetadatas, accountMetadatas[position]) + virtualOffset++ } else { // not a duplicate serializedData = append(serializedData, 0xff) @@ -865,9 +1055,35 @@ func serializeParametersUnaligned(execCtx *ExecutionCtx) ([]byte, []uint64, uint dataLen := uint64(len(borrowedAcct.Data())) preLens = append(preLens, dataLen) serializedData = binary.LittleEndian.AppendUint64(serializedData, dataLen) + accountMetadatas = append(accountMetadatas, serializedAcctMetadata{ + originalDataLen: dataLen, + vmDataAddr: sbpf.VaddrInput + vmAcctOffset + 51, + vmKeyAddr: sbpf.VaddrInput + vmAcctOffset + 3, + vmLamportsAddr: sbpf.VaddrInput + vmAcctOffset + 35, + vmOwnerAddr: sbpf.VaddrInput + vmAcctOffset + 51 + dataLen, + }) + + if vasa { + dataStart := vmAcctOffset + 51 + appendVasaMetadataRegion(&inputRegions, regionStart, dataStart, hostRegionStart) + if directMapping { + data, writable, onWrite, err := directMappedAccountData(execCtx, borrowedAcct, dataLen) + if err != nil { + return nil, nil, 0, nil, nil, err + } + appendVasaAccountDataRegion(&inputRegions, dataStart, 0, dataLen, dataLen, writable, int(acct.indexOfAcct), data, onWrite) + hostRegionStart = uint64(l) + 51 + } else { + appendVasaAccountDataRegion(&inputRegions, dataStart, uint64(l)+51, dataLen, dataLen, accountDataRegionWritable(borrowedAcct, execCtx.Features), int(acct.indexOfAcct), nil, nil) + hostRegionStart = uint64(l) + 51 + dataLen + } + regionStart = safemath.SaturatingAddU64(dataStart, dataLen) + } - // data in account - serializedData = append(serializedData, borrowedAcct.Data()...) + if !directMapping { + // data in account + serializedData = append(serializedData, borrowedAcct.Data()...) 
+ } // owner owner := [32]byte(borrowedAcct.Owner()) @@ -888,11 +1104,12 @@ func serializeParametersUnaligned(execCtx *ExecutionCtx) ([]byte, []uint64, uint rentEpoch = borrowedAcct.RentEpoch() } serializedData = binary.LittleEndian.AppendUint64(serializedData, rentEpoch) + virtualOffset += 51 + dataLen + solana.PublicKeyLength + 1 + 8 } } // instr data len - instructionDataOffset := uint64(len(serializedData)) + 8 // offset of actual instruction data bytes (past length prefix) + instructionDataOffset := virtualOffset + 8 // offset of actual instruction data bytes (past length prefix) serializedData = binary.LittleEndian.AppendUint64(serializedData, uint64(len(instrData))) // instr data @@ -901,13 +1118,18 @@ func serializeParametersUnaligned(execCtx *ExecutionCtx) ([]byte, []uint64, uint // program id programIdSlice := programId[:] serializedData = append(serializedData, programIdSlice...) + virtualOffset += 8 + uint64(len(instrData)) + solana.PublicKeyLength + + if vasa { + appendVasaMetadataRegion(&inputRegions, regionStart, virtualOffset, hostRegionStart) + } // sanity check for expected len vs. 
serialized data size if uint64(len(serializedData)) != size { panic("mismatch between serialized data and expected length") } - return serializedData, preLens, instructionDataOffset, nil + return serializedData, preLens, instructionDataOffset, accountMetadatas, inputRegions, nil } func deserializeParametersUnaligned(execCtx *ExecutionCtx, parameterBytes []byte, preLens []uint64) error { @@ -916,6 +1138,8 @@ func deserializeParametersUnaligned(execCtx *ExecutionCtx, parameterBytes []byte if err != nil { return err } + vasa := execCtx.Features.IsActive(features.VirtualAddressSpaceAdjustments) + directMapping := accountDataDirectMappingActive(execCtx) var off uint64 @@ -951,32 +1175,56 @@ func deserializeParametersUnaligned(execCtx *ExecutionCtx, parameterBytes []byte off += 8 // data length - if uint64(len(parameterBytes)) < (off + preLen) { - return InstrErrInvalidArgument + var data []byte + if !directMapping { + if uint64(len(parameterBytes)) < (off + preLen) { + return InstrErrInvalidArgument + } + data = parameterBytes[off : off+preLen] } - data := parameterBytes[off : off+preLen] - resizeErr := borrowedAcct.CanDataBeResized(uint64(len(data))) - changedErr := borrowedAcct.DataCanBeChanged(execCtx.Features) + if !vasa { + resizeErr := borrowedAcct.CanDataBeResized(uint64(len(data))) + changedErr := borrowedAcct.DataCanBeChanged(execCtx.Features) - if resizeErr != nil || changedErr != nil { - acctBytes := borrowedAcct.Data() - if len(acctBytes) != len(data) { - return fmt.Errorf("data cannot be changed, but did anyway") - } - for count := range acctBytes { - if acctBytes[count] != data[count] { + if resizeErr != nil || changedErr != nil { + acctBytes := borrowedAcct.Data() + if len(acctBytes) != len(data) { return fmt.Errorf("data cannot be changed, but did anyway") } + for count := range acctBytes { + if acctBytes[count] != data[count] { + return fmt.Errorf("data cannot be changed, but did anyway") + } + } + } else { + err = borrowedAcct.SetData(execCtx.Features, 
data) + if err != nil { + return err + } } - } else { + } else if directMapping { + if uint64(len(borrowedAcct.Data())) != preLen { + err = borrowedAcct.SetDataLength(preLen, execCtx.Features) + if err != nil { + return err + } + } + } else if borrowedAcct.DataCanBeChanged(execCtx.Features) == nil { err = borrowedAcct.SetData(execCtx.Features, data) if err != nil { return err } + } else if uint64(len(borrowedAcct.Data())) != preLen { + err = borrowedAcct.SetDataLength(preLen, execCtx.Features) + if err != nil { + return err + } } - off += preLen + if !directMapping { + off += preLen + } off += solana.PublicKeyLength // owner off += 1 // executable @@ -1013,19 +1261,26 @@ func executeLoadedProgram(execCtx *ExecutionCtx, program *sbpf.Program, syscallR var parameterBytes []byte var preLens []uint64 var instrDataOffset uint64 + var accountMetadatas []serializedAcctMetadata + var inputRegions []sbpf.InputRegion if isLoaderDeprecated { - parameterBytes, preLens, instrDataOffset, err = serializeParametersUnaligned(execCtx) + parameterBytes, preLens, instrDataOffset, accountMetadatas, inputRegions, err = serializeParametersUnaligned(execCtx) if err != nil { return err } } else { - parameterBytes, preLens, instrDataOffset, err = serializeParametersAligned(execCtx) + parameterBytes, preLens, instrDataOffset, accountMetadatas, inputRegions, err = serializeParametersAligned(execCtx) if err != nil { return err } } + execCtx.serializedAccountMetadataStack = append(execCtx.serializedAccountMetadataStack, accountMetadatas) + defer func() { + execCtx.serializedAccountMetadataStack = execCtx.serializedAccountMetadataStack[:len(execCtx.serializedAccountMetadataStack)-1] + }() + var inputDataVaddr uint64 if execCtx.Features.IsActive(features.ProvideInstructionDataOffsetInVmR2) { inputDataVaddr = sbpf.VaddrInput + instrDataOffset @@ -1041,8 +1296,10 @@ func executeLoadedProgram(execCtx *ExecutionCtx, program *sbpf.Program, syscallR Context: execCtx, TxSignature: 
execCtx.TransactionContext.Signature, ProgramId: programId, + InputRegions: inputRegions, + DisableStackFrameGaps: execCtx.Features.IsActive(features.VirtualAddressSpaceAdjustments) || + !program.SbpfVersion.StackFrameGaps(), } - start := time.Now() interpreter := sbpf.NewInterpreter(program, opts) defer interpreter.Finish() @@ -1051,57 +1308,52 @@ func executeLoadedProgram(execCtx *ExecutionCtx, program *sbpf.Program, syscallR ret, _, runErr := interpreter.Run() metrics.GlobalBlockReplay.SbpfInterpreterRun.AddTimingSince(start) + if execCtx.Features.IsActive(features.VirtualAddressSpaceAdjustments) { + runErr = mapVirtualAddressSpaceRunErr(execCtx, runErr, inputRegions) + } + if runErr != nil { //mlog.Log.Debugf("program execution result: %s", runErr) } else if ret != 0 { - runErr = fmt.Errorf("program execution (%s) returned failure: %d", programId, ret) - //mlog.Log.Debugf("program execution (%s) returned failure: %d", programId, ret) - } else { - //mlog.Log.Debugf("program execution (%s) returned success", programId) + runErr = instrErrFromProgramStatus(ret) } - /* - _, returnData := execCtx.TransactionContext.ReturnData() - if len(returnData) != 0 { - base64.StdEncoding.EncodeToString(returnData) - mlog.Log.Debugf("Program return %s %s", returnedDataProgId, encodedStr) - }*/ - // deserialize data if runErr == nil { if isLoaderDeprecated { err = deserializeParametersUnaligned(execCtx, parameterBytes, preLens) if err != nil { - //mlog.Log.Debugf("failed to deserialize (unaligned), %s", err) return InstrErrInvalidArgument } } else { err = deserializeParametersAligned(execCtx, parameterBytes, preLens) if err != nil { - //mlog.Log.Debugf("failed to deserialize (aligned), %s", err) return InstrErrInvalidArgument } } } - return runErr + return normalizeProgramRunErr(runErr) } func executeProgramFromBytes(execCtx *ExecutionCtx, programAddr solana.PublicKey, programData []byte, syscallRegistry sbpf.SyscallRegistry) error { start := time.Now() loader, err := 
loader.NewLoaderWithSyscalls(programData, syscallRegistry, false, &execCtx.Features) if err != nil { - return err + return InstrErrUnsupportedProgramId } program, err := loader.Load() if err != nil { - return err + return InstrErrUnsupportedProgramId + } + if err := program.Verify(); err != nil { + return InstrErrUnsupportedProgramId } entry := &accountsdb.ProgramCacheEntry{Program: program} if !execCtx.IsSimulation { - execCtx.SlotCtx.AccountsDb.AddProgramToCache(programAddr, entry) + addProgramToCache(execCtx, programAddr, entry) } metrics.GlobalBlockReplay.AddProgramToCache.AddTimingSince(start) @@ -1109,6 +1361,78 @@ func executeProgramFromBytes(execCtx *ExecutionCtx, programAddr solana.PublicKey return executeLoadedProgram(execCtx, program, syscallRegistry) } +func addProgramToCache(execCtx *ExecutionCtx, programAddr solana.PublicKey, entry *accountsdb.ProgramCacheEntry) { + if execCtx.SlotCtx == nil || execCtx.SlotCtx.AccountsDb == nil { + return + } + execCtx.SlotCtx.AccountsDb.AddProgramToCache(programAddr, entry) +} + +func mapVirtualAddressSpaceRunErr(execCtx *ExecutionCtx, err error, inputRegions []sbpf.InputRegion) error { + if err == nil { + return nil + } + + var badAccess sbpf.ExcBadAccess + if !errors.As(err, &badAccess) { + return err + } + + for _, region := range inputRegions { + if region.AccountIndex < 0 { + continue + } + + regionStart := sbpf.VaddrInput + region.Offset + regionEnd := safemath.SaturatingAddU64(regionStart, region.AddressSpaceReserved) + if badAccess.Addr < regionStart || badAccess.Addr >= regionEnd { + continue + } + accessEnd := safemath.SaturatingAddU64(badAccess.Addr, badAccess.Size) + if accessEnd > regionEnd { + return err + } + + txCtx := execCtx.TransactionContext + instrCtx, borrowErr := txCtx.CurrentInstructionCtx() + if borrowErr != nil { + return borrowErr + } + account, borrowErr := instrCtx.BorrowInstructionAccount(txCtx, uint64(region.AccountIndex)) + if borrowErr != nil { + return borrowErr + } + changeErr := 
account.DataCanBeChanged(execCtx.Features) + account.Drop() + + if badAccess.Write { + if changeErr != nil { + return changeErr + } + return InstrErrInvalidRealloc + } + if changeErr != nil { + return InstrErrAccountDataTooSmall + } + return InstrErrInvalidRealloc + } + + return err +} + +func normalizeProgramRunErr(err error) error { + if err == nil { + return nil + } + if _, ok := solanaErrCode(err); ok { + return err + } + if IsCustomErr(err) { + return err + } + return InstrErrProgramFailedToComplete +} + func BpfLoaderProgramExecute(execCtx *ExecutionCtx) error { //mlog.Log.Debugf("BpfLoaderProgramExecute") @@ -1155,7 +1479,9 @@ func BpfLoaderProgramExecute(execCtx *ExecutionCtx) error { } if !programAcct.IsExecutable() { - //mlog.Log.Debugf("program %s is not executable", programAcct) + return InstrErrUnsupportedProgramId + } + if programAcct.Lamports() == 0 { return InstrErrUnsupportedProgramId } @@ -1244,8 +1570,8 @@ func BpfLoaderProgramExecute(execCtx *ExecutionCtx) error { return err } - if programDataAcctState.Type == UpgradeableLoaderStateTypeUninitialized { - return InstrErrInvalidAccountData + if programDataAcctState.Type != UpgradeableLoaderStateTypeProgramData { + return InstrErrUnsupportedProgramId } programDataSlot := programDataAcctState.ProgramData.Slot @@ -1253,6 +1579,9 @@ func BpfLoaderProgramExecute(execCtx *ExecutionCtx) error { return InstrErrInvalidAccountData } + if len(programDataAcct.Data) < upgradeableLoaderSizeOfProgramDataMetaData { + return InstrErrUnsupportedProgramId + } programAcctKey = programAcctState.Program.ProgramDataAddress programBytes = programDataAcct.Data[upgradeableLoaderSizeOfProgramDataMetaData:] metrics.GlobalBlockReplay.GetProgramDataUncachedMarshal.AddTimingSince(start) @@ -1479,8 +1808,8 @@ func UpgradeableLoaderDeployWithMaxDataLen(execCtx *ExecutionCtx, txCtx *Transac return InstrErrInvalidArgument } - if bufferAcctState.Buffer.AuthorityAddress != nil && authorityKey != nil && - 
*bufferAcctState.Buffer.AuthorityAddress != *authorityKey { + if (bufferAcctState.Buffer.AuthorityAddress == nil) != (authorityKey == nil) || + (bufferAcctState.Buffer.AuthorityAddress != nil && *bufferAcctState.Buffer.AuthorityAddress != *authorityKey) { return InstrErrIncorrectAuthority } @@ -1656,7 +1985,7 @@ func UpgradeableLoaderDeployWithMaxDataLen(execCtx *ExecutionCtx, txCtx *Transac //mlog.Log.Debugf("deployed program: %s", newProgramId) entry := &accountsdb.ProgramCacheEntry{Program: loadedProgram, DeploymentSlot: clock.Slot} - execCtx.SlotCtx.AccountsDb.AddProgramToCache(programDataKey, entry) + addProgramToCache(execCtx, programDataKey, entry) return nil } @@ -1908,7 +2237,7 @@ func UpgradeableLoaderUpgrade(execCtx *ExecutionCtx, txCtx *TransactionCtx, inst //mlog.Log.Debugf("upgraded program %s", program.Key()) entry := &accountsdb.ProgramCacheEntry{Program: loadedProgram, DeploymentSlot: clock.Slot} - execCtx.SlotCtx.AccountsDb.AddProgramToCache(programData.Key(), entry) + addProgramToCache(execCtx, programData.Key(), entry) return nil } @@ -2558,7 +2887,7 @@ func UpgradeableLoaderExtendProgram(execCtx *ExecutionCtx, txCtx *TransactionCtx //mlog.Log.Debugf("Extended ProgramData account by %d bytes", additionalBytes) entry := &accountsdb.ProgramCacheEntry{Program: loadedProgram, DeploymentSlot: clock.Slot} - execCtx.SlotCtx.AccountsDb.AddProgramToCache(programDataAcct.Key(), entry) + addProgramToCache(execCtx, programDataAcct.Key(), entry) return nil } diff --git a/pkg/sealevel/bpf_loader_error_test.go b/pkg/sealevel/bpf_loader_error_test.go new file mode 100644 index 00000000..4264c143 --- /dev/null +++ b/pkg/sealevel/bpf_loader_error_test.go @@ -0,0 +1,36 @@ +package sealevel + +import ( + "fmt" + "testing" + + "github.com/Overclock-Validator/mithril/pkg/cu" + "github.com/Overclock-Validator/mithril/pkg/sbpf" + "github.com/stretchr/testify/require" +) + +func TestNormalizeProgramRunErrMapsOutOfCU(t *testing.T) { + err := &sbpf.Exception{Detail: 
fmt.Errorf("wrapped: %w", sbpf.ExcOutOfCU)} + + require.ErrorIs(t, normalizeProgramRunErr(err), InstrErrProgramFailedToComplete) + require.ErrorIs(t, normalizeProgramRunErr(cu.ErrComputeExceeded), InstrErrProgramFailedToComplete) +} + +func TestNormalizeProgramRunErrPreservesSyscallInstructionErr(t *testing.T) { + err := &sbpf.Exception{Detail: fmt.Errorf("wrapped: %w", sbpf.ExcSyscallError{Err: InstrErrComputationalBudgetExceeded})} + + require.ErrorIs(t, normalizeProgramRunErr(err), InstrErrComputationalBudgetExceeded) +} + +func TestInstrErrFromProgramStatus(t *testing.T) { + require.ErrorAs(t, instrErrFromProgramStatus(uint64(1)<<32), &InstrErrCustomCode{}) + require.ErrorIs(t, instrErrFromProgramStatus(uint64(2)<<32), InstrErrInvalidArgument) + require.ErrorIs(t, instrErrFromProgramStatus(uint64(3)<<32), InstrErrInvalidInstructionData) + require.ErrorIs(t, instrErrFromProgramStatus(uint64(23)<<32), InstrErrInvalidAccountOwner) + require.ErrorIs(t, instrErrFromProgramStatus(uint64(26)<<32), InstrErrIncorrectAuthority) + + var custom InstrErrCustomCode + require.ErrorAs(t, instrErrFromProgramStatus(42), &custom) + require.Equal(t, uint32(42), custom.Code) + require.ErrorIs(t, instrErrFromProgramStatus(uint64(99)<<32), InstrErrInvalidError) +} diff --git a/pkg/sealevel/errors.go b/pkg/sealevel/errors.go index 64cb6e13..b2040ec6 100644 --- a/pkg/sealevel/errors.go +++ b/pkg/sealevel/errors.go @@ -1,6 +1,83 @@ package sealevel -import "errors" +import ( + "errors" + "fmt" +) + +type InstrErrCustomCode struct { + Code uint32 +} + +func (err InstrErrCustomCode) Error() string { + return fmt.Sprintf("InstrErrCustom(%d)", err.Code) +} + +const programErrorBuiltinBitShift = 32 + +func instrErrFromProgramStatus(status uint64) error { + const customZero = uint64(1) << programErrorBuiltinBitShift + + switch status { + case customZero: + return InstrErrCustomCode{Code: 0} + case uint64(2) << programErrorBuiltinBitShift: + return InstrErrInvalidArgument + case uint64(3) << 
programErrorBuiltinBitShift: + return InstrErrInvalidInstructionData + case uint64(4) << programErrorBuiltinBitShift: + return InstrErrInvalidAccountData + case uint64(5) << programErrorBuiltinBitShift: + return InstrErrAccountDataTooSmall + case uint64(6) << programErrorBuiltinBitShift: + return InstrErrInsufficientFunds + case uint64(7) << programErrorBuiltinBitShift: + return InstrErrIncorrectProgramId + case uint64(8) << programErrorBuiltinBitShift: + return InstrErrMissingRequiredSignature + case uint64(9) << programErrorBuiltinBitShift: + return InstrErrAccountAlreadyInitialized + case uint64(10) << programErrorBuiltinBitShift: + return InstrErrUninitializedAccount + case uint64(11) << programErrorBuiltinBitShift: + return InstrErrNotEnoughAccountKeys + case uint64(12) << programErrorBuiltinBitShift: + return InstrErrAccountBorrowFailed + case uint64(13) << programErrorBuiltinBitShift: + return InstrErrMaxSeedLengthExceeded + case uint64(14) << programErrorBuiltinBitShift: + return InstrErrInvalidSeeds + case uint64(15) << programErrorBuiltinBitShift: + return InstrErrBorshIoError + case uint64(16) << programErrorBuiltinBitShift: + return InstrErrAccountNotRentExempt + case uint64(17) << programErrorBuiltinBitShift: + return InstrErrUnsupportedSysvar + case uint64(18) << programErrorBuiltinBitShift: + return InstrErrIllegalOwner + case uint64(19) << programErrorBuiltinBitShift: + return InstrErrMaxAccountsDataAllocationsExceeded + case uint64(20) << programErrorBuiltinBitShift: + return InstrErrInvalidRealloc + case uint64(21) << programErrorBuiltinBitShift: + return InstrErrMaxInstructionTraceLengthExceeded + case uint64(22) << programErrorBuiltinBitShift: + return InstrErrBuiltinProgramsMustConsumeComputeUnits + case uint64(23) << programErrorBuiltinBitShift: + return InstrErrInvalidAccountOwner + case uint64(24) << programErrorBuiltinBitShift: + return InstrErrArithmeticOverflow + case uint64(25) << programErrorBuiltinBitShift: + return InstrErrImmutable + 
case uint64(26) << programErrorBuiltinBitShift: + return InstrErrIncorrectAuthority + default: + if status>>programErrorBuiltinBitShift == 0 { + return InstrErrCustomCode{Code: uint32(status)} + } + return InstrErrInvalidError + } +} // instruction errors var ( @@ -79,6 +156,7 @@ var ( SyscallErrInstructionTooLarge = errors.New("SyscallErrInstructionTooLarge") SyscallErrMaxInstructionAccountInfosExceeded = errors.New("SyscallErrMaxInstructionAccountInfosExceeded") SyscallErrTooManyAccounts = errors.New("SyscallErrTooManyAccounts") + SyscallErrInvalidPointer = errors.New("SyscallError::InvalidPointer") ) var ( @@ -331,14 +409,92 @@ var customErrs = map[error]bool{ } func IsCustomErr(err error) bool { + var custom InstrErrCustomCode + if errors.As(err, &custom) { + return true + } return customErrs[err] } +var instructionErrTargets = []error{ + InstrErrGenericError, + InstrErrInvalidArgument, + InstrErrInvalidInstructionData, + InstrErrInvalidAccountData, + InstrErrAccountDataTooSmall, + InstrErrInsufficientFunds, + InstrErrIncorrectProgramId, + InstrErrMissingRequiredSignature, + InstrErrAccountAlreadyInitialized, + InstrErrUninitializedAccount, + InstrErrUnbalancedInstruction, + InstrErrModifiedProgramId, + InstrErrExternalAccountLamportSpend, + InstrErrExternalAccountDataModified, + InstrErrReadonlyLamportChange, + InstrErrReadonlyDataModified, + InstrErrDuplicateAccountIndex, + InstrErrExecutableModified, + InstrErrRentEpochModified, + InstrErrNotEnoughAccountKeys, + InstrErrAccountDataSizeChanged, + InstrErrAccountNotExecutable, + InstrErrAccountBorrowFailed, + InstrErrAccountBorrowOutstanding, + InstrErrDuplicateAccountOutOfSync, + InstrErrCustom, + InstrErrInvalidError, + InstrErrExecutableDataModified, + InstrErrExecutableLamportChange, + InstrErrExecutableAccountNotRentExempt, + InstrErrUnsupportedProgramId, + InstrErrCallDepth, + InstrErrMissingAccount, + InstrErrReentrancyNotAllowed, + InstrErrMaxSeedLengthExceeded, + InstrErrInvalidSeeds, + 
InstrErrInvalidRealloc, + InstrErrComputationalBudgetExceeded, + InstrErrPrivilegeEscalation, + InstrErrProgramEnvironmentSetupFailure, + InstrErrProgramFailedToComplete, + InstrErrProgramFailedToCompile, + InstrErrImmutable, + InstrErrIncorrectAuthority, + InstrErrBorshIoError, + InstrErrAccountNotRentExempt, + InstrErrInvalidAccountOwner, + InstrErrArithmeticOverflow, + InstrErrUnsupportedSysvar, + InstrErrIllegalOwner, + InstrErrMaxAccountsDataAllocationsExceeded, + InstrErrMaxAccountsExceeded, + InstrErrMaxInstructionTraceLengthExceeded, + InstrErrBuiltinProgramsMustConsumeComputeUnits, +} + +func solanaErrCode(err error) (int, bool) { + if err == nil { + return InstrErrCodeSuccess, true + } + if code, ok := solanaNumericalErrCodes[err]; ok { + return code, true + } + for _, instructionErr := range instructionErrTargets { + if errors.Is(err, instructionErr) { + return solanaNumericalErrCodes[instructionErr], true + } + } + return 0, false +} + // TODO: add additional error conversions func TranslateErrToErrCode(err error) int { - if err == nil { - return InstrErrCodeSuccess + var custom InstrErrCustomCode + if errors.As(err, &custom) { + return int(custom.Code) } - return solanaNumericalErrCodes[err] + code, _ := solanaErrCode(err) + return code } diff --git a/pkg/sealevel/execution_ctx.go b/pkg/sealevel/execution_ctx.go index cc7e8152..4862c48f 100644 --- a/pkg/sealevel/execution_ctx.go +++ b/pkg/sealevel/execution_ctx.go @@ -34,6 +34,8 @@ type ExecutionCtx struct { RecordInnerInstructions bool currentTopLevelInstrIdx uint8 InnerInstrs []RecordedInnerInstr + + serializedAccountMetadataStack [][]serializedAcctMetadata } // RecordedInnerInstr is a CPI invocation captured during execution. 
diff --git a/pkg/sealevel/syscalls.go b/pkg/sealevel/syscalls.go index 8fb8b505..03b49868 100644 --- a/pkg/sealevel/syscalls.go +++ b/pkg/sealevel/syscalls.go @@ -72,6 +72,7 @@ func Syscalls(ft *features.Features, isDeploy bool, h uint32) (f sbpf.Syscall, o f = SyscallKeccak256 case hash_sol_blake3: f = SyscallBlake3 + ok = ft.IsActive(features.Blake3SyscallEnabled) case hash_sol_secp256k1_recover: f = SyscallSecp256k1Recover case hash_sol_poseidon: diff --git a/pkg/sealevel/syscalls_common.go b/pkg/sealevel/syscalls_common.go index 860c9c1f..c8b656b9 100644 --- a/pkg/sealevel/syscalls_common.go +++ b/pkg/sealevel/syscalls_common.go @@ -4,7 +4,9 @@ import ( "errors" "math" - "github.com/Overclock-Validator/mithril/pkg/cu" + a "github.com/Overclock-Validator/mithril/pkg/addresses" + "github.com/Overclock-Validator/mithril/pkg/features" + "github.com/Overclock-Validator/mithril/pkg/sbpf" ) func isNonOverlapping(src, srcLen, dst, dstLen uint64) bool { @@ -23,12 +25,48 @@ func syscallErrCustom(msg string) (uint64, error) { return math.MaxUint64, errors.New(msg) } +func syscallCheckAligned(execCtx *ExecutionCtx) bool { + txCtx := execCtx.TransactionContext + if txCtx == nil { + return true + } + instrCtx, err := txCtx.CurrentInstructionCtx() + if err != nil { + return true + } + programAcct, err := instrCtx.BorrowLastProgramAccount(txCtx) + if err != nil { + return true + } + defer programAcct.Drop() + + return programAcct.Owner() != a.BpfLoaderDeprecatedAddr +} + +func syscallAddressIsAligned(execCtx *ExecutionCtx, addr uint64, alignment uint64) bool { + if !syscallCheckAligned(execCtx) { + return false + } + if alignment <= 1 { + return true + } + return addr%alignment == 0 +} + +func syscallAddressRequiresAlignment(execCtx *ExecutionCtx, addr uint64, alignment uint64) bool { + return syscallCheckAligned(execCtx) && alignment > 1 && addr%alignment != 0 +} + +func syscallParameterAddressRestricted(execCtx *ExecutionCtx, addr uint64) bool { + return 
execCtx.Features.IsActive(features.SyscallParameterAddressRestrictions) && addr >= sbpf.VaddrInput +} + func syscallErr(err error) (uint64, error) { return math.MaxUint64, err } func syscallCuErr() (uint64, error) { - return math.MaxUint64, cu.ErrComputeExceeded + return math.MaxUint64, InstrErrComputationalBudgetExceeded } func syscallSuccess(result uint64) (uint64, error) { diff --git a/pkg/sealevel/syscalls_cpi.go b/pkg/sealevel/syscalls_cpi.go index 3324f764..5513f749 100644 --- a/pkg/sealevel/syscalls_cpi.go +++ b/pkg/sealevel/syscalls_cpi.go @@ -3,7 +3,6 @@ package sealevel import ( "bytes" "encoding/binary" - "unsafe" a "github.com/Overclock-Validator/mithril/pkg/addresses" "github.com/Overclock-Validator/mithril/pkg/base58" @@ -24,6 +23,9 @@ const ( MaxCpiAccountInfos = 128 MaxCpiAccountInfosSimd0339 = 255 AccountInfoByteSize = 80 + + solAccountInfoCDataLenOffset = 16 + solAccountInfoRustDataLenOffset = 32 ) func checkInstructionSize(execCtx *ExecutionCtx, numAccounts uint64, dataLen uint64) error { @@ -336,8 +338,82 @@ func cpiInvokeUnits(f *features.Features) uint64 { return cu.CUInvokeUnits } +func syscallParameterAddressRangeRestricted(execCtx *ExecutionCtx, addr, size uint64) bool { + return execCtx.Features.IsActive(features.SyscallParameterAddressRestrictions) && + safemath.SaturatingAddU64(addr, size) >= sbpf.VaddrInput +} + +func currentSerializedAccountMetadata(execCtx *ExecutionCtx, indexInCaller uint64) (serializedAcctMetadata, error) { + if len(execCtx.serializedAccountMetadataStack) == 0 { + return serializedAcctMetadata{}, InstrErrMissingAccount + } + accountMetadatas := execCtx.serializedAccountMetadataStack[len(execCtx.serializedAccountMetadataStack)-1] + if indexInCaller >= uint64(len(accountMetadatas)) { + return serializedAcctMetadata{}, InstrErrMissingAccount + } + return accountMetadatas[indexInCaller], nil +} + +func checkAccountInfoPointer(vmAddr, expectedVmAddr uint64) error { + if vmAddr != expectedVmAddr { + return 
SyscallErrInvalidPointer + } + return nil +} + +func accountDataLenForCpi(execCtx *ExecutionCtx, refToLenInVm []byte, fallbackLen uint64) uint64 { + if execCtx.Features.IsActive(features.SyscallParameterAddressRestrictions) { + return binary.LittleEndian.Uint64(refToLenInVm) + } + return fallbackLen +} + +func checkCpiDataLength(execCtx *ExecutionCtx, dataLen, originalDataLen uint64, isLoaderDeprecated bool) error { + if !execCtx.Features.IsActive(features.SyscallParameterAddressRestrictions) { + return nil + } + + reserved := originalDataLen + if !isLoaderDeprecated { + reserved = safemath.SaturatingAddU64(reserved, MaxPermittedDataIncrease) + } + if dataLen > reserved { + return InstrErrInvalidRealloc + } + return nil +} + +type inputBackingTranslator interface { + TranslateInput(addr uint64, size uint64) ([]byte, error) +} + +type inputRegionLengthUpdater interface { + SetInputRegionLength(addr uint64, length uint64, writable bool) bool +} + +type inputRegionDataUpdater interface { + SetInputRegionData(addr uint64, data []byte, length uint64, writable bool) bool +} + +func translateSerializedAccountData(vm sbpf.VM, execCtx *ExecutionCtx, addr, size uint64) ([]byte, error) { + if accountDataDirectMappingActive(execCtx) { + return nil, nil + } + if execCtx.Features.IsActive(features.VirtualAddressSpaceAdjustments) { + if translator, ok := vm.(inputBackingTranslator); ok { + return translator.TranslateInput(addr, size) + } + } + return vm.Translate(addr, size, true) +} + func translateAccountInfosC(vm sbpf.VM, accountInfosAddr, accountInfosLen uint64) ([]SolAccountInfoC, []solana.PublicKey, error) { size := safemath.SaturatingMulU64(accountInfosLen, SolAccountInfoCSize) + execCtx := executionCtx(vm) + if syscallParameterAddressRangeRestricted(execCtx, accountInfosAddr, size) { + return nil, nil, SyscallErrInvalidPointer + } + accountInfosData, err := vm.Translate(accountInfosAddr, size, false) if err != nil { return nil, nil, err @@ -355,7 +431,6 @@ func 
translateAccountInfosC(vm sbpf.VM, accountInfosAddr, accountInfosLen uint64 accountInfos = append(accountInfos, acctInfo) } - execCtx := executionCtx(vm) err = checkAccountInfos(execCtx, uint64(len(accountInfos))) if err != nil { return nil, nil, err @@ -384,6 +459,11 @@ func translateAccountInfosC(vm sbpf.VM, accountInfosAddr, accountInfosLen uint64 func translateAccountInfosRust(vm sbpf.VM, accountInfosAddr, accountInfosLen uint64) ([]SolAccountInfoRust, []solana.PublicKey, error) { size := safemath.SaturatingMulU64(accountInfosLen, SolAccountInfoRustSize) + execCtx := executionCtx(vm) + if syscallParameterAddressRangeRestricted(execCtx, accountInfosAddr, size) { + return nil, nil, SyscallErrInvalidPointer + } + accountInfosData, err := vm.Translate(accountInfosAddr, size, false) if err != nil { return nil, nil, err @@ -401,7 +481,6 @@ func translateAccountInfosRust(vm sbpf.VM, accountInfosAddr, accountInfosLen uin accountInfos = append(accountInfos, acctInfo) } - execCtx := executionCtx(vm) err = checkAccountInfos(execCtx, uint64(len(accountInfos))) if err != nil { return nil, nil, err @@ -428,7 +507,33 @@ func translateAccountInfosRust(vm sbpf.VM, accountInfosAddr, accountInfosLen uin return accountInfos, accountInfoKeys, nil } -func callerAccountFromAccountInfoC(vm sbpf.VM, execCtx *ExecutionCtx, callerAcctIdx uint64, accountInfo SolAccountInfoC, accountInfosAddr uint64) (CallerAccount, error) { +func callerAccountFromAccountInfoC(vm sbpf.VM, execCtx *ExecutionCtx, indexInCaller, callerAcctIdx uint64, accountInfo SolAccountInfoC, accountInfosAddr uint64, isLoaderDeprecated bool) (CallerAccount, error) { + originalDataLen := accountInfo.DataLen + var accountMetadata serializedAcctMetadata + var err error + usesSerializedAccountMetadata := execCtx.Features.IsActive(features.SyscallParameterAddressRestrictions) || + execCtx.Features.IsActive(features.VirtualAddressSpaceAdjustments) + if usesSerializedAccountMetadata { + accountMetadata, err = 
currentSerializedAccountMetadata(execCtx, indexInCaller) + if err != nil { + return CallerAccount{}, err + } + originalDataLen = accountMetadata.originalDataLen + } + if execCtx.Features.IsActive(features.SyscallParameterAddressRestrictions) { + if err = checkAccountInfoPointer(accountInfo.KeyAddr, accountMetadata.vmKeyAddr); err != nil { + return CallerAccount{}, err + } + if err = checkAccountInfoPointer(accountInfo.OwnerAddr, accountMetadata.vmOwnerAddr); err != nil { + return CallerAccount{}, err + } + if err = checkAccountInfoPointer(accountInfo.LamportsAddr, accountMetadata.vmLamportsAddr); err != nil { + return CallerAccount{}, err + } + if err = checkAccountInfoPointer(accountInfo.DataAddr, accountMetadata.vmDataAddr); err != nil { + return CallerAccount{}, err + } + } lamports, err := vm.Translate(accountInfo.LamportsAddr, 8, true) if err != nil { @@ -440,31 +545,61 @@ func callerAccountFromAccountInfoC(vm sbpf.VM, execCtx *ExecutionCtx, callerAcct return CallerAccount{}, err } - cost := accountInfo.DataLen / cu.CUCpiBytesPerUnit - err = execCtx.ComputeMeter.Consume(cost) + dataLenVmAddr := accountInfosAddr + (callerAcctIdx * SolAccountInfoCSize) + solAccountInfoCDataLenOffset + if syscallParameterAddressRestricted(execCtx, dataLenVmAddr) { + return CallerAccount{}, SyscallErrInvalidPointer + } + + refToLenInVm, err := vm.Translate(dataLenVmAddr, 8, true) if err != nil { return CallerAccount{}, err } - serializedData, err := vm.Translate(accountInfo.DataAddr, accountInfo.DataLen, true) + dataLen := accountDataLenForCpi(execCtx, refToLenInVm, accountInfo.DataLen) + err = checkCpiDataLength(execCtx, dataLen, originalDataLen, isLoaderDeprecated) if err != nil { return CallerAccount{}, err } - dataLenVmAddr := (accountInfosAddr + (callerAcctIdx * SolAccountInfoCSize)) + uint64(uintptr(unsafe.Pointer(&accountInfo.DataLen))) - uint64(uintptr(unsafe.Pointer(&accountInfo))) + cost := dataLen / cu.CUCpiBytesPerUnit + err = execCtx.ComputeMeter.Consume(cost) + if 
err != nil { + return CallerAccount{}, err + } - refToLenInVm, err := vm.Translate(dataLenVmAddr, 8, true) + serializedData, err := translateSerializedAccountData(vm, execCtx, accountInfo.DataAddr, dataLen) if err != nil { return CallerAccount{}, err } - callerAcct := CallerAccount{Lamports: lamports, Owner: owner, OriginalDataLen: accountInfo.DataLen, + callerAcct := CallerAccount{Lamports: lamports, Owner: owner, OriginalDataLen: originalDataLen, SerializedData: serializedData, VmDataAddr: accountInfo.DataAddr, RefToLenInVm: refToLenInVm} return callerAcct, nil } -func callerAccountFromAccountInfoRust(vm sbpf.VM, execCtx *ExecutionCtx, accountInfo SolAccountInfoRust) (CallerAccount, error) { +func callerAccountFromAccountInfoRust(vm sbpf.VM, execCtx *ExecutionCtx, indexInCaller uint64, accountInfo SolAccountInfoRust, isLoaderDeprecated bool) (CallerAccount, error) { + var accountMetadata serializedAcctMetadata + var err error + usesSerializedAccountMetadata := execCtx.Features.IsActive(features.SyscallParameterAddressRestrictions) || + execCtx.Features.IsActive(features.VirtualAddressSpaceAdjustments) + if usesSerializedAccountMetadata { + accountMetadata, err = currentSerializedAccountMetadata(execCtx, indexInCaller) + if err != nil { + return CallerAccount{}, err + } + } + if execCtx.Features.IsActive(features.SyscallParameterAddressRestrictions) { + if err = checkAccountInfoPointer(accountInfo.PubkeyAddr, accountMetadata.vmKeyAddr); err != nil { + return CallerAccount{}, err + } + if err = checkAccountInfoPointer(accountInfo.OwnerAddr, accountMetadata.vmOwnerAddr); err != nil { + return CallerAccount{}, err + } + if syscallParameterAddressRestricted(execCtx, accountInfo.LamportsBoxAddr) { + return CallerAccount{}, SyscallErrInvalidPointer + } + } lamportsBoxData, err := vm.Translate(accountInfo.LamportsBoxAddr, RefCellRustSize, false) if err != nil { @@ -479,16 +614,26 @@ func callerAccountFromAccountInfoRust(vm sbpf.VM, execCtx *ExecutionCtx, account return 
CallerAccount{}, err } + if execCtx.Features.IsActive(features.SyscallParameterAddressRestrictions) { + if err = checkAccountInfoPointer(lamportsBox.Addr, accountMetadata.vmLamportsAddr); err != nil { + return CallerAccount{}, err + } + } + lamports, err := vm.Translate(lamportsBox.Addr, 8, true) if err != nil { return CallerAccount{}, err } - owner, err := vm.Translate(accountInfo.OwnerAddr, solana.PublicKeyLength, false) + owner, err := vm.Translate(accountInfo.OwnerAddr, solana.PublicKeyLength, true) if err != nil { return CallerAccount{}, err } + if syscallParameterAddressRestricted(execCtx, accountInfo.DataBoxAddr) { + return CallerAccount{}, SyscallErrInvalidPointer + } + dataBoxBytes, err := vm.Translate(accountInfo.DataBoxAddr, RefCellVecRustSize, false) if err != nil { return CallerAccount{}, err @@ -502,85 +647,149 @@ func callerAccountFromAccountInfoRust(vm sbpf.VM, execCtx *ExecutionCtx, account return CallerAccount{}, err } - cost := dataBox.Len / cu.CUCpiBytesPerUnit - err = execCtx.ComputeMeter.Consume(cost) + originalDataLen := dataBox.Len + if usesSerializedAccountMetadata { + originalDataLen = accountMetadata.originalDataLen + } + if execCtx.Features.IsActive(features.SyscallParameterAddressRestrictions) { + if err = checkAccountInfoPointer(dataBox.Addr, accountMetadata.vmDataAddr); err != nil { + return CallerAccount{}, err + } + } + + dataLenVmAddr := safemath.SaturatingAddU64(accountInfo.DataBoxAddr, solAccountInfoRustDataLenOffset) + if syscallParameterAddressRestricted(execCtx, dataLenVmAddr) { + return CallerAccount{}, SyscallErrInvalidPointer + } + + refToLenInVm, err := vm.Translate(dataLenVmAddr, 8, true) if err != nil { return CallerAccount{}, err } - serializedData, err := vm.Translate(dataBox.Addr, dataBox.Len, false) + dataLen := accountDataLenForCpi(execCtx, refToLenInVm, dataBox.Len) + err = checkCpiDataLength(execCtx, dataLen, originalDataLen, isLoaderDeprecated) if err != nil { return CallerAccount{}, err } - refToLenInVm, err := 
vm.Translate(safemath.SaturatingAddU64(accountInfo.DataBoxAddr, 32), 8, true) + cost := dataLen / cu.CUCpiBytesPerUnit + err = execCtx.ComputeMeter.Consume(cost) if err != nil { return CallerAccount{}, err } - callerAcct := CallerAccount{Lamports: lamports, Owner: owner, OriginalDataLen: dataBox.Len, + serializedData, err := translateSerializedAccountData(vm, execCtx, dataBox.Addr, dataLen) + if err != nil { + return CallerAccount{}, err + } + + callerAcct := CallerAccount{Lamports: lamports, Owner: owner, OriginalDataLen: originalDataLen, SerializedData: serializedData, VmDataAddr: dataBox.Addr, RefToLenInVm: refToLenInVm} return callerAcct, nil } -func updateCalleeAccount(execCtx *ExecutionCtx, callerAccount CallerAccount, calleeAccount *BorrowedAccount) error { +func updateCalleeAccount(vm sbpf.VM, execCtx *ExecutionCtx, callerAccount CallerAccount, calleeAccount *BorrowedAccount) (bool, error) { var err error + mustUpdateCaller := false callerLamports := binary.LittleEndian.Uint64(callerAccount.Lamports) if calleeAccount.Account.Lamports != callerLamports { err = calleeAccount.SetLamports(callerLamports, execCtx.Features) if err != nil { - return err + return false, err } } - err1 := calleeAccount.CanDataBeResized(uint64(len(callerAccount.SerializedData))) - err2 := calleeAccount.DataCanBeChanged(execCtx.Features) - - if err1 != nil { - err = err1 - } else if err2 != nil { - err = err2 - } + if execCtx.Features.IsActive(features.VirtualAddressSpaceAdjustments) { + directMapping := accountDataDirectMappingActive(execCtx) + prevLen := uint64(len(calleeAccount.Data())) + postLen := binary.LittleEndian.Uint64(callerAccount.RefToLenInVm) + if prevLen != postLen { + if !directMapping && postLen < prevLen { + previousSerializedData, translateErr := translateSerializedAccountData(vm, execCtx, callerAccount.VmDataAddr, prevLen) + if translateErr != nil { + return false, translateErr + } + if uint64(len(previousSerializedData)) < prevLen { + return false, 
InstrErrAccountDataTooSmall + } + for i := postLen; i < prevLen; i++ { + previousSerializedData[i] = 0 + } + } + err = calleeAccount.SetDataLength(postLen, execCtx.Features) + if err != nil { + return false, err + } + mustUpdateCaller = true + } - // can't change data - if err != nil { - if !bytes.Equal(callerAccount.SerializedData, calleeAccount.Data()) { - return err + if !directMapping && calleeAccount.DataCanBeChanged(execCtx.Features) == nil { + if uint64(len(callerAccount.SerializedData)) < postLen { + return false, InstrErrAccountDataTooSmall + } + err = calleeAccount.SetData(execCtx.Features, callerAccount.SerializedData[:postLen]) + if err != nil { + return false, err + } } - err = nil } else { - err = calleeAccount.SetData(execCtx.Features, callerAccount.SerializedData) + err1 := calleeAccount.CanDataBeResized(uint64(len(callerAccount.SerializedData))) + err2 := calleeAccount.DataCanBeChanged(execCtx.Features) + + if err1 != nil { + err = err1 + } else if err2 != nil { + err = err2 + } + + // can't change data if err != nil { - return err + if !bytes.Equal(callerAccount.SerializedData, calleeAccount.Data()) { + return false, err + } + err = nil + } else { + err = calleeAccount.SetData(execCtx.Features, callerAccount.SerializedData) + if err != nil { + return false, err + } } } if calleeAccount.Owner() != solana.PublicKeyFromBytes(callerAccount.Owner) { err = calleeAccount.SetOwner(execCtx.Features, solana.PublicKeyFromBytes(callerAccount.Owner)) + if err == nil { + mustUpdateCaller = true + } } - return err + return mustUpdateCaller, err } -func updateCallerAccount(vm sbpf.VM, callerAcct *CallerAccount, calleeAcct *BorrowedAccount) error { +func updateCallerAccount(vm sbpf.VM, callerAcct *CallerAccount, calleeAcct *BorrowedAccount, isLoaderDeprecated bool) error { binary.LittleEndian.PutUint64(callerAcct.Lamports, calleeAcct.Lamports()) copy(callerAcct.Owner, calleeAcct.Account.Owner[:]) prevLen := binary.LittleEndian.Uint64(callerAcct.RefToLenInVm) 
postLen := uint64(len(calleeAcct.Data())) + execCtx := executionCtx(vm) + syscallParameterAddressRestrictions := execCtx.Features.IsActive(features.SyscallParameterAddressRestrictions) + directMapping := accountDataDirectMappingActive(execCtx) + maxPermittedIncrease := uint64(MaxPermittedDataIncrease) + if syscallParameterAddressRestrictions && isLoaderDeprecated { + maxPermittedIncrease = 0 + } + addressSpaceReservedForAccount := safemath.SaturatingAddU64(callerAcct.OriginalDataLen, maxPermittedIncrease) - if prevLen != postLen { - // TODO: use constant - maxPermittedIncrease := uint64(10240) - - // account data size increased by too much - if postLen > safemath.SaturatingAddU64(callerAcct.OriginalDataLen, maxPermittedIncrease) { - return InstrErrInvalidRealloc - } + if postLen > addressSpaceReservedForAccount && (syscallParameterAddressRestrictions || prevLen != postLen) { + return InstrErrInvalidRealloc + } - if postLen < prevLen { + if prevLen != postLen { + if !directMapping && postLen < prevLen { if uint64(len(callerAcct.SerializedData)) < postLen { return InstrErrAccountDataTooSmall } @@ -589,12 +798,14 @@ func updateCallerAccount(vm sbpf.VM, callerAcct *CallerAccount, calleeAcct *Borr } } - sd, err := vm.Translate(callerAcct.VmDataAddr, postLen, true) - if err != nil { - return err + if !directMapping { + sd, err := translateSerializedAccountData(vm, execCtx, callerAcct.VmDataAddr, postLen) + if err != nil { + return err + } + callerAcct.SerializedData = sd } - callerAcct.SerializedData = sd binary.LittleEndian.PutUint64(callerAcct.RefToLenInVm, postLen) ptrAddr := safemath.SaturatingSubU64(callerAcct.VmDataAddr, 8) @@ -606,6 +817,10 @@ func updateCallerAccount(vm sbpf.VM, callerAcct *CallerAccount, calleeAcct *Borr binary.LittleEndian.PutUint64(serializedLenSlice, postLen) } + if directMapping { + return nil + } + toSlice := callerAcct.SerializedData fromSlice := calleeAcct.Data() @@ -624,9 +839,40 @@ func updateCallerAccount(vm sbpf.VM, callerAcct 
*CallerAccount, calleeAcct *Borr return nil } +func updateCallerAccountRegion(vm sbpf.VM, execCtx *ExecutionCtx, callerAcct *CallerAccount, calleeAcct *BorrowedAccount, isLoaderDeprecated bool) error { + reserved := callerAcct.OriginalDataLen + if !isLoaderDeprecated { + reserved = safemath.SaturatingAddU64(reserved, MaxPermittedDataIncrease) + } + if reserved == 0 { + return nil + } + + writable := calleeAcct.DataCanBeChanged(execCtx.Features) == nil + if accountDataDirectMappingActive(execCtx) { + updater, ok := vm.(inputRegionDataUpdater) + if !ok { + return nil + } + if !updater.SetInputRegionData(callerAcct.VmDataAddr, calleeAcct.Data(), uint64(len(calleeAcct.Data())), writable) { + return InstrErrMissingAccount + } + } else { + updater, ok := vm.(inputRegionLengthUpdater) + if !ok { + return nil + } + if !updater.SetInputRegionLength(callerAcct.VmDataAddr, uint64(len(calleeAcct.Data())), writable) { + return InstrErrMissingAccount + } + } + return nil +} + func translateAndUpdateAccountsC(vm sbpf.VM, instructionAccts []InstructionAccount, programIndices []uint64, accountInfoKeys []solana.PublicKey, accountInfos []SolAccountInfoC, accountInfosAddr uint64, isLoaderDeprecated bool) (TranslatedAccounts, error) { execCtx := executionCtx(vm) txCtx := execCtx.TransactionContext + syscallParameterAddressRestrictions := execCtx.Features.IsActive(features.SyscallParameterAddressRestrictions) ixCtx, err := txCtx.CurrentInstructionCtx() if err != nil { @@ -647,52 +893,85 @@ func translateAndUpdateAccountsC(vm sbpf.VM, instructionAccts []InstructionAccou if uint64(instructionAcctIdx) != instructionAcct.IndexInCallee { continue } - calleeAcct, err := ixCtx.BorrowInstructionAccount(txCtx, instructionAcct.IndexInCaller) - if err != nil { - return nil, err - } - defer calleeAcct.Drop() - accountKey, err := txCtx.KeyOfAccountAtIndex(instructionAcct.IndexInTransaction) - if err != nil { - return nil, err - } + err := func() error { + calleeAcct, err := 
ixCtx.BorrowInstructionAccount(txCtx, instructionAcct.IndexInCaller) + if err != nil { + return err + } + defer calleeAcct.Drop() - if calleeAcct.IsExecutable() { - cost := uint64(len(calleeAcct.Data()) / cu.CUCpiBytesPerUnit) - err = execCtx.ComputeMeter.Consume(cost) + accountKey, err := txCtx.KeyOfAccountAtIndex(instructionAcct.IndexInTransaction) if err != nil { - return nil, InstrErrComputationalBudgetExceeded + return err } - accounts = append(accounts, TranslatedAccount{IndexOfAccount: instructionAcct.IndexInCaller, CallerAccount: nil}) - } else { - var found bool + + if calleeAcct.IsExecutable() { + cost := uint64(len(calleeAcct.Data()) / cu.CUCpiBytesPerUnit) + err = execCtx.ComputeMeter.Consume(cost) + if err != nil { + return InstrErrComputationalBudgetExceeded + } + accounts = append(accounts, TranslatedAccount{IndexOfAccount: instructionAcct.IndexInCaller, CallerAccount: nil}) + return nil + } + for index, accountInfoKey := range accountInfoKeys { - if accountKey == accountInfoKey { - accountInfo := accountInfos[index] - callerAcct, err := callerAccountFromAccountInfoC(vm, execCtx, uint64(index), accountInfo, accountInfosAddr) - if err != nil { - return nil, err - } - err = updateCalleeAccount(execCtx, callerAcct, calleeAcct) + if accountKey != accountInfoKey { + continue + } + + accountInfo := accountInfos[index] + callerAcct, err := callerAccountFromAccountInfoC(vm, execCtx, instructionAcct.IndexInCaller, uint64(index), accountInfo, accountInfosAddr, isLoaderDeprecated) + if err != nil { + return err + } + + mustUpdateCaller := false + if !syscallParameterAddressRestrictions { + mustUpdateCaller, err = updateCalleeAccount(vm, execCtx, callerAcct, calleeAcct) if err != nil { - return nil, err + return err } + } - var c *CallerAccount - if instructionAcct.IsWritable { - c = &callerAcct - } else { - c = nil - } - accounts = append(accounts, TranslatedAccount{IndexOfAccount: instructionAcct.IndexInCaller, CallerAccount: c}) - found = true - break + var c 
*CallerAccount + if instructionAcct.IsWritable || syscallParameterAddressRestrictions { + c = &callerAcct } + accounts = append(accounts, TranslatedAccount{ + IndexOfAccount: instructionAcct.IndexInCaller, + CallerAccount: c, + UpdateCallerAccount: instructionAcct.IsWritable, + UpdateCallerRegion: instructionAcct.IsWritable || mustUpdateCaller || syscallParameterAddressRestrictions, + }) + return nil } - if !found { - return nil, InstrErrMissingAccount + + return InstrErrMissingAccount + }() + if err != nil { + return nil, err + } + } + + if syscallParameterAddressRestrictions { + for accountIdx := range accounts { + account := &accounts[accountIdx] + if account.CallerAccount == nil { + continue + } + + calleeAcct, err := ixCtx.BorrowInstructionAccount(txCtx, account.IndexOfAccount) + if err != nil { + return nil, err } + mustUpdateCaller, err := updateCalleeAccount(vm, execCtx, *account.CallerAccount, calleeAcct) + calleeAcct.Drop() + if err != nil { + return nil, err + } + account.UpdateCallerRegion = account.UpdateCallerAccount || mustUpdateCaller } } @@ -702,6 +981,7 @@ func translateAndUpdateAccountsC(vm sbpf.VM, instructionAccts []InstructionAccou func translateAndUpdateAccountsRust(vm sbpf.VM, instructionAccts []InstructionAccount, programIndices []uint64, accountInfoKeys []solana.PublicKey, accountInfos []SolAccountInfoRust, accountInfosAddr uint64, isLoaderDeprecated bool) (TranslatedAccounts, error) { execCtx := executionCtx(vm) txCtx := execCtx.TransactionContext + syscallParameterAddressRestrictions := execCtx.Features.IsActive(features.SyscallParameterAddressRestrictions) ixCtx, err := txCtx.CurrentInstructionCtx() if err != nil { @@ -722,52 +1002,84 @@ func translateAndUpdateAccountsRust(vm sbpf.VM, instructionAccts []InstructionAc continue } - calleeAcct, err := ixCtx.BorrowInstructionAccount(txCtx, instructionAcct.IndexInCaller) - if err != nil { - return nil, err - } - defer calleeAcct.Drop() - - accountKey, err := 
txCtx.KeyOfAccountAtIndex(instructionAcct.IndexInTransaction) - if err != nil { - return nil, err - } + err := func() error { + calleeAcct, err := ixCtx.BorrowInstructionAccount(txCtx, instructionAcct.IndexInCaller) + if err != nil { + return err + } + defer calleeAcct.Drop() - if calleeAcct.IsExecutable() { - cost := uint64(len(calleeAcct.Data()) / cu.CUCpiBytesPerUnit) - err = execCtx.ComputeMeter.Consume(cost) + accountKey, err := txCtx.KeyOfAccountAtIndex(instructionAcct.IndexInTransaction) if err != nil { - return nil, InstrErrComputationalBudgetExceeded + return err } - accounts = append(accounts, TranslatedAccount{IndexOfAccount: instructionAcct.IndexInCaller, CallerAccount: nil}) - } else { - var found bool + + if calleeAcct.IsExecutable() { + cost := uint64(len(calleeAcct.Data()) / cu.CUCpiBytesPerUnit) + err = execCtx.ComputeMeter.Consume(cost) + if err != nil { + return InstrErrComputationalBudgetExceeded + } + accounts = append(accounts, TranslatedAccount{IndexOfAccount: instructionAcct.IndexInCaller, CallerAccount: nil}) + return nil + } + for index, accountInfoKey := range accountInfoKeys { - if accountKey == accountInfoKey { - accountInfo := accountInfos[index] - callerAcct, err := callerAccountFromAccountInfoRust(vm, execCtx, accountInfo) - if err != nil { - return nil, err - } - err = updateCalleeAccount(execCtx, callerAcct, calleeAcct) + if accountKey != accountInfoKey { + continue + } + + accountInfo := accountInfos[index] + callerAcct, err := callerAccountFromAccountInfoRust(vm, execCtx, instructionAcct.IndexInCaller, accountInfo, isLoaderDeprecated) + if err != nil { + return err + } + + mustUpdateCaller := false + if !syscallParameterAddressRestrictions { + mustUpdateCaller, err = updateCalleeAccount(vm, execCtx, callerAcct, calleeAcct) if err != nil { - return nil, err + return err } + } - var c *CallerAccount - if instructionAcct.IsWritable { - c = &callerAcct - } else { - c = nil - } - accounts = append(accounts, 
TranslatedAccount{IndexOfAccount: instructionAcct.IndexInCaller, CallerAccount: c}) - found = true - break + var c *CallerAccount + if instructionAcct.IsWritable || syscallParameterAddressRestrictions { + c = &callerAcct } + accounts = append(accounts, TranslatedAccount{ + IndexOfAccount: instructionAcct.IndexInCaller, + CallerAccount: c, + UpdateCallerAccount: instructionAcct.IsWritable, + UpdateCallerRegion: instructionAcct.IsWritable || mustUpdateCaller || syscallParameterAddressRestrictions, + }) + return nil + } + + return InstrErrMissingAccount + }() + if err != nil { + return nil, err + } + } + + if syscallParameterAddressRestrictions { + for accountIdx := range accounts { + account := &accounts[accountIdx] + if account.CallerAccount == nil { + continue + } + + calleeAcct, err := ixCtx.BorrowInstructionAccount(txCtx, account.IndexOfAccount) + if err != nil { + return nil, err } - if !found { - return nil, InstrErrMissingAccount + mustUpdateCaller, err := updateCalleeAccount(vm, execCtx, *account.CallerAccount, calleeAcct) + calleeAcct.Drop() + if err != nil { + return nil, err } + account.UpdateCallerRegion = account.UpdateCallerAccount || mustUpdateCaller } } @@ -851,14 +1163,32 @@ func SyscallInvokeSignedCImpl(vm sbpf.VM, instructionAddr, accountInfosAddr, acc } for _, acct := range accounts { - if acct.CallerAccount != nil { + if acct.CallerAccount != nil && acct.UpdateCallerAccount { var calleeAcct *BorrowedAccount calleeAcct, err = instructionCtx.BorrowInstructionAccount(txCtx, acct.IndexOfAccount) if err != nil { return syscallErr(err) } - defer calleeAcct.Drop() - err = updateCallerAccount(vm, acct.CallerAccount, calleeAcct) + err = updateCallerAccount(vm, acct.CallerAccount, calleeAcct, isLoaderDeprecated) + calleeAcct.Drop() + if err != nil { + return syscallErr(err) + } + } + } + + if execCtx.Features.IsActive(features.VirtualAddressSpaceAdjustments) { + for _, acct := range accounts { + if acct.CallerAccount == nil || !acct.UpdateCallerRegion { + 
continue + } + var calleeAcct *BorrowedAccount + calleeAcct, err = instructionCtx.BorrowInstructionAccount(txCtx, acct.IndexOfAccount) + if err != nil { + return syscallErr(err) + } + err = updateCallerAccountRegion(vm, execCtx, acct.CallerAccount, calleeAcct, isLoaderDeprecated) + calleeAcct.Drop() if err != nil { return syscallErr(err) } @@ -929,14 +1259,32 @@ func SyscallInvokeSignedRustImpl(vm sbpf.VM, instructionAddr, accountInfosAddr, } for _, acct := range accounts { - if acct.CallerAccount != nil { + if acct.CallerAccount != nil && acct.UpdateCallerAccount { var calleeAcct *BorrowedAccount calleeAcct, err = instructionCtx.BorrowInstructionAccount(txCtx, acct.IndexOfAccount) if err != nil { return syscallErr(err) } - defer calleeAcct.Drop() - err = updateCallerAccount(vm, acct.CallerAccount, calleeAcct) + err = updateCallerAccount(vm, acct.CallerAccount, calleeAcct, isLoaderDeprecated) + calleeAcct.Drop() + if err != nil { + return syscallErr(err) + } + } + } + + if execCtx.Features.IsActive(features.VirtualAddressSpaceAdjustments) { + for _, acct := range accounts { + if acct.CallerAccount == nil || !acct.UpdateCallerRegion { + continue + } + var calleeAcct *BorrowedAccount + calleeAcct, err = instructionCtx.BorrowInstructionAccount(txCtx, acct.IndexOfAccount) + if err != nil { + return syscallErr(err) + } + err = updateCallerAccountRegion(vm, execCtx, acct.CallerAccount, calleeAcct, isLoaderDeprecated) + calleeAcct.Drop() if err != nil { return syscallErr(err) } diff --git a/pkg/sealevel/syscalls_cpi_0459_test.go b/pkg/sealevel/syscalls_cpi_0459_test.go new file mode 100644 index 00000000..0b4ddc26 --- /dev/null +++ b/pkg/sealevel/syscalls_cpi_0459_test.go @@ -0,0 +1,58 @@ +package sealevel + +import ( + "testing" + + "github.com/Overclock-Validator/mithril/pkg/cu" + feat "github.com/Overclock-Validator/mithril/pkg/features" + "github.com/Overclock-Validator/mithril/pkg/sbpf" + "github.com/stretchr/testify/require" +) + +func 
TestSyscallParameterAddressRangeRestricted(t *testing.T) { + features := feat.NewFeaturesDefault() + execCtx := &ExecutionCtx{Features: *features} + + require.False(t, syscallParameterAddressRangeRestricted(execCtx, sbpf.VaddrInput-1, 1)) + + features.EnableFeature(feat.SyscallParameterAddressRestrictions, 0) + require.True(t, syscallParameterAddressRangeRestricted(execCtx, sbpf.VaddrInput, 0)) + require.True(t, syscallParameterAddressRangeRestricted(execCtx, sbpf.VaddrInput-1, 1)) + require.False(t, syscallParameterAddressRangeRestricted(execCtx, sbpf.VaddrInput-2, 1)) +} + +func TestCheckCpiDataLengthSyscallParameterAddressRestrictions(t *testing.T) { + features := feat.NewFeaturesDefault() + execCtx := &ExecutionCtx{Features: *features} + + require.NoError(t, checkCpiDataLength(execCtx, MaxPermittedDataIncrease+2, 1, true)) + + features.EnableFeature(feat.SyscallParameterAddressRestrictions, 0) + require.NoError(t, checkCpiDataLength(execCtx, 1, 1, true)) + require.ErrorIs(t, checkCpiDataLength(execCtx, 2, 1, true), InstrErrInvalidRealloc) + require.NoError(t, checkCpiDataLength(execCtx, MaxPermittedDataIncrease+1, 1, false)) + require.ErrorIs(t, checkCpiDataLength(execCtx, MaxPermittedDataIncrease+2, 1, false), InstrErrInvalidRealloc) +} + +func TestTranslateSerializedAccountDataDirectMappingSkipsSerializedBytes(t *testing.T) { + features := feat.NewFeaturesDefault() + features.EnableFeature(feat.VirtualAddressSpaceAdjustments, 0) + features.EnableFeature(feat.AccountDataDirectMapping, 0) + execCtx := &ExecutionCtx{Features: *features} + + meter := cu.NewComputeMeter(1) + vm := sbpf.NewInterpreter(&sbpf.Program{TextVA: sbpf.VaddrProgram, Funcs: map[uint32]int64{}}, &sbpf.VMOpts{ + Input: []byte{0x42}, + Context: execCtx, + ComputeMeter: &meter, + }) + defer vm.Finish() + + data, err := translateSerializedAccountData(vm, execCtx, sbpf.VaddrInput+42, 1) + require.NoError(t, err) + require.Nil(t, data) + + features.DisableFeature(feat.AccountDataDirectMapping) + 
_, err = translateSerializedAccountData(vm, execCtx, sbpf.VaddrInput+42, 1) + require.Error(t, err) +} diff --git a/pkg/sealevel/syscalls_curve.go b/pkg/sealevel/syscalls_curve.go index 667a1fc0..fe79930f 100644 --- a/pkg/sealevel/syscalls_curve.go +++ b/pkg/sealevel/syscalls_curve.go @@ -314,7 +314,7 @@ func SyscallCurveMultiscalarMultiplicationImpl(vm sbpf.VM, curveId, scalarsAddr, scalars, err := unmarshalEdwardsScalars(scalarsBytes) if err != nil { - return syscallErr(err) + return syscallSuccess(1) } pointsBytes, err := vm.Translate(pointsAddr, pointsLen*CurvePointBytesLen, false) @@ -324,7 +324,7 @@ func SyscallCurveMultiscalarMultiplicationImpl(vm sbpf.VM, curveId, scalarsAddr, points, err := unmarshalEdwardsPoints(pointsBytes) if err != nil { - return syscallErr(err) + return syscallSuccess(1) } resultPoint := edwards25519.NewIdentityPoint() @@ -355,7 +355,7 @@ func SyscallCurveMultiscalarMultiplicationImpl(vm sbpf.VM, curveId, scalarsAddr, scalars, err := unmarshalRistrettoScalars(scalarsBytes) if err != nil { - return syscallErr(err) + return syscallSuccess(1) } pointsBytes, err := vm.Translate(pointsAddr, pointsLen*CurvePointBytesLen, false) @@ -365,7 +365,7 @@ func SyscallCurveMultiscalarMultiplicationImpl(vm sbpf.VM, curveId, scalarsAddr, points, err := unmarshalRistrettoElements(pointsBytes) if err != nil { - return syscallErr(err) + return syscallSuccess(1) } resultPoint := ristretto255.NewElement().MultiScalarMult(scalars, points) diff --git a/pkg/sealevel/syscalls_gen/main.go b/pkg/sealevel/syscalls_gen/main.go index 5f8c488b..88a140a5 100644 --- a/pkg/sealevel/syscalls_gen/main.go +++ b/pkg/sealevel/syscalls_gen/main.go @@ -26,7 +26,7 @@ func main() { {"sol_log_data", "SyscallLogData", ""}, {"sol_sha256", "SyscallSha256", ""}, {"sol_keccak256", "SyscallKeccak256", ""}, - {"sol_blake3", "SyscallBlake3", ""}, + {"sol_blake3", "SyscallBlake3", "ft.IsActive(features.Blake3SyscallEnabled)"}, {"sol_secp256k1_recover", "SyscallSecp256k1Recover", ""}, 
{"sol_poseidon", "SyscallPoseidon", ""}, {"sol_curve_validate_point", "SyscallValidatePoint", "ft.IsActive(features.Curve25519SyscallEnabled)"}, diff --git a/pkg/sealevel/syscalls_hash_test.go b/pkg/sealevel/syscalls_hash_test.go new file mode 100644 index 00000000..96bb2b91 --- /dev/null +++ b/pkg/sealevel/syscalls_hash_test.go @@ -0,0 +1,21 @@ +package sealevel + +import ( + "testing" + + "github.com/Overclock-Validator/mithril/pkg/features" + "github.com/Overclock-Validator/mithril/pkg/sbpf" + "github.com/stretchr/testify/assert" +) + +func TestSyscalls_Blake3RequiresFeature(t *testing.T) { + ft := features.NewFeaturesDefault() + + _, ok := Syscalls(ft, false, sbpf.SymbolHash("sol_blake3")) + assert.False(t, ok) + + ft.EnableFeature(features.Blake3SyscallEnabled, 0) + + _, ok = Syscalls(ft, false, sbpf.SymbolHash("sol_blake3")) + assert.True(t, ok) +} diff --git a/pkg/sealevel/syscalls_log.go b/pkg/sealevel/syscalls_log.go index d9b350b2..cf03112b 100644 --- a/pkg/sealevel/syscalls_log.go +++ b/pkg/sealevel/syscalls_log.go @@ -4,6 +4,7 @@ import ( "bytes" "encoding/base64" "fmt" + "unicode/utf8" //"github.com/Overclock-Validator/mithril/pkg/mlog" "github.com/Overclock-Validator/mithril/pkg/cu" @@ -27,6 +28,9 @@ func SyscallLogImpl(vm sbpf.VM, ptr, strlen uint64) (uint64, error) { if err = vm.Read(ptr, buf); err != nil { return syscallErr(err) } + if !utf8.Valid(buf) { + return syscallErr(SyscallErrInvalidString) + } execCtx.Log.Log("Program log: " + string(buf)) diff --git a/pkg/sealevel/syscalls_log_test.go b/pkg/sealevel/syscalls_log_test.go new file mode 100644 index 00000000..a05e504a --- /dev/null +++ b/pkg/sealevel/syscalls_log_test.go @@ -0,0 +1,18 @@ +package sealevel + +import ( + "testing" + + "github.com/Overclock-Validator/mithril/pkg/cu" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestSyscallLogRejectsInvalidUTF8(t *testing.T) { + vm := newBls12_381SyscallTestVM(4) + vm.mem[0] = 0xff + + _, err := 
SyscallLogImpl(vm, 0, 1) + require.ErrorIs(t, err, SyscallErrInvalidString) + assert.Equal(t, uint64(1_000_000-cu.CUSyscallBaseCost), vm.ComputeMeter().Remaining()) +} diff --git a/pkg/sealevel/syscalls_pda.go b/pkg/sealevel/syscalls_pda.go index 4309f0c9..8f36cd05 100644 --- a/pkg/sealevel/syscalls_pda.go +++ b/pkg/sealevel/syscalls_pda.go @@ -2,7 +2,7 @@ package sealevel import ( "bytes" - + "errors" "math" "github.com/Overclock-Validator/mithril/pkg/cu" @@ -18,6 +18,10 @@ func translateAndValidateSeeds(vm sbpf.VM, seedsAddr, seedsLen uint64) ([][]byte return nil, SyscallErrMaxSeedLengthExceeded } + if syscallAddressRequiresAlignment(executionCtx(vm), seedsAddr, 8) { + return nil, errors.New("SyscallError::UnalignedPointer") + } + seedsData, err := vm.Translate(seedsAddr, seedsLen*16, false) if err != nil { return nil, err diff --git a/pkg/sealevel/syscalls_pda_test.go b/pkg/sealevel/syscalls_pda_test.go new file mode 100644 index 00000000..82073a92 --- /dev/null +++ b/pkg/sealevel/syscalls_pda_test.go @@ -0,0 +1,23 @@ +package sealevel + +import ( + "encoding/binary" + "testing" + + "github.com/Overclock-Validator/mithril/pkg/cu" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestSyscallTryFindProgramAddressRejectsUnalignedSeedVector(t *testing.T) { + vm := newBls12_381SyscallTestVM(128) + + seedsAddr := uint64(1) + binary.LittleEndian.PutUint64(vm.mem[seedsAddr:seedsAddr+8], 24) + binary.LittleEndian.PutUint64(vm.mem[seedsAddr+8:seedsAddr+16], 1) + vm.mem[24] = 'x' + + _, err := SyscallTryFindProgramAddressImpl(vm, seedsAddr, 1, 40, 72, 104) + require.EqualError(t, err, "SyscallError::UnalignedPointer") + assert.Equal(t, uint64(1_000_000-cu.CUCreateProgramAddressUnits), vm.ComputeMeter().Remaining()) +} diff --git a/pkg/sealevel/syscalls_sysvar.go b/pkg/sealevel/syscalls_sysvar.go index 4dead1c9..f68b649e 100644 --- a/pkg/sealevel/syscalls_sysvar.go +++ b/pkg/sealevel/syscalls_sysvar.go @@ -25,6 +25,12 @@ func 
SyscallGetClockSysvarImpl(vm sbpf.VM, addr uint64) (uint64, error) { if err != nil { return syscallCuErr() } + if !syscallAddressIsAligned(execCtx, addr, 8) { + return syscallErrCustom("SyscallError::UnalignedPointer") + } + if syscallParameterAddressRestricted(execCtx, addr) { + return syscallErrCustom("SyscallError::InvalidPointer") + } var clockDst []byte clockDst, err = vm.Translate(addr, SysvarClockStructLen, true) @@ -58,6 +64,12 @@ func SyscallGetRentSysvarImpl(vm sbpf.VM, addr uint64) (uint64, error) { if err != nil { return syscallCuErr() } + if !syscallAddressIsAligned(execCtx, addr, 8) { + return syscallErrCustom("SyscallError::UnalignedPointer") + } + if syscallParameterAddressRestricted(execCtx, addr) { + return syscallErrCustom("SyscallError::InvalidPointer") + } rentDst, err := vm.Translate(addr, SysvarRentStructLen, true) if err != nil { @@ -87,6 +99,12 @@ func SyscallGetEpochScheduleSysvarImpl(vm sbpf.VM, addr uint64) (uint64, error) if err != nil { return syscallCuErr() } + if !syscallAddressIsAligned(execCtx, addr, 8) { + return syscallErrCustom("SyscallError::UnalignedPointer") + } + if syscallParameterAddressRestricted(execCtx, addr) { + return syscallErrCustom("SyscallError::InvalidPointer") + } epochScheduleDst, err := vm.Translate(addr, SysvarEpochScheduleStructLen, true) if err != nil { @@ -133,6 +151,12 @@ func SyscallGetEpochRewardsSysvarImpl(vm sbpf.VM, addr uint64) (uint64, error) { if err != nil { return syscallCuErr() } + if !syscallAddressIsAligned(execCtx, addr, 16) { + return syscallErrCustom("SyscallError::UnalignedPointer") + } + if syscallParameterAddressRestricted(execCtx, addr) { + return syscallErrCustom("SyscallError::InvalidPointer") + } epochRewardsDst, err := vm.Translate(addr, SysvarEpochRewardsStructLen, true) if err != nil { @@ -141,7 +165,7 @@ func SyscallGetEpochRewardsSysvarImpl(vm sbpf.VM, addr uint64) (uint64, error) { epochRewards, err := ReadEpochRewardsSysvar(execCtx) if err != nil { - return syscallSuccess(1) 
+ return syscallErr(err) } binary.LittleEndian.PutUint64(epochRewardsDst[:8], epochRewards.DistributionStartingBlockHeight) @@ -175,13 +199,22 @@ func SyscallGetLastRestartSlotSysvarImpl(vm sbpf.VM, addr uint64) (uint64, error if err != nil { return syscallCuErr() } + if !syscallAddressIsAligned(execCtx, addr, 8) { + return syscallErrCustom("SyscallError::UnalignedPointer") + } + if syscallParameterAddressRestricted(execCtx, addr) { + return syscallErrCustom("SyscallError::InvalidPointer") + } lastRestartSlotDst, err := vm.Translate(addr, SysvarLastRestartSlotStructLen, true) if err != nil { return syscallErr(err) } - lrs := ReadLastRestartSlotSysvar(execCtx) + lrs, err := ReadLastRestartSlotSysvar(execCtx) + if err != nil { + return syscallErr(err) + } binary.LittleEndian.PutUint64(lastRestartSlotDst[:8], lrs.LastRestartSlot) return syscallSuccess(0) @@ -197,11 +230,20 @@ const ( var permittedSysvarAddrs = []solana.PublicKey{SysvarClockAddr, SysvarEpochScheduleAddr, SysvarEpochRewardsAddr, SysvarRentAddr, SysvarSlotHashesAddr, SysvarStakeHistoryAddr, SysvarLastRestartSlotAddr} -func fetchSysvarBytesForPubkey(pubkey solana.PublicKey) ([]byte, error) { +func fetchSysvarBytesForPubkey(execCtx *ExecutionCtx, pubkey solana.PublicKey) ([]byte, error) { if !slices.Contains(permittedSysvarAddrs, pubkey) { return nil, fmt.Errorf("unrecognised sysvar") } + accts := addrObjectForLookup(execCtx) + if accts != nil && *accts != nil { + key := [32]byte(pubkey) + acct, err := (*accts).GetAccount(&key) + if err == nil { + return acct.Data, nil + } + } + var sysvarAcct *accounts.Account if pubkey == SysvarClockAddr { sysvarAcct = SysvarCache.Clock.Acct @@ -218,6 +260,9 @@ func fetchSysvarBytesForPubkey(pubkey solana.PublicKey) ([]byte, error) { } else if pubkey == SysvarLastRestartSlotAddr { sysvarAcct = SysvarCache.LastRestartSlot.Acct } + if sysvarAcct == nil { + return nil, fmt.Errorf("sysvar account not found") + } return sysvarAcct.Data, nil } @@ -233,17 +278,23 @@ func 
SyscallGetSysvarImpl(vm sbpf.VM, sysvarIdAddr uint64, varAddr uint64, offse if err != nil { return syscallCuErr() } + if !syscallCheckAligned(execCtx) { + return syscallErrCustom("SyscallError::UnalignedPointer") + } + if syscallParameterAddressRestricted(execCtx, varAddr) { + return syscallErrCustom("SyscallError::InvalidPointer") + } - sysvarIdBytes, err := vm.Translate(sysvarIdAddr, 32, false) + varBuf, err := vm.Translate(varAddr, length, true) if err != nil { return syscallErr(err) } - sysvarId := solana.PublicKeyFromBytes(sysvarIdBytes) - varBuf, err := vm.Translate(varAddr, length, true) + sysvarIdBytes, err := vm.Translate(sysvarIdAddr, 32, false) if err != nil { return syscallErr(err) } + sysvarId := solana.PublicKeyFromBytes(sysvarIdBytes) offsetLen, err := safemath.CheckedAddU64(offset, length) if err != nil { @@ -255,7 +306,7 @@ func SyscallGetSysvarImpl(vm sbpf.VM, sysvarIdAddr uint64, varAddr uint64, offse return syscallErr(InstrErrArithmeticOverflow) } - sysvarBuf, err := fetchSysvarBytesForPubkey(sysvarId) + sysvarBuf, err := fetchSysvarBytesForPubkey(execCtx, sysvarId) if err != nil { return syscallSuccess(sysvarNotFound) } diff --git a/pkg/sealevel/sysvar_epoch_rewards.go b/pkg/sealevel/sysvar_epoch_rewards.go index da7093e3..80cf0e84 100644 --- a/pkg/sealevel/sysvar_epoch_rewards.go +++ b/pkg/sealevel/sysvar_epoch_rewards.go @@ -67,7 +67,7 @@ func (ser *SysvarEpochRewards) UnmarshalWithDecoder(decoder *bin.Decoder) error return fmt.Errorf("failed to read DistributedRewards when decoding SysvarEpochRewards: %w", err) } - ser.Active, err = decoder.ReadBool() + ser.Active, err = ReadBool(decoder) return err } @@ -154,7 +154,10 @@ func ReadEpochRewardsSysvar(execCtx *ExecutionCtx) (SysvarEpochRewards, error) { dec := bin.NewBinDecoder(epochRewardsSysvarAcct.Data) var epochRewards SysvarEpochRewards - epochRewards.MustUnmarshalWithDecoder(dec) + err = epochRewards.UnmarshalWithDecoder(dec) + if err != nil { + return SysvarEpochRewards{}, 
InstrErrUnsupportedSysvar + } return epochRewards, nil } diff --git a/pkg/sealevel/sysvar_last_restart_slot.go b/pkg/sealevel/sysvar_last_restart_slot.go index 6fbdc46f..95fd9fca 100644 --- a/pkg/sealevel/sysvar_last_restart_slot.go +++ b/pkg/sealevel/sysvar_last_restart_slot.go @@ -35,20 +35,27 @@ func (sr *SysvarLastRestartSlot) MustUnmarshalWithDecoder(decoder *bin.Decoder) } } -func ReadLastRestartSlotSysvar(execCtx *ExecutionCtx) SysvarLastRestartSlot { +func ReadLastRestartSlotSysvar(execCtx *ExecutionCtx) (SysvarLastRestartSlot, error) { + if SysvarCache.LastRestartSlot.Sysvar != nil { + return *SysvarCache.LastRestartSlot.Sysvar, nil + } + accts := addrObjectForLookup(execCtx) lrsAcct, err := (*accts).GetAccount(&SysvarLastRestartSlotAddr) if err != nil { - panic("failed to read LastRestartSlot sysvar account") + return SysvarLastRestartSlot{}, InstrErrUnsupportedSysvar } dec := bin.NewBinDecoder(lrsAcct.Data) var lrs SysvarLastRestartSlot - lrs.MustUnmarshalWithDecoder(dec) + err = lrs.UnmarshalWithDecoder(dec) + if err != nil { + return SysvarLastRestartSlot{}, InstrErrUnsupportedSysvar + } - return lrs + return lrs, nil } func WriteLastRestartSlotSysvar(accts *accounts.Accounts, lastRestartSlot SysvarLastRestartSlot) { diff --git a/pkg/sealevel/types.go b/pkg/sealevel/types.go index a6c840e1..55842937 100644 --- a/pkg/sealevel/types.go +++ b/pkg/sealevel/types.go @@ -128,8 +128,10 @@ type RefCellVecRust struct { type TranslatedAccounts []TranslatedAccount type TranslatedAccount struct { - IndexOfAccount uint64 - CallerAccount *CallerAccount + IndexOfAccount uint64 + CallerAccount *CallerAccount + UpdateCallerAccount bool + UpdateCallerRegion bool } type CallerAccount struct { @@ -157,6 +159,9 @@ func (accountMeta *AccountMeta) Unmarshal(buf io.Reader) error { } copy(accountMeta.Pubkey[:], accountMetaBytes[:32]) + if accountMetaBytes[32] > 1 || accountMetaBytes[33] > 1 { + return InstrErrInvalidArgument + } accountMeta.IsSigner = accountMetaBytes[32] != 0 
accountMeta.IsWritable = accountMetaBytes[33] != 0 From 4b1df309d4a692689a04ff0aa7b8320978c39c6b Mon Sep 17 00:00:00 2001 From: smcio Date: Wed, 20 May 2026 17:29:02 +0200 Subject: [PATCH 4/4] forkchoice & block source fixes, quic tpu client for rpcserver, diagnostics/debugging --- cmd/mithril/node/node.go | 132 +++++--- go.mod | 1 + go.sum | 2 + pkg/blockstream/block_source.go | 229 +++++++++++-- pkg/blockstream/block_source_test.go | 242 ++++++++++++++ pkg/forkchoice/forkchoice.go | 71 ++++ pkg/forkchoice/forkchoice_test.go | 171 +++++++++- pkg/forkchoice/vote_parser.go | 73 ++++- pkg/replay/block.go | 233 +++++-------- pkg/replay/consensus.go | 111 +++++++ pkg/replay/consensus_fallback.go | 14 - pkg/replay/diagnostics.go | 361 +++++++++++++++++++++ pkg/replay/epoch.go | 5 + pkg/replay/epoch_authorized_voters_test.go | 36 ++ pkg/replay/epoch_schedule.go | 43 +++ pkg/replay/epoch_stakes_seed.go | 126 +++++++ pkg/replay/epoch_stakes_seed_test.go | 86 +++++ pkg/replay/sysvar.go | 17 +- pkg/replay/sysvar_clock_test.go | 52 +++ pkg/replay/transaction.go | 7 + pkg/rpcserver/rpcserver.go | 19 +- pkg/rpcserver/send_transaction.go | 112 +++++-- pkg/rpcserver/send_transaction_test.go | 50 ++- pkg/rpcserver/tpu_quic.go | 192 +++++++++++ pkg/rpcserver/tpu_quic_test.go | 60 ++++ pkg/snapshot/manifest_seed.go | 7 + pkg/snapshot/manifest_seed_test.go | 51 +++ pkg/state/state.go | 26 +- 28 files changed, 2227 insertions(+), 302 deletions(-) create mode 100644 pkg/replay/consensus.go delete mode 100644 pkg/replay/consensus_fallback.go create mode 100644 pkg/replay/diagnostics.go create mode 100644 pkg/replay/epoch_authorized_voters_test.go create mode 100644 pkg/replay/epoch_schedule.go create mode 100644 pkg/replay/epoch_stakes_seed.go create mode 100644 pkg/replay/epoch_stakes_seed_test.go create mode 100644 pkg/replay/sysvar_clock_test.go create mode 100644 pkg/rpcserver/tpu_quic.go create mode 100644 pkg/rpcserver/tpu_quic_test.go create mode 100644 
pkg/snapshot/manifest_seed_test.go diff --git a/cmd/mithril/node/node.go b/cmd/mithril/node/node.go index e2b84eb8..cfa011f2 100644 --- a/cmd/mithril/node/node.go +++ b/cmd/mithril/node/node.go @@ -113,6 +113,87 @@ var ( lightbringerQuiet bool ) +func snapshotEpochForState(manifest *snapshot.SnapshotManifest) uint64 { + if manifest == nil || manifest.Bank == nil { + return 0 + } + if manifest.Bank.EpochSchedule.SlotsPerEpoch != 0 { + epoch := manifest.Bank.EpochSchedule.GetEpoch(manifest.Bank.Slot) + if manifest.Bank.Epoch != epoch { + mlog.Log.Warnf("manifest bank epoch %d differs from manifest epoch schedule epoch %d at slot %d; using schedule-derived epoch", + manifest.Bank.Epoch, epoch, manifest.Bank.Slot) + } + return epoch + } + + return manifest.Bank.Epoch +} + +func epochScheduleFromState(s *state.MithrilState) *sealevel.SysvarEpochSchedule { + if s != nil && s.ManifestEpochSchedule != nil && s.ManifestEpochSchedule.SlotsPerEpoch != 0 { + return &sealevel.SysvarEpochSchedule{ + SlotsPerEpoch: s.ManifestEpochSchedule.SlotsPerEpoch, + LeaderScheduleSlotOffset: s.ManifestEpochSchedule.LeaderScheduleSlotOffset, + Warmup: s.ManifestEpochSchedule.Warmup, + FirstNormalEpoch: s.ManifestEpochSchedule.FirstNormalEpoch, + FirstNormalSlot: s.ManifestEpochSchedule.FirstNormalSlot, + } + } + return sealevel.SysvarCache.EpochSchedule.Sysvar +} + +func epochForStateSlot(s *state.MithrilState, slot uint64) uint64 { + if epochSchedule := epochScheduleFromState(s); epochSchedule != nil { + return epochSchedule.GetEpoch(slot) + } + return 0 +} + +func manifestEpochScheduleSeedMatches(s *state.MithrilState, manifest *snapshot.SnapshotManifest) bool { + if s == nil || s.ManifestEpochSchedule == nil || manifest == nil || manifest.Bank == nil { + return false + } + m := manifest.Bank.EpochSchedule + return s.ManifestEpochSchedule.SlotsPerEpoch == m.SlotsPerEpoch && + s.ManifestEpochSchedule.LeaderScheduleSlotOffset == m.LeaderScheduleSlotOffset && + s.ManifestEpochSchedule.Warmup 
== m.Warmup && + s.ManifestEpochSchedule.FirstNormalEpoch == m.FirstNormalEpoch && + s.ManifestEpochSchedule.FirstNormalSlot == m.FirstNormalSlot +} + +func refreshManifestSeedFromManifest(accountsPath string, s *state.MithrilState, manifest *snapshot.SnapshotManifest) { + if s == nil || manifest == nil || manifest.Bank == nil { + return + } + + snapshotEpoch := snapshotEpochForState(manifest) + if manifestEpochScheduleSeedMatches(s, manifest) && s.SnapshotEpoch == snapshotEpoch { + return + } + + oldSnapshotEpoch := s.SnapshotEpoch + if oldSnapshotEpoch != 0 && oldSnapshotEpoch != snapshotEpoch && s.LastSlot > s.SnapshotSlot { + reason := fmt.Sprintf("snapshot epoch frame changed from %d to %d after replay had already persisted slot %d; rebuild AccountsDB from snapshot", + oldSnapshotEpoch, snapshotEpoch, s.LastSlot) + if err := s.MarkCorrupted(accountsPath, reason); err != nil { + mlog.Log.Errorf("failed to mark state as corrupted: %v", err) + } + klog.Fatalf(reason) + } + + s.SnapshotEpoch = snapshotEpoch + snapshot.PopulateManifestSeed(s, manifest) + if s.LastSlot > 0 { + s.LastEpoch = epochForStateSlot(s, s.LastSlot) + } + if err := s.Save(accountsPath); err != nil { + mlog.Log.Errorf("failed to refresh manifest seed data in state file: %v", err) + return + } + mlog.Log.Warnf("refreshed manifest-derived state seed data from snapshot manifest (snapshot_epoch %d -> %d)", + oldSnapshotEpoch, snapshotEpoch) +} + func init() { // [bootstrap] section flags Run.Flags().StringVar(&bootstrapMode, "bootstrap", "auto", "Bootstrap mode: 'auto' (use AccountsDB if exists, else snapshot), 'accountsdb' (require existing), 'snapshot' (rebuild from snapshot), 'new-snapshot' (always download fresh)") @@ -894,10 +975,7 @@ func runLive(c *cobra.Command, args []string) { } // Write state file - var snapshotEpoch uint64 - if sealevel.SysvarCache.EpochSchedule.Sysvar != nil { - snapshotEpoch = sealevel.SysvarCache.EpochSchedule.Sysvar.GetEpoch(manifest.Bank.Slot) - } + snapshotEpoch 
:= snapshotEpochForState(manifest) mithrilState = state.NewReadyState(manifest.Bank.Slot, snapshotEpoch, "", "", 0, 0) // Populate manifest seed data so replay doesn't need manifest at runtime snapshot.PopulateManifestSeed(mithrilState, manifest) @@ -926,6 +1004,7 @@ func runLive(c *cobra.Command, args []string) { if err != nil { klog.Fatalf("failed to load manifest: %v", err) } + refreshManifestSeedFromManifest(accountsPath, mithrilState, manifest) // Run integrity check if we have a state file (warn only, don't fail - user chose force mode) if hasValidState { if err := mithrilState.ValidateAgainstBankhashDB(accountsDb); err != nil { @@ -964,10 +1043,7 @@ func runLive(c *cobra.Command, args []string) { klog.Fatalf("failed to build AccountsDB from snapshot: %v", err) } // Write state file to mark build as complete - var snapshotEpoch uint64 - if sealevel.SysvarCache.EpochSchedule.Sysvar != nil { - snapshotEpoch = sealevel.SysvarCache.EpochSchedule.Sysvar.GetEpoch(manifest.Bank.Slot) - } + snapshotEpoch := snapshotEpochForState(manifest) mithrilState = state.NewReadyState(manifest.Bank.Slot, snapshotEpoch, "", "", 0, 0) // Populate manifest seed data so replay doesn't need manifest at runtime snapshot.PopulateManifestSeed(mithrilState, manifest) @@ -1022,10 +1098,7 @@ func runLive(c *cobra.Command, args []string) { klog.Fatalf("failed to build AccountsDB from snapshot: %v", err) } // Write state file to mark build as complete - var snapshotEpoch uint64 - if sealevel.SysvarCache.EpochSchedule.Sysvar != nil { - snapshotEpoch = sealevel.SysvarCache.EpochSchedule.Sysvar.GetEpoch(manifest.Bank.Slot) - } + snapshotEpoch := snapshotEpochForState(manifest) mithrilState = state.NewReadyState(manifest.Bank.Slot, snapshotEpoch, "", "", 0, 0) // Populate manifest seed data so replay doesn't need manifest at runtime snapshot.PopulateManifestSeed(mithrilState, manifest) @@ -1094,10 +1167,7 @@ func runLive(c *cobra.Command, args []string) { if err != nil { klog.Fatalf("failed to 
build AccountsDB from snapshot: %v", err) } - var snapshotEpoch uint64 - if sealevel.SysvarCache.EpochSchedule.Sysvar != nil { - snapshotEpoch = sealevel.SysvarCache.EpochSchedule.Sysvar.GetEpoch(manifest.Bank.Slot) - } + snapshotEpoch := snapshotEpochForState(manifest) mithrilState = state.NewReadyState(manifest.Bank.Slot, snapshotEpoch, "", "", 0, 0) // Populate manifest seed data so replay doesn't need manifest at runtime snapshot.PopulateManifestSeed(mithrilState, manifest) @@ -1122,6 +1192,7 @@ func runLive(c *cobra.Command, args []string) { if err != nil { klog.Fatalf("failed to load manifest: %v", err) } + refreshManifestSeedFromManifest(accountsPath, mithrilState, manifest) // Validate state file matches AccountsDB (detect Ctrl+Z / kill -9 corruption) if err := mithrilState.ValidateAgainstBankhashDB(accountsDb); err != nil { @@ -1190,10 +1261,7 @@ func runLive(c *cobra.Command, args []string) { klog.Fatalf("failed to build AccountsDB from snapshot: %v", err) } // Write state file to mark build as complete - var snapshotEpoch uint64 - if sealevel.SysvarCache.EpochSchedule.Sysvar != nil { - snapshotEpoch = sealevel.SysvarCache.EpochSchedule.Sysvar.GetEpoch(manifest.Bank.Slot) - } + snapshotEpoch := snapshotEpochForState(manifest) mithrilState = state.NewReadyStateWithOpts(state.NewReadyStateOpts{ SnapshotSlot: manifest.Bank.Slot, SnapshotEpoch: snapshotEpoch, @@ -1312,10 +1380,7 @@ postBootstrap: if mithrilState == nil { // Initialize state for this session - var snapshotEpoch uint64 - if sealevel.SysvarCache.EpochSchedule.Sysvar != nil { - snapshotEpoch = sealevel.SysvarCache.EpochSchedule.Sysvar.GetEpoch(manifest.Bank.Slot) - } + snapshotEpoch := snapshotEpochForState(manifest) mithrilState = state.NewReadyState(manifest.Bank.Slot, snapshotEpoch, "", "", 0, 0) // Populate manifest seed data so replay doesn't need manifest at runtime snapshot.PopulateManifestSeed(mithrilState, manifest) @@ -1363,7 +1428,7 @@ postBootstrap: if rpcPort < 0 || rpcPort > 65535 { 
klog.Fatalf("invalid port: %d", rpcPort) } else if rpcPort != 0 { - rpcServer = rpcserver.NewRpcServer(accountsDb, uint16(rpcPort)) + rpcServer = rpcserver.NewRpcServer(accountsDb, uint16(rpcPort), epochScheduleFromState(mithrilState)) rpcServer.Start() mlog.Log.Infof("Started RPC server on port %d", rpcPort) } @@ -1433,10 +1498,7 @@ postBootstrap: var shutdownCtx *state.ShutdownContext if result.LastAcctsLtHash != nil { // Calculate epoch for the last persisted slot - var lastEpoch uint64 - if sealevel.SysvarCache.EpochSchedule.Sysvar != nil { - lastEpoch = sealevel.SysvarCache.EpochSchedule.Sysvar.GetEpoch(result.LastPersistedSlot) - } + lastEpoch := epochForStateSlot(mithrilState, result.LastPersistedSlot) // Determine shutdown reason shutdownReason := state.ShutdownReasonCompleted if result.WasCancelled { @@ -1503,11 +1565,8 @@ postBootstrap: // Print shutdown summary if cancelled or error if (result.WasCancelled || result.Error != nil) && result.LastPersistedSlot > 0 { // Calculate epoch from slot using epoch schedule - var epoch, snapshotEpoch uint64 - if sealevel.SysvarCache.EpochSchedule.Sysvar != nil { - epoch = sealevel.SysvarCache.EpochSchedule.Sysvar.GetEpoch(result.LastPersistedSlot) - snapshotEpoch = sealevel.SysvarCache.EpochSchedule.Sysvar.GetEpoch(snapshotBaseSlot) - } + epoch := epochForStateSlot(mithrilState, result.LastPersistedSlot) + snapshotEpoch := epochForStateSlot(mithrilState, snapshotBaseSlot) progress.PrintShutdownSummary(progress.ShutdownInfo{ LastSlot: result.LastPersistedSlot, LastBankhash: result.LastPersistedBankhash, @@ -2309,10 +2368,7 @@ func runReplayWithRecovery( } // Calculate epoch for the last persisted slot - var lastEpoch uint64 - if sealevel.SysvarCache.EpochSchedule.Sysvar != nil { - lastEpoch = sealevel.SysvarCache.EpochSchedule.Sysvar.GetEpoch(r.LastPersistedSlot) - } + lastEpoch := epochForStateSlot(mithrilState, r.LastPersistedSlot) // Build shutdown context var shutdownCtx *state.ShutdownContext diff --git a/go.mod 
b/go.mod index ca31970e..2483faa8 100644 --- a/go.mod +++ b/go.mod @@ -20,6 +20,7 @@ require ( github.com/nixberg/chacha-rng-go v0.1.0 github.com/novifinancial/serde-reflection/serde-generate/runtime/golang v0.0.0-20220519162058-e5cd3c3b3f3a github.com/prometheus/client_golang v1.20.4 + github.com/quic-go/quic-go v0.59.1 github.com/spf13/cobra v1.9.1 github.com/stretchr/testify v1.11.1 github.com/twmb/murmur3 v1.1.8 diff --git a/go.sum b/go.sum index 7b1cd4a5..7ea49725 100644 --- a/go.sum +++ b/go.sum @@ -302,6 +302,8 @@ github.com/prometheus/common v0.67.4 h1:yR3NqWO1/UyO1w2PhUvXlGQs/PtFmoveVO0KZ4+L github.com/prometheus/common v0.67.4/go.mod h1:gP0fq6YjjNCLssJCQp0yk4M8W6ikLURwkdd/YKtTbyI= github.com/prometheus/procfs v0.15.1 h1:YagwOFzUgYfKKHX6Dr+sHT7km/hxC76UB0learggepc= github.com/prometheus/procfs v0.15.1/go.mod h1:fB45yRUv8NstnjriLhBQLuOUt+WW4BsoGhij/e3PBqk= +github.com/quic-go/quic-go v0.59.1 h1:0Gmua0HW1Tv7ANR7hUYwRyD0MG5OJfgvYSZasGZzBic= +github.com/quic-go/quic-go v0.59.1/go.mod h1:upnsH4Ju1YkqpLXC305eW3yDZ4NfnNbmQRCMWS58IKU= github.com/rivo/uniseg v0.2.0/go.mod h1:J6wj4VEh+S6ZtnVlnTBMWIodfgj8LQOQFoIToxlJtxc= github.com/rivo/uniseg v0.4.7 h1:WUdvkW8uEhrYfLC4ZzdpI2ztxP1I582+49Oc5Mq64VQ= github.com/rivo/uniseg v0.4.7/go.mod h1:FN3SvrM+Zdj16jyLfmOkMNblXMcoc8DfTHruCPUcx88= diff --git a/pkg/blockstream/block_source.go b/pkg/blockstream/block_source.go index 85a141f5..256b708c 100644 --- a/pkg/blockstream/block_source.go +++ b/pkg/blockstream/block_source.go @@ -342,19 +342,20 @@ const ( defaultTipGateThreshold = 128 // Lightbringer stream settings - lightbringerDialTimeout = 10 * time.Second - lightbringerRetryBackoff = 2 * time.Second - lightbringerMaxRetryBackoff = 15 * time.Second - lightbringerBufferSlots = 256 - lightbringerFirstSlotWarn = 10 * time.Second - lightbringerIdleReconnect = 30 * time.Second - lightbringerNoEmitReconnect = 30 * time.Second - lightbringerGapReconnectAfter = 30 * time.Second - lightbringerDeepGapReconnect = 15 * time.Second - 
lightbringerMinHandoffRun = 8 - lightbringerGapFallbackWait = 8 * time.Second - lightbringerGapBufferDepth = 32 - lightbringerRecoverySlots = 0 + lightbringerDialTimeout = 10 * time.Second + lightbringerRetryBackoff = 2 * time.Second + lightbringerMaxRetryBackoff = 15 * time.Second + lightbringerBufferSlots = 256 + lightbringerFirstSlotWarn = 10 * time.Second + lightbringerIdleReconnect = 30 * time.Second + lightbringerNoEmitReconnect = 30 * time.Second + lightbringerGapReconnectAfter = 30 * time.Second + lightbringerDeepGapReconnect = 15 * time.Second + lightbringerMinHandoffRun = 8 + lightbringerLiveEdgeHandoffMaxLag = 4 + lightbringerGapFallbackWait = 8 * time.Second + lightbringerGapBufferDepth = 32 + lightbringerRecoverySlots = 0 // RPC getBlock can transiently report SlotSkipped or "block not available" // for slots that later turn out to have real blocks. Never emit a skipped @@ -516,6 +517,9 @@ func (bs *BlockSource) updateMode() { if wasNearTip { // Currently in near-tip mode - switch to catchup if gap exceeds threshold if gap >= bs.catchupThreshold { + if bs.shouldDeferCatchupForConsensusBufferedLightbringer(gap, lastExecuted, tip) { + return + } bs.isNearTip.Store(false) mlog.Log.Infof("MODE SWITCH: near-tip → CATCHUP | gap=%d (threshold=%d) | exec_slot=%d | tip=%d", gap, bs.catchupThreshold, lastExecuted, tip) @@ -539,6 +543,58 @@ func (bs *BlockSource) updateMode() { } } +func (bs *BlockSource) consensusBufferedLightbringerMaxReplayGap() uint64 { + if bs.catchupThreshold == 0 { + return 0 + } + if bs.catchupThreshold > math.MaxUint64/2 { + return math.MaxUint64 + } + return bs.catchupThreshold * 2 +} + +func (bs *BlockSource) consensusBufferedLightbringerMaxSourceGap() uint64 { + maxGap := bs.nearTipThreshold + if maxGap == 0 || (bs.catchupThreshold > 0 && maxGap > bs.catchupThreshold) { + maxGap = bs.catchupThreshold + } + return maxGap +} + +func (bs *BlockSource) shouldDeferCatchupForConsensusBufferedLightbringer(gap uint64, lastExecuted uint64, tip 
uint64) bool { + if bs.sourceType != BlockSourceLightbringer || bs.lightbringerEndpoint == "" { + return false + } + if !bs.consensusManagedLightbringer || !bs.lightbringerActive.Load() || !bs.lightbringerConnected.Load() { + return false + } + if maxReplayGap := bs.consensusBufferedLightbringerMaxReplayGap(); maxReplayGap != 0 && gap > maxReplayGap { + return false + } + + latestStreamed := bs.lightbringerLastStreamSlot.Load() + if latestStreamed <= lastExecuted { + return false + } + if tip > latestStreamed { + sourceGap := tip - latestStreamed + if maxSourceGap := bs.consensusBufferedLightbringerMaxSourceGap(); maxSourceGap != 0 && sourceGap > maxSourceGap { + return false + } + } + + lastRecvUnix := bs.lightbringerLastRecvUnix.Load() + if lastRecvUnix == 0 || time.Since(time.Unix(lastRecvUnix, 0)) >= lightbringerIdleReconnect { + return false + } + lastProgressUnix := bs.lastProgress.Load() + if lastProgressUnix != 0 && time.Since(time.Unix(lastProgressUnix, 0)) >= lightbringerNoEmitReconnect { + return false + } + + return true +} + // effectiveTipSafetyMargin returns the tip safety margin for the current mode. // In near-tip mode, we return 0 (no margin) - we rely on fast retries instead. func (bs *BlockSource) effectiveTipSafetyMargin() uint64 { @@ -1140,6 +1196,83 @@ func (bs *BlockSource) purgeRPCStateAtOrBeyondSlot(slot uint64) { bs.retryMu.Unlock() } +func (bs *BlockSource) lightbringerHandoffMaxReplayGap() uint64 { + // Arm Lightbringer only in the lower half of the near-tip window. Once + // forkchoice buffering starts, replay can wait for vote-confirmed path + // resolution; keeping this headroom prevents immediate lost-tip fallback. 
+ maxGap := bs.nearTipThreshold / 2 + if maxGap == 0 && bs.nearTipThreshold > 0 { + maxGap = 1 + } + if bs.catchupThreshold > 0 && maxGap >= bs.catchupThreshold { + maxGap = bs.catchupThreshold - 1 + } + return maxGap +} + +func (bs *BlockSource) lightbringerHandoffTipEstimate() uint64 { + tip := bs.confirmedTip.Load() + if bs.lightbringerConnected.Load() { + if streamed := bs.lightbringerLastStreamSlot.Load(); streamed > tip { + tip = streamed + } + } + return tip +} + +func (bs *BlockSource) lightbringerHandoffReplayGapOK() (bool, uint64, uint64, uint64, uint64) { + maxGap := bs.lightbringerHandoffMaxReplayGap() + tip := bs.lightbringerHandoffTipEstimate() + lastExecuted := bs.lastExecutedSlot.Load() + if tip == 0 || lastExecuted == 0 { + return true, 0, maxGap, tip, lastExecuted + } + + var gap uint64 + if tip > lastExecuted { + gap = tip - lastExecuted + } + return gap <= maxGap, gap, maxGap, tip, lastExecuted +} + +func lightbringerDefaultHandoffLastSlot(waitingSlot uint64) uint64 { + requiredLastSlot := waitingSlot + uint64(lightbringerMinHandoffRun) - 1 + if requiredLastSlot < waitingSlot { + requiredLastSlot = math.MaxUint64 + } + return requiredLastSlot +} + +func (bs *BlockSource) lightbringerLiveEdgeHandoffMaxLag() uint64 { + maxLag := bs.nearTipLookahead + 2 + if maxLag < lightbringerLiveEdgeHandoffMaxLag { + maxLag = lightbringerLiveEdgeHandoffMaxLag + } + return maxLag +} + +func (bs *BlockSource) lightbringerHandoffRequiredLastSlot(waitingSlot uint64) uint64 { + requiredLastSlot := lightbringerDefaultHandoffLastSlot(waitingSlot) + if !bs.consensusManagedLightbringer || !bs.lightbringerConnected.Load() { + return requiredLastSlot + } + + latestStreamed := bs.lightbringerLastStreamSlot.Load() + if latestStreamed < waitingSlot { + return requiredLastSlot + } + + tip := bs.lightbringerHandoffTipEstimate() + if tip > latestStreamed && tip-latestStreamed > bs.lightbringerLiveEdgeHandoffMaxLag() { + return requiredLastSlot + } + + if latestStreamed < 
requiredLastSlot { + return latestStreamed + } + return requiredLastSlot +} + func (bs *BlockSource) prepareLightbringerHandoff(waitingSlot uint64, anchorSlot uint64) ([]*b.Block, uint64, bool) { if !bs.isNearTip.Load() { return nil, 0, false @@ -1147,6 +1280,9 @@ func (bs *BlockSource) prepareLightbringerHandoff(waitingSlot uint64, anchorSlot if handoffSlot := bs.lightbringerHandoffSlot.Load(); handoffSlot != 0 { return nil, handoffSlot, false } + if ok, _, _, _, _ := bs.lightbringerHandoffReplayGapOK(); !ok { + return nil, 0, false + } bs.lightbringerBufferMu.Lock() defer bs.lightbringerBufferMu.Unlock() @@ -1156,10 +1292,7 @@ func (bs *BlockSource) prepareLightbringerHandoff(waitingSlot uint64, anchorSlot return nil, 0, false } - requiredLastSlot := waitingSlot + uint64(lightbringerMinHandoffRun) - 1 - if requiredLastSlot < waitingSlot { - requiredLastSlot = math.MaxUint64 - } + requiredLastSlot := bs.lightbringerHandoffRequiredLastSlot(waitingSlot) if coveredUntil < requiredLastSlot { return nil, 0, false } @@ -1256,11 +1389,12 @@ func (bs *BlockSource) lightbringerHandoffWaitReason(waitingSlot uint64, anchorS return fmt.Sprintf("no buffered Lightbringer slot at or beyond waiting slot %d", waitingSlot) } - requiredLastSlot := waitingSlot + uint64(lightbringerMinHandoffRun) - 1 - if requiredLastSlot < waitingSlot { - requiredLastSlot = math.MaxUint64 - } + requiredLastSlot := bs.lightbringerHandoffRequiredLastSlot(waitingSlot) if coveredUntil >= requiredLastSlot { + if ok, gap, maxGap, tip, lastExecuted := bs.lightbringerHandoffReplayGapOK(); !ok { + return fmt.Sprintf("handoff-ready runway buffered through slot %d, but replay gap %d exceeds handoff arm threshold %d (last executed %d, live tip estimate %d)", + coveredUntil, gap, maxGap, lastExecuted, tip) + } return fmt.Sprintf("handoff-ready runway buffered through slot %d", coveredUntil) } @@ -1316,6 +1450,10 @@ func (bs *BlockSource) maybePrepareLightbringerHandoff() { anchorSlot := bs.lastEmittedBlockSlot 
bs.reorderMu.Unlock() + if ok, _, _, _, _ := bs.lightbringerHandoffReplayGapOK(); !ok { + return + } + blocks, handoffSlot, prepared := bs.prepareLightbringerHandoff(waitingSlot, anchorSlot) if !prepared { return @@ -1330,6 +1468,45 @@ func (bs *BlockSource) maybePrepareLightbringerHandoff() { bs.enqueueLightbringerBlocks(blocks) } +func (bs *BlockSource) lightbringerStagingMaxReplayGap() uint64 { + maxGap := bs.catchupThreshold + minGap := bs.nearTipThreshold + uint64(lightbringerMinHandoffRun) + if minGap < bs.nearTipThreshold { + minGap = math.MaxUint64 + } + if maxGap < minGap { + maxGap = minGap + } + return maxGap +} + +func (bs *BlockSource) shouldStageLightbringerSlot(slot uint64) bool { + if bs.lightbringerForceRPCUntil.Load() != 0 { + return false + } + if bs.lightbringerCooldownUntil.Load() != 0 { + return false + } + + lastExecuted := bs.lastExecutedSlot.Load() + if lastExecuted == 0 || slot <= lastExecuted { + return false + } + + bs.reorderMu.Lock() + nextSlot := bs.nextSlotToSend + bs.reorderMu.Unlock() + if nextSlot != 0 && slot < nextSlot { + return false + } + + tip := bs.lightbringerHandoffTipEstimate() + if tip <= lastExecuted { + return false + } + return tip-lastExecuted <= bs.lightbringerStagingMaxReplayGap() +} + func (bs *BlockSource) shouldDecodeLightbringerSlot(slot uint64) bool { if bs.lightbringerForceRPCUntil.Load() != 0 { return false @@ -1339,7 +1516,7 @@ func (bs *BlockSource) shouldDecodeLightbringerSlot(slot uint64) bool { } handoffSlot := bs.lightbringerHandoffSlot.Load() if handoffSlot == 0 { - return bs.isNearTip.Load() + return bs.isNearTip.Load() || bs.shouldStageLightbringerSlot(slot) } return bs.isNearTip.Load() && slot >= handoffSlot } @@ -1665,12 +1842,8 @@ func (bs *BlockSource) runLightbringerStream() { blk := block.FromLightbringerStreamMsg(resp) if bs.lightbringerHandoffSlot.Load() == 0 { - // Keep the stream warm during catchup, but do not buffer live blocks until - // we are actually in near-tip mode. 
Buffering catchup-time stream traffic - // can create a large backlog that blocks this recv loop at handoff time. - if !bs.isNearTip.Load() { - continue - } + // Stage a bounded runway before near-tip so handoff does not have + // to build its whole connected run while replay is already at tip. bs.bufferLightbringerBlock(blk) bs.maybePrepareLightbringerHandoff() continue diff --git a/pkg/blockstream/block_source_test.go b/pkg/blockstream/block_source_test.go index 40a2e874..9b218013 100644 --- a/pkg/blockstream/block_source_test.go +++ b/pkg/blockstream/block_source_test.go @@ -165,6 +165,66 @@ func TestPrepareLightbringerHandoffRequiresMinimumRunway(t *testing.T) { } } +func TestPrepareLightbringerHandoffAllowsLiveEdgeRunwayAtTip(t *testing.T) { + bs := NewBlockSource(&BlockSourceOpts{ + SourceType: BlockSourceLightbringer, + LightbringerEndpoint: "127.0.0.1:50051", + StartSlot: 100, + EndSlot: 200, + ConsensusManagedLightbringer: true, + }) + + bs.isNearTip.Store(true) + bs.lightbringerConnected.Store(true) + bs.lastExecutedSlot.Store(150) + bs.confirmedTip.Store(151) + bs.lightbringerLastStreamSlot.Store(151) + bs.lastEmittedBlockSlot = 150 + bs.lightbringerBuffer[151] = &b.Block{Slot: 151, FromLightbringer: true, SourceParentSlot: 150} + bs.lightbringerBufferOrder = append(bs.lightbringerBufferOrder, 151) + + reason := bs.lightbringerHandoffWaitReason(151, 150) + if !strings.Contains(reason, "handoff-ready runway buffered through slot 151") { + t.Fatalf("expected live-edge runway to be handoff-ready, got %q", reason) + } + + blocks, handoffSlot, prepared := bs.prepareLightbringerHandoff(151, 150) + if !prepared { + t.Fatalf("expected consensus-managed handoff to prepare at the live edge") + } + if handoffSlot != 151 { + t.Fatalf("expected handoff slot 151, got %d", handoffSlot) + } + if len(blocks) != 1 || blocks[0].Slot != 151 { + t.Fatalf("expected single live-edge Lightbringer block to be enqueued, got %+v", blocks) + } +} + +func 
TestPrepareLightbringerHandoffKeepsMinimumRunwayWhenLightbringerLagsTip(t *testing.T) { + bs := NewBlockSource(&BlockSourceOpts{ + SourceType: BlockSourceLightbringer, + LightbringerEndpoint: "127.0.0.1:50051", + StartSlot: 100, + EndSlot: 200, + ConsensusManagedLightbringer: true, + }) + + bs.isNearTip.Store(true) + bs.lightbringerConnected.Store(true) + bs.lastExecutedSlot.Store(150) + bs.confirmedTip.Store(157) + bs.lightbringerLastStreamSlot.Store(151) + bs.lastEmittedBlockSlot = 150 + bs.lightbringerBuffer[151] = &b.Block{Slot: 151, FromLightbringer: true, SourceParentSlot: 150} + bs.lightbringerBufferOrder = append(bs.lightbringerBufferOrder, 151) + + blocks, handoffSlot, prepared := bs.prepareLightbringerHandoff(151, 150) + if prepared || handoffSlot != 0 || len(blocks) != 0 { + t.Fatalf("expected stale Lightbringer stream to require the full handoff runway, got prepared=%v handoff=%d blocks=%+v", + prepared, handoffSlot, blocks) + } +} + func TestPrepareLightbringerHandoffRequiresRunwayThroughConfiguredBoundary(t *testing.T) { bs := NewBlockSource(&BlockSourceOpts{ SourceType: BlockSourceLightbringer, @@ -278,6 +338,188 @@ func TestPrepareLightbringerHandoffPurgesRPCOwnedStateAtBoundary(t *testing.T) { } } +func TestMaybePrepareLightbringerHandoffDefersWhenStreamTipShowsReplayGapTooLarge(t *testing.T) { + bs := NewBlockSource(&BlockSourceOpts{ + SourceType: BlockSourceLightbringer, + LightbringerEndpoint: "127.0.0.1:50051", + StartSlot: 100, + EndSlot: 200, + }) + + bs.isNearTip.Store(true) + bs.lightbringerConnected.Store(true) + bs.lastExecutedSlot.Store(101) + bs.confirmedTip.Store(117) + bs.lightbringerLastStreamSlot.Store(118) + bs.lastEmittedBlockSlot = 110 + bs.nextSlotToSend = 111 + for slot := uint64(111); slot <= 118; slot++ { + parentSlot := slot - 1 + if slot == 111 { + parentSlot = 110 + } + bs.lightbringerBuffer[slot] = &b.Block{Slot: slot, FromLightbringer: true, SourceParentSlot: parentSlot} + bs.lightbringerBufferOrder = 
append(bs.lightbringerBufferOrder, slot) + } + + bs.maybePrepareLightbringerHandoff() + + if got := bs.lightbringerHandoffSlot.Load(); got != 0 { + t.Fatalf("expected handoff to stay unarmed while replay gap exceeds handoff threshold, got %d", got) + } + if queued := len(bs.resultQueue); queued != 0 { + t.Fatalf("expected no Lightbringer blocks to be enqueued before handoff, got %d", queued) + } +} + +func TestMaybePrepareLightbringerHandoffArmsWhenReplayGapHasHeadroom(t *testing.T) { + bs := NewBlockSource(&BlockSourceOpts{ + SourceType: BlockSourceLightbringer, + LightbringerEndpoint: "127.0.0.1:50051", + StartSlot: 100, + EndSlot: 200, + }) + + bs.isNearTip.Store(true) + bs.lightbringerConnected.Store(true) + bs.lastExecutedSlot.Store(102) + bs.confirmedTip.Store(117) + bs.lightbringerLastStreamSlot.Store(118) + bs.lastEmittedBlockSlot = 110 + bs.nextSlotToSend = 111 + for slot := uint64(111); slot <= 118; slot++ { + parentSlot := slot - 1 + if slot == 111 { + parentSlot = 110 + } + bs.lightbringerBuffer[slot] = &b.Block{Slot: slot, FromLightbringer: true, SourceParentSlot: parentSlot} + bs.lightbringerBufferOrder = append(bs.lightbringerBufferOrder, slot) + } + + bs.maybePrepareLightbringerHandoff() + + if got := bs.lightbringerHandoffSlot.Load(); got != 111 { + t.Fatalf("expected handoff to arm at slot 111 once replay gap has headroom, got %d", got) + } + if queued := len(bs.resultQueue); queued != 8 { + t.Fatalf("expected the 8-slot Lightbringer runway to be enqueued, got %d", queued) + } +} + +func TestShouldDecodeLightbringerSlotStagesBeforeNearTipWithinCatchupWindow(t *testing.T) { + bs := NewBlockSource(&BlockSourceOpts{ + SourceType: BlockSourceLightbringer, + LightbringerEndpoint: "127.0.0.1:50051", + StartSlot: 100, + EndSlot: 300, + }) + + bs.isNearTip.Store(false) + bs.lightbringerConnected.Store(true) + bs.lastExecutedSlot.Store(100) + bs.confirmedTip.Store(164) + bs.lightbringerLastStreamSlot.Store(164) + bs.nextSlotToSend = 110 + + if 
!bs.shouldDecodeLightbringerSlot(120) { + t.Fatalf("expected Lightbringer slot within catchup staging window to be decoded") + } + if bs.shouldDecodeLightbringerSlot(109) { + t.Fatalf("expected slot behind the emission frontier to stay unstaged") + } +} + +func TestShouldDecodeLightbringerSlotDoesNotStageWhenReplayGapTooLarge(t *testing.T) { + bs := NewBlockSource(&BlockSourceOpts{ + SourceType: BlockSourceLightbringer, + LightbringerEndpoint: "127.0.0.1:50051", + StartSlot: 100, + EndSlot: 300, + }) + + bs.isNearTip.Store(false) + bs.lightbringerConnected.Store(true) + bs.lastExecutedSlot.Store(100) + bs.confirmedTip.Store(165) + bs.lightbringerLastStreamSlot.Store(165) + bs.nextSlotToSend = 101 + + if bs.shouldDecodeLightbringerSlot(120) { + t.Fatalf("expected Lightbringer staging to wait until replay is inside the catchup staging window") + } +} + +func TestUpdateModeDefersCatchupWhileConsensusManagedLightbringerIsLive(t *testing.T) { + bs := NewBlockSource(&BlockSourceOpts{ + SourceType: BlockSourceLightbringer, + LightbringerEndpoint: "127.0.0.1:50051", + StartSlot: 100, + EndSlot: 300, + ConsensusManagedLightbringer: true, + }) + + bs.lightbringerStarted.Store(true) + bs.isNearTip.Store(true) + bs.lightbringerActive.Store(true) + bs.lightbringerConnected.Store(true) + bs.lightbringerHandoffSlot.Store(101) + bs.lastExecutedSlot.Store(100) + bs.confirmedTip.Store(165) + bs.lightbringerLastStreamSlot.Store(164) + bs.lightbringerLastRecvUnix.Store(time.Now().Unix()) + bs.lastProgress.Store(time.Now().Unix()) + bs.nextSlotToSend = 101 + + bs.updateMode() + + if !bs.isNearTip.Load() { + t.Fatalf("expected near-tip mode to remain active while Lightbringer observations are fresh") + } + if !bs.lightbringerActive.Load() { + t.Fatalf("expected Lightbringer to stay active during consensus buffering") + } + if bs.lightbringerNeedRPCResume.Load() { + t.Fatalf("expected RPC resume flag to stay clear while deferring catchup") + } +} + +func 
TestUpdateModeFallsBackWhenConsensusManagedLightbringerReplayGapExceedsGrace(t *testing.T) { + bs := NewBlockSource(&BlockSourceOpts{ + SourceType: BlockSourceLightbringer, + LightbringerEndpoint: "127.0.0.1:50051", + StartSlot: 100, + EndSlot: 300, + ConsensusManagedLightbringer: true, + }) + + bs.lightbringerStarted.Store(true) + bs.isNearTip.Store(true) + bs.lightbringerActive.Store(true) + bs.lightbringerConnected.Store(true) + bs.lightbringerHandoffSlot.Store(101) + bs.lastExecutedSlot.Store(100) + bs.confirmedTip.Store(229) + bs.lightbringerLastStreamSlot.Store(229) + bs.lightbringerLastRecvUnix.Store(time.Now().Unix()) + bs.lastProgress.Store(time.Now().Unix()) + bs.nextSlotToSend = 150 + + bs.updateMode() + + if bs.isNearTip.Load() { + t.Fatalf("expected near-tip mode to fall back once replay gap exceeds consensus buffering grace") + } + if bs.lightbringerActive.Load() { + t.Fatalf("expected Lightbringer to be marked inactive after fallback") + } + if !bs.lightbringerNeedRPCResume.Load() { + t.Fatalf("expected RPC resume flag to be raised after fallback") + } + if got := bs.nextSlotToSend; got != 101 { + t.Fatalf("expected consensus-managed fallback to rewind emission frontier to replay next slot 101, got %d", got) + } +} + func TestSynthesizeLightbringerSkipsLockedDoesNotInferMissingSlotsFromReconnectingDescendant(t *testing.T) { bs := NewBlockSource(&BlockSourceOpts{ SourceType: BlockSourceLightbringer, diff --git a/pkg/forkchoice/forkchoice.go b/pkg/forkchoice/forkchoice.go index 0fa11489..9a47353c 100644 --- a/pkg/forkchoice/forkchoice.go +++ b/pkg/forkchoice/forkchoice.go @@ -2,6 +2,7 @@ package forkchoice import ( "fmt" + "sort" "sync" "github.com/Overclock-Validator/mithril/pkg/base58" @@ -42,6 +43,27 @@ type BankhashResult struct { ThresholdStake uint64 } +// VoteHashDiagnostic is a JSON-friendly snapshot of votes accumulated for one +// bankhash within a target slot. 
+type VoteHashDiagnostic struct { + Bankhash string `json:"bankhash"` + Stake uint64 `json:"stake"` + VoterCount int `json:"voter_count"` + Confirmed bool `json:"confirmed"` +} + +// SlotVoteDiagnostic is a JSON-friendly snapshot of forkchoice's accumulated +// vote state for a target slot. +type SlotVoteDiagnostic struct { + Slot uint64 `json:"slot"` + Status string `json:"status"` + WinningHash string `json:"winning_hash,omitempty"` + LatestObservedSlot uint64 `json:"latest_observed_slot"` + TotalEpochStake uint64 `json:"total_epoch_stake"` + ThresholdStake uint64 `json:"threshold_stake"` + Hashes []VoteHashDiagnostic `json:"hashes,omitempty"` +} + // ConfirmedLeaf is a vote-confirmed bankhash winner paired with the observed // block slot it belongs to. type ConfirmedLeaf struct { @@ -582,3 +604,52 @@ func (s *ForkChoiceService) GetSupermajorityHash(slot uint64) (solana.Hash, Bank return solana.Hash{}, BankhashNoSupermajority } + +// SlotVoteDiagnostics returns a compact snapshot of forkchoice vote totals for +// a target slot. It is intended for rare consensus mismatch artifacts rather +// than hot-path logging. 
+func (s *ForkChoiceService) SlotVoteDiagnostics(slot uint64) SlotVoteDiagnostic { + s.state.mu.Lock() + defer s.state.mu.Unlock() + + out := SlotVoteDiagnostic{ + Slot: slot, + Status: BankhashNoSupermajority.String(), + LatestObservedSlot: s.state.latestObservedSlot, + TotalEpochStake: s.state.totalEpochStake, + } + + accumulator, exists := s.state.voteStakeTotals[slot] + if !exists { + if s.state.latestObservedSlot < slot+VoteConfirmationTimeoutSlots { + out.Status = BankhashNeedWait.String() + } + return out + } + + out.TotalEpochStake = accumulator.totalEpochStake + out.ThresholdStake = accumulator.thresholdStake + if winningHash, ok := accumulator.winningHash(); ok { + out.Status = BankhashHasSupermajority.String() + out.WinningHash = base58.Encode(winningHash[:]) + } else if s.state.latestObservedSlot < slot+VoteConfirmationTimeoutSlots { + out.Status = BankhashNeedWait.String() + } + + out.Hashes = make([]VoteHashDiagnostic, 0, len(accumulator.trackers)) + for bankhash, tracker := range accumulator.trackers { + out.Hashes = append(out.Hashes, VoteHashDiagnostic{ + Bankhash: base58.Encode(bankhash[:]), + Stake: tracker.stake, + VoterCount: len(tracker.voted), + Confirmed: accumulator.hashHasSupermajority(bankhash), + }) + } + sort.Slice(out.Hashes, func(i, j int) bool { + if out.Hashes[i].Stake == out.Hashes[j].Stake { + return out.Hashes[i].Bankhash < out.Hashes[j].Bankhash + } + return out.Hashes[i].Stake > out.Hashes[j].Stake + }) + return out +} diff --git a/pkg/forkchoice/forkchoice_test.go b/pkg/forkchoice/forkchoice_test.go index d1c79c63..ff8f0fe9 100644 --- a/pkg/forkchoice/forkchoice_test.go +++ b/pkg/forkchoice/forkchoice_test.go @@ -2,6 +2,7 @@ package forkchoice import ( "encoding/binary" + "encoding/json" "testing" "github.com/Overclock-Validator/mithril/pkg/epochstakes" @@ -420,6 +421,157 @@ func TestSequentialProcessingDeterminesWinner(t *testing.T) { assert.Equal(t, hashX, resultY.WinningHash, "winning hash should still be hashX") } +func 
TestParseAndValidateVoteTxFindsVoteInstructionAfterComputeBudget(t *testing.T) { + voterKey := solana.PublicKey{1} + voteAcct := solana.PublicKey{2} + votedSlot := uint64(50) + votedHash := solana.Hash{0xBB} + + epochAuth := epochstakes.NewEpochAuthorizedVotersCache() + epochAuth.PutEntry(voteAcct, voterKey) + + tx := buildTestVoteTx(voteAcct, voterKey, votedSlot, votedHash) + voteInstr := tx.Message.Instructions[0] + tx.Message.AccountKeys = []solana.PublicKey{ + voterKey, + voteAcct, + solana.SysVarSlotHashesPubkey, + solana.SysVarClockPubkey, + solana.ComputeBudget, + solana.VoteProgramID, + } + tx.Message.Instructions = []solana.CompiledInstruction{ + { + ProgramIDIndex: 4, + Data: solana.Base58([]byte{2, 0, 0, 0, 200, 0, 0, 0, 0}), + }, + { + ProgramIDIndex: 5, + Accounts: voteInstr.Accounts, + Data: voteInstr.Data, + }, + } + + require.True(t, tx.IsVote(), "fixture should still be recognized as a vote transaction") + + info, ok := parseAndValidateVoteTx(tx, epochAuth) + require.True(t, ok) + assert.Equal(t, votedSlot, info.slot) + assert.Equal(t, votedHash, solana.Hash(info.bankHash)) + assert.Equal(t, voteAcct, info.votePubkey) +} + +func TestParseAndValidateVoteTxUsesSignerForVoteAuthority(t *testing.T) { + voterKey := solana.PublicKey{1} + voteAcct := solana.PublicKey{2} + votedSlot := uint64(50) + votedHash := solana.Hash{0xBB} + + epochAuth := epochstakes.NewEpochAuthorizedVotersCache() + epochAuth.PutEntry(voteAcct, voterKey) + + tx := buildTestVoteTx(voteAcct, voterKey, votedSlot, votedHash) + require.True(t, tx.IsVote(), "fixture should be recognized as a vote transaction") + require.Equal(t, solana.SysVarSlotHashesPubkey, tx.Message.AccountKeys[tx.Message.Instructions[0].Accounts[1]], + "legacy vote account 1 is the slot-hashes sysvar, not the vote authority") + + info, ok := parseAndValidateVoteTx(tx, epochAuth) + require.True(t, ok) + assert.Equal(t, votedSlot, info.slot) + assert.Equal(t, votedHash, solana.Hash(info.bankHash)) + assert.Equal(t, 
voteAcct, info.votePubkey) +} + +func TestParseAndValidateVoteTxRejectsAuthorityOutsideInstruction(t *testing.T) { + voterKey := solana.PublicKey{1} + voteAcct := solana.PublicKey{2} + + epochAuth := epochstakes.NewEpochAuthorizedVotersCache() + epochAuth.PutEntry(voteAcct, voterKey) + + tx := buildTestVoteTx(voteAcct, voterKey, 50, solana.Hash{0xBB}) + tx.Message.Instructions[0].Accounts = []uint16{1, 2, 3} + + _, ok := parseAndValidateVoteTx(tx, epochAuth) + require.False(t, ok) +} + +func TestParseAndValidateVoteTxAcceptsLiveTowerSyncShape(t *testing.T) { + voteAuthority := solana.MustPublicKeyFromBase58("DRpbCBMxVnDK7maPM5tGv6MvB3v1sRMC86PZ8okm21hy") + voteAcct := solana.MustPublicKeyFromBase58("3N7s9zXMZ4QqvHQR15t5GNHyqc89KduzMP7423eWiD5g") + votedHash := solana.MustHashFromBase58("F4GcS4MtttPknSkbGW3KCXWJd6mWvzaXDnHHyM87Gd2A") + + var data solana.Base58 + err := json.Unmarshal([]byte(`"67MGmzm8yEnRh15X2h4HuP1ZCWg1Ld1zPNgZqhBGEySYPXuCReZ8tSvvhrKA1j7q6ky81hjNVPUp6WvdLbnfVYTmFvK2C2QCSBbGAoibsseTrrczvs6Xk47BPdpcN6PB9bYaFnu8wtykuo4WLhELbCuYYwwUyA6zNqZfDHLePABFKUDJLbyE9DsqoiATDtoznG7Bevvfra"`), &data) + require.NoError(t, err) + + tx := &solana.Transaction{ + Message: solana.Message{ + Header: solana.MessageHeader{ + NumRequiredSignatures: 1, + }, + AccountKeys: []solana.PublicKey{ + voteAuthority, + voteAcct, + solana.VoteProgramID, + }, + Instructions: []solana.CompiledInstruction{ + { + ProgramIDIndex: 2, + Accounts: []uint16{1, 0}, + Data: data, + }, + }, + }, + Signatures: []solana.Signature{{}}, + } + + epochAuth := epochstakes.NewEpochAuthorizedVotersCache() + epochAuth.PutEntry(voteAcct, voteAuthority) + + info, ok := parseAndValidateVoteTx(tx, epochAuth) + require.True(t, ok) + assert.Equal(t, uint64(420404777), info.slot) + assert.Equal(t, votedHash, solana.Hash(info.bankHash)) + assert.Equal(t, voteAcct, info.votePubkey) +} + +func TestParseAndValidateVoteTxFallsBackToSignatureCountWhenHeaderSignerCountMissing(t *testing.T) { + voteAuthority := 
solana.MustPublicKeyFromBase58("GmCxjmjKZoaKN1DKunbYq8RCYib94Nm3sHyncFfofaF5") + voteAcct := solana.MustPublicKeyFromBase58("7S9dHgoeMYvtShTjEC3x5D3THRDQz123WVGPseZsm3hm") + + var data solana.Base58 + err := json.Unmarshal([]byte(`"67MGn8HzmNzWfjLAq5WGPoC4LktJMeSH2UUTqWmWd2VXRbURnRQM4hTJvxGRcSbKb6CYLd3x42wvAjAyYsY19ajzUtqxDcE4XZP4eHV47zTUEkudvy7R2a7sJAaJtS9nk9D2NtMP3du8S8BFSUhjLPmVW9pmh4CgnBS5Jh7B8XNkQmGLS8sCGSWY9UbZrYipFy7rEVjirv"`), &data) + require.NoError(t, err) + + tx := &solana.Transaction{ + Message: solana.Message{ + Header: solana.MessageHeader{}, + AccountKeys: []solana.PublicKey{ + voteAuthority, + voteAcct, + solana.VoteProgramID, + }, + Instructions: []solana.CompiledInstruction{ + { + ProgramIDIndex: 2, + Accounts: []uint16{1, 0}, + Data: data, + }, + }, + }, + Signatures: []solana.Signature{{}}, + } + + epochAuth := epochstakes.NewEpochAuthorizedVotersCache() + epochAuth.PutEntry(voteAcct, voteAuthority) + + info, ok := parseAndValidateVoteTx(tx, epochAuth) + require.True(t, ok) + assert.Equal(t, uint64(420407984), info.slot) + assert.Equal(t, voteAcct, info.votePubkey) +} + func TestObserveBlockResolvesParentSlotFromParentBlockhash(t *testing.T) { epochAuth := epochstakes.NewEpochAuthorizedVotersCache() service := NewForkChoiceService(0, map[solana.PublicKey]uint64{}, 100, epochAuth) @@ -541,8 +693,7 @@ func TestFindConfirmedLeafReturnsHighestObservedWinner(t *testing.T) { assert.Equal(t, hash(0xA7), leaf.Bankhash) } -// buildTestVoteTx constructs a minimal valid vote transaction for testing. -// The tx passes IsVote(), IsSigner(authority), and parseAndValidateVoteTx(). +// buildTestVoteTx constructs a minimal legacy Vote instruction for testing. 
func buildTestVoteTx(voteAcct, voteAuthority solana.PublicKey, slot uint64, hash solana.Hash) *solana.Transaction { // Encode VoteProgramInstrTypeVote (type=2): // [type:4][num_slots:8][slot:8][hash:32][timestamp_opt:1] = 53 bytes @@ -556,17 +707,23 @@ func buildTestVoteTx(voteAcct, voteAuthority solana.PublicKey, slot uint64, hash return &solana.Transaction{ Message: solana.Message{ Header: solana.MessageHeader{ - NumRequiredSignatures: 2, + NumRequiredSignatures: 1, + }, + AccountKeys: []solana.PublicKey{ + voteAuthority, + voteAcct, + solana.SysVarSlotHashesPubkey, + solana.SysVarClockPubkey, + solana.VoteProgramID, }, - AccountKeys: []solana.PublicKey{voteAcct, voteAuthority, solana.VoteProgramID}, Instructions: []solana.CompiledInstruction{ { - ProgramIDIndex: 2, - Accounts: []uint16{0, 1}, + ProgramIDIndex: 4, + Accounts: []uint16{1, 2, 3, 0}, Data: solana.Base58(data), }, }, }, - Signatures: []solana.Signature{{}, {}}, + Signatures: []solana.Signature{{}}, } } diff --git a/pkg/forkchoice/vote_parser.go b/pkg/forkchoice/vote_parser.go index de908758..7f6a0ed0 100644 --- a/pkg/forkchoice/vote_parser.go +++ b/pkg/forkchoice/vote_parser.go @@ -17,22 +17,31 @@ type voteInfo struct { // parseAndValidateVoteTx validates a vote transaction against the given authorized // voters cache. Accepts the cache as a parameter to avoid racing with epoch updates. 
func parseAndValidateVoteTx(tx *solana.Transaction, authorizedVoters *epochstakes.EpochAuthorizedVotersCache) (*voteInfo, bool) { - if len(tx.Message.Instructions) < 1 { + if len(tx.Message.Instructions) == 0 { return nil, false } - instr := tx.Message.Instructions[0] + for _, instr := range tx.Message.Instructions { + programID, err := tx.ResolveProgramIDIndex(instr.ProgramIDIndex) + if err != nil || !programID.Equals(solana.VoteProgramID) { + continue + } + return parseAndValidateVoteInstruction(tx, instr, authorizedVoters) + } + + return nil, false +} - if len(instr.Accounts) < 2 { +func parseAndValidateVoteInstruction(tx *solana.Transaction, instr solana.CompiledInstruction, authorizedVoters *epochstakes.EpochAuthorizedVotersCache) (*voteInfo, bool) { + if len(instr.Accounts) < 1 { return nil, false } - votePubkey := tx.Message.AccountKeys[instr.Accounts[0]] - voteAuthority := tx.Message.AccountKeys[instr.Accounts[1]] - - if authorizedVoters == nil { + votePubkey, err := tx.Message.Account(instr.Accounts[0]) + if err != nil { return nil, false } - if !(tx.IsSigner(voteAuthority) && authorizedVoters.IsAuthorizedVoter(votePubkey, voteAuthority)) { + + if !hasAuthorizedVoteSigner(tx, instr, votePubkey, authorizedVoters) { return nil, false } @@ -79,6 +88,9 @@ func parseAndValidateVoteTx(tx *solana.Transaction, authorizedVoters *epochstake case sealevel.VoteProgramInstrTypeVote: { + if !hasLegacyVoteSysvarAccounts(tx, instr) { + return nil, false + } var vote sealevel.VoteInstrVote err = vote.UnmarshalWithDecoder(decoder) if err != nil { @@ -97,6 +109,9 @@ func parseAndValidateVoteTx(tx *solana.Transaction, authorizedVoters *epochstake case sealevel.VoteProgramInstrTypeVoteSwitch: { + if !hasLegacyVoteSysvarAccounts(tx, instr) { + return nil, false + } var vote sealevel.VoteInstrVoteSwitch err = vote.UnmarshalWithDecoder(decoder) if err != nil { @@ -188,6 +203,48 @@ func parseAndValidateVoteTx(tx *solana.Transaction, authorizedVoters *epochstake } } +func 
hasAuthorizedVoteSigner(tx *solana.Transaction, instr solana.CompiledInstruction, votePubkey solana.PublicKey, authorizedVoters *epochstakes.EpochAuthorizedVotersCache) bool { + if authorizedVoters == nil { + return false + } + + // The Vote program validates authority from the instruction's signer set; + // for common vote instructions, account 1 is a sysvar rather than the voter. + numSigners := voteTransactionSignerCount(tx) + if numSigners > len(tx.Message.AccountKeys) { + numSigners = len(tx.Message.AccountKeys) + } + for _, accountIndex := range instr.Accounts { + if int(accountIndex) >= numSigners { + continue + } + if authorizedVoters.IsAuthorizedVoter(votePubkey, tx.Message.AccountKeys[accountIndex]) { + return true + } + } + return false +} + +func voteTransactionSignerCount(tx *solana.Transaction) int { + numSigners := int(tx.Message.Header.NumRequiredSignatures) + if numSigners == 0 && len(tx.Signatures) > 0 { + numSigners = len(tx.Signatures) + } + return numSigners +} + +func hasLegacyVoteSysvarAccounts(tx *solana.Transaction, instr solana.CompiledInstruction) bool { + if len(instr.Accounts) < 3 { + return false + } + slotHashes, err := tx.Message.Account(instr.Accounts[1]) + if err != nil || slotHashes != solana.SysVarSlotHashesPubkey { + return false + } + clock, err := tx.Message.Account(instr.Accounts[2]) + return err == nil && clock == solana.SysVarClockPubkey +} + func getLastLockout(lockouts *deque.Deque[sealevel.VoteLockout]) (*sealevel.VoteLockout, bool) { lockoutsLen := lockouts.Len() if lockoutsLen == 0 { diff --git a/pkg/replay/block.go b/pkg/replay/block.go index 09b5c2fd..c5b4072a 100644 --- a/pkg/replay/block.go +++ b/pkg/replay/block.go @@ -69,14 +69,6 @@ type BlockFetchOpts struct { NearTipLookahead int // Slots ahead to schedule in near-tip, 0 = use default } -// ConsensusOpts contains vote-anchored consensus configuration. -// Nil means use defaults (max_depth=64, policy="halt"). 
-type ConsensusOpts struct { - SkipPathMaxDepth int // Max slots for skip-path solver (default: 64) - UnresolvedPolicy string // "halt" or "warn" (default: "halt") - EnforceOnSource string // "lightbringer" or "all" (default: "lightbringer") -} - var SerializedParameterArena *arena.Arena[byte] // Commit state tracking for panic recovery @@ -409,7 +401,7 @@ func cacheConstantSysvars(acctsDb *accountsdb.AccountsDb) { } } -func loadBlockAccountsAndUpdateSysvars(accountsDb *accountsdb.AccountsDb, block *b.Block) (accounts.Accounts, accounts.Accounts, error) { +func loadBlockAccountsAndUpdateSysvars(accountsDb *accountsdb.AccountsDb, block *b.Block, epochSchedule *sealevel.SysvarEpochSchedule) (accounts.Accounts, accounts.Accounts, error) { err := resolveAddrTableLookups(accountsDb, block) if err != nil { return nil, nil, err @@ -465,7 +457,7 @@ func loadBlockAccountsAndUpdateSysvars(accountsDb *accountsdb.AccountsDb, block panic("unable to unmarshal clock sysvar") } - err = updateClockSysvar(&clock, block) + err = updateClockSysvar(&clock, block, epochSchedule) if err != nil { panic(fmt.Sprintf("failed to update clock sysvar: %s", err)) } @@ -793,6 +785,7 @@ func setupInitialVoteAcctsAndStakeAccts(acctsDb *accountsdb.AccountsDb, block *b if err := RebuildVoteCacheFromAccountsDB(acctsDb, block.Slot, voteAcctStakes, 0); err != nil { mlog.Log.Warnf("vote cache rebuild had errors: %v", err) } + rebuildAuthorizedVotersFromVoteCache(block.Epoch) // Seed EpochStakesPerVoteAcct and TotalEpochStake from the epoch stakes cache, // loaded by buildInitialEpochStakesCache() from the manifest. These are @@ -1115,19 +1108,29 @@ func initializeBlockHeight(rpcc *rpcclient.RpcClient, mithrilState *state.Mithri return nil } -// buildInitialEpochStakesCache seeds the epoch stakes cache from state file or manifest. 
-// Priority: 1) State file ManifestEpochStakes, 2) Direct manifest (backwards compat) -func buildInitialEpochStakesCache(mithrilState *state.MithrilState) error { - // Require state file ManifestEpochStakes (PersistedEpochStakes JSON format) - if mithrilState == nil || len(mithrilState.ManifestEpochStakes) == 0 { - return fmt.Errorf("state file missing manifest_epoch_stakes - delete AccountsDB and rebuild from snapshot") +// buildInitialEpochStakesCache seeds the epoch stakes cache from state file manifest data. +func buildInitialEpochStakesCache(mithrilState *state.MithrilState, currentEpoch uint64, snapshotEpoch uint64) error { + seeds, rebased, err := prepareManifestEpochStakesForRuntime(mithrilState, currentEpoch, snapshotEpoch) + if err != nil { + return err + } + if rebased { + sourceEpochs := make([]uint64, 0, len(seeds)) + runtimeEpochs := make([]uint64, 0, len(seeds)) + for _, seed := range seeds { + sourceEpochs = append(sourceEpochs, seed.sourceEpoch) + runtimeEpochs = append(runtimeEpochs, seed.runtimeEpoch) + } + mlog.Log.Warnf("rebasing manifest epoch stakes from snapshot epochs %v to runtime epochs %v (snapshot epoch %d)", + sourceEpochs, runtimeEpochs, snapshotEpoch) } - for epoch, data := range mithrilState.ManifestEpochStakes { - if loadedEpoch, err := global.DeserializeAndLoadEpochStakes([]byte(data)); err != nil { - return fmt.Errorf("failed to load manifest epoch %d stakes from state file: %w", epoch, err) + for _, seed := range seeds { + if loadedEpoch, err := global.DeserializeAndLoadEpochStakes(seed.data); err != nil { + return fmt.Errorf("failed to load manifest epoch %d stakes from state file: %w", seed.sourceEpoch, err) } else { - mlog.Log.Debugf("loaded epoch %d stakes from state file manifest_epoch_stakes", loadedEpoch) + mlog.Log.Debugf("loaded manifest epoch %d stakes as runtime epoch %d from state file manifest_epoch_stakes", + seed.sourceEpoch, loadedEpoch) } } @@ -1226,12 +1229,21 @@ func ReplayBlocks( } 
cacheConstantSysvars(acctsDb) - epochSchedule := sealevel.SysvarCache.EpochSchedule.Sysvar + epochSchedule, usingManifestEpochSchedule, err := bankEpochScheduleForReplay(mithrilState) + if err != nil { + result.Error = err + return result + } + if usingManifestEpochSchedule && !epochSchedulesEqual(epochSchedule, sealevel.SysvarCache.EpochSchedule.Sysvar) { + sysvarSchedule := sealevel.SysvarCache.EpochSchedule.Sysvar + mlog.Log.Warnf("bank epoch schedule differs from SysvarEpochSchedule account; using manifest bank schedule for replay | bank_slots_per_epoch=%d bank_leader_offset=%d bank_first_normal_epoch=%d bank_first_normal_slot=%d | sysvar_slots_per_epoch=%d sysvar_leader_offset=%d sysvar_first_normal_epoch=%d sysvar_first_normal_slot=%d", + epochSchedule.SlotsPerEpoch, epochSchedule.LeaderScheduleSlotOffset, epochSchedule.FirstNormalEpoch, epochSchedule.FirstNormalSlot, + sysvarSchedule.SlotsPerEpoch, sysvarSchedule.LeaderScheduleSlotOffset, sysvarSchedule.FirstNormalEpoch, sysvarSchedule.FirstNormalSlot) + } global.SetCalcUnixTimeForClockSysvar(true) global.SetManageLeaderSchedule(true) - var err error var currentSlot uint64 currentEpoch := epochSchedule.GetEpoch(startSlot) var lastSlotCtx *sealevel.SlotCtx @@ -1305,14 +1317,14 @@ func ReplayBlocks( } } else { // Resume in same epoch as snapshot, no boundaries crossed - state file epoch stakes still valid - if err := buildInitialEpochStakesCache(mithrilState); err != nil { + if err := buildInitialEpochStakesCache(mithrilState, currentEpoch, snapshotEpoch); err != nil { result.Error = err return result } } } else { // Fresh start: load all epochs from state file - if err := buildInitialEpochStakesCache(mithrilState); err != nil { + if err := buildInitialEpochStakesCache(mithrilState, currentEpoch, snapshotEpoch); err != nil { result.Error = err return result } @@ -1323,34 +1335,13 @@ func ReplayBlocks( } // Resolve consensus config defaults before forkchoice init so we can // check whether enforcement requires 
authorized voters. - consensusMaxDepth := 64 - consensusPolicy := "halt" - consensusEnforceSource := "lightbringer" - if consensusOpts != nil { - if consensusOpts.SkipPathMaxDepth > 0 { - consensusMaxDepth = consensusOpts.SkipPathMaxDepth - } - if consensusOpts.UnresolvedPolicy != "" { - consensusPolicy = consensusOpts.UnresolvedPolicy - } - if consensusOpts.EnforceOnSource != "" { - consensusEnforceSource = consensusOpts.EnforceOnSource - } - } - switch consensusEnforceSource { - case "lightbringer", "all": - default: - mlog.Log.Warnf("forkchoice: invalid EnforceOnSource=%q, defaulting to \"lightbringer\"", consensusEnforceSource) - consensusEnforceSource = "lightbringer" - } - - consensusEnforceActive := consensusEnforceSource == "all" || useLightbringer + consensusCfg := resolveConsensusConfig(consensusOpts, useLightbringer, isLive) epochAuthVoters := global.EpochAuthorizedVoters() if epochAuthVoters == nil { // Without authorized voters, forkchoice can't parse votes → no supermajority → enforcement is blind. // If consensus enforcement is active, this is a fatal misconfiguration. - if consensusEnforceActive && consensusPolicy == "halt" { + if consensusCfg.enforceActive && consensusCfg.policy == "halt" { result.Error = fmt.Errorf("forkchoice: EpochAuthorizedVoters is nil — cannot enforce consensus without vote parsing (check snapshot/state file)") return result } @@ -1364,8 +1355,8 @@ func ReplayBlocks( // Instantiate the consensus coordinator for skip-path resolution and policy. // In Lightbringer mode this now resolves a pre-execution block/skip path from // the current anchor to a vote-confirmed leaf. 
- consensusCoordinator := forkchoice.NewConsensusCoordinator(forkChoice, consensusMaxDepth, consensusPolicy) - consensusBufferedExecutionActive := !isLive || consensusEnforceSource == "all" + consensusCoordinator := forkchoice.NewConsensusCoordinator(forkChoice, consensusCfg.maxDepth, consensusCfg.policy) + consensusBufferedExecutionActive := consensusCfg.bufferedExecutionActive var statsCounter int var execTimes []float64 // seconds per block @@ -1383,12 +1374,6 @@ func ReplayBlocks( voteTxCounts = make([]uint64, 0, summaryInterval) nonVoteTxCounts = make([]uint64, 0, summaryInterval) - type pendingConsensusPath struct { - leafSlot uint64 - leafBankhash solana.Hash - decisions []forkchoice.SlotDecision - } - var readyConsensusPath *pendingConsensusPath observedConsensusBlocks := make(map[uint64]*b.Block) @@ -1402,9 +1387,9 @@ func ReplayBlocks( StartSlot: startSlot, EndSlot: endSlot, BlockDir: blockDir, - ConsensusManagedLightbringer: consensusEnforceActive && + ConsensusManagedLightbringer: consensusCfg.enforceActive && isLive && - consensusEnforceSource == "lightbringer", + consensusCfg.enforceSource == "lightbringer", } } else { opts = &blockstream.BlockSourceOpts{ @@ -1443,32 +1428,6 @@ func ReplayBlocks( var skippedSlotsCount int // Track skipped slots for 100-slot summary replayStartLogged := false - // writeConsensusArtifact writes a best-effort JSON diagnostic artifact to the - // per-run consensus subdirectory. If the log dir is empty or any step fails, - // it logs a warning and continues — artifact failure must not crash replay. 
- writeConsensusArtifact := func(filename string, data map[string]interface{}) { - logDir := mlog.GetLogDir() - if logDir == "" { - return - } - dir := filepath.Join(logDir, "consensus") - if err := os.MkdirAll(dir, 0755); err != nil { - mlog.Log.Warnf("consensus artifact: failed to create directory %s: %v", dir, err) - return - } - artifactPath := filepath.Join(dir, filename) - artifactJSON, jsonErr := json.MarshalIndent(data, "", " ") - if jsonErr != nil { - mlog.Log.Warnf("consensus artifact: failed to marshal JSON for %s: %v", filename, jsonErr) - return - } - if writeErr := os.WriteFile(artifactPath, artifactJSON, 0644); writeErr != nil { - mlog.Log.Warnf("consensus artifact: failed to write %s: %v", artifactPath, writeErr) - return - } - mlog.Log.FileOnlyf("consensus artifact written: %s", artifactPath) - } - currentConsensusAnchorSlot := func() uint64 { if lastSlotCtx != nil { return lastSlotCtx.Slot @@ -1498,25 +1457,8 @@ func ReplayBlocks( } } - pruneObservedConsensusBlocks := func(anchorSlot uint64) { - if observedConsensusBlocks == nil || anchorSlot == 0 { - return - } - for slot := range observedConsensusBlocks { - if slot <= anchorSlot { - delete(observedConsensusBlocks, slot) - } - } - } - - clearObservedConsensusBlocks := func() { - for slot := range observedConsensusBlocks { - delete(observedConsensusBlocks, slot) - } - } - syncConsensusBufferedExecutionMode := func(triggerSlot uint64) { - if !consensusEnforceActive || !isLive || consensusEnforceSource != "lightbringer" { + if !consensusCfg.enforceActive || !isLive || consensusCfg.enforceSource != "lightbringer" { return } @@ -1531,7 +1473,7 @@ func ReplayBlocks( consensusBufferedExecutionActive = false readyConsensusPath = nil - clearObservedConsensusBlocks() + clearObservedConsensusBlocks(observedConsensusBlocks) observeConsensusAnchor() mlog.Log.Warnf("forkchoice: suspending buffered execution at slot %d because block source left near-tip mode; RPC catchup will continue from anchor %d 
(discarded_observed_blocks=%d discarded_ready_decisions=%d next_emitted_slot=%d)", triggerSlot, anchorSlot, discardedObservedBlocks, readyDecisionCount, stats.NextSlot) @@ -1539,18 +1481,18 @@ func ReplayBlocks( } observeBlockForConsensus := func(block *b.Block) error { - if !consensusEnforceActive { + if !consensusCfg.enforceActive { return nil } - if !consensusBufferedExecutionActive && isLive && consensusEnforceSource == "lightbringer" { + if !consensusBufferedExecutionActive && isLive && consensusCfg.enforceSource == "lightbringer" { if block == nil || !block.FromLightbringer { return nil } consensusBufferedExecutionActive = true readyConsensusPath = nil observeConsensusAnchor() - pruneObservedConsensusBlocks(currentConsensusAnchorSlot()) + pruneObservedConsensusBlocks(observedConsensusBlocks, currentConsensusAnchorSlot()) mlog.Log.Warnf("forkchoice: enabling buffered execution at slot %d after block source switched to Lightbringer", block.Slot) } @@ -1704,7 +1646,7 @@ func ReplayBlocks( case errors.Is(err, forkchoice.ErrDepthExceeded): if consensusCoordinator.Policy() == "halt" { result.Error = fmt.Errorf("forkchoice: unable to resolve a confirmed path within %d slots from anchor %d", - consensusMaxDepth, currentConsensusAnchorSlot()) + consensusCfg.maxDepth, currentConsensusAnchorSlot()) break } mlog.Log.Warnf("forkchoice: path resolution exceeded max depth from anchor %d", currentConsensusAnchorSlot()) @@ -1722,11 +1664,7 @@ func ReplayBlocks( continue } - readyConsensusPath = &pendingConsensusPath{ - leafSlot: resolvedPath.LeafSlot, - leafBankhash: resolvedPath.LeafBankhash, - decisions: resolvedPath.SlotDecisions, - } + readyConsensusPath = newPendingConsensusPath(currentConsensusAnchorSlot(), resolvedPath) continue } } @@ -1767,10 +1705,11 @@ func ReplayBlocks( currentSlot = block.Slot block.Epoch = epochSchedule.GetEpoch(currentSlot) var configErr error + initialBlockConfigured := lastSlotCtx == nil // Use lastSlotCtx == nil to detect first block, not 
currentSlot == startSlot. // This handles the case where startSlot (or slots after it) are skipped - // the first emitted block might have slot > startSlot. - if lastSlotCtx == nil { + if initialBlockConfigured { if resumeState != nil { // RESUME: Use resume state + state file (for static fields) configErr = configureInitialBlockFromResume(acctsDb, block, resumeState, mithrilState, epochSchedule) @@ -1787,6 +1726,18 @@ func ReplayBlocks( result.Error = configErr break } + if initialBlockConfigured { + // Initial block configuration rebuilds VoteCache and EpochAuthorizedVoters + // from AccountsDB. Forkchoice is created before that happens, so refresh + // its epoch view here to avoid using stale manifest voters after resume or + // an epoch boundary. + forkChoice.UpdateEpoch( + block.Epoch, + global.EpochStakes(block.Epoch), + global.EpochTotalStake(block.Epoch), + global.EpochAuthorizedVoters(), + ) + } // Log replay start message once, after initial configuration completes if !replayStartLogged { @@ -1868,36 +1819,9 @@ func ReplayBlocks( block.ParentEpochUpdatedAccts = append(block.ParentEpochUpdatedAccts, parentDistributedAccts...) } - // EAH (Epoch Accounts Hash) Workaround - DISABLED - // Background: Before the AccountsLtHash feature was activated, Solana required an Epoch - // Accounts Hash at specific slots during partitioned epoch rewards. This hash covers all - // accounts and is included in the bankhash calculation at the EahStopOffsetSlot. - // Problem: Mithril does not implement EAH computation (it's expensive and was being phased out). - // Workaround: For historical slots that require EAH, we fetch the expected bankhash from RPC - // instead of computing it locally. This allows replaying old slots without EAH implementation. - // Note: This workaround is only needed for slots before AccountsLtHash activation (~Nov 2024). 
- // If replaying historical slots and hitting EAH requirements, the bankhash is fetched from - // a trusted RPC endpoint rather than computed. The bankhash is NOT stored to bankhash_db - // in this case (see ProcessBlock's early return when HasEahWorkaround is true). - // Uncomment if you need to replay pre-AccountsLtHash historical slots. - /* - if !block.Features.IsActive(features.AccountsLtHash) { - if partitionedEpochRewardsEnabled && block.Slot == partitionedRewardsInfo.EahStopOffsetSlot { - if replayCtx.HasEpochAcctsHash { - block.EpochAcctsHash = replayCtx.EpochAcctsHash - } else { - block.EahWorkaroundBankhash, err = fetchBankhashForSlot(rpcc, block.Slot) - if err != nil { - panic(fmt.Sprintf("unable to fetch bankhash for EAH workaround for slot %d", block.Slot)) - } - block.HasEahWorkaround = true - } - } - } - */ metrics.GlobalBlockReplay.PreprocessBlock.AddTimingSince(start) - lastSlotCtx, err = ProcessBlock(acctsDb, block, txParallelism, dbgOpts, pt) + lastSlotCtx, err = ProcessBlock(acctsDb, block, epochSchedule, txParallelism, dbgOpts, pt) if err != nil { mlog.Log.Errorf("error encountered during block replay: %s\n", err) result.Error = err @@ -1912,8 +1836,6 @@ func ReplayBlocks( } if consensusBufferedExecutionActive { - observeConsensusAnchor() - pruneObservedConsensusBlocks(currentConsensusAnchorSlot()) if readyConsensusPath != nil && block.Slot == readyConsensusPath.leafSlot { actualBankhash := solana.HashFromBytes(lastSlotCtx.FinalBankhash) if actualBankhash != readyConsensusPath.leafBankhash { @@ -1924,14 +1846,17 @@ func ReplayBlocks( ) writeConsensusArtifact( fmt.Sprintf("bankhash_mismatch_slot_%d.json", block.Slot), - map[string]interface{}{ - "type": "bankhash_mismatch", - "checked_slot": block.Slot, - "our_bankhash": base58.Encode(actualBankhash[:]), - "winning_bankhash": base58.Encode(readyConsensusPath.leafBankhash[:]), - "policy": consensusCoordinator.Policy(), - "run_id": CurrentRunID, - }, + buildConsensusMismatchArtifact( + block, + 
lastSlotCtx, + readyConsensusPath, + actualBankhash, + blockStream.GetFetchStats(), + forkChoice, + consensusCoordinator.Policy(), + observedConsensusBlocks, + currentConsensusAnchorSlot(), + ), ) if consensusCoordinator.Policy() == "halt" { result.Error = fmt.Errorf("consensus halt: slot %d bankhash mismatch (our=%s winning=%s)", @@ -1941,6 +1866,8 @@ func ReplayBlocks( } readyConsensusPath = nil } + observeConsensusAnchor() + pruneObservedConsensusBlocks(observedConsensusBlocks, currentConsensusAnchorSlot()) } replayCtx.Capitalization -= lastSlotCtx.LamportsBurnt @@ -2665,6 +2592,7 @@ func parallelTxLoop(slotCtx *sealevel.SlotCtx, sigverifyWg *sync.WaitGroup, bloc func ProcessBlock( acctsDb *accountsdb.AccountsDb, block *b.Block, + epochSchedule *sealevel.SysvarEpochSchedule, txParallelism int, dbgOpts *DebugOptions, // pt is updated after StoreAccounts completes through a callback. @@ -2731,6 +2659,8 @@ func ProcessBlock( unresolvedBlock := &b.Block{ Transactions: make([]*solana.Transaction, len(block.Transactions)), TxMetas: make([]*rpc.TransactionMeta, len(block.TxMetas)), + Slot: block.Slot, + ParentSlot: block.ParentSlot, } for i := range block.Transactions { clonedTx, cloneErr := cloneTransaction(block.Transactions[i]) @@ -2747,7 +2677,7 @@ func ProcessBlock( start = time.Now() setReplayStage("load_accounts") loadAcctsRegion := trace.StartRegion(ctx, "LoadBlockAccounts") - accts, parentAccts, err := loadBlockAccountsAndUpdateSysvars(acctsDb, block) + accts, parentAccts, err := loadBlockAccountsAndUpdateSysvars(acctsDb, block, epochSchedule) loadAcctsRegion.End() if err != nil { panic(fmt.Sprintf("unable to load slot accounts and update sysvars: %s", err)) @@ -2787,7 +2717,6 @@ func ProcessBlock( start = time.Now() setReplayStage("collect_rent") - epochSchedule := sealevel.SysvarCache.EpochSchedule.Sysvar rentSysvar := sealevel.SysvarCache.Rent.Sysvar rentAccts := rent.CollectRentEagerly(slotCtx, rentSysvar, epochSchedule) 
metrics.GlobalBlockReplay.Rent.AddTimingSince(start) diff --git a/pkg/replay/consensus.go b/pkg/replay/consensus.go new file mode 100644 index 00000000..24ea48b6 --- /dev/null +++ b/pkg/replay/consensus.go @@ -0,0 +1,111 @@ +package replay + +import ( + b "github.com/Overclock-Validator/mithril/pkg/block" + "github.com/Overclock-Validator/mithril/pkg/blockstream" + "github.com/Overclock-Validator/mithril/pkg/forkchoice" + "github.com/Overclock-Validator/mithril/pkg/mlog" + "github.com/gagliardetto/solana-go" +) + +const ( + defaultConsensusMaxDepth = 64 + defaultConsensusPolicy = "halt" + defaultConsensusEnforceSource = "lightbringer" +) + +// ConsensusOpts contains vote-anchored consensus configuration. +// Nil means use defaults (max_depth=64, policy="halt"). +type ConsensusOpts struct { + SkipPathMaxDepth int // Max slots for skip-path solver (default: 64) + UnresolvedPolicy string // "halt" or "warn" (default: "halt") + EnforceOnSource string // "lightbringer" or "all" (default: "lightbringer") +} + +type consensusConfig struct { + maxDepth int + policy string + enforceSource string + enforceActive bool + bufferedExecutionActive bool +} + +// pendingConsensusPath tracks a vote-resolved path that replay has observed but +// has not yet executed through to the confirmed leaf. 
+type pendingConsensusPath struct { + anchorSlot uint64 + leafSlot uint64 + leafBankhash solana.Hash + decisions []forkchoice.SlotDecision + originalDecisions []forkchoice.SlotDecision +} + +func resolveConsensusConfig(opts *ConsensusOpts, useLightbringer, isLive bool) consensusConfig { + cfg := consensusConfig{ + maxDepth: defaultConsensusMaxDepth, + policy: defaultConsensusPolicy, + enforceSource: defaultConsensusEnforceSource, + } + + if opts != nil { + if opts.SkipPathMaxDepth > 0 { + cfg.maxDepth = opts.SkipPathMaxDepth + } + if opts.UnresolvedPolicy != "" { + cfg.policy = opts.UnresolvedPolicy + } + if opts.EnforceOnSource != "" { + cfg.enforceSource = opts.EnforceOnSource + } + } + + switch cfg.enforceSource { + case "lightbringer", "all": + default: + mlog.Log.Warnf("forkchoice: invalid EnforceOnSource=%q, defaulting to %q", cfg.enforceSource, defaultConsensusEnforceSource) + cfg.enforceSource = defaultConsensusEnforceSource + } + + cfg.enforceActive = cfg.enforceSource == "all" || useLightbringer + cfg.bufferedExecutionActive = !isLive || cfg.enforceSource == "all" + return cfg +} + +func newPendingConsensusPath(anchorSlot uint64, resolvedPath *forkchoice.ResolvedPath) *pendingConsensusPath { + if resolvedPath == nil { + return nil + } + decisions := append([]forkchoice.SlotDecision(nil), resolvedPath.SlotDecisions...) 
+ return &pendingConsensusPath{ + anchorSlot: anchorSlot, + leafSlot: resolvedPath.LeafSlot, + leafBankhash: resolvedPath.LeafBankhash, + decisions: append([]forkchoice.SlotDecision(nil), decisions...), + originalDecisions: decisions, + } +} + +func pruneObservedConsensusBlocks(blocks map[uint64]*b.Block, anchorSlot uint64) { + if blocks == nil || anchorSlot == 0 { + return + } + for slot := range blocks { + if slot <= anchorSlot { + delete(blocks, slot) + } + } +} + +func clearObservedConsensusBlocks(blocks map[uint64]*b.Block) { + for slot := range blocks { + delete(blocks, slot) + } +} + +func shouldDiscardLightbringerObservationAfterFallback(isLive, useLightbringer bool, block *b.Block, stats blockstream.FetchStatsSnapshot) bool { + return isLive && + useLightbringer && + block != nil && + block.FromLightbringer && + (!stats.IsNearTip || stats.CurrentSource != "lightbringer") +} diff --git a/pkg/replay/consensus_fallback.go b/pkg/replay/consensus_fallback.go deleted file mode 100644 index d34c9eb4..00000000 --- a/pkg/replay/consensus_fallback.go +++ /dev/null @@ -1,14 +0,0 @@ -package replay - -import ( - b "github.com/Overclock-Validator/mithril/pkg/block" - "github.com/Overclock-Validator/mithril/pkg/blockstream" -) - -func shouldDiscardLightbringerObservationAfterFallback(isLive, useLightbringer bool, block *b.Block, stats blockstream.FetchStatsSnapshot) bool { - return isLive && - useLightbringer && - block != nil && - block.FromLightbringer && - (!stats.IsNearTip || stats.CurrentSource != "lightbringer") -} diff --git a/pkg/replay/diagnostics.go b/pkg/replay/diagnostics.go new file mode 100644 index 00000000..35dacfcb --- /dev/null +++ b/pkg/replay/diagnostics.go @@ -0,0 +1,361 @@ +package replay + +import ( + "encoding/hex" + "encoding/json" + "fmt" + "os" + "path/filepath" + "sort" + "time" + + "github.com/Overclock-Validator/mithril/pkg/base58" + b "github.com/Overclock-Validator/mithril/pkg/block" + 
"github.com/Overclock-Validator/mithril/pkg/blockstream" + "github.com/Overclock-Validator/mithril/pkg/forkchoice" + "github.com/Overclock-Validator/mithril/pkg/lthash" + "github.com/Overclock-Validator/mithril/pkg/mlog" + "github.com/Overclock-Validator/mithril/pkg/sealevel" + "github.com/gagliardetto/solana-go" + "github.com/gagliardetto/solana-go/rpc" +) + +type consensusTxDiagnostic struct { + Index int `json:"index"` + Signatures []string `json:"signatures"` + IsVote bool `json:"is_vote"` + Version string `json:"version"` + RecentBlockhash string `json:"recent_blockhash"` + AccountKeyCount int `json:"account_key_count"` + WritableAccountCount int `json:"writable_account_count,omitempty"` + ReadonlyAccountCount int `json:"readonly_account_count,omitempty"` + AddressTableCount int `json:"address_table_count,omitempty"` + AddressLookupCount int `json:"address_lookup_count,omitempty"` + AddressTableLookupKeys []string `json:"address_table_lookup_keys,omitempty"` + InstructionCount int `json:"instruction_count"` + ProgramIDs []string `json:"program_ids,omitempty"` + Meta any `json:"meta"` +} + +type consensusEntryDiagnostic struct { + Index int `json:"index"` + NumHashes uint64 `json:"num_hashes"` + Hash string `json:"hash,omitempty"` + TxCount int `json:"tx_count"` + FirstTxIndex uint64 `json:"first_tx_index,omitempty"` + LastTxIndex uint64 `json:"last_tx_index,omitempty"` + TxIndices []uint64 `json:"tx_indices,omitempty"` +} + +func consensusHashString(hash [32]byte) string { + if hash == ([32]byte{}) { + return "" + } + return base58.Encode(hash[:]) +} + +func consensusByteHashString(hash []byte) string { + if len(hash) == 0 { + return "" + } + if len(hash) == 32 { + return base58.Encode(hash) + } + return hex.EncodeToString(hash) +} + +func consensusLtHashChecksum(ltHash *lthash.LtHash) string { + if ltHash == nil { + return "" + } + return consensusByteHashString(ltHash.Checksum()) +} + +func consensusSignatureStrings(tx *solana.Transaction) []string { + sigs 
:= make([]string, 0, len(tx.Signatures)) + for _, sig := range tx.Signatures { + sigs = append(sigs, sig.String()) + } + return sigs +} + +func consensusTxVersion(tx *solana.Transaction) string { + if tx.Message.IsVersioned() { + return fmt.Sprintf("v%d", tx.Message.GetVersion()) + } + return "legacy" +} + +func consensusProgramIDs(tx *solana.Transaction) []string { + out := make([]string, 0, len(tx.Message.Instructions)) + seen := make(map[solana.PublicKey]struct{}, len(tx.Message.Instructions)) + for _, instr := range tx.Message.Instructions { + idx := int(instr.ProgramIDIndex) + if idx < 0 || idx >= len(tx.Message.AccountKeys) { + out = append(out, fmt.Sprintf("invalid_program_index:%d", idx)) + continue + } + programID := tx.Message.AccountKeys[idx] + if _, ok := seen[programID]; ok { + continue + } + seen[programID] = struct{}{} + out = append(out, programID.String()) + } + return out +} + +func consensusLookupTableKeys(tx *solana.Transaction) []string { + if !tx.Message.IsVersioned() || tx.Message.AddressTableLookups.NumLookups() == 0 { + return nil + } + tableIDs := tx.Message.GetAddressTableLookups().GetTableIDs() + out := make([]string, 0, len(tableIDs)) + for _, tableID := range tableIDs { + out = append(out, tableID.String()) + } + return out +} + +func consensusTxMetaDiagnostic(txMeta *rpc.TransactionMeta) any { + if txMeta == nil { + return map[string]any{"present": false} + } + out := map[string]any{ + "present": true, + "fee": txMeta.Fee, + "pre_balance_count": len(txMeta.PreBalances), + "post_balance_count": len(txMeta.PostBalances), + "loaded_writable_address_count": len(txMeta.LoadedAddresses.Writable), + "loaded_readonly_address_count": len(txMeta.LoadedAddresses.ReadOnly), + } + if txMeta.Err != nil { + out["err"] = fmt.Sprintf("%v", txMeta.Err) + } + if txMeta.ComputeUnitsConsumed != nil { + out["compute_units_consumed"] = *txMeta.ComputeUnitsConsumed + } + return out +} + +func consensusTxDiagnostics(block *b.Block) []consensusTxDiagnostic { + 
out := make([]consensusTxDiagnostic, 0, len(block.Transactions)) + for idx, tx := range block.Transactions { + var txMeta *rpc.TransactionMeta + if idx < len(block.TxMetas) { + txMeta = block.TxMetas[idx] + } + + txDiag := consensusTxDiagnostic{ + Index: idx, + Signatures: consensusSignatureStrings(tx), + IsVote: tx.IsVote(), + Version: consensusTxVersion(tx), + RecentBlockhash: tx.Message.RecentBlockhash.String(), + AccountKeyCount: len(tx.Message.AccountKeys), + AddressTableCount: len(tx.Message.AddressTableLookups), + AddressLookupCount: tx.Message.AddressTableLookups.NumLookups(), + InstructionCount: len(tx.Message.Instructions), + ProgramIDs: consensusProgramIDs(tx), + Meta: consensusTxMetaDiagnostic(txMeta), + } + if canDeriveAccountsFromMessage(tx) { + txDiag.WritableAccountCount = len(messageWritableAccounts(&tx.Message)) + txDiag.ReadonlyAccountCount = len(messageReadonlyAccounts(&tx.Message)) + } + txDiag.AddressTableLookupKeys = consensusLookupTableKeys(tx) + out = append(out, txDiag) + } + return out +} + +func consensusEntryDiagnostics(block *b.Block) []consensusEntryDiagnostic { + out := make([]consensusEntryDiagnostic, 0, len(block.Entries)) + for idx, entry := range block.Entries { + entryDiag := consensusEntryDiagnostic{ + Index: idx, + NumHashes: entry.NumHashes, + Hash: consensusByteHashString(entry.Hash), + TxCount: len(entry.Indices), + TxIndices: append([]uint64(nil), entry.Indices...), + } + if len(entry.Indices) > 0 { + entryDiag.FirstTxIndex = entry.Indices[0] + entryDiag.LastTxIndex = entry.Indices[len(entry.Indices)-1] + } + out = append(out, entryDiag) + } + return out +} + +func consensusBlockDiagnostic(block *b.Block) map[string]any { + voteTxCount := 0 + for _, tx := range block.Transactions { + if tx.IsVote() { + voteTxCount++ + } + } + + return map[string]any{ + "slot": block.Slot, + "parent_slot": block.ParentSlot, + "source_parent_slot": block.SourceParentSlot, + "block_height": block.BlockHeight, + "epoch": block.Epoch, + 
"from_lightbringer": block.FromLightbringer, + "is_skipped": block.IsSkipped, + "leader": block.Leader.String(), + "blockhash": consensusHashString(block.Blockhash), + "last_blockhash": consensusHashString(block.LastBlockhash), + "parent_bankhash": consensusHashString(block.ParentBankhash), + "expected_bankhash": consensusHashString(block.ExpectedBankhash), + "accts_lthash_checksum": consensusLtHashChecksum(block.AcctsLtHash), + "num_signatures": block.NumSignatures, + "prev_num_signatures": block.PrevNumSignatures, + "initial_lamports_per_sig": block.InitialPreviousLamportsPerSignature, + "tx_count": len(block.Transactions), + "vote_tx_count": voteTxCount, + "non_vote_tx_count": len(block.Transactions) - voteTxCount, + "tx_meta_count": len(block.TxMetas), + "entry_count": len(block.Entries), + "entries": consensusEntryDiagnostics(block), + "transactions": consensusTxDiagnostics(block), + "latest_evicted_blockhash": consensusHashString(block.LatestEvictedBlockhash), + "has_eah_workaround": block.HasEahWorkaround, + "eah_workaround_bankhash": consensusByteHashString(block.EahWorkaroundBankhash), + "num_reward_partitions": block.NumRewardPartitions, + "reward_count": len(block.Rewards), + "updated_account_count": len(block.UpdatedAccts), + "epoch_updated_account_count": len(block.EpochUpdatedAccts), + } +} + +func consensusObservedBlocksDiagnostic(blocks map[uint64]*b.Block) []map[string]any { + slots := make([]uint64, 0, len(blocks)) + for slot := range blocks { + slots = append(slots, slot) + } + sort.Slice(slots, func(i, j int) bool { return slots[i] < slots[j] }) + + out := make([]map[string]any, 0, len(slots)) + for _, slot := range slots { + block := blocks[slot] + out = append(out, map[string]any{ + "slot": block.Slot, + "source_parent_slot": block.SourceParentSlot, + "from_lightbringer": block.FromLightbringer, + "is_skipped": block.IsSkipped, + "blockhash": consensusHashString(block.Blockhash), + "last_blockhash": consensusHashString(block.LastBlockhash), + 
"tx_count": len(block.Transactions), + "entry_count": len(block.Entries), + }) + } + return out +} + +func consensusDecisionDiagnostics(decisions []forkchoice.SlotDecision) []map[string]any { + out := make([]map[string]any, 0, len(decisions)) + for _, decision := range decisions { + out = append(out, map[string]any{ + "slot": decision.Slot, + "use_block": decision.UseBlock, + }) + } + return out +} + +// writeConsensusArtifact writes a best-effort JSON diagnostic artifact to the +// per-run consensus subdirectory. If the log dir is empty or any step fails, +// it logs a warning and continues; artifact failure must not crash replay. +func writeConsensusArtifact(filename string, data map[string]interface{}) { + logDir := mlog.GetLogDir() + if logDir == "" { + return + } + dir := filepath.Join(logDir, "consensus") + if err := os.MkdirAll(dir, 0755); err != nil { + mlog.Log.Warnf("consensus artifact: failed to create directory %s: %v", dir, err) + return + } + artifactPath := filepath.Join(dir, filename) + artifactJSON, jsonErr := json.MarshalIndent(data, "", " ") + if jsonErr != nil { + mlog.Log.Warnf("consensus artifact: failed to marshal JSON for %s: %v", filename, jsonErr) + return + } + if writeErr := os.WriteFile(artifactPath, artifactJSON, 0644); writeErr != nil { + mlog.Log.Warnf("consensus artifact: failed to write %s: %v", artifactPath, writeErr) + return + } + mlog.Log.FileOnlyf("consensus artifact written: %s", artifactPath) +} + +func buildConsensusMismatchArtifact( + block *b.Block, + slotCtx *sealevel.SlotCtx, + path *pendingConsensusPath, + actualBankhash solana.Hash, + fetchStats blockstream.FetchStatsSnapshot, + forkChoice *forkchoice.ForkChoiceService, + consensusPolicy string, + observedConsensusBlocks map[uint64]*b.Block, + executionAnchorAfterReplay uint64, +) map[string]interface{} { + artifact := map[string]interface{}{ + "type": "bankhash_mismatch", + "checked_slot": block.Slot, + "our_bankhash": base58.Encode(actualBankhash[:]), + 
"winning_bankhash": base58.Encode(path.leafBankhash[:]), + "policy": consensusPolicy, + "run_id": CurrentRunID, + "created_at": time.Now().UTC().Format(time.RFC3339Nano), + "path_anchor_slot": path.anchorSlot, + "execution_anchor_after_replay": executionAnchorAfterReplay, + "source": map[string]interface{}{ + "current_source": fetchStats.CurrentSource, + "source_status": fetchStats.SourceStatus, + "is_near_tip": fetchStats.IsNearTip, + "next_slot": fetchStats.NextSlot, + "confirmed_tip": fetchStats.ConfirmedTip, + "processed_tip": fetchStats.ProcessedTip, + "handoff_slot": fetchStats.HandoffSlot, + "waiting_slot_state": fetchStats.WaitingSlotState, + "waiting_slot_retries": fetchStats.WaitingSlotRetries, + "inflight": fetchStats.InflightCount, + "retry_queue_len": fetchStats.RetryQueueLen, + "stream_buffer_depth": fetchStats.BufferDepth, + "reorder_buffer_len": fetchStats.ReorderBufLen, + }, + "block": consensusBlockDiagnostic(block), + "forkchoice_vote_summary": forkChoice.SlotVoteDiagnostics(path.leafSlot), + "parent_vote_summary": forkChoice.SlotVoteDiagnostics(block.ParentSlot), + "observed_consensus_blocks": consensusObservedBlocksDiagnostic(observedConsensusBlocks), + "resolved_path": map[string]interface{}{ + "anchor_slot": path.anchorSlot, + "leaf_slot": path.leafSlot, + "leaf_bankhash": base58.Encode(path.leafBankhash[:]), + "remaining_decisions": consensusDecisionDiagnostics(path.decisions), + "original_decisions": consensusDecisionDiagnostics(path.originalDecisions), + }, + } + if slotCtx != nil { + artifact["slot_context"] = map[string]interface{}{ + "slot": slotCtx.Slot, + "parent_slot": slotCtx.ParentSlot, + "epoch": slotCtx.Epoch, + "blockhash": consensusHashString(slotCtx.Blockhash), + "last_blockhash": consensusHashString(slotCtx.LastBlockhash), + "latest_evicted_blockhash": consensusHashString(slotCtx.LatestEvictedBlockhash), + "final_bankhash": consensusByteHashString(slotCtx.FinalBankhash), + "accts_lthash_checksum": 
consensusLtHashChecksum(slotCtx.AcctsLtHash), + "num_signatures": slotCtx.NumSignatures, + "lamports_burnt": slotCtx.LamportsBurnt, + "total_compute_units_consumed": slotCtx.TotalComputeUnitsConsumed, + "modified_account_count": len(slotCtx.ModifiedAccts), + "writable_account_count": len(slotCtx.WritableAccts), + "total_epoch_stake": slotCtx.TotalEpochStake, + } + } + return artifact +} diff --git a/pkg/replay/epoch.go b/pkg/replay/epoch.go index 6c805c2e..b2e5e93a 100644 --- a/pkg/replay/epoch.go +++ b/pkg/replay/epoch.go @@ -324,6 +324,11 @@ func rebuildAuthorizedVotersFromVoteCache(epoch uint64) { if err == nil { newCache.PutEntry(voteAcct, voter) } + case sealevel.VoteStateVersionV4: + voter, _, err := voteState.V4.AuthorizedVoters.GetOrCalculateAuthorizedVoterForEpoch(epoch) + if err == nil { + newCache.PutEntry(voteAcct, voter) + } } } diff --git a/pkg/replay/epoch_authorized_voters_test.go b/pkg/replay/epoch_authorized_voters_test.go new file mode 100644 index 00000000..e03d7069 --- /dev/null +++ b/pkg/replay/epoch_authorized_voters_test.go @@ -0,0 +1,36 @@ +package replay + +import ( + "testing" + + "github.com/Overclock-Validator/mithril/pkg/global" + "github.com/Overclock-Validator/mithril/pkg/sealevel" + "github.com/gagliardetto/solana-go" + "github.com/stretchr/testify/require" +) + +func TestRebuildAuthorizedVotersFromVoteCacheIncludesV4(t *testing.T) { + epoch := uint64(973) + voteAcct := solana.PublicKey{0x44} + authorizedVoter := solana.PublicKey{0x55} + + oldCache := global.EpochAuthorizedVoters() + defer global.SetEpochAuthorizedVoters(oldCache) + defer global.DeleteVoteCacheItem(voteAcct) + + var authorizedVoters sealevel.AuthorizedVoters + authorizedVoters.AuthorizedVoters.Set(epoch, authorizedVoter) + + global.PutVoteCacheItem(voteAcct, &sealevel.VoteStateVersions{ + Type: sealevel.VoteStateVersionV4, + V4: sealevel.VoteState4{ + AuthorizedVoters: authorizedVoters, + }, + }) + + rebuildAuthorizedVotersFromVoteCache(epoch) + + cache := 
global.EpochAuthorizedVoters() + require.NotNil(t, cache) + require.True(t, cache.IsAuthorizedVoter(voteAcct, authorizedVoter)) +} diff --git a/pkg/replay/epoch_schedule.go b/pkg/replay/epoch_schedule.go new file mode 100644 index 00000000..79aa8e75 --- /dev/null +++ b/pkg/replay/epoch_schedule.go @@ -0,0 +1,43 @@ +package replay + +import ( + "fmt" + + "github.com/Overclock-Validator/mithril/pkg/sealevel" + "github.com/Overclock-Validator/mithril/pkg/state" +) + +func bankEpochScheduleFromState(s *state.MithrilState) (*sealevel.SysvarEpochSchedule, bool) { + if s == nil || s.ManifestEpochSchedule == nil || s.ManifestEpochSchedule.SlotsPerEpoch == 0 { + return nil, false + } + + return &sealevel.SysvarEpochSchedule{ + SlotsPerEpoch: s.ManifestEpochSchedule.SlotsPerEpoch, + LeaderScheduleSlotOffset: s.ManifestEpochSchedule.LeaderScheduleSlotOffset, + Warmup: s.ManifestEpochSchedule.Warmup, + FirstNormalEpoch: s.ManifestEpochSchedule.FirstNormalEpoch, + FirstNormalSlot: s.ManifestEpochSchedule.FirstNormalSlot, + }, true +} + +func bankEpochScheduleForReplay(s *state.MithrilState) (*sealevel.SysvarEpochSchedule, bool, error) { + if epochSchedule, ok := bankEpochScheduleFromState(s); ok { + return epochSchedule, true, nil + } + if sealevel.SysvarCache.EpochSchedule.Sysvar != nil { + return sealevel.SysvarCache.EpochSchedule.Sysvar, false, nil + } + return nil, false, fmt.Errorf("missing epoch schedule") +} + +func epochSchedulesEqual(a, b *sealevel.SysvarEpochSchedule) bool { + if a == nil || b == nil { + return a == b + } + return a.SlotsPerEpoch == b.SlotsPerEpoch && + a.LeaderScheduleSlotOffset == b.LeaderScheduleSlotOffset && + a.Warmup == b.Warmup && + a.FirstNormalEpoch == b.FirstNormalEpoch && + a.FirstNormalSlot == b.FirstNormalSlot +} diff --git a/pkg/replay/epoch_stakes_seed.go b/pkg/replay/epoch_stakes_seed.go new file mode 100644 index 00000000..f14a6c1f --- /dev/null +++ b/pkg/replay/epoch_stakes_seed.go @@ -0,0 +1,126 @@ +package replay + +import ( + 
"encoding/json" + "fmt" + "sort" + + "github.com/Overclock-Validator/mithril/pkg/epochstakes" + "github.com/Overclock-Validator/mithril/pkg/state" +) + +type manifestEpochStakeSeed struct { + sourceEpoch uint64 + runtimeEpoch uint64 + data []byte +} + +func prepareManifestEpochStakesForRuntime(mithrilState *state.MithrilState, currentEpoch uint64, snapshotEpoch uint64) ([]manifestEpochStakeSeed, bool, error) { + if mithrilState == nil || len(mithrilState.ManifestEpochStakes) == 0 { + return nil, false, fmt.Errorf("state file missing manifest_epoch_stakes - delete AccountsDB and rebuild from snapshot") + } + if snapshotEpoch == 0 { + snapshotEpoch = currentEpoch + } + + keys := sortedManifestEpochStakeKeys(mithrilState.ManifestEpochStakes) + needsRebase := !manifestEpochStakeKeyExists(mithrilState.ManifestEpochStakes, currentEpoch) && + !manifestEpochStakeKeyExists(mithrilState.ManifestEpochStakes, snapshotEpoch) + sourceSnapshotEpoch := snapshotEpoch + if needsRebase { + var ok bool + sourceSnapshotEpoch, ok = inferManifestSourceSnapshotEpoch(keys, mithrilState.ManifestParentSlot) + if !ok { + return nil, false, fmt.Errorf("state file manifest_epoch_stakes has no epoch keys") + } + } + + seeds := make([]manifestEpochStakeSeed, 0, len(keys)) + for _, sourceEpoch := range keys { + runtimeEpoch := sourceEpoch + if needsRebase { + var ok bool + runtimeEpoch, ok = rebaseManifestEpoch(sourceEpoch, sourceSnapshotEpoch, snapshotEpoch) + if !ok { + return nil, false, fmt.Errorf("cannot rebase manifest epoch %d from source snapshot epoch %d to runtime snapshot epoch %d", + sourceEpoch, sourceSnapshotEpoch, snapshotEpoch) + } + } + + data := []byte(mithrilState.ManifestEpochStakes[sourceEpoch]) + if runtimeEpoch != sourceEpoch { + var persisted epochstakes.PersistedEpochStakes + if err := json.Unmarshal(data, &persisted); err != nil { + return nil, false, fmt.Errorf("failed to decode manifest epoch %d stakes for rebase: %w", sourceEpoch, err) + } + persisted.Epoch = 
runtimeEpoch + rebased, err := json.Marshal(persisted) + if err != nil { + return nil, false, fmt.Errorf("failed to encode manifest epoch %d stakes rebased to %d: %w", sourceEpoch, runtimeEpoch, err) + } + data = rebased + } + + seeds = append(seeds, manifestEpochStakeSeed{ + sourceEpoch: sourceEpoch, + runtimeEpoch: runtimeEpoch, + data: data, + }) + } + + return seeds, needsRebase, nil +} + +func sortedManifestEpochStakeKeys(stakes map[uint64]string) []uint64 { + keys := make([]uint64, 0, len(stakes)) + for epoch := range stakes { + keys = append(keys, epoch) + } + sort.Slice(keys, func(i, j int) bool { return keys[i] < keys[j] }) + return keys +} + +func manifestEpochStakeKeyExists(stakes map[uint64]string, epoch uint64) bool { + _, exists := stakes[epoch] + return exists +} + +func inferManifestSourceSnapshotEpoch(keys []uint64, parentSlot uint64) (uint64, bool) { + if len(keys) == 0 { + return 0, false + } + + // The affected devnet snapshots have manifest Bank epoch data serialized in + // the 432k-slot frame even though the EpochSchedule sysvar account uses 8192. + const agaveDefaultSlotsPerEpoch = 432000 + if parentSlot > 0 { + legacyEpoch := parentSlot / agaveDefaultSlotsPerEpoch + if uint64SliceContains(keys, legacyEpoch) { + return legacyEpoch, true + } + } + + // VersionedEpochStakes normally carries the snapshot epoch plus the next + // leader-schedule epoch, so the penultimate key is the best fallback. 
+ if len(keys) >= 2 { + return keys[len(keys)-2], true + } + return keys[0], true +} + +func uint64SliceContains(values []uint64, needle uint64) bool { + idx := sort.Search(len(values), func(i int) bool { return values[i] >= needle }) + return idx < len(values) && values[idx] == needle +} + +func rebaseManifestEpoch(sourceEpoch uint64, sourceSnapshotEpoch uint64, runtimeSnapshotEpoch uint64) (uint64, bool) { + if runtimeSnapshotEpoch >= sourceSnapshotEpoch { + return sourceEpoch + (runtimeSnapshotEpoch - sourceSnapshotEpoch), true + } + + delta := sourceSnapshotEpoch - runtimeSnapshotEpoch + if sourceEpoch < delta { + return 0, false + } + return sourceEpoch - delta, true +} diff --git a/pkg/replay/epoch_stakes_seed_test.go b/pkg/replay/epoch_stakes_seed_test.go new file mode 100644 index 00000000..6f1cbda2 --- /dev/null +++ b/pkg/replay/epoch_stakes_seed_test.go @@ -0,0 +1,86 @@ +package replay + +import ( + "encoding/json" + "testing" + + "github.com/Overclock-Validator/mithril/pkg/epochstakes" + "github.com/Overclock-Validator/mithril/pkg/state" +) + +func TestPrepareManifestEpochStakesForRuntimeRepairsPreviouslyRebasedDevnetFrame(t *testing.T) { + const ( + parentSlot = 463538376 + snapshotEpoch = 1073 + currentEpoch = 1073 + ) + + mithrilState := &state.MithrilState{ + ManifestParentSlot: parentSlot, + ManifestEpochStakes: make(map[uint64]string), + } + for _, epoch := range []uint64{56581, 56582, 56583, 56584, 56585} { + mithrilState.ManifestEpochStakes[epoch] = persistedEpochStakeJSON(t, epoch) + } + + seeds, rebased, err := prepareManifestEpochStakesForRuntime(mithrilState, currentEpoch, snapshotEpoch) + if err != nil { + t.Fatalf("prepareManifestEpochStakesForRuntime returned error: %v", err) + } + if !rebased { + t.Fatalf("expected manifest epoch stakes to be rebased") + } + + wantEpochs := []uint64{1070, 1071, 1072, 1073, 1074} + if len(seeds) != len(wantEpochs) { + t.Fatalf("expected %d seeds, got %d", len(wantEpochs), len(seeds)) + } + for i, wantEpoch 
:= range wantEpochs { + if seeds[i].runtimeEpoch != wantEpoch { + t.Fatalf("seed %d runtime epoch = %d, want %d", i, seeds[i].runtimeEpoch, wantEpoch) + } + var persisted epochstakes.PersistedEpochStakes + if err := json.Unmarshal(seeds[i].data, &persisted); err != nil { + t.Fatalf("failed to decode seed %d: %v", i, err) + } + if persisted.Epoch != wantEpoch { + t.Fatalf("seed %d payload epoch = %d, want %d", i, persisted.Epoch, wantEpoch) + } + } +} + +func TestPrepareManifestEpochStakesForRuntimeKeepsRuntimeFrame(t *testing.T) { + mithrilState := &state.MithrilState{ + ManifestParentSlot: 463538376, + ManifestEpochStakes: map[uint64]string{ + 1073: persistedEpochStakeJSON(t, 1073), + 1074: persistedEpochStakeJSON(t, 1074), + }, + } + + seeds, rebased, err := prepareManifestEpochStakesForRuntime(mithrilState, 1073, 1073) + if err != nil { + t.Fatalf("prepareManifestEpochStakesForRuntime returned error: %v", err) + } + if rebased { + t.Fatalf("did not expect manifest epoch stakes to be rebased") + } + if len(seeds) != 2 || seeds[0].runtimeEpoch != 1073 || seeds[1].runtimeEpoch != 1074 { + t.Fatalf("unexpected seeds: %#v", seeds) + } +} + +func persistedEpochStakeJSON(t *testing.T, epoch uint64) string { + t.Helper() + + data, err := json.Marshal(epochstakes.PersistedEpochStakes{ + Epoch: epoch, + TotalStake: 42, + Stakes: map[string]uint64{}, + VoteAccts: map[string]*epochstakes.VoteAccountJSON{}, + }) + if err != nil { + t.Fatalf("failed to marshal epoch stakes: %v", err) + } + return string(data) +} diff --git a/pkg/replay/sysvar.go b/pkg/replay/sysvar.go index 7e8a8b24..05ccf84b 100644 --- a/pkg/replay/sysvar.go +++ b/pkg/replay/sysvar.go @@ -18,13 +18,18 @@ const nsPerSlot = 400000000 const maxAllowableDriftFast = 25 const maxAllowableDriftSlow = 150 -func updateClockSysvar(clock *sealevel.SysvarClock, block *block.Block) error { - epochSchedule := sealevel.SysvarCache.EpochSchedule.Sysvar +func updateClockSysvar(clock *sealevel.SysvarClock, block *block.Block, 
epochSchedule *sealevel.SysvarEpochSchedule) error { + epochOld := clock.Epoch + epochNew := block.Epoch + + if epochOld != epochNew && epochOld+1 != epochNew { + return fmt.Errorf("unexpected epoch transition in Clock sysvar: clock epoch %d, block epoch %d at slot %d", epochOld, epochNew, block.Slot) + } if global.CalcUnixTimeForClockSysvar() { firstSlotInEpoch := epochSchedule.FirstSlotInEpoch(clock.Epoch) epochStartTimestamp := clock.EpochStartTimestamp - timestampEstimate := getTimestampEstimate(block.Slot, firstSlotInEpoch, epochStartTimestamp) + timestampEstimate := getTimestampEstimate(block.Slot, firstSlotInEpoch, epochStartTimestamp, epochSchedule) if timestampEstimate > clock.UnixTimestamp { clock.UnixTimestamp = timestampEstimate } @@ -33,8 +38,6 @@ func updateClockSysvar(clock *sealevel.SysvarClock, block *block.Block) error { } clock.Slot = block.Slot - epochOld := clock.Epoch - epochNew := block.Epoch clock.Epoch = epochNew if epochOld != epochNew { @@ -51,9 +54,7 @@ type tsEntry struct { timestamp int64 } -func getTimestampEstimate(slot uint64, epochStartTimestampSlot uint64, epochStartTimestamp int64) int64 { - epochSchedule := sealevel.SysvarCache.EpochSchedule.Sysvar - +func getTimestampEstimate(slot uint64, epochStartTimestampSlot uint64, epochStartTimestamp int64, epochSchedule *sealevel.SysvarEpochSchedule) int64 { slotsPerEpoch := epochSchedule.SlotsPerEpoch voteAccts := global.VoteCache() diff --git a/pkg/replay/sysvar_clock_test.go b/pkg/replay/sysvar_clock_test.go new file mode 100644 index 00000000..353741a3 --- /dev/null +++ b/pkg/replay/sysvar_clock_test.go @@ -0,0 +1,52 @@ +package replay + +import ( + "testing" + + "github.com/Overclock-Validator/mithril/pkg/block" + "github.com/Overclock-Validator/mithril/pkg/global" + "github.com/Overclock-Validator/mithril/pkg/sealevel" + "github.com/stretchr/testify/require" +) + +func TestUpdateClockSysvarUsesBankEpochSchedule(t *testing.T) { + prevCalcUnixTime := 
global.CalcUnixTimeForClockSysvar() + t.Cleanup(func() { + global.SetCalcUnixTimeForClockSysvar(prevCalcUnixTime) + }) + + bankEpochSchedule := &sealevel.SysvarEpochSchedule{ + SlotsPerEpoch: 432000, + LeaderScheduleSlotOffset: 432000, + } + global.SetCalcUnixTimeForClockSysvar(false) + + clock := &sealevel.SysvarClock{ + Slot: 463624424, + Epoch: 1073, + LeaderScheduleEpoch: 1074, + EpochStartTimestamp: 1779227894, + UnixTimestamp: 1779261346, + } + blk := &block.Block{ + Slot: 463624425, + Epoch: 1073, + UnixTimestamp: 1779261347, + } + + err := updateClockSysvar(clock, blk, bankEpochSchedule) + require.NoError(t, err) + + require.Equal(t, blk.Slot, clock.Slot) + require.Equal(t, blk.Epoch, clock.Epoch) + require.Equal(t, blk.UnixTimestamp, clock.UnixTimestamp) + require.Equal(t, uint64(1074), clock.LeaderScheduleEpoch) + require.Equal(t, int64(1779227894), clock.EpochStartTimestamp) +} + +func TestUpdateClockSysvarRejectsMismatchedEpochFrame(t *testing.T) { + clock := &sealevel.SysvarClock{Slot: 463624424, Epoch: 1073} + blk := &block.Block{Slot: 463624425, Epoch: 56592} + err := updateClockSysvar(clock, blk, &sealevel.SysvarEpochSchedule{SlotsPerEpoch: 8192, LeaderScheduleSlotOffset: 8192}) + require.Error(t, err) +} diff --git a/pkg/replay/transaction.go b/pkg/replay/transaction.go index 8f58582e..775fb744 100644 --- a/pkg/replay/transaction.go +++ b/pkg/replay/transaction.go @@ -514,6 +514,13 @@ func ProcessTransaction(slotCtx *sealevel.SlotCtx, sigverifyWg *sync.WaitGroup, // Handle transaction errors from the pure function if output.ProcessingResult.TransactionError != nil { txErr := output.ProcessingResult.TransactionError + if dbgOpts.IsDebugTx(tx.Signatures[0]) && execCtx != nil { + if logRecorder, ok := execCtx.Log.(*sealevel.LogRecorder); ok { + for _, l := range logRecorder.Logs { + mlog.Log.Debugf("%s", l) + } + } + } switch txErr.ErrorType { case TransactionErrorSanitizeFailure: diff --git a/pkg/rpcserver/rpcserver.go b/pkg/rpcserver/rpcserver.go 
index 5994d54a..1035bc49 100644 --- a/pkg/rpcserver/rpcserver.go +++ b/pkg/rpcserver/rpcserver.go @@ -6,7 +6,6 @@ import ( "net" "net/http" "net/http/httptest" - "net/netip" "strings" "sync" "time" @@ -30,19 +29,19 @@ type RpcServer struct { slotCtxMu sync.RWMutex leaderTPUCacheMu sync.RWMutex - leaderTPUByIdentity map[solana.PublicKey]netip.AddrPort + leaderTPUByIdentity map[solana.PublicKey]tpuEndpoint leaderTPUCacheUpdatedAt time.Time clusterNodesRefreshEvery time.Duration clusterNodesRefreshOnce sync.Once clusterRPCEndpoints []string clusterNodesFetcher clusterNodesFetcher - // packetSender is injectable for tests; production defaults to UDP. - packetSender packetSender + // transactionSender is injectable for tests; production supports QUIC with UDP fallback. + transactionSender transactionSender sendTransactionLeaderForwardCount uint64 } -func NewRpcServer(acctsDb *accountsdb.AccountsDb, port uint16) *RpcServer { +func NewRpcServer(acctsDb *accountsdb.AccountsDb, port uint16, epochSchedule *sealevel.SysvarEpochSchedule) *RpcServer { var err error rpcServer := &RpcServer{} @@ -63,11 +62,15 @@ func NewRpcServer(acctsDb *accountsdb.AccountsDb, port uint16) *RpcServer { rpcServer.rpcService.Register("MithrilRpc", rpcServer) rpcServer.acctsDb = acctsDb - rpcServer.epochSchedule = fetchAndUnmarshalEpochScheduleSysvar(acctsDb) - rpcServer.leaderTPUByIdentity = make(map[solana.PublicKey]netip.AddrPort) + if epochSchedule != nil { + rpcServer.epochSchedule = epochSchedule + } else { + rpcServer.epochSchedule = fetchAndUnmarshalEpochScheduleSysvar(acctsDb) + } + rpcServer.leaderTPUByIdentity = make(map[solana.PublicKey]tpuEndpoint) rpcServer.clusterNodesRefreshEvery = sendTransactionClusterNodesRefreshEvery rpcServer.clusterRPCEndpoints = configuredSendTransactionRPCEndpoints() - rpcServer.packetSender = defaultPacketSender + rpcServer.transactionSender = defaultTransactionSender rpcServer.sendTransactionLeaderForwardCount = sendTransactionLeaderForwardCount return 
rpcServer diff --git a/pkg/rpcserver/send_transaction.go b/pkg/rpcserver/send_transaction.go index 22885a1b..6b29b91e 100644 --- a/pkg/rpcserver/send_transaction.go +++ b/pkg/rpcserver/send_transaction.go @@ -8,6 +8,7 @@ import ( "net" "net/netip" "strings" + "sync" "time" "github.com/Overclock-Validator/mithril/pkg/config" @@ -24,7 +25,23 @@ import ( ) type clusterNodesFetcher func(context.Context) ([]*solanarpc.GetClusterNodesResult, error) -type packetSender func([]byte, netip.AddrPort) error +type transactionSender func(context.Context, []byte, tpuEndpoint) error + +type tpuTransport string + +const ( + tpuTransportUDP tpuTransport = "udp" + tpuTransportQUIC tpuTransport = "quic" +) + +type tpuEndpoint struct { + Addr netip.AddrPort + Transport tpuTransport +} + +func (endpoint tpuEndpoint) String() string { + return fmt.Sprintf("%s/%s", endpoint.Addr.String(), endpoint.Transport) +} type sendTransactionConfig struct { encoding string @@ -42,6 +59,7 @@ const ( sendTransactionTargetCount = sendTransactionLeaderForwardCount + 1 sendTransactionLeaderLookahead = 64 sendTransactionClusterNodesRefreshEvery = 10 * time.Minute + sendTransactionTPUSendTimeout = 3 * time.Second maxSanitizedInstructionCount = 64 ) @@ -400,25 +418,39 @@ func (rpcServer *RpcServer) forwardTransactionToUpcomingLeaders(ctx context.Cont targetCount = sendTransactionTargetCount } - targets, err := rpcServer.resolveUpcomingLeaderTPUAddresses(ctx, targetCount) + targets, err := rpcServer.resolveUpcomingLeaderTPUEndpoints(ctx, targetCount) if err != nil { return err } - send := rpcServer.packetSender + send := rpcServer.transactionSender if send == nil { - send = defaultPacketSender + send = defaultTransactionSender } var sendErrs []error sentCount := 0 + var sendMu sync.Mutex + var sendWg sync.WaitGroup for _, target := range targets { - if err := send(wire, target); err != nil { - sendErrs = append(sendErrs, fmt.Errorf("%s: %w", target.String(), err)) - continue - } - sentCount++ + target := 
target + sendWg.Add(1) + go func() { + defer sendWg.Done() + sendCtx, cancel := context.WithTimeout(ctx, sendTransactionTPUSendTimeout) + defer cancel() + + err := send(sendCtx, wire, target) + sendMu.Lock() + defer sendMu.Unlock() + if err != nil { + sendErrs = append(sendErrs, fmt.Errorf("%s: %w", target.String(), err)) + return + } + sentCount++ + }() } + sendWg.Wait() if sentCount == 0 { return fmt.Errorf("failed to forward transaction to any leader TPU: %w", errors.Join(sendErrs...)) @@ -429,12 +461,12 @@ func (rpcServer *RpcServer) forwardTransactionToUpcomingLeaders(ctx context.Cont return nil } -func (rpcServer *RpcServer) resolveUpcomingLeaderTPUAddresses(ctx context.Context, want int) ([]netip.AddrPort, error) { +func (rpcServer *RpcServer) resolveUpcomingLeaderTPUEndpoints(ctx context.Context, want int) ([]tpuEndpoint, error) { if want <= 0 { want = 1 } - targets, updatedAt := rpcServer.collectUpcomingLeaderTPUAddressesFromCache(want) + targets, updatedAt := rpcServer.collectUpcomingLeaderTPUEndpointsFromCache(want) cacheStale := updatedAt.IsZero() || time.Since(updatedAt) >= rpcServer.clusterNodesRefreshInterval() if cacheStale || len(targets) < want { if err := rpcServer.refreshLeaderTPUCache(ctx); err != nil { @@ -444,7 +476,7 @@ func (rpcServer *RpcServer) resolveUpcomingLeaderTPUAddresses(ctx context.Contex mlog.Log.Warnf("sendTransaction: using partial cached TPU target set after refresh failure: %v", err) return targets, nil } - targets, _ = rpcServer.collectUpcomingLeaderTPUAddressesFromCache(want) + targets, _ = rpcServer.collectUpcomingLeaderTPUEndpointsFromCache(want) } if len(targets) == 0 { @@ -488,20 +520,21 @@ func (rpcServer *RpcServer) refreshLeaderTPUCache(ctx context.Context) error { return err } - leaderTPUs := make(map[solana.PublicKey]netip.AddrPort, len(nodes)) + leaderTPUs := make(map[solana.PublicKey]tpuEndpoint, len(nodes)) for _, node := range nodes { - if node == nil || node.TPU == nil || *node.TPU == "" { + if node == nil { 
continue } - addr, err := netip.ParseAddrPort(*node.TPU) - if err != nil { + + endpoint, ok := leaderTPUEndpointFromClusterNode(node) + if !ok { continue } - leaderTPUs[node.Pubkey] = addr + leaderTPUs[node.Pubkey] = endpoint } if len(leaderTPUs) == 0 { - return fmt.Errorf("cluster did not advertise any TPU UDP endpoints") + return fmt.Errorf("cluster did not advertise any TPU endpoints") } rpcServer.leaderTPUCacheMu.Lock() @@ -511,19 +544,37 @@ func (rpcServer *RpcServer) refreshLeaderTPUCache(ctx context.Context) error { return nil } -func (rpcServer *RpcServer) collectUpcomingLeaderTPUAddressesFromCache(want int) ([]netip.AddrPort, time.Time) { +func leaderTPUEndpointFromClusterNode(node *solanarpc.GetClusterNodesResult) (tpuEndpoint, bool) { + if endpoint, ok := parseLeaderTPUEndpoint(node.TPUQUIC, tpuTransportQUIC); ok { + return endpoint, true + } + return parseLeaderTPUEndpoint(node.TPU, tpuTransportUDP) +} + +func parseLeaderTPUEndpoint(raw *string, transport tpuTransport) (tpuEndpoint, bool) { + if raw == nil || *raw == "" { + return tpuEndpoint{}, false + } + addr, err := netip.ParseAddrPort(*raw) + if err != nil { + return tpuEndpoint{}, false + } + return tpuEndpoint{Addr: addr, Transport: transport}, true +} + +func (rpcServer *RpcServer) collectUpcomingLeaderTPUEndpointsFromCache(want int) ([]tpuEndpoint, time.Time) { rpcServer.leaderTPUCacheMu.RLock() - nodeTPUs := make(map[solana.PublicKey]netip.AddrPort, len(rpcServer.leaderTPUByIdentity)) - for leader, addr := range rpcServer.leaderTPUByIdentity { - nodeTPUs[leader] = addr + nodeTPUs := make(map[solana.PublicKey]tpuEndpoint, len(rpcServer.leaderTPUByIdentity)) + for leader, endpoint := range rpcServer.leaderTPUByIdentity { + nodeTPUs[leader] = endpoint } updatedAt := rpcServer.leaderTPUCacheUpdatedAt rpcServer.leaderTPUCacheMu.RUnlock() currentSlot := global.Slot() - targets := make([]netip.AddrPort, 0, want) + targets := make([]tpuEndpoint, 0, want) seenLeaders := 
make(map[solana.PublicKey]struct{}, want) - seenTargets := make(map[netip.AddrPort]struct{}, want) + seenTargets := make(map[tpuEndpoint]struct{}, want) for offset := uint64(0); offset < sendTransactionLeaderLookahead && len(targets) < want; offset++ { leader, ok := global.LeaderForSlot(currentSlot + offset) @@ -557,7 +608,18 @@ func configuredSendTransactionRPCEndpoints() []string { return endpoints } -func defaultPacketSender(payload []byte, target netip.AddrPort) error { +func defaultTransactionSender(ctx context.Context, payload []byte, target tpuEndpoint) error { + switch target.Transport { + case tpuTransportQUIC: + return defaultTPUQUICSender.Send(ctx, payload, target.Addr) + case tpuTransportUDP: + return defaultUDPPacketSender(payload, target.Addr) + default: + return fmt.Errorf("unsupported TPU transport %q", target.Transport) + } +} + +func defaultUDPPacketSender(payload []byte, target netip.AddrPort) error { conn, err := net.ListenUDP("udp", nil) if err != nil { return err diff --git a/pkg/rpcserver/send_transaction_test.go b/pkg/rpcserver/send_transaction_test.go index 95441546..3ef4eae2 100644 --- a/pkg/rpcserver/send_transaction_test.go +++ b/pkg/rpcserver/send_transaction_test.go @@ -5,6 +5,7 @@ import ( "encoding/base64" "encoding/json" "net" + "net/netip" "testing" "time" @@ -126,7 +127,7 @@ func TestSendTransaction_SkipPreflight_FansOutToUpcomingLeaders(t *testing.T) { tx, wire := testLegacyTransaction(t) fetchCount := 0 rpcServer := &RpcServer{ - packetSender: defaultPacketSender, + transactionSender: defaultTransactionSender, clusterNodesRefreshEvery: sendTransactionClusterNodesRefreshEvery, sendTransactionLeaderForwardCount: sendTransactionLeaderForwardCount, clusterNodesFetcher: func(context.Context) ([]*solanarpc.GetClusterNodesResult, error) { @@ -161,7 +162,7 @@ func TestSendTransaction_SkipPreflight_FansOutToUpcomingLeaders(t *testing.T) { assert.Equal(t, wire, mustReadUDP(t, listenerF)) } -func 
TestResolveUpcomingLeaderTPUAddresses_UsesFreshCacheWithoutRefetch(t *testing.T) { +func TestResolveUpcomingLeaderTPUEndpoints_UsesFreshCacheWithoutRefetch(t *testing.T) { leaderA := solana.PublicKey{0x01} leaderB := solana.PublicKey{0x02} leaderC := solana.PublicKey{0x03} @@ -202,13 +203,13 @@ func TestResolveUpcomingLeaderTPUAddresses_UsesFreshCacheWithoutRefetch(t *testi } require.NoError(t, rpcServer.refreshLeaderTPUCache(context.Background())) - targets, err := rpcServer.resolveUpcomingLeaderTPUAddresses(context.Background(), sendTransactionTargetCount) + targets, err := rpcServer.resolveUpcomingLeaderTPUEndpoints(context.Background(), 6) require.NoError(t, err) - require.Len(t, targets, sendTransactionTargetCount) + require.Len(t, targets, 6) assert.Equal(t, 1, fetchCount, "fresh cache should satisfy resolution without another RPC poll") } -func TestResolveUpcomingLeaderTPUAddresses_RefreshesStaleCache(t *testing.T) { +func TestResolveUpcomingLeaderTPUEndpoints_RefreshesStaleCache(t *testing.T) { leaderA := solana.PublicKey{0x11} leaderB := solana.PublicKey{0x12} leaderC := solana.PublicKey{0x13} @@ -253,12 +254,47 @@ func TestResolveUpcomingLeaderTPUAddresses_RefreshesStaleCache(t *testing.T) { rpcServer.leaderTPUCacheUpdatedAt = time.Now().Add(-11 * time.Minute) rpcServer.leaderTPUCacheMu.Unlock() - targets, err := rpcServer.resolveUpcomingLeaderTPUAddresses(context.Background(), sendTransactionTargetCount) + targets, err := rpcServer.resolveUpcomingLeaderTPUEndpoints(context.Background(), 6) require.NoError(t, err) - require.Len(t, targets, sendTransactionTargetCount) + require.Len(t, targets, 6) assert.Equal(t, 2, fetchCount, "stale cache should trigger a refresh before resolving targets") } +func TestRefreshLeaderTPUCache_PrefersQUICEndpoints(t *testing.T) { + leaderA := solana.PublicKey{0x21} + leaderB := solana.PublicKey{0x22} + leaderC := solana.PublicKey{0x23} + + rpcServer := &RpcServer{ + clusterNodesFetcher: func(context.Context) 
([]*solanarpc.GetClusterNodesResult, error) { + return []*solanarpc.GetClusterNodesResult{ + { + Pubkey: leaderA, + TPU: stringPtr("127.0.0.1:9001"), + TPUQUIC: stringPtr("127.0.0.1:10001"), + }, + { + Pubkey: leaderB, + TPUQUIC: stringPtr("127.0.0.1:10002"), + }, + { + Pubkey: leaderC, + TPU: stringPtr("127.0.0.1:9003"), + }, + }, nil + }, + } + + require.NoError(t, rpcServer.refreshLeaderTPUCache(context.Background())) + + rpcServer.leaderTPUCacheMu.RLock() + defer rpcServer.leaderTPUCacheMu.RUnlock() + + assert.Equal(t, tpuEndpoint{Addr: netip.MustParseAddrPort("127.0.0.1:10001"), Transport: tpuTransportQUIC}, rpcServer.leaderTPUByIdentity[leaderA]) + assert.Equal(t, tpuEndpoint{Addr: netip.MustParseAddrPort("127.0.0.1:10002"), Transport: tpuTransportQUIC}, rpcServer.leaderTPUByIdentity[leaderB]) + assert.Equal(t, tpuEndpoint{Addr: netip.MustParseAddrPort("127.0.0.1:9003"), Transport: tpuTransportUDP}, rpcServer.leaderTPUByIdentity[leaderC]) +} + func mustRawParams(t *testing.T, params []interface{}) jsonrpc.RawParams { t.Helper() raw, err := json.Marshal(params) diff --git a/pkg/rpcserver/tpu_quic.go b/pkg/rpcserver/tpu_quic.go new file mode 100644 index 00000000..b780682a --- /dev/null +++ b/pkg/rpcserver/tpu_quic.go @@ -0,0 +1,192 @@ +package rpcserver + +import ( + "context" + "crypto/ed25519" + "crypto/rand" + "crypto/tls" + "crypto/x509" + "crypto/x509/pkix" + "fmt" + "math/big" + "net/netip" + "strconv" + "strings" + "sync" + "time" + + "github.com/quic-go/quic-go" +) + +const ( + tpuQUICALPN = "solana-tpu" + tpuQUICMaxIdleTimeout = 30 * time.Second + tpuQUICKeepAlive = time.Second + tpuQUICHandshakeTimeout = 3 * time.Second +) + +var defaultTPUQUICSender = newTPUQUICSender() + +type tpuQUICSender struct { + mu sync.Mutex + conns map[netip.AddrPort]*quic.Conn + tlsConf *tls.Config + quicConf *quic.Config +} + +func newTPUQUICSender() *tpuQUICSender { + cert, err := newTPUQUICClientCertificate() + if err != nil { + panic(fmt.Sprintf("create TPU QUIC client 
certificate: %v", err)) + } + + return &tpuQUICSender{ + conns: make(map[netip.AddrPort]*quic.Conn), + tlsConf: &tls.Config{ + Certificates: []tls.Certificate{cert}, + InsecureSkipVerify: true, + MinVersion: tls.VersionTLS13, + NextProtos: []string{tpuQUICALPN}, + ClientSessionCache: tls.NewLRUClientSessionCache(128), + }, + quicConf: &quic.Config{ + HandshakeIdleTimeout: tpuQUICHandshakeTimeout, + MaxIdleTimeout: tpuQUICMaxIdleTimeout, + KeepAlivePeriod: tpuQUICKeepAlive, + TokenStore: quic.NewLRUTokenStore(128, 4), + }, + } +} + +func (sender *tpuQUICSender) Send(ctx context.Context, payload []byte, addr netip.AddrPort) error { + var lastErr error + for attempt := 0; attempt < 2; attempt++ { + conn, err := sender.connection(ctx, addr) + if err != nil { + lastErr = err + continue + } + + if err := sender.sendOnConnection(ctx, conn, payload); err != nil { + lastErr = err + sender.dropConnection(addr, conn) + continue + } + return nil + } + return lastErr +} + +func (sender *tpuQUICSender) connection(ctx context.Context, addr netip.AddrPort) (*quic.Conn, error) { + sender.mu.Lock() + if conn := sender.conns[addr]; conn != nil { + if conn.Context().Err() == nil { + sender.mu.Unlock() + return conn, nil + } + delete(sender.conns, addr) + } + sender.mu.Unlock() + + conn, err := quic.DialAddr(ctx, addr.String(), sender.tlsConfigFor(addr), sender.quicConf) + if err != nil { + return nil, err + } + + sender.mu.Lock() + defer sender.mu.Unlock() + if existing := sender.conns[addr]; existing != nil && existing.Context().Err() == nil { + _ = conn.CloseWithError(0, "superseded") + return existing, nil + } + sender.conns[addr] = conn + return conn, nil +} + +func (sender *tpuQUICSender) tlsConfigFor(addr netip.AddrPort) *tls.Config { + conf := sender.tlsConf.Clone() + conf.ServerName = tpuQUICServerName(addr) + return conf +} + +func (sender *tpuQUICSender) sendOnConnection(ctx context.Context, conn *quic.Conn, payload []byte) error { + stream, err := 
conn.OpenUniStreamSync(ctx) + if err != nil { + return err + } + + if deadline, ok := ctx.Deadline(); ok { + _ = stream.SetWriteDeadline(deadline) + } + + remaining := payload + for len(remaining) > 0 { + written, err := stream.Write(remaining) + if err != nil { + stream.CancelWrite(0) + return err + } + if written == 0 { + stream.CancelWrite(0) + return fmt.Errorf("short QUIC stream write: wrote 0 of %d bytes", len(remaining)) + } + remaining = remaining[written:] + } + + if err := stream.Close(); err != nil { + return err + } + return nil +} + +func (sender *tpuQUICSender) dropConnection(addr netip.AddrPort, conn *quic.Conn) { + sender.mu.Lock() + if sender.conns[addr] == conn { + delete(sender.conns, addr) + } + sender.mu.Unlock() + _ = conn.CloseWithError(0, "dropped") +} + +func tpuQUICServerName(addr netip.AddrPort) string { + host := addr.Addr().String() + if addr.Addr().Is6() { + host = strings.ReplaceAll(host, ":", "-") + } + return host + "." + strconv.Itoa(int(addr.Port())) + ".sol" +} + +func newTPUQUICClientCertificate() (tls.Certificate, error) { + pub, priv, err := ed25519.GenerateKey(rand.Reader) + if err != nil { + return tls.Certificate{}, err + } + + template := &x509.Certificate{ + SerialNumber: big.NewInt(1), + Subject: pkix.Name{ + CommonName: "Solana node", + }, + NotBefore: time.Unix(0, 0), + NotAfter: time.Date(4096, 1, 1, 0, 0, 0, 0, time.UTC), + KeyUsage: x509.KeyUsageDigitalSignature, + ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageClientAuth}, + DNSNames: []string{"localhost"}, + BasicConstraintsValid: true, + } + + certDER, err := x509.CreateCertificate(rand.Reader, template, template, pub, priv) + if err != nil { + return tls.Certificate{}, err + } + cert, err := x509.ParseCertificate(certDER) + if err != nil { + return tls.Certificate{}, err + } + + return tls.Certificate{ + Certificate: [][]byte{certDER}, + PrivateKey: priv, + Leaf: cert, + }, nil +} diff --git a/pkg/rpcserver/tpu_quic_test.go b/pkg/rpcserver/tpu_quic_test.go new 
file mode 100644 index 00000000..9bfc70ba --- /dev/null +++ b/pkg/rpcserver/tpu_quic_test.go @@ -0,0 +1,60 @@ +package rpcserver + +import ( + "context" + "crypto/tls" + "io" + "net/netip" + "testing" + "time" + + "github.com/quic-go/quic-go" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestTPUQUICSenderSendsOneTransactionPerUniStream(t *testing.T) { + serverCert, err := newTPUQUICClientCertificate() + require.NoError(t, err) + + listener, err := quic.ListenAddr( + "127.0.0.1:0", + &tls.Config{ + Certificates: []tls.Certificate{serverCert}, + ClientAuth: tls.RequireAnyClientCert, + NextProtos: []string{tpuQUICALPN}, + MinVersion: tls.VersionTLS13, + }, + nil, + ) + require.NoError(t, err) + defer listener.Close() + + received := make(chan []byte, 1) + go func() { + conn, err := listener.Accept(context.Background()) + if err != nil { + received <- nil + return + } + stream, err := conn.AcceptUniStream(context.Background()) + if err != nil { + received <- nil + return + } + payload, err := io.ReadAll(stream) + if err != nil { + received <- nil + return + } + received <- payload + }() + + addr := netip.MustParseAddrPort(listener.Addr().String()) + payload := []byte("wire transaction") + ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second) + defer cancel() + + require.NoError(t, newTPUQUICSender().Send(ctx, payload, addr)) + assert.Equal(t, payload, <-received) +} diff --git a/pkg/snapshot/manifest_seed.go b/pkg/snapshot/manifest_seed.go index 0af6859c..9d144a97 100644 --- a/pkg/snapshot/manifest_seed.go +++ b/pkg/snapshot/manifest_seed.go @@ -67,6 +67,13 @@ func PopulateManifestSeed(s *state.MithrilState, m *SnapshotManifest) { s.ManifestInflationTaper = m.Bank.Inflation.Taper s.ManifestInflationFoundation = m.Bank.Inflation.FoundationVal s.ManifestInflationFoundationTerm = m.Bank.Inflation.FoundationTerm + s.ManifestEpochSchedule = &state.ManifestEpochScheduleSeed{ + SlotsPerEpoch: 
m.Bank.EpochSchedule.SlotsPerEpoch, + LeaderScheduleSlotOffset: m.Bank.EpochSchedule.LeaderScheduleSlotOffset, + Warmup: m.Bank.EpochSchedule.Warmup, + FirstNormalEpoch: m.Bank.EpochSchedule.FirstNormalEpoch, + FirstNormalSlot: m.Bank.EpochSchedule.FirstNormalSlot, + } // Epoch account hash (base64 for consistency with LtHash) if m.EpochAccountHash != [32]byte{} { diff --git a/pkg/snapshot/manifest_seed_test.go b/pkg/snapshot/manifest_seed_test.go new file mode 100644 index 00000000..e87c0081 --- /dev/null +++ b/pkg/snapshot/manifest_seed_test.go @@ -0,0 +1,51 @@ +package snapshot + +import ( + "encoding/json" + "testing" + + "github.com/Overclock-Validator/mithril/pkg/epochstakes" + "github.com/Overclock-Validator/mithril/pkg/sealevel" + "github.com/Overclock-Validator/mithril/pkg/state" + "github.com/stretchr/testify/require" +) + +func TestPopulateManifestSeedKeepsManifestEpochFrame(t *testing.T) { + manifest := &SnapshotManifest{ + Bank: &DeserializableVersionedBank{ + Slot: 463538376, + Epoch: 1073, + EpochSchedule: sealevel.SysvarEpochSchedule{ + SlotsPerEpoch: 432000, + LeaderScheduleSlotOffset: 432000, + }, + }, + VersionedEpochStakes: []VersionedEpochStakesPair{ + { + Epoch: 1073, + Val: VersionedEpochStakes{ + TotalStake: 42, + Stakes: Stake{}, + }, + }, + }, + } + mithrilState := state.NewReadyState(manifest.Bank.Slot, 1073, "", "", 0, 0) + + PopulateManifestSeed(mithrilState, manifest) + + require.Equal(t, uint64(432000), mithrilState.ManifestEpochSchedule.SlotsPerEpoch) + require.Equal(t, uint64(432000), mithrilState.ManifestEpochSchedule.LeaderScheduleSlotOffset) + + data, exists := mithrilState.ManifestEpochStakes[1073] + if !exists { + t.Fatalf("expected manifest epoch stakes for epoch 1073, got keys %#v", mithrilState.ManifestEpochStakes) + } + var persisted epochstakes.PersistedEpochStakes + if err := json.Unmarshal([]byte(data), &persisted); err != nil { + t.Fatalf("failed to decode persisted epoch stakes: %v", err) + } + if persisted.Epoch != 
1073 { + t.Fatalf("persisted epoch = %d, want 1073", persisted.Epoch) + } +} diff --git a/pkg/state/state.go b/pkg/state/state.go index b9bb706c..4ef42b9c 100644 --- a/pkg/state/state.go +++ b/pkg/state/state.go @@ -81,13 +81,14 @@ type MithrilState struct { ManifestEvictedBlockhash string `json:"manifest_evicted_blockhash,omitempty"` // base58 // ReplayCtx seed (inflation/capitalization at snapshot) - ManifestCapitalization uint64 `json:"manifest_capitalization,omitempty"` - ManifestSlotsPerYear float64 `json:"manifest_slots_per_year,omitempty"` - ManifestInflationInitial float64 `json:"manifest_inflation_initial,omitempty"` - ManifestInflationTerminal float64 `json:"manifest_inflation_terminal,omitempty"` - ManifestInflationTaper float64 `json:"manifest_inflation_taper,omitempty"` - ManifestInflationFoundation float64 `json:"manifest_inflation_foundation,omitempty"` - ManifestInflationFoundationTerm float64 `json:"manifest_inflation_foundation_term,omitempty"` + ManifestCapitalization uint64 `json:"manifest_capitalization,omitempty"` + ManifestSlotsPerYear float64 `json:"manifest_slots_per_year,omitempty"` + ManifestInflationInitial float64 `json:"manifest_inflation_initial,omitempty"` + ManifestInflationTerminal float64 `json:"manifest_inflation_terminal,omitempty"` + ManifestInflationTaper float64 `json:"manifest_inflation_taper,omitempty"` + ManifestInflationFoundation float64 `json:"manifest_inflation_foundation,omitempty"` + ManifestInflationFoundationTerm float64 `json:"manifest_inflation_foundation_term,omitempty"` + ManifestEpochSchedule *ManifestEpochScheduleSeed `json:"manifest_epoch_schedule,omitempty"` // Epoch account hash (base64 for consistency with LtHash) ManifestEpochAcctsHash string `json:"manifest_epoch_accts_hash,omitempty"` // base64 @@ -182,6 +183,17 @@ type ManifestFeeRateGovernorSeed struct { BurnPercent byte `json:"burn_percent"` } +// ManifestEpochScheduleSeed contains the bank epoch schedule serialized in the +// snapshot manifest. 
Some clusters can expose a divergent EpochSchedule sysvar +// account, so replay uses this bank schedule for epoch/leader/rewards logic. +type ManifestEpochScheduleSeed struct { + SlotsPerEpoch uint64 `json:"slots_per_epoch"` + LeaderScheduleSlotOffset uint64 `json:"leader_schedule_slot_offset"` + Warmup bool `json:"warmup"` + FirstNormalEpoch uint64 `json:"first_normal_epoch"` + FirstNormalSlot uint64 `json:"first_normal_slot"` +} + // SlotHashEntry represents a single entry in the SlotHashes sysvar type SlotHashEntry struct { Slot uint64 `json:"slot"`