diff --git a/eth/handler.go b/eth/handler.go index bd8e2c2bc375..b7a0e297f099 100644 --- a/eth/handler.go +++ b/eth/handler.go @@ -208,7 +208,7 @@ func NewProtocolManager(config *params.ChainConfig, mode downloader.SyncMode, ne } // Construct the different synchronisation mechanisms - manager.downloader = downloader.New(chaindb, manager.eventMux, blockchain, nil, manager.removePeer, handleProposedBlock) + manager.downloader = downloader.New(chaindb, manager.eventMux, blockchain, nil, manager.removePeerByID, handleProposedBlock) validator := func(header *types.Header) error { return engine.VerifyHeader(blockchain, header, true) @@ -237,7 +237,7 @@ func NewProtocolManager(config *params.ChainConfig, mode downloader.SyncMode, ne atomic.StoreUint32(&manager.acceptTxs, 1) // Mark initial sync done on any fetcher import return manager.blockchain.PrepareBlock(block) } - manager.fetcher = fetcher.New(blockchain.GetBlockByHash, validator, handleProposedBlock, manager.BroadcastBlock, heighter, inserter, prepare, manager.removePeer) + manager.fetcher = fetcher.New(blockchain.GetBlockByHash, validator, handleProposedBlock, manager.BroadcastBlock, heighter, inserter, prepare, manager.removePeerByID) //Define bft function broadcasts := bft.BroadcastFns{ Vote: manager.BroadcastVote, @@ -261,23 +261,43 @@ func (pm *ProtocolManager) addLendingPoolProtocol(lendingpool lendingPool) { pm.lendingpool = lendingpool } -func (pm *ProtocolManager) removePeer(id string) { - // Short circuit if the peer was already removed - peer := pm.peers.Peer(id) +// removePeer disconnects a peer instance, unregistering it only when it is the +// current primary connection because the downloader invariant is primary-only. 
+func (pm *ProtocolManager) removePeer(peer *peer) { if peer == nil { return } - log.Debug("Removing Ethereum peer", "peer", id) + removedPrimary, err := pm.peers.UnregisterPeer(peer) + if err != nil { + if errors.Is(err, errPairNotRegistered) { + log.Debug("Stale paired peer removal", "peer", peer.id, "err", err) + } else if errors.Is(err, errNotRegistered) { + log.Debug("Peer already removed", "peer", peer.id) + } else { + log.Warn("Peer removal failed", "peer", peer.id, "err", err) + } + // Intentionally disconnect even on not-registered errors. For an + // already tearing-down peer this is redundant, and for a stale paired + // peer it is a harmless idempotent fallback that keeps cleanup robust. + peer.Peer.Disconnect(p2p.DiscUselessPeer) + return + } + log.Debug("Removing Ethereum peer", "peer", peer.id) - // Unregister the peer from the downloader and Ethereum peer set - pm.downloader.UnregisterPeer(id) - if err := pm.peers.Unregister(id); err != nil { - log.Debug("Peer removal failed", "peer", id, "err", err) + // Only the currently registered primary connection is tracked by the + // downloader. Paired connections skip downloader.RegisterPeer in handle. + if removedPrimary { + pm.downloader.UnregisterPeer(peer.id) } // Hard disconnect at the networking layer peer.Peer.Disconnect(p2p.DiscUselessPeer) } +// removePeerByID adapts downloader and fetcher callbacks that only expose a peer id. +func (pm *ProtocolManager) removePeerByID(id string) { + pm.removePeer(pm.peers.Peer(id)) +} + func (pm *ProtocolManager) Start(maxPeers int) { pm.maxPeers = maxPeers @@ -369,7 +389,7 @@ func (pm *ProtocolManager) handle(p *peer) error { p.Log().Error("Ethereum peer registration failed", "err", err) return err } - defer pm.removePeer(p.id) + defer pm.removePeer(p) if err != p2p.ErrAddPairPeer { // Register the peer in the downloader. 
If the downloader considers it banned, we disconnect if err := pm.downloader.RegisterPeer(p.id, p.version, p); err != nil { @@ -389,7 +409,7 @@ func (pm *ProtocolManager) handle(p *peer) error { // Start a timer to disconnect if the peer doesn't reply in time p.forkDrop = time.AfterFunc(daoChallengeTimeout, func() { p.Log().Debug("Timed out DAO fork-check, dropping") - pm.removePeer(p.id) + pm.removePeer(p) }) // Make sure it's cleaned up if the peer dies off defer func() { @@ -958,7 +978,7 @@ func (pm *ProtocolManager) BroadcastVote(vote *types.Vote) { err := peer.SendVote(vote) if err != nil { log.Debug("[BroadcastVote] Fail to broadcast vote message", "peerId", peer.id, "version", peer.version, "blockNum", vote.ProposedBlockInfo.Number, "err", err) - pm.removePeer(peer.id) + pm.removePeer(peer) } } log.Trace("Propagated Vote", "vote hash", vote.Hash(), "voted block hash", vote.ProposedBlockInfo.Hash.Hex(), "number", vote.ProposedBlockInfo.Number, "round", vote.ProposedBlockInfo.Round, "recipients", len(peers)) @@ -975,7 +995,7 @@ func (pm *ProtocolManager) BroadcastTimeout(timeout *types.Timeout) { err := peer.SendTimeout(timeout) if err != nil { log.Debug("[BroadcastTimeout] Fail to broadcast timeout message, remove peer", "peerId", peer.id, "version", peer.version, "timeout", timeout, "err", err) - pm.removePeer(peer.id) + pm.removePeer(peer) } } log.Trace("Propagated Timeout", "hash", hash, "recipients", len(peers)) @@ -992,7 +1012,7 @@ func (pm *ProtocolManager) BroadcastSyncInfo(syncInfo *types.SyncInfo) { err := peer.SendSyncInfo(syncInfo) if err != nil { log.Debug("[BroadcastSyncInfo] Fail to broadcast syncInfo message, remove peer", "peerId", peer.id, "version", peer.version, "syncInfo", syncInfo, "err", err) - pm.removePeer(peer.id) + pm.removePeer(peer) } } log.Trace("Propagated SyncInfo", "hash", hash, "recipients", len(peers)) diff --git a/eth/handler_test.go b/eth/handler_test.go index 3c1254c5af48..1e02a9c88f98 100644 --- a/eth/handler_test.go +++ 
b/eth/handler_test.go @@ -35,6 +35,7 @@ import ( "github.com/XinFinOrg/XDPoSChain/eth/ethconfig" "github.com/XinFinOrg/XDPoSChain/event" "github.com/XinFinOrg/XDPoSChain/p2p" + "github.com/XinFinOrg/XDPoSChain/p2p/discover" "github.com/XinFinOrg/XDPoSChain/params" ) @@ -73,6 +74,145 @@ func TestGetBlockHeaders62(t *testing.T) { testGetBlockHeaders(t, 62) } // TestGetBlockHeaders63 tests get block headers 63. func TestGetBlockHeaders63(t *testing.T) { testGetBlockHeaders(t, 63) } +func TestGetBlockHeadersAfterPairDrop63(t *testing.T) { + pm, _ := newTestProtocolManagerMust(t, downloader.FullSync, downloader.MaxHashFetch+15, nil, nil) + defer pm.Stop() + + var id discover.NodeID + id[0] = 1 + + newPeerWithID := func(name string) (*testPeer, <-chan error) { + app, net := p2p.MsgPipe() + peer := pm.newPeer(eth63, p2p.NewPeer(id, name, nil), net) + + errc := make(chan error, 1) + go func() { + select { + case pm.newPeerCh <- peer: + errc <- pm.handle(peer) + case <-pm.quitSync: + errc <- p2p.DiscQuitting + } + }() + + return &testPeer{app: app, net: net, peer: peer}, errc + } + + primary, primaryErrc := newPeerWithID("primary") + defer primary.close() + pair, pairErrc := newPeerWithID("pair") + + var ( + genesis = pm.blockchain.Genesis() + head = pm.blockchain.CurrentHeader() + hash = head.Hash() + td = pm.blockchain.GetTd(hash, head.Number.Uint64()) + ) + primary.handshake(t, td, hash, genesis.Hash()) + pair.handshake(t, td, hash, genesis.Hash()) + + ticker := time.NewTicker(10 * time.Millisecond) + defer ticker.Stop() + timeout := time.After(3 * time.Second) + for primary.pairWriter() == nil { + select { + case <-ticker.C: + case <-timeout: + t.Fatalf("pairing state not established: peers=%d pairRWSet=%t", pm.peers.Len(), primary.pairWriter() != nil) + } + } + + pair.close() + select { + case <-pairErrc: + case <-time.After(time.Second): + t.Fatal("timed out waiting for pair peer to disconnect") + } + + if got := pm.peers.Peer(primary.id); got != primary.peer { + 
t.Fatal("primary peer was removed after pair disconnect") + } + if primary.pairWriter() != nil { + t.Fatal("primary peer still has stale pair writer after pair disconnect") + } + if primary.PairPeer() != nil { + t.Fatal("primary peer still references pair peer after pair disconnect") + } + + query := &getBlockHeadersData{Origin: hashOrNumber{Number: 1}, Amount: 1} + expected := []*types.Header{pm.blockchain.GetBlockByNumber(1).Header()} + + if err := p2p.Send(primary.app, GetBlockHeadersMsg, query); err != nil { + t.Fatalf("failed to send header request from primary peer: %v", err) + } + if err := p2p.ExpectMsg(primary.app, BlockHeadersMsg, expected); err != nil { + t.Fatalf("failed to receive header response after pair disconnect: %v", err) + } + + select { + case err := <-primaryErrc: + t.Fatalf("primary peer disconnected unexpectedly: %v", err) + default: + } +} + +func TestRemovePeerKeepsDownloaderRegistrationForPair(t *testing.T) { + pm, _ := newTestProtocolManagerMust(t, downloader.FullSync, 0, nil, nil) + defer pm.Stop() + + var id discover.NodeID + id[0] = 11 + + primary := pm.newPeer(eth63, p2p.NewPeer(id, "primary", nil), &stubMsgReadWriter{}) + pair := pm.newPeer(eth63, p2p.NewPeer(id, "pair", nil), &stubMsgReadWriter{}) + + if err := pm.peers.Register(primary); err != nil { + t.Fatalf("register primary: %v", err) + } + if err := pm.downloader.RegisterPeer(primary.id, primary.version, primary); err != nil { + t.Fatalf("register downloader primary: %v", err) + } + if err := pm.peers.Register(pair); err != p2p.ErrAddPairPeer { + t.Fatalf("register pair: got %v want %v", err, p2p.ErrAddPairPeer) + } + pm.removePeer(pair) + + if err := pm.downloader.RegisterPeer(primary.id, primary.version, primary); err == nil { + t.Fatal("pair removal should keep downloader primary registration") + } + if got := pm.peers.Peer(primary.id); got != primary { + t.Fatal("pair removal should keep primary registered") + } + if primary.pairWriter() != nil { + t.Fatal("pair removal 
should clear the primary pair writer") + } +} + +func TestRemovePeerUnregistersDownloaderForPrimary(t *testing.T) { + pm, _ := newTestProtocolManagerMust(t, downloader.FullSync, 0, nil, nil) + defer pm.Stop() + + var id discover.NodeID + id[0] = 12 + + primary := pm.newPeer(eth63, p2p.NewPeer(id, "primary", nil), &stubMsgReadWriter{}) + + if err := pm.peers.Register(primary); err != nil { + t.Fatalf("register primary: %v", err) + } + if err := pm.downloader.RegisterPeer(primary.id, primary.version, primary); err != nil { + t.Fatalf("register downloader primary: %v", err) + } + pm.removePeer(primary) + + if err := pm.downloader.RegisterPeer(primary.id, primary.version, primary); err != nil { + t.Fatalf("primary removal should unregister downloader peer: %v", err) + } + if got := pm.peers.Peer(primary.id); got != nil { + t.Fatal("primary removal should unregister primary peer") + } +} + func testGetBlockHeaders(t *testing.T, protocol int) { pm, _ := newTestProtocolManagerMust(t, downloader.FullSync, downloader.MaxHashFetch+15, nil, nil) peer, _ := newTestPeer("peer", protocol, pm, true) diff --git a/eth/peer.go b/eth/peer.go index b864f8cec244..e7c569392d21 100644 --- a/eth/peer.go +++ b/eth/peer.go @@ -21,10 +21,12 @@ import ( "fmt" "math/big" "sync" + "sync/atomic" "time" "github.com/XinFinOrg/XDPoSChain/common" "github.com/XinFinOrg/XDPoSChain/core/types" + "github.com/XinFinOrg/XDPoSChain/log" "github.com/XinFinOrg/XDPoSChain/p2p" "github.com/XinFinOrg/XDPoSChain/rlp" mapset "github.com/deckarep/golang-set/v2" @@ -34,6 +36,7 @@ var ( errClosed = errors.New("peer set is closed") errAlreadyRegistered = errors.New("peer is already registered") errNotRegistered = errors.New("peer is not registered") + errPairNotRegistered = errors.New("paired peer is not registered") ) const ( @@ -56,11 +59,12 @@ type PeerInfo struct { } type peer struct { - id string + id string + isPair bool *p2p.Peer rw p2p.MsgReadWriter - pairRw p2p.MsgReadWriter + pairRW 
atomic.Pointer[p2p.MsgReadWriter] version int // Protocol version negotiated forkDrop *time.Timer // Timed connection dropper if forks aren't validated in time @@ -99,6 +103,29 @@ func newPeer(version int, p *p2p.Peer, rw p2p.MsgReadWriter) *peer { } } +func (p *peer) pairWriter() p2p.MsgReadWriter { + pairRW := p.pairRW.Load() + if pairRW == nil { + return nil + } + return *pairRW +} + +func (p *peer) setPairWriter(rw *p2p.MsgReadWriter) { + p.pairRW.Store(rw) +} + +func (p *peer) clearPairWriter() { + p.pairRW.Store(nil) +} + +func (p *peer) msgWriter() p2p.MsgReadWriter { + if pairRW := p.pairWriter(); pairRW != nil { + return pairRW + } + return p.rw +} + // Info gathers and returns a collection of metadata known about a peer. func (p *peer) Info() *PeerInfo { hash, td := p.Head() @@ -262,59 +289,35 @@ func (p *peer) SendNewBlock(block *types.Block, td *big.Int) error { } p.knownBlocks.Add(block.Hash()) - if p.pairRw != nil { - return p2p.Send(p.pairRw, NewBlockMsg, []interface{}{block, td}) - } else { - return p2p.Send(p.rw, NewBlockMsg, []interface{}{block, td}) - } + return p2p.Send(p.msgWriter(), NewBlockMsg, []interface{}{block, td}) } // SendBlockHeaders sends a batch of block headers to the remote peer. func (p *peer) SendBlockHeaders(headers []*types.Header) error { - if p.pairRw != nil { - return p2p.Send(p.pairRw, BlockHeadersMsg, headers) - } else { - return p2p.Send(p.rw, BlockHeadersMsg, headers) - } + return p2p.Send(p.msgWriter(), BlockHeadersMsg, headers) } // SendBlockBodies sends a batch of block contents to the remote peer. func (p *peer) SendBlockBodies(bodies []*blockBody) error { - if p.pairRw != nil { - return p2p.Send(p.pairRw, BlockBodiesMsg, blockBodiesData(bodies)) - } else { - return p2p.Send(p.rw, BlockBodiesMsg, blockBodiesData(bodies)) - } + return p2p.Send(p.msgWriter(), BlockBodiesMsg, blockBodiesData(bodies)) } // SendBlockBodiesRLP sends a batch of block contents to the remote peer from // an already RLP encoded format. 
func (p *peer) SendBlockBodiesRLP(bodies []rlp.RawValue) error { - if p.pairRw != nil { - return p2p.Send(p.pairRw, BlockBodiesMsg, bodies) - } else { - return p2p.Send(p.rw, BlockBodiesMsg, bodies) - } + return p2p.Send(p.msgWriter(), BlockBodiesMsg, bodies) } // SendNodeDataRLP sends a batch of arbitrary internal data, corresponding to the // hashes requested. func (p *peer) SendNodeData(data [][]byte) error { - if p.pairRw != nil { - return p2p.Send(p.pairRw, NodeDataMsg, data) - } else { - return p2p.Send(p.rw, NodeDataMsg, data) - } + return p2p.Send(p.msgWriter(), NodeDataMsg, data) } // SendReceiptsRLP sends a batch of transaction receipts, corresponding to the // ones requested from an already RLP encoded format. func (p *peer) SendReceiptsRLP(receipts []rlp.RawValue) error { - if p.pairRw != nil { - return p2p.Send(p.pairRw, ReceiptsMsg, receipts) - } else { - return p2p.Send(p.rw, ReceiptsMsg, receipts) - } + return p2p.Send(p.msgWriter(), ReceiptsMsg, receipts) } func (p *peer) SendVote(vote *types.Vote) error { @@ -323,11 +326,7 @@ func (p *peer) SendVote(vote *types.Vote) error { } p.knownVote.Add(vote.Hash()) - if p.pairRw != nil { - return p2p.Send(p.pairRw, VoteMsg, vote) - } else { - return p2p.Send(p.rw, VoteMsg, vote) - } + return p2p.Send(p.msgWriter(), VoteMsg, vote) } /* @@ -341,11 +340,7 @@ func (p *peer) SendTimeout(timeout *types.Timeout) error { } p.knownTimeout.Add(timeout.Hash()) - if p.pairRw != nil { - return p2p.Send(p.pairRw, TimeoutMsg, timeout) - } else { - return p2p.Send(p.rw, TimeoutMsg, timeout) - } + return p2p.Send(p.msgWriter(), TimeoutMsg, timeout) } /* @@ -359,11 +354,7 @@ func (p *peer) SendSyncInfo(syncInfo *types.SyncInfo) error { } p.knownSyncInfo.Add(syncInfo.Hash()) - if p.pairRw != nil { - return p2p.Send(p.pairRw, SyncInfoMsg, syncInfo) - } else { - return p2p.Send(p.rw, SyncInfoMsg, syncInfo) - } + return p2p.Send(p.msgWriter(), SyncInfoMsg, syncInfo) } /* @@ -376,65 +367,41 @@ func (p *peer) AsyncSendSyncInfo() { 
// single header. It is used solely by the fetcher. func (p *peer) RequestOneHeader(hash common.Hash) error { p.Log().Debug("Fetching single header", "hash", hash) - if p.pairRw != nil { - return p2p.Send(p.pairRw, GetBlockHeadersMsg, &getBlockHeadersData{Origin: hashOrNumber{Hash: hash}, Amount: uint64(1), Skip: uint64(0), Reverse: false}) - } else { - return p2p.Send(p.rw, GetBlockHeadersMsg, &getBlockHeadersData{Origin: hashOrNumber{Hash: hash}, Amount: uint64(1), Skip: uint64(0), Reverse: false}) - } + return p2p.Send(p.msgWriter(), GetBlockHeadersMsg, &getBlockHeadersData{Origin: hashOrNumber{Hash: hash}, Amount: uint64(1), Skip: uint64(0), Reverse: false}) } // RequestHeadersByHash fetches a batch of blocks' headers corresponding to the // specified header query, based on the hash of an origin block. func (p *peer) RequestHeadersByHash(origin common.Hash, amount int, skip int, reverse bool) error { p.Log().Debug("Fetching batch of headers", "count", amount, "fromhash", origin, "skip", skip, "reverse", reverse) - if p.pairRw != nil { - return p2p.Send(p.pairRw, GetBlockHeadersMsg, &getBlockHeadersData{Origin: hashOrNumber{Hash: origin}, Amount: uint64(amount), Skip: uint64(skip), Reverse: reverse}) - } else { - return p2p.Send(p.rw, GetBlockHeadersMsg, &getBlockHeadersData{Origin: hashOrNumber{Hash: origin}, Amount: uint64(amount), Skip: uint64(skip), Reverse: reverse}) - } + return p2p.Send(p.msgWriter(), GetBlockHeadersMsg, &getBlockHeadersData{Origin: hashOrNumber{Hash: origin}, Amount: uint64(amount), Skip: uint64(skip), Reverse: reverse}) } // RequestHeadersByNumber fetches a batch of blocks' headers corresponding to the // specified header query, based on the number of an origin block. 
func (p *peer) RequestHeadersByNumber(origin uint64, amount int, skip int, reverse bool) error { p.Log().Debug("Fetching batch of headers", "count", amount, "fromnum", origin, "skip", skip, "reverse", reverse) - if p.pairRw != nil { - return p2p.Send(p.pairRw, GetBlockHeadersMsg, &getBlockHeadersData{Origin: hashOrNumber{Number: origin}, Amount: uint64(amount), Skip: uint64(skip), Reverse: reverse}) - } else { - return p2p.Send(p.rw, GetBlockHeadersMsg, &getBlockHeadersData{Origin: hashOrNumber{Number: origin}, Amount: uint64(amount), Skip: uint64(skip), Reverse: reverse}) - } + return p2p.Send(p.msgWriter(), GetBlockHeadersMsg, &getBlockHeadersData{Origin: hashOrNumber{Number: origin}, Amount: uint64(amount), Skip: uint64(skip), Reverse: reverse}) } // RequestBodies fetches a batch of blocks' bodies corresponding to the hashes // specified. func (p *peer) RequestBodies(hashes []common.Hash) error { p.Log().Debug("Fetching batch of block bodies", "count", len(hashes)) - if p.pairRw != nil { - return p2p.Send(p.pairRw, GetBlockBodiesMsg, hashes) - } else { - return p2p.Send(p.rw, GetBlockBodiesMsg, hashes) - } + return p2p.Send(p.msgWriter(), GetBlockBodiesMsg, hashes) } // RequestNodeData fetches a batch of arbitrary data from a node's known state // data, corresponding to the specified hashes. func (p *peer) RequestNodeData(hashes []common.Hash) error { p.Log().Debug("Fetching batch of state data", "count", len(hashes)) - if p.pairRw != nil { - return p2p.Send(p.pairRw, GetNodeDataMsg, hashes) - } else { - return p2p.Send(p.rw, GetNodeDataMsg, hashes) - } + return p2p.Send(p.msgWriter(), GetNodeDataMsg, hashes) } // RequestReceipts fetches a batch of transaction receipts from a remote node. 
func (p *peer) RequestReceipts(hashes []common.Hash) error { p.Log().Debug("Fetching batch of receipts", "count", len(hashes)) - if p.pairRw != nil { - return p2p.Send(p.pairRw, GetReceiptsMsg, hashes) - } else { - return p2p.Send(p.rw, GetReceiptsMsg, hashes) - } + return p2p.Send(p.msgWriter(), GetReceiptsMsg, hashes) } // Handshake executes the eth protocol handshake, negotiating version number, @@ -531,29 +498,72 @@ func (ps *peerSet) Register(p *peer) error { return errClosed } if existPeer, ok := ps.peers[p.id]; ok { - if existPeer.pairRw != nil { + // Mark duplicate connections as pair-role peers before any early return + // so rejected duplicate pairs still take the pair cleanup path in + // UnregisterPeer and report errPairNotRegistered instead of looking + // like an unregistered primary. + p.isPair = true + if existPeer.pairWriter() != nil { return errAlreadyRegistered } existPeer.SetPairPeer(p.Peer) - existPeer.pairRw = p.rw p.SetPairPeer(existPeer.Peer) + rw := p.rw + existPeer.setPairWriter(&rw) return p2p.ErrAddPairPeer } + p.isPair = false ps.peers[p.id] = p return nil } -// Unregister removes a remote peer from the active set, disabling any further -// actions to/from that particular entity. -func (ps *peerSet) Unregister(id string) error { +// UnregisterPeer removes a specific peer instance from the active set. The +// returned flag reports whether this exact instance was the currently +// registered primary peer, allowing callers to keep downloader bookkeeping in +// sync without performing a separate lookup. +// +// When the instance is a paired connection, the primary peer remains +// registered and only its pair state is cleared. 
+func (ps *peerSet) UnregisterPeer(p *peer) (bool, error) { ps.lock.Lock() defer ps.lock.Unlock() - if _, ok := ps.peers[id]; !ok { - return errNotRegistered - } - delete(ps.peers, id) - return nil + current, ok := ps.peers[p.id] + if !ok { + if p.isPair { + return false, errPairNotRegistered + } + return false, errNotRegistered + } + if current == p { + // Keep the p2p-level primary->pair peer link intact here. The paired + // connection is disconnected by the p2p.Peer.run exit path, and its own + // unregister path will clean up any remaining stale pair state. + if pairPeer := current.PairPeer(); pairPeer != nil { + pairPeer.ClearPairPeer(current.Peer) + } + // Clear the eth-level paired writer now because only the registered + // primary stores pairRW. + current.clearPairWriter() + delete(ps.peers, p.id) + return true, nil + } + if pairPeer := current.PairPeer(); p.isPair && pairPeer != p.Peer { + // A paired connection is known for this id, but it is not this specific + // peer instance anymore (for example, a stale pair after primary removal). + return false, errPairNotRegistered + } + if current.ClearPairPeer(p.Peer) { + p.ClearPairPeer(current.Peer) + current.clearPairWriter() + // The current primary cleared its active paired connection; the primary + // stays registered, so callers should not treat this as a primary removal. + return false, nil + } + // Reaching here means p shares the id with the current primary but is neither + // the registered primary nor the currently tracked pair instance. + log.Trace("Ignoring unregister for unexpected same-id peer", "peer", p.id, "isPair", p.isPair, "hasTrackedPair", current.PairPeer() != nil) + return false, errNotRegistered } // Peer retrieves the registered peer with the given id. 
diff --git a/eth/peer_test.go b/eth/peer_test.go new file mode 100644 index 000000000000..142984d8658e --- /dev/null +++ b/eth/peer_test.go @@ -0,0 +1,421 @@ +// Copyright 2023 The go-ethereum Authors +// This file is part of the go-ethereum library. +// +// The go-ethereum library is free software: you can redistribute it and/or modify +// it under the terms of the GNU Lesser General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// The go-ethereum library is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Lesser General Public License for more details. +// +// You should have received a copy of the GNU Lesser General Public License +// along with the go-ethereum library. If not, see . + +package eth + +import ( + "errors" + "fmt" + "io" + "sync" + "sync/atomic" + "testing" + + "github.com/XinFinOrg/XDPoSChain/p2p" + "github.com/XinFinOrg/XDPoSChain/p2p/discover" +) + +type stubMsgReadWriter struct{} + +func (*stubMsgReadWriter) ReadMsg() (p2p.Msg, error) { + return p2p.Msg{}, io.EOF +} + +func (*stubMsgReadWriter) WriteMsg(p2p.Msg) error { + return nil +} + +func TestPeerMsgWriterUsesAtomicPairWriter(t *testing.T) { + var p peer + + primaryRW := &stubMsgReadWriter{} + pairedRW := p2p.MsgReadWriter(&stubMsgReadWriter{}) + p.rw = primaryRW + + if got := p.msgWriter(); got != primaryRW { + t.Fatal("msgWriter should return primary writer when pair writer is unset") + } + + p.setPairWriter(&pairedRW) + if got := p.msgWriter(); got != pairedRW { + t.Fatal("msgWriter should return pair writer when one is registered") + } + + p.clearPairWriter() + if got := p.msgWriter(); got != primaryRW { + t.Fatal("msgWriter should fall back to primary writer after clearing pair writer") + } +} + +func 
TestPeerSetUnregisterPairKeepsPrimaryAndClearsPairWriter(t *testing.T) { + ps := newPeerSet() + var id discover.NodeID + id[0] = 1 + + primaryRW := &stubMsgReadWriter{} + pairRW := &stubMsgReadWriter{} + + primary := newPeer(eth63, p2p.NewPeer(id, "primary", nil), primaryRW) + pair := newPeer(eth63, p2p.NewPeer(id, "pair", nil), pairRW) + + if err := ps.Register(primary); err != nil { + t.Fatalf("register primary: %v", err) + } + if primary.isPair { + t.Fatal("primary peer should not be marked as pair after primary registration") + } + if err := ps.Register(pair); err != p2p.ErrAddPairPeer { + t.Fatalf("register pair: got %v want %v", err, p2p.ErrAddPairPeer) + } + if !pair.isPair { + t.Fatal("paired peer should be marked as pair after duplicate registration") + } + if primary.pairWriter() != pair.rw { + t.Fatal("primary did not record pair writer") + } + if pair.PairPeer() != primary.Peer { + t.Fatal("pair peer did not record primary peer") + } + + removedPrimary, err := ps.UnregisterPeer(pair) + if err != nil { + t.Fatalf("unregister pair: %v", err) + } + if removedPrimary { + t.Fatal("pair unregister should not report primary removal") + } + if got := ps.Peer(primary.id); got != primary { + t.Fatal("primary peer was removed while unregistering pair") + } + if primary.pairWriter() != nil { + t.Fatal("primary peer still references pair writer after unregister") + } + if primary.PairPeer() != nil { + t.Fatal("primary peer still references pair peer after unregister") + } + if pair.PairPeer() != nil { + t.Fatal("pair peer still references primary peer after unregister") + } +} + +func TestPeerSetUnregisterPrimaryKeepsDisconnectLinkAndClearsPairBacklink(t *testing.T) { + ps := newPeerSet() + var id discover.NodeID + id[0] = 2 + + primaryRW := &stubMsgReadWriter{} + pairRW := &stubMsgReadWriter{} + + primary := newPeer(eth63, p2p.NewPeer(id, "primary", nil), primaryRW) + pair := newPeer(eth63, p2p.NewPeer(id, "pair", nil), pairRW) + + if err := ps.Register(primary); 
err != nil { + t.Fatalf("register primary: %v", err) + } + if err := ps.Register(pair); err != p2p.ErrAddPairPeer { + t.Fatalf("register pair: got %v want %v", err, p2p.ErrAddPairPeer) + } + if primary.pairWriter() != pair.rw { + t.Fatal("primary did not record pair writer") + } + if primary.PairPeer() != pair.Peer { + t.Fatal("primary peer did not record pair peer") + } + + removedPrimary, err := ps.UnregisterPeer(primary) + if err != nil { + t.Fatalf("unregister primary: %v", err) + } + if !removedPrimary { + t.Fatal("primary unregister should report primary removal") + } + if got := ps.Peer(primary.id); got != nil { + t.Fatal("primary peer still registered after unregister") + } + if primary.pairWriter() != nil { + t.Fatal("primary peer still references pair writer after unregister") + } + if primary.PairPeer() != pair.Peer { + t.Fatal("primary peer lost pair link needed for disconnect cleanup") + } + if pair.PairPeer() != nil { + t.Fatal("pair peer still references removed primary") + } +} + +func TestPeerSetUnregisterPairAfterDisconnectReturnsPairNotRegistered(t *testing.T) { + ps := newPeerSet() + var id discover.NodeID + id[0] = 3 + + primary := newPeer(eth63, p2p.NewPeer(id, "primary", nil), &stubMsgReadWriter{}) + pair := newPeer(eth63, p2p.NewPeer(id, "pair", nil), &stubMsgReadWriter{}) + + if err := ps.Register(primary); err != nil { + t.Fatalf("register primary: %v", err) + } + if err := ps.Register(pair); err != p2p.ErrAddPairPeer { + t.Fatalf("register pair: got %v want %v", err, p2p.ErrAddPairPeer) + } + removedPrimary, err := ps.UnregisterPeer(pair) + if err != nil { + t.Fatalf("first unregister pair: %v", err) + } + if removedPrimary { + t.Fatal("pair unregister should not report primary removal") + } + removedPrimary, err = ps.UnregisterPeer(pair) + if err != errPairNotRegistered { + t.Fatalf("second unregister pair: got %v want %v", err, errPairNotRegistered) + } + if removedPrimary { + t.Fatal("stale pair unregister should not report primary 
removal") + } + if got := ps.Peer(primary.id); got != primary { + t.Fatal("primary peer was removed after pair double unregister") + } + if primary.pairWriter() != nil { + t.Fatal("primary peer still references pair writer after pair double unregister") + } + if primary.PairPeer() != nil { + t.Fatal("primary peer still references pair peer after pair double unregister") + } + if pair.PairPeer() != nil { + t.Fatal("pair peer still references primary peer after pair double unregister") + } +} + +func TestPeerSetUnregisterStalePairAfterPrimaryRemovalReturnsPairNotRegistered(t *testing.T) { + ps := newPeerSet() + var id discover.NodeID + id[0] = 5 + + primary := newPeer(eth63, p2p.NewPeer(id, "primary", nil), &stubMsgReadWriter{}) + pair := newPeer(eth63, p2p.NewPeer(id, "pair", nil), &stubMsgReadWriter{}) + + if err := ps.Register(primary); err != nil { + t.Fatalf("register primary: %v", err) + } + if err := ps.Register(pair); err != p2p.ErrAddPairPeer { + t.Fatalf("register pair: got %v want %v", err, p2p.ErrAddPairPeer) + } + removedPrimary, err := ps.UnregisterPeer(primary) + if err != nil { + t.Fatalf("unregister primary: %v", err) + } + if !removedPrimary { + t.Fatal("primary unregister should report primary removal") + } + removedPrimary, err = ps.UnregisterPeer(pair) + if err != errPairNotRegistered { + t.Fatalf("unregister stale pair: got %v want %v", err, errPairNotRegistered) + } + if removedPrimary { + t.Fatal("stale pair unregister should not report primary removal") + } + if errors.Is(errPairNotRegistered, errNotRegistered) { + t.Fatal("pair not registered should remain distinct from not registered") + } + removedPrimary, err = ps.UnregisterPeer(primary) + if err != errNotRegistered { + t.Fatalf("unregister stale primary: got %v want %v", err, errNotRegistered) + } + if removedPrimary { + t.Fatal("stale primary unregister should not report primary removal") + } +} + +func TestPeerSetUnregisterStalePairAfterPrimaryClearsPairReturnsPairNotRegistered(t 
*testing.T) {
+	ps := newPeerSet()
+	var id discover.NodeID
+	id[0] = 7
+
+	primary := newPeer(eth63, p2p.NewPeer(id, "primary", nil), &stubMsgReadWriter{})
+	pair := newPeer(eth63, p2p.NewPeer(id, "pair", nil), &stubMsgReadWriter{})
+
+	if err := ps.Register(primary); err != nil {
+		t.Fatalf("register primary: %v", err)
+	}
+	if err := ps.Register(pair); err != p2p.ErrAddPairPeer {
+		t.Fatalf("register pair: got %v want %v", err, p2p.ErrAddPairPeer)
+	}
+	if !primary.ClearPairPeer(pair.Peer) {
+		t.Fatal("primary pair link was not established")
+	}
+	primary.clearPairWriter()
+
+	removedPrimary, err := ps.UnregisterPeer(pair)
+	if err != errPairNotRegistered {
+		t.Fatalf("unregister stale pair after primary clear: got %v want %v", err, errPairNotRegistered)
+	}
+	if removedPrimary {
+		t.Fatal("stale pair unregister should not report primary removal")
+	}
+	if got := ps.Peer(primary.id); got != primary {
+		t.Fatal("primary peer should remain registered after stale pair cleanup")
+	}
+	if primary.PairPeer() != nil {
+		t.Fatal("primary peer should keep cleared pair reference")
+	}
+	if pair.PairPeer() != primary.Peer {
+		t.Fatal("pair peer should retain stale primary reference for cleanup path")
+	}
+}
+
+func TestPeerSetUnregisterRejectedDuplicatePairReturnsPairNotRegistered(t *testing.T) { // a third same-ID peer that Register rejects must unregister as errPairNotRegistered and leave the existing primary/pair wiring alone
+	ps := newPeerSet()
+	var id discover.NodeID
+	id[0] = 6 // all three peers below are built from this one node ID
+
+	primary := newPeer(eth63, p2p.NewPeer(id, "primary", nil), &stubMsgReadWriter{})
+	pair := newPeer(eth63, p2p.NewPeer(id, "pair", nil), &stubMsgReadWriter{})
+	rejectedPair := newPeer(eth63, p2p.NewPeer(id, "rejected-pair", nil), &stubMsgReadWriter{})
+
+	if err := ps.Register(primary); err != nil {
+		t.Fatalf("register primary: %v", err)
+	}
+	if err := ps.Register(pair); err != p2p.ErrAddPairPeer { // the second same-ID registration is expected to return the ErrAddPairPeer sentinel, not nil
+		t.Fatalf("register pair: got %v want %v", err, p2p.ErrAddPairPeer)
+	}
+	if err := ps.Register(rejectedPair); err != errAlreadyRegistered { // the third same-ID registration is expected to be refused
+		t.Fatalf("register rejected pair: got %v want %v", err, errAlreadyRegistered)
+	}
+	if !rejectedPair.isPair { // the refused Register is still expected to have flagged the peer as a pair
+		t.Fatal("rejected duplicate pair should retain pair role for cleanup semantics")
+	}
+	removedPrimary, err := ps.UnregisterPeer(rejectedPair)
+	if err != errPairNotRegistered {
+		t.Fatalf("unregister rejected pair: got %v want %v", err, errPairNotRegistered)
+	}
+	if removedPrimary {
+		t.Fatal("rejected pair unregister should not report primary removal")
+	}
+	if rejectedPair.PairPeer() != nil {
+		t.Fatal("rejected pair should not retain a pair peer link")
+	}
+	if primary.pairWriter() != pair.rw { // the accepted pair's writer must still be the one attached to primary
+		t.Fatal("primary pair writer changed after rejected duplicate pair")
+	}
+}
+
+func TestPeerPairRWConcurrentRegisterAndSend(t *testing.T) { // runs sends on primary concurrently with unregister/register churn of pair and primary; presumably meant to be run under -race (confirm)
+	ps := newPeerSet()
+	var id discover.NodeID
+	id[0] = 4
+
+	primary := newPeer(eth63, p2p.NewPeer(id, "primary", nil), &stubMsgReadWriter{})
+	pair := newPeer(eth63, p2p.NewPeer(id, "pair", nil), &stubMsgReadWriter{})
+
+	if err := ps.Register(primary); err != nil {
+		t.Fatalf("register primary: %v", err)
+	}
+	if err := ps.Register(pair); err != p2p.ErrAddPairPeer {
+		t.Fatalf("register pair: got %v want %v", err, p2p.ErrAddPairPeer)
+	}
+
+	start := make(chan struct{}) // closed below to release all three goroutines together
+	errc := make(chan error, 3) // one slot per goroutine; each sends at most one error and then returns, so sends never block
+	var lifecycleMu sync.Mutex // serializes the two churn goroutines with each other; the sender goroutine never takes it
+	var primaryCycles atomic.Int32 // counts successful primary unregisters, checked after the run
+	var wg sync.WaitGroup
+	wg.Add(3)
+
+	go func() { // sender: 2000 SendBlockHeaders calls on primary without holding lifecycleMu
+		defer wg.Done()
+		<-start
+		for i := 0; i < 2000; i++ {
+			if err := primary.SendBlockHeaders(nil); err != nil {
+				errc <- fmt.Errorf("send block headers: %w", err)
+				return
+			}
+		}
+	}()
+
+	go func() { // pair churn: unregister then re-register pair, each round under lifecycleMu
+		defer wg.Done()
+		<-start
+		for i := 0; i < 2000; i++ {
+			lifecycleMu.Lock()
+			removedPrimary, err := ps.UnregisterPeer(pair)
+			if err != nil {
+				lifecycleMu.Unlock() // every early return releases the lock first so the other churn goroutine is not left blocked
+				errc <- fmt.Errorf("unregister pair: %w", err)
+				return
+			}
+			if removedPrimary {
+				lifecycleMu.Unlock()
+				errc <- fmt.Errorf("unregister pair: reported primary removal")
+				return
+			}
+			if err := ps.Register(pair); err != p2p.ErrAddPairPeer {
+				lifecycleMu.Unlock()
+				errc <- fmt.Errorf("register pair: got %v want %v", err, p2p.ErrAddPairPeer)
+				return
+			}
+			lifecycleMu.Unlock()
+		}
+	}()
+
+	go func() { // primary churn: unregister primary, then re-register primary and pair, each round under lifecycleMu
+		defer wg.Done()
+		<-start
+		for i := 0; i < 200; i++ {
+			lifecycleMu.Lock()
+			removedPrimary, err := ps.UnregisterPeer(primary)
+			if err != nil {
+				lifecycleMu.Unlock()
+				errc <- fmt.Errorf("unregister primary: %w", err)
+				return
+			}
+			if !removedPrimary {
+				lifecycleMu.Unlock()
+				errc <- fmt.Errorf("unregister primary: did not report primary removal")
+				return
+			}
+			primaryCycles.Add(1)
+			if err := ps.Register(primary); err != nil {
+				lifecycleMu.Unlock()
+				errc <- fmt.Errorf("register primary: %v", err)
+				return
+			}
+			if err := ps.Register(pair); err != p2p.ErrAddPairPeer { // pair is registered again here so each round ends with both peers registered
+				lifecycleMu.Unlock()
+				errc <- fmt.Errorf("re-register pair after primary: got %v want %v", err, p2p.ErrAddPairPeer)
+				return
+			}
+			lifecycleMu.Unlock()
+		}
+	}()
+
+	close(start)
+	wg.Wait()
+	close(errc) // safe after wg.Wait: no goroutine can still send, and closing lets the range below terminate
+
+	for err := range errc {
+		if err != nil {
+			t.Fatal(err)
+		}
+	}
+	if primaryCycles.Load() == 0 { // fails the test if the primary-churn goroutine never completed an unregister
+		t.Fatal("primary unregister path was not exercised")
+	}
+	if got := ps.Peer(primary.id); got != primary {
+		t.Fatal("primary peer was not restored after concurrent lifecycle churn")
+	}
+	if primary.pairWriter() != pair.rw {
+		t.Fatal("primary peer did not restore pair writer after primary re-register")
+	}
+}
diff --git a/p2p/peer.go b/p2p/peer.go index b8b257f56dbc..c96c46a4f7cd 100644 --- a/p2p/peer.go +++ b/p2p/peer.go @@ -23,6 +23,7 @@ import ( "net" "slices" "sync" + "sync/atomic" "time" "github.com/XinFinOrg/XDPoSChain/common/mclock" @@ -115,8 +116,7 @@ type Peer struct { // events receives message send / receive events if set events *event.Feed - pairPeerMu sync.RWMutex - pairPeer *Peer + pairPeer atomic.Pointer[Peer] } // NewPeer returns a peer for testing purposes. 
@@ -193,27 +193,15 @@ func (p *Peer) Log() log.Logger {
 }
 
 func (p *Peer) PairPeer() *Peer {
-	p.pairPeerMu.RLock()
-	defer p.pairPeerMu.RUnlock()
-
-	return p.pairPeer
+	return p.pairPeer.Load() // atomic read of the stored pair; nil when nothing is stored
 }
 
 func (p *Peer) SetPairPeer(pair *Peer) {
-	p.pairPeerMu.Lock()
-	p.pairPeer = pair
-	p.pairPeerMu.Unlock()
+	p.pairPeer.Store(pair) // unconditional atomic overwrite; passing nil clears the pair
 }
 
 func (p *Peer) ClearPairPeer(pair *Peer) bool {
-	p.pairPeerMu.Lock()
-	defer p.pairPeerMu.Unlock()
-
-	if p.pairPeer != pair {
-		return false
-	}
-	p.pairPeer = nil
-	return true
+	return p.pairPeer.CompareAndSwap(pair, nil) // sets the pair to nil only if it still equals pair; reports whether the swap happened
 }
 
 func (p *Peer) run() (remoteRequested bool, err error) {
diff --git a/p2p/peer_test.go b/p2p/peer_test.go
index 09907bbb341c..b5caf5ed2a38 100644
--- a/p2p/peer_test.go
+++ b/p2p/peer_test.go
@@ -225,6 +225,37 @@ func TestPeerRunDisconnectsPairPeer(t *testing.T) {
 	}
 }
 
+func TestPeerPairPeerAtomicSemantics(t *testing.T) { // walks PairPeer/SetPairPeer/ClearPairPeer through set, mismatched clear, matched clear, and explicit nil set
+	var peer Peer // zero-value Peer; the first check below expects it to start with no pair
+	first := &Peer{}
+	second := &Peer{}
+
+	if got := peer.PairPeer(); got != nil {
+		t.Fatalf("initial pair peer = %p, want nil", got)
+	}
+	peer.SetPairPeer(first)
+	if got := peer.PairPeer(); got != first {
+		t.Fatalf("pair peer after first set = %p, want %p", got, first)
+	}
+	if cleared := peer.ClearPairPeer(second); cleared { // second is not the stored pair, so the clear must be refused
+		t.Fatal("ClearPairPeer should fail when current pair differs")
+	}
+	if got := peer.PairPeer(); got != first { // a refused clear must leave the stored pair untouched
+		t.Fatalf("pair peer after failed clear = %p, want %p", got, first)
+	}
+	if cleared := peer.ClearPairPeer(first); !cleared {
+		t.Fatal("ClearPairPeer should succeed for the current pair")
+	}
+	if got := peer.PairPeer(); got != nil {
+		t.Fatalf("pair peer after successful clear = %p, want nil", got)
+	}
+	peer.SetPairPeer(second)
+	peer.SetPairPeer(nil) // storing nil directly is expected to clear the pair as well
+	if got := peer.PairPeer(); got != nil {
+		t.Fatalf("pair peer after explicit nil set = %p, want nil", got)
+	}
+}
+
 func TestNewPeer(t *testing.T) {
 	name := "nodename"
 	caps := []Cap{{"foo", 2}, {"bar", 3}}