diff --git a/api.go b/api.go index 50fa88e3..eef03e5a 100644 --- a/api.go +++ b/api.go @@ -58,7 +58,7 @@ type Storage interface { type Communication interface { // Nodes returns all nodes that participate in the epoch. - Nodes() []NodeID + Nodes() Nodes // Send sends a message to the given destination node Send(msg *Message, destination NodeID) @@ -130,8 +130,33 @@ type SignatureAggregator interface { // Aggregate aggregates several signatures into a QuorumCertificate Aggregate([]Signature) (QuorumCertificate, error) + // AppendSignatures appends signatures to an existing signature. + // If the existing signature is empty, it just aggregates the given signatures. + AppendSignatures([]byte, ...[]byte) ([]byte, error) + // IsQuorum returns true if the given signers constitute a quorum. // In the case of PoA, this means at least a quorum of the nodes are given. // In the case of PoS, this means at least two thirds of the st. IsQuorum([]NodeID) bool } + +// Nodes is a list of Node elements. +type Nodes []Node + +// NodeIDs returns the NodeIDs of the nodes in the Nodes. +func (nws Nodes) NodeIDs() []NodeID { + nodes := make([]NodeID, len(nws)) + for i, nw := range nws { + nodes[i] = nw.Node + } + return nodes +} + +// Node is a struct that pairs a node with its weight in the signature aggregator. +type Node struct { + Node NodeID + Weight uint64 +} + +// SignatureAggregatorCreator creates a SignatureAggregator from a list of nodes and their weights. +type SignatureAggregatorCreator func([]Node) SignatureAggregator diff --git a/blacklist.go b/blacklist.go index 842ba638..a80c16d0 100644 --- a/blacklist.go +++ b/blacklist.go @@ -206,7 +206,7 @@ func (bl *Blacklist) ApplyUpdates(updates []BlacklistUpdate, round uint64) Black } // garbageCollectSuspectedNodes returns a new list of suspected nodes for the given round. -// Nodes that are no longer suspected or have been redeemed, will not be included in the returned suspected nodes. +// NodeIDs that are no longer suspected or have been redeemed, will not be included in the returned suspected nodes. // It will also garbage-collect any redeem votes from past orbits, unless hey have surpassed the threshold of f+1. // It does not modify the current blacklist. func (bl *Blacklist) garbageCollectSuspectedNodes(round uint64) SuspectedNodes { diff --git a/blacklist_test.go b/blacklist_test.go index f47fb5ea..0a477234 100644 --- a/blacklist_test.go +++ b/blacklist_test.go @@ -465,8 +465,8 @@ func TestComputeBlacklistUpdates(t *testing.T) { func TestAdvanceRound(t *testing.T) { nodes := []uint16{0, 1, 2, 3} - // Nodes 0, 2 are suspected. - // Nodes 1 and 3 are not suspected. + // NodeIDs 0, 2 are suspected. + // NodeIDs 1 and 3 are not suspected. // Node 2 can be redeemed. suspectedNodesBefore := SuspectedNodes{ {NodeIndex: 0, SuspectingCount: 2, OrbitSuspected: 1, RedeemingCount: 1, OrbitToRedeem: 1}, diff --git a/epoch.go b/epoch.go index c81a51e8..157d7b7c 100644 --- a/epoch.go +++ b/epoch.go @@ -69,7 +69,7 @@ type EpochConfig struct { Signer Signer Verifier SignatureVerifier BlockDeserializer BlockDeserializer - SignatureAggregator SignatureAggregator + SignatureAggregatorCreator SignatureAggregatorCreator Comm Communication Storage Storage WAL WriteAheadLog @@ -83,6 +83,7 @@ type EpochConfig struct { type Epoch struct { EpochConfig // Runtime + signatureAggregator SignatureAggregator oneTimeVerifier *OneTimeVerifier buildBlockScheduler *BasicScheduler blockVerificationScheduler *BlockDependencyManager @@ -94,6 +95,7 @@ type Epoch struct { blockBuilderCtx context.Context blockBuilderCancelFunc context.CancelFunc nodes NodeIDs + nodeWeights Nodes eligibleNodeIDs map[string]struct{} rounds map[uint64]*Round emptyVotes map[uint64]*EmptyVoteSet @@ -198,8 +200,9 @@ func (e *Epoch) init() error { e.finishCtx, e.finishFn = context.WithCancel(context.Background()) e.blockBuilderCtx = context.Background() e.blockBuilderCancelFunc = func() {} - e.nodes = e.Comm.Nodes() - SortNodes(e.nodes) + e.nodeWeights = e.Comm.Nodes() + SortNodes(e.nodeWeights) + e.nodes = e.nodeWeights.NodeIDs() e.timedOutRounds = make(map[uint16]uint64, len(e.nodes)) e.redeemedRounds = make(map[uint16]uint64, len(e.nodes)) e.rounds = make(map[uint64]*Round) @@ -208,6 +211,7 @@ func (e *Epoch) init() error { e.futureMessages = make(messagesFromNode, len(e.nodes)) e.replicationState = NewReplicationState(e.Logger, e.Comm, e.ID, e.MaxRoundWindow, e.ReplicationEnabled, e.StartTime, &e.lock, e.RandomSource) e.timeoutHandler = NewTimeoutHandler(e.Logger, "emptyVoteRebroadcast", e.StartTime, e.MaxRebroadcastWait, e.emptyVoteTimeoutTaskRunner) + e.signatureAggregator = e.SignatureAggregatorCreator(e.nodeWeights) for _, node := range e.nodes { e.futureMessages[string(node)] = make(map[uint64]*messagesForRound) @@ -739,7 +743,7 @@ func (e *Epoch) handleFinalizationMessage(message *Finalization, from NodeID) er return nil } - if err := VerifyQC(message.QC, e.Logger, "Finalization", e.SignatureAggregator.IsQuorum, e.eligibleNodeIDs, message, from); err != nil { + if err := VerifyQC(message.QC, e.Logger, "Finalization", e.signatureAggregator.IsQuorum, e.eligibleNodeIDs, message, from); err != nil { e.Logger.Debug("Received an invalid finalization", zap.Int("round", int(message.Finalization.Round)), zap.Stringer("NodeID", from)) @@ -1159,7 +1163,7 @@ func (e *Epoch) maybeCollectFinalization(round *Round) error { var finalizations []*FinalizeVote for _, finalizationsWithTheSameDigest := range finalizationsByMD { - if e.SignatureAggregator.IsQuorum(NodeIDsFromVotes(finalizationsWithTheSameDigest)) { + if e.signatureAggregator.IsQuorum(NodeIDsFromVotes(finalizationsWithTheSameDigest)) { finalizations = finalizationsWithTheSameDigest break } @@ -1174,7 +1178,7 @@ func (e *Epoch) maybeCollectFinalization(round *Round) error { } func (e *Epoch) assembleFinalization(round *Round, finalizationVotes []*FinalizeVote) error { - finalization, err := NewFinalization(e.Logger, e.SignatureAggregator, finalizationVotes) + finalization, err := NewFinalization(e.Logger, e.signatureAggregator, finalizationVotes) if err != nil { return err } @@ -1387,13 +1391,13 @@ func (e *Epoch) maybeAssembleEmptyNotarization() error { } // Check if we found a quorum of votes for the same metadata - popularEmptyVote, signatures, found := findEmptyVoteThatIsQuorum(emptyVotes.votes, e.SignatureAggregator.IsQuorum) + popularEmptyVote, signatures, found := findEmptyVoteThatIsQuorum(emptyVotes.votes, e.signatureAggregator.IsQuorum) if !found { e.Logger.Debug("Could not find empty vote with a quorum or more votes", zap.Uint64("round", e.round)) return nil } - qc, err := e.SignatureAggregator.Aggregate(signatures) + qc, err := e.signatureAggregator.Aggregate(signatures) if err != nil { e.Logger.Error("Could not aggregate empty votes signatures", zap.Error(err), zap.Uint64("round", e.round)) return nil @@ -1500,7 +1504,7 @@ func (e *Epoch) maybeCollectNotarization() error { } } - if !e.SignatureAggregator.IsQuorum(NodeIDsFromVotes(votesForOurBlock)) { + if !e.signatureAggregator.IsQuorum(NodeIDsFromVotes(votesForOurBlock)) { e.Logger.Debug("Not enough votes to form a notarization for our block", zap.Uint64("round", e.round), zap.Int("voteForOurBlock", len(votesForOurBlock)), @@ -1508,7 +1512,7 @@ func (e *Epoch) maybeCollectNotarization() error { return nil } - notarization, err := NewNotarization(e.Logger, e.SignatureAggregator, votesForCurrentRound, block.BlockHeader()) + notarization, err := NewNotarization(e.Logger, e.signatureAggregator, votesForCurrentRound, block.BlockHeader()) if err != nil { return err } @@ -1594,7 +1598,7 @@ func (e *Epoch) handleEmptyNotarizationMessage(emptyNotarization *EmptyNotarizat } // Otherwise, this round is not notarized or finalized yet, so verify the empty notarization and store it. - if err := VerifyQC(emptyNotarization.QC, e.Logger, "Empty notarization", e.SignatureAggregator.IsQuorum, e.eligibleNodeIDs, emptyNotarization, from); err != nil { + if err := VerifyQC(emptyNotarization.QC, e.Logger, "Empty notarization", e.signatureAggregator.IsQuorum, e.eligibleNodeIDs, emptyNotarization, from); err != nil { return nil } @@ -1650,7 +1654,7 @@ func (e *Epoch) handleNotarizationMessage(message *Notarization, from NodeID) er return nil } - if err := VerifyQC(message.QC, e.Logger, "Notarization", e.SignatureAggregator.IsQuorum, e.eligibleNodeIDs, message, from); err != nil { + if err := VerifyQC(message.QC, e.Logger, "Notarization", e.signatureAggregator.IsQuorum, e.eligibleNodeIDs, message, from); err != nil { return nil } @@ -3206,20 +3210,20 @@ func (e *Epoch) verifyQuorumRound(q QuorumRound, from NodeID) error { if q.Finalization != nil { // extra check needed if we have a finalized block - err := VerifyQC(q.Finalization.QC, e.Logger, "Finalization", e.SignatureAggregator.IsQuorum, e.eligibleNodeIDs, q.Finalization, from) + err := VerifyQC(q.Finalization.QC, e.Logger, "Finalization", e.signatureAggregator.IsQuorum, e.eligibleNodeIDs, q.Finalization, from) if err != nil { return errors.New("invalid finalization") } } if q.Notarization != nil { - if err := VerifyQC(q.Notarization.QC, e.Logger, "Notarization", e.SignatureAggregator.IsQuorum, e.eligibleNodeIDs, q.Notarization, from); err != nil { + if err := VerifyQC(q.Notarization.QC, e.Logger, "Notarization", e.signatureAggregator.IsQuorum, e.eligibleNodeIDs, q.Notarization, from); err != nil { return fmt.Errorf("invalid notarization: %v", err) } } if q.EmptyNotarization != nil { - err := VerifyQC(q.EmptyNotarization.QC, e.Logger, "Empty notarization", e.SignatureAggregator.IsQuorum, e.eligibleNodeIDs, q.EmptyNotarization, from) + err := VerifyQC(q.EmptyNotarization.QC, e.Logger, "Empty notarization", e.signatureAggregator.IsQuorum, e.eligibleNodeIDs, q.EmptyNotarization, from) if err != nil { return fmt.Errorf("invalid empty notarization QC: %v", err) } @@ -3421,9 +3425,9 @@ func (e *Epoch) nextSeqToCommit() uint64 { } // SortNodes sorts the nodes in place by their byte representations. -func SortNodes(nodes []NodeID) { - slices.SortFunc(nodes, func(a, b NodeID) int { - return bytes.Compare(a[:], b[:]) +func SortNodes(nodes Nodes) { + slices.SortFunc(nodes, func(a, b Node) int { + return bytes.Compare(a.Node[:], b.Node[:]) }) } diff --git a/epoch_failover_test.go b/epoch_failover_test.go index a6c39042..9db6a9b1 100644 --- a/epoch_failover_test.go +++ b/epoch_failover_test.go @@ -85,9 +85,9 @@ func TestEpochLeaderFailoverWithEmptyNotarization(t *testing.T) { func TestEpochRebroadcastsEmptyVoteAfterBlockProposalReceived(t *testing.T) { bb := testutil.NewTestBlockBuilder() - nodes := []NodeID{{1}, {2}, {3}, {4}} + nodes := NodeIDs{{1}, {2}, {3}, {4}} - comm := newRebroadcastComm(nodes) + comm := newRebroadcastComm(nodes.EqualWeightedNodes()) conf, wal, _ := testutil.DefaultTestNodeEpochConfig(t, nodes[3], comm, bb) epochTime := conf.StartTime e, err := NewEpoch(conf) @@ -349,12 +349,12 @@ func TestEpochLeaderFailoverDoNotPersistEmptyRoundTwice(t *testing.T) { } func TestEpochLeaderRecursivelyFetchNotarizedBlocks(t *testing.T) { - nodes := []NodeID{{1}, {2}, {3}, {4}} + nodes := NodeIDs{{1}, {2}, {3}, {4}} bb := testutil.NewTestBlockBuilder() recordedMessages := make(chan *Message, 100) - comm := &recordingComm{Communication: testutil.NoopComm(nodes), SentMessages: recordedMessages} + comm := &recordingComm{Communication: testutil.NoopComm(nodes.EqualWeightedNodes()), SentMessages: recordedMessages} conf, wal, _ := testutil.DefaultTestNodeEpochConfig(t, nodes[0], comm, bb) e, err := NewEpoch(conf) @@ -1115,18 +1115,18 @@ func TestEpochBlacklist(t *testing.T) { } type rebroadcastComm struct { - nodes []NodeID + nodes Nodes emptyVotes chan *EmptyVote } -func newRebroadcastComm(nodes []NodeID) *rebroadcastComm { +func newRebroadcastComm(nodes Nodes) *rebroadcastComm { return &rebroadcastComm{ nodes: nodes, emptyVotes: make(chan *EmptyVote, 10), } } -func (r *rebroadcastComm) Nodes() []NodeID { +func (r *rebroadcastComm) Nodes() Nodes { return r.nodes } @@ -1142,9 +1142,9 @@ func (r *rebroadcastComm) Broadcast(msg *Message) { func TestEpochRebroadcastsEmptyVote(t *testing.T) { bb := testutil.NewTestBlockBuilder() - nodes := []NodeID{{1}, {2}, {3}, {4}} + nodes := NodeIDs{{1}, {2}, {3}, {4}} - comm := newRebroadcastComm(nodes) + comm := newRebroadcastComm(nodes.EqualWeightedNodes()) conf, wal, _ := testutil.DefaultTestNodeEpochConfig(t, nodes[3], comm, bb) epochTime := conf.StartTime e, err := NewEpoch(conf) @@ -1227,7 +1227,7 @@ func runCrashAndRestartExecution(t *testing.T, e *Epoch, bb *testutil.TestBlockB // Case 2: t.Run(fmt.Sprintf("%s-with-crash", t.Name()), func(t *testing.T) { - conf, _, _ := testutil.DefaultTestNodeEpochConfig(t, nodes[0], testutil.NewNoopComm(nodes), bbAfterCrash) + conf, _, _ := testutil.DefaultTestNodeEpochConfig(t, nodes[0].Node, testutil.NewNoopComm(nodes.NodeIDs()), bbAfterCrash) conf.Storage = cloneStorage conf.WAL = cloneWAL diff --git a/epoch_test.go b/epoch_test.go index 83004dff..a7a034b5 100644 --- a/epoch_test.go +++ b/epoch_test.go @@ -118,7 +118,8 @@ func TestFinalizeSameSequence(t *testing.T) { require.NoError(t, err) // create a notarization and now we should send a finalize vote for seq 1 again - notarization, err := testutil.NewNotarization(e.Logger, e.SignatureAggregator, block, nodes[1:]) + sigAggr := e.SignatureAggregatorCreator(conf.Comm.Nodes()) + notarization, err := testutil.NewNotarization(e.Logger, sigAggr, block, nodes[1:]) require.NoError(t, err) testutil.InjectTestNotarization(t, e, notarization, nodes[1]) @@ -202,7 +203,7 @@ func testFinalizeSameSequenceGap(t *testing.T, nodes []NodeID, numEmptyNotarizat } for range numEmptyNotarizations { - leader := LeaderForRound(e.Comm.Nodes(), e.Metadata().Round) + leader := LeaderForRound(e.Comm.Nodes().NodeIDs(), e.Metadata().Round) if e.ID.Equals(leader) { fVote := advanceWithFinalizeCheck(t, e, recordingComm, bb) finalizeVoteSeqs[fVote.Finalization.Seq] = fVote @@ -236,7 +237,7 @@ func testFinalizeSameSequenceGap(t *testing.T, nodes []NodeID, numEmptyNotarizat verified <- struct{}{} } - leader := LeaderForRound(e.Comm.Nodes(), 1+numEmptyNotarizations+numNotarizations) + leader := LeaderForRound(e.Comm.Nodes().NodeIDs(), 1+numEmptyNotarizations+numNotarizations) if e.ID.Equals(leader) { return } @@ -275,7 +276,8 @@ func testFinalizeSameSequenceGap(t *testing.T, nodes []NodeID, numEmptyNotarizat } // create a notarization and now we should send a finalize vote for seqToDoubleFinalize again - notarization, err := testutil.NewNotarization(e.Logger, e.SignatureAggregator, block, nodes[1:]) + sigAggr := e.SignatureAggregatorCreator(conf.Comm.Nodes()) + notarization, err := testutil.NewNotarization(e.Logger, sigAggr, block, nodes[1:]) require.NoError(t, err) testutil.InjectTestNotarization(t, e, notarization, nodes[1]) @@ -377,7 +379,8 @@ func TestEpochHandleNotarizationFutureRound(t *testing.T) { require.NoError(t, e.Start()) // Create a notarization for round 1 which is a future round because we haven't gone through round 0 yet. - notarization, err := testutil.NewNotarization(conf.Logger, conf.SignatureAggregator, secondBlock, nodes) + sigAggr := e.SignatureAggregatorCreator(conf.Comm.Nodes()) + notarization, err := testutil.NewNotarization(conf.Logger, sigAggr, secondBlock, nodes) require.NoError(t, err) // Give the node the notarization message before receiving the first block @@ -442,7 +445,8 @@ func TestEpochIndexFinalization(t *testing.T) { // when we receive that finalization, we should commit the rest of the finalizations for seqs // 1 & 2 - finalization, _ := testutil.NewFinalizationRecord(t, conf.Logger, conf.SignatureAggregator, firstBlock, e.Comm.Nodes()) + sigAggr := e.SignatureAggregatorCreator(conf.Comm.Nodes()) + finalization, _ := testutil.NewFinalizationRecord(t, conf.Logger, sigAggr, firstBlock, e.Comm.Nodes().NodeIDs()) testutil.InjectTestFinalization(t, e, &finalization, nodes[1]) storage.WaitForBlockCommit(2) @@ -539,7 +543,8 @@ func TestEpochIncreasesRoundAfterFinalization(t *testing.T) { require.Equal(t, uint64(0), storage.NumBlocks()) // create the finalized block - finalization, _ := testutil.NewFinalizationRecord(t, l, conf.SignatureAggregator, block, nodes) + sigAggr := e.SignatureAggregatorCreator(conf.Comm.Nodes()) + finalization, _ := testutil.NewFinalizationRecord(t, l, sigAggr, block, nodes) testutil.InjectTestFinalization(t, e, &finalization, nodes[1]) storage.WaitForBlockCommit(1) @@ -746,10 +751,10 @@ func TestEpochStartedTwice(t *testing.T) { } func advanceRoundFromEmpty(t *testing.T, e *Epoch) { - leader := LeaderForRound(e.Comm.Nodes(), e.Metadata().Round) + leader := LeaderForRound(e.Comm.Nodes().NodeIDs(), e.Metadata().Round) require.False(t, e.ID.Equals(leader), "epoch cannot be the leader for the empty round") - emptyNote := testutil.NewEmptyNotarization(e.Comm.Nodes(), e.Metadata().Round) + emptyNote := testutil.NewEmptyNotarization(e.Comm.Nodes().NodeIDs(), e.Metadata().Round) err := e.HandleMessage(&Message{ EmptyNotarization: emptyNote, }, leader) @@ -1010,8 +1015,10 @@ func TestEpochQCSignedByNonExistentNodes(t *testing.T) { wal.AssertWALSize(1) + sigAggr := e.SignatureAggregatorCreator(conf.Comm.Nodes()) + t.Run("notarization with unknown signer isn't taken into account", func(t *testing.T) { - notarization, err := testutil.NewNotarization(conf.Logger, conf.SignatureAggregator, block, []NodeID{{2}, {3}, {5}}) + notarization, err := testutil.NewNotarization(conf.Logger, sigAggr, block, []NodeID{{2}, {3}, {5}}) require.NoError(t, err) err = e.HandleMessage(&Message{ @@ -1025,7 +1032,7 @@ func TestEpochQCSignedByNonExistentNodes(t *testing.T) { }) t.Run("notarization with double signer isn't taken into account", func(t *testing.T) { - notarization, err := testutil.NewNotarization(conf.Logger, conf.SignatureAggregator, block, []NodeID{{2}, {3}}) + notarization, err := testutil.NewNotarization(conf.Logger, sigAggr, block, []NodeID{{2}, {3}}) require.NoError(t, err) tqc := notarization.QC.(testutil.TestQC) @@ -1081,7 +1088,7 @@ func TestEpochQCSignedByNonExistentNodes(t *testing.T) { }) t.Run("finalization with unknown signer isn't taken into account", func(t *testing.T) { - finalization, _ := testutil.NewFinalizationRecord(t, conf.Logger, conf.SignatureAggregator, block, []NodeID{{2}, {3}, {5}}) + finalization, _ := testutil.NewFinalizationRecord(t, conf.Logger, sigAggr, block, []NodeID{{2}, {3}, {5}}) err = e.HandleMessage(&Message{ Finalization: &finalization, @@ -1092,7 +1099,7 @@ func TestEpochQCSignedByNonExistentNodes(t *testing.T) { }) t.Run("finalization with double signer isn't taken into account", func(t *testing.T) { - finalization, _ := testutil.NewFinalizationRecord(t, conf.Logger, conf.SignatureAggregator, block, []NodeID{{2}, {3}, {3}}) + finalization, _ := testutil.NewFinalizationRecord(t, conf.Logger, sigAggr, block, []NodeID{{2}, {3}, {3}}) err = e.HandleMessage(&Message{ Finalization: &finalization, @@ -1250,7 +1257,8 @@ func TestEpochSendsBlockDigestRequest(t *testing.T) { require.True(t, built) block := bb.GetBuiltBlock() - notarization, err := testutil.NewNotarization(conf.Logger, conf.SignatureAggregator, block, nodes) + sigAggr := e.SignatureAggregatorCreator(conf.Comm.Nodes()) + notarization, err := testutil.NewNotarization(conf.Logger, sigAggr, block, nodes) require.NoError(t, err) err = e.HandleMessage(&Message{ @@ -1495,7 +1503,8 @@ func TestDoubleIncrementOnPersistNotarization(t *testing.T) { require.True(t, ok) block := bb.GetBuiltBlock() - notarization, err := testutil.NewNotarization(conf.Logger, conf.SignatureAggregator, block, nodes) + sigAggr := e.SignatureAggregatorCreator(conf.Comm.Nodes()) + notarization, err := testutil.NewNotarization(conf.Logger, sigAggr, block, nodes) require.NoError(t, err) err = e.HandleMessage(&Message{ @@ -1586,7 +1595,8 @@ func TestRejectsOldNotarizationAndVotes(t *testing.T) { } // send notarization for round 1, after the finalization was sent - notarization, err := testutil.NewNotarization(conf.Logger, conf.SignatureAggregator, block, nodes) + sigAggr := e.SignatureAggregatorCreator(conf.Comm.Nodes()) + notarization, err := testutil.NewNotarization(conf.Logger, sigAggr, block, nodes) require.NoError(t, err) err = e.HandleMessage(&Message{ @@ -1634,7 +1644,7 @@ func advanceRound(t *testing.T, e *simplex.Epoch, bb *testutil.TestBlockBuilder, nodes := e.Comm.Nodes() quorum := simplex.Quorum(len(nodes)) // leader is the proposer of the new block for the given round - leader := simplex.LeaderForRound(nodes, e.Metadata().Round) + leader := simplex.LeaderForRound(nodes.NodeIDs(), e.Metadata().Round) md := e.Metadata() if injectedMD != nil { md = *injectedMD @@ -1667,8 +1677,9 @@ func advanceRound(t *testing.T, e *simplex.Epoch, bb *testutil.TestBlockBuilder, var notarization *simplex.Notarization if notarize { // start at one since our node has already voted - n, err := testutil.NewNotarization(e.Logger, e.SignatureAggregator, block, nodes[0:quorum]) - testutil.InjectTestNotarization(t, e, n, nodes[1]) + sigAggr := e.SignatureAggregatorCreator(nodes) + n, err := testutil.NewNotarization(e.Logger, sigAggr, block, nodes.NodeIDs()[0:quorum]) + testutil.InjectTestNotarization(t, e, n, nodes[1].Node) e.WAL.(*testutil.TestWAL).AssertNotarization(block.Metadata.Round) require.NoError(t, err) @@ -1677,10 +1688,10 @@ func advanceRound(t *testing.T, e *simplex.Epoch, bb *testutil.TestBlockBuilder, if finalize { for i := 0; i <= quorum; i++ { - if nodes[i].Equals(e.ID) { + if nodes[i].Node.Equals(e.ID) { continue } - testutil.InjectTestFinalizeVote(t, e, block, nodes[i]) + testutil.InjectTestFinalizeVote(t, e, block, nodes[i].Node) } if nextSeqToCommit != block.Metadata.Seq { diff --git a/global.go b/global.go index 4fdec9bd..771da09a 100644 --- a/global.go +++ b/global.go @@ -51,3 +51,14 @@ func (nodes NodeIDs) IndexOf(id NodeID) int { } return -1 } + +func (nodes NodeIDs) EqualWeightedNodes() Nodes { + weights := make(Nodes, len(nodes)) + for i, node := range nodes { + weights[i] = Node{ + Node: node, + Weight: 1, + } + } + return weights +} diff --git a/msm/build_decision.go b/msm/build_decision.go index b19ecaf2..f0786114 100644 --- a/msm/build_decision.go +++ b/msm/build_decision.go @@ -14,31 +14,10 @@ import ( // blockBuildingDecision represents the decision of whether we should build a block at the current time, // and if so, whether we should also transition to a new epoch along the way. -type blockBuildingDecision int8 - -const ( - decisionUndefined blockBuildingDecision = iota - decisionBuild // We should build a block, and we don't need to transition to a new epoch. - decisionTransitionEpoch // We should transition to a new epoch immediately, but we don't need to build a block. - decisionBuildAndTransitionEpoch // We should build a block and transition to a new epoch along the way. - decisionContextCanceled -) - -func (bbd blockBuildingDecision) String() string { - switch bbd { - case decisionUndefined: - return "undefined" - case decisionBuild: - return "build block" - case decisionTransitionEpoch: - return "transition epoch" - case decisionBuildAndTransitionEpoch: - return "build block and transition epoch" - case decisionContextCanceled: - return "context canceled" - default: - return "unknown" - } +type blockBuildingDecision struct { + buildInnerBlock bool + transitionEpoch bool + pChainHeight uint64 } // PChainProgressListener listens for changes in the P-chain height. @@ -65,18 +44,18 @@ type blockBuildingDecider struct { // The P-chain height is returned because sampling the P-chain height afterwards might be inconsistent with the decision that was made. func (bbd *blockBuildingDecider) shouldBuildBlock( ctx context.Context, -) (blockBuildingDecision, uint64, error) { +) (blockBuildingDecision, error) { for { pChainHeight := bbd.getPChainHeight() shouldTransitionEpoch, err := bbd.hasValidatorSetChanged(pChainHeight) if err != nil { - return decisionUndefined, 0, err + return blockBuildingDecision{}, err } if shouldTransitionEpoch { // If we should transition to a new epoch, maybe we can also build a block along the way. - return bbd.buildBlockWithEpochTransition(ctx), pChainHeight, nil + return bbd.buildBlockWithEpochTransition(ctx, pChainHeight) } // Else, we don't need to transition to a new epoch, but maybe we should build a block. @@ -85,7 +64,7 @@ func (bbd *blockBuildingDecider) shouldBuildBlock( // If the context was cancelled in the meantime, abandon evaluation. if ctx.Err() != nil { - return decisionContextCanceled, 0, nil + return blockBuildingDecision{}, ctx.Err() } // If we've reached here, either the P-chain height has changed, or a block is ready to be built. @@ -99,7 +78,7 @@ func (bbd *blockBuildingDecider) shouldBuildBlock( // Else, we have reached here because a block is ready to be built, and the P-chain height has not changed, // which means we should build a block. - return decisionBuild, pChainHeight, nil + return blockBuildingDecision{buildInnerBlock: true, pChainHeight: pChainHeight}, nil } } @@ -130,7 +109,7 @@ func (bbd *blockBuildingDecider) waitForPChainChangeOrPendingBlock(ctx context.C // It waits up to a limited amount of time (bbd.maxBlockBuildingWaitTime) for a block to be ready to be built, // and if no block is ready by then, it returns the decision to transition epoch without building a block. // Otherwise, it returns the decision to build a block and transition epoch along the way. -func (bbd *blockBuildingDecider) buildBlockWithEpochTransition(ctx context.Context) blockBuildingDecision { +func (bbd *blockBuildingDecider) buildBlockWithEpochTransition(ctx context.Context, pChainHeight uint64) (blockBuildingDecision, error) { impatientContext, cancel := context.WithTimeout(ctx, bbd.maxBlockBuildingWaitTime) defer cancel() @@ -139,15 +118,15 @@ func (bbd *blockBuildingDecider) buildBlockWithEpochTransition(ctx context.Conte bbd.waitForPendingBlock(impatientContext) if ctx.Err() != nil { - return decisionContextCanceled + return blockBuildingDecision{}, ctx.Err() } if impatientContext.Err() != nil { // We have returned from waitForPendingBlock because impatientContext has timed out, // which means we don't need to build a block. - return decisionTransitionEpoch + return blockBuildingDecision{transitionEpoch: true, pChainHeight: pChainHeight}, nil } // Block is ready to be built - return decisionBuildAndTransitionEpoch + return blockBuildingDecision{buildInnerBlock: true, transitionEpoch: true, pChainHeight: pChainHeight}, nil } diff --git a/msm/build_decision_test.go b/msm/build_decision_test.go index 828aabeb..b1b13402 100644 --- a/msm/build_decision_test.go +++ b/msm/build_decision_test.go @@ -35,10 +35,9 @@ func TestShouldBuildBlock_VMSignalsBlock(t *testing.T) { getPChainHeight: func() uint64 { return 100 }, } - result, pChainHeight, err := bbd.shouldBuildBlock(t.Context()) + decision, err := bbd.shouldBuildBlock(t.Context()) require.NoError(t, err) - require.Equal(t, decisionBuild, result) - require.Equal(t, uint64(100), pChainHeight) + require.Equal(t, blockBuildingDecision{buildInnerBlock: true, pChainHeight: 100}, decision) } func TestShouldBuildBlock_ContextCanceled(t *testing.T) { @@ -60,9 +59,9 @@ func TestShouldBuildBlock_ContextCanceled(t *testing.T) { getPChainHeight: func() uint64 { return 100 }, } - result, _, err := bbd.shouldBuildBlock(ctx) - require.NoError(t, err) - require.Equal(t, decisionContextCanceled, result) + decision, err := bbd.shouldBuildBlock(ctx) + require.ErrorIs(t, err, context.Canceled) + require.Equal(t, blockBuildingDecision{}, decision) } func TestShouldBuildBlock_PChainHeightChangeTriggersEpochTransition(t *testing.T) { @@ -92,10 +91,9 @@ func TestShouldBuildBlock_PChainHeightChangeTriggersEpochTransition(t *testing.T getPChainHeight: func() uint64 { return pChainHeight.Load() }, } - result, resultPChainHeight, err := bbd.shouldBuildBlock(t.Context()) + decision, err := bbd.shouldBuildBlock(t.Context()) require.NoError(t, err) - require.Equal(t, decisionTransitionEpoch, result) - require.Equal(t, uint64(200), resultPChainHeight) + require.Equal(t, blockBuildingDecision{transitionEpoch: true, pChainHeight: 200}, decision) } func TestShouldBuildBlock_PChainHeightChangeButNoEpochTransition(t *testing.T) { @@ -128,10 +126,9 @@ func TestShouldBuildBlock_PChainHeightChangeButNoEpochTransition(t *testing.T) { getPChainHeight: func() uint64 { return pChainHeight.Load() }, } - result, resultPChainHeight, err := bbd.shouldBuildBlock(t.Context()) + decision, err := bbd.shouldBuildBlock(t.Context()) require.NoError(t, err) - require.Equal(t, decisionBuild, result) - require.Equal(t, uint64(200), resultPChainHeight) + require.Equal(t, blockBuildingDecision{buildInnerBlock: true, pChainHeight: 200}, decision) } func TestShouldBuildBlock_EpochTransitionWithVMBlock(t *testing.T) { @@ -148,10 +145,9 @@ func TestShouldBuildBlock_EpochTransitionWithVMBlock(t *testing.T) { getPChainHeight: func() uint64 { return 100 }, } - result, pChainHeight, err := bbd.shouldBuildBlock(t.Context()) + decision, err := bbd.shouldBuildBlock(t.Context()) require.NoError(t, err) - require.Equal(t, decisionBuildAndTransitionEpoch, result) - require.Equal(t, uint64(100), pChainHeight) + require.Equal(t, blockBuildingDecision{buildInnerBlock: true, transitionEpoch: true, pChainHeight: 100}, decision) } func TestShouldBuildBlock_EpochTransitionWithoutVMBlock(t *testing.T) { @@ -170,10 +166,9 @@ func TestShouldBuildBlock_EpochTransitionWithoutVMBlock(t *testing.T) { getPChainHeight: func() uint64 { return 100 }, } - result, pChainHeight, err := bbd.shouldBuildBlock(t.Context()) + decision, err := bbd.shouldBuildBlock(t.Context()) require.NoError(t, err) - require.Equal(t, decisionTransitionEpoch, result) - require.Equal(t, uint64(100), pChainHeight) + require.Equal(t, blockBuildingDecision{transitionEpoch: true, pChainHeight: 100}, decision) } func TestShouldBuildBlock_EpochTransitionContextCanceled(t *testing.T) { @@ -196,7 +191,7 @@ func TestShouldBuildBlock_EpochTransitionContextCanceled(t *testing.T) { getPChainHeight: func() uint64 { return 100 }, } - result, _, err := bbd.shouldBuildBlock(ctx) - require.NoError(t, err) - require.Equal(t, decisionContextCanceled, result) + decision, err := bbd.shouldBuildBlock(ctx) + require.ErrorIs(t, err, context.Canceled) + require.Equal(t, blockBuildingDecision{}, decision) } diff --git a/msm/encoding.go b/msm/encoding.go index 59e425ad..7fd22189 100644 --- a/msm/encoding.go +++ b/msm/encoding.go @@ -5,9 +5,9 @@ package metadata import ( "bytes" - "fmt" - "math" "slices" + + "github.com/ava-labs/simplex" ) //go:generate go run github.com/StephenButtolph/canoto/canoto encoding.go @@ -62,6 +62,7 @@ type SimplexEpochInfo struct { NextEpochApprovals *NextEpochApprovals `canoto:"pointer,7"` // SealingBlockSeq is the block sequence of the sealing block of the current epoch. // It defines the validator set of the next epoch. + // It is set once the first Telock is built and is copied over to subsequent Telocks. SealingBlockSeq uint64 `canoto:"uint,8"` canotoData canotoData_SimplexEpochInfo @@ -109,6 +110,30 @@ func (sei *SimplexEpochInfo) Equal(other *SimplexEpochInfo) bool { return true } +func (sei *SimplexEpochInfo) NextState() state { + prevBlockSimplexEpochInfo := sei + // If this is the first ever epoch, then this is also the first ever block to be built by Simplex. + if prevBlockSimplexEpochInfo.EpochNumber == 0 { + return stateFirstSimplexBlock + } + + // If we don't have a next P-chain preference height, it means we are not transitioning to a new epoch just yet. + if prevBlockSimplexEpochInfo.NextPChainReferenceHeight == 0 { + return stateBuildBlockNormalOp + } + + // If the previous block has a sealing block sequence, it's a Telock. + // If it has a block validation descriptor, it's a sealing block. + // Either way, the epoch has been sealed. + if prevBlockSimplexEpochInfo.SealingBlockSeq > 0 || prevBlockSimplexEpochInfo.BlockValidationDescriptor != nil { + return stateBuildBlockEpochSealed + } + + // In any other case, NextPChainReferenceHeight > 0 but the previous block is not a Telock or sealing block, + // it means we are in the process of collecting approvals for the next epoch. + return stateBuildCollectingApprovals +} + type NodeBLSMapping struct { NodeID nodeID `canoto:"fixed bytes,1"` BLSKey []byte `canoto:"bytes,2"` @@ -199,44 +224,37 @@ func (nea *NextEpochApprovals) Equals(other *NextEpochApprovals) bool { type NodeBLSMappings []NodeBLSMapping -func (nbms NodeBLSMappings) TotalWeight() (int64, error) { - var totalWeight uint64 - for _, nbm := range nbms { - var err error - totalWeight, err = safeAdd(totalWeight, nbm.Weight) - if err != nil { - return 0, fmt.Errorf("failed to sum weights of all nodes: %w", err) +func (nbms NodeBLSMappings) NodeWeights() simplex.Nodes { + nodeWeights := make(simplex.Nodes, len(nbms)) + for i, nbm := range nbms { + nodeWeights[i] = simplex.Node{ + Node: nbm.NodeID[:], + Weight: nbm.Weight, } } + return nodeWeights +} - if totalWeight == 0 { - return 0, fmt.Errorf("total weight of validators is 0") - } - - if totalWeight > math.MaxInt64 { - return 0, fmt.Errorf("total weight of validators is too big, overflows int64: %d", totalWeight) +// IndexByNodeID returns a mapping from NodeID to the validator's index in the set, +// which is the position used by approval bitmasks. +func (nbms NodeBLSMappings) IndexByNodeID() map[nodeID]int { + result := make(map[nodeID]int, len(nbms)) + for i, nbm := range nbms { + result[nbm.NodeID] = i } - return int64(totalWeight), nil + return result } -func (nbms NodeBLSMappings) ApprovingWeight(approvingNodes bitmask) (int64, error) { - var approvingWeight uint64 +func (nbms NodeBLSMappings) SelectSubset(bitmask bitmask) []simplex.NodeID { + nodeIDs := make([]simplex.NodeID, 0, len(nbms)) for i, nbm := range nbms { - if !approvingNodes.Contains(i) { + if !bitmask.Contains(i) { continue } - var err error - approvingWeight, err = safeAdd(approvingWeight, nbm.Weight) - if err != nil { - return 0, fmt.Errorf("failed to compute approving weights: %w", err) - } - } - - if approvingWeight > math.MaxInt64 { - return 0, fmt.Errorf("approving weight of validators is too big, overflows int64: %d", approvingWeight) + nodeIDs = append(nodeIDs, nbm.NodeID[:]) } - return int64(approvingWeight), nil + return nodeIDs } func (nbms NodeBLSMappings) Clone() NodeBLSMappings { @@ -282,10 +300,10 @@ type ValidatorSetApproval struct { type ValidatorSetApprovals []ValidatorSetApproval -func (vsa ValidatorSetApprovals) Filter(f func(int, ValidatorSetApproval) bool) ValidatorSetApprovals { +func (vsa ValidatorSetApprovals) Filter(f func(ValidatorSetApproval, simplex.Logger) bool, logger simplex.Logger) ValidatorSetApprovals { result := make(ValidatorSetApprovals, 0, len(vsa)) - for i, v := range vsa { - if f(i, v) { + for _, v := range vsa { + if f(v, logger) { result = append(result, v) } } diff --git a/msm/encoding_test.go b/msm/encoding_test.go index 4c7bc321..ef259503 100644 --- a/msm/encoding_test.go +++ b/msm/encoding_test.go @@ -6,6 +6,8 @@ package metadata import ( "testing" + "github.com/ava-labs/simplex" + "github.com/ava-labs/simplex/testutil" "github.com/stretchr/testify/require" ) @@ -402,22 +404,23 @@ func TestNodeBLSMappingsCompare(t *testing.T) { } func TestValidatorSetApprovalsFilter(t *testing.T) { + logger := testutil.MakeLogger(t) approvals := ValidatorSetApprovals{ {NodeID: nodeID{1}, PChainHeight: 10}, {NodeID: nodeID{2}, PChainHeight: 20}, {NodeID: nodeID{3}, PChainHeight: 30}, } - filtered := approvals.Filter(func(_ int, v ValidatorSetApproval) bool { + filtered := approvals.Filter(func(v ValidatorSetApproval, _ simplex.Logger) bool { return v.PChainHeight > 15 - }) + }, logger) require.Len(t, filtered, 2) require.Equal(t, uint64(20), filtered[0].PChainHeight) require.Equal(t, uint64(30), filtered[1].PChainHeight) // Filter all - filtered = approvals.Filter(func(int, ValidatorSetApproval) bool { + filtered = approvals.Filter(func(ValidatorSetApproval, simplex.Logger) bool { return false - }) + }, logger) require.Empty(t, filtered) } diff --git a/msm/fake_node_test.go b/msm/fake_node_test.go new file mode 100644 index 00000000..58eefb30 --- /dev/null +++ b/msm/fake_node_test.go @@ -0,0 +1,449 @@ +// Copyright (C) 2019-2025, Ava Labs, Inc. All rights reserved. +// See the file LICENSE for licensing terms. + +package metadata + +import ( + "context" + "crypto/rand" + "fmt" + "sync/atomic" + "testing" + "time" + + "github.com/ava-labs/simplex" + "github.com/stretchr/testify/require" +) + +func TestFakeNode(t *testing.T) { + validatorSetRetriever := validatorSetRetriever{ + resultMap: map[uint64]NodeBLSMappings{ + 0: {{BLSKey: []byte{1}, Weight: 1, NodeID: [20]byte{1}}, {BLSKey: []byte{2}, Weight: 1, NodeID: [20]byte{2}}}, + 100: {{BLSKey: []byte{1}, Weight: 1, NodeID: [20]byte{1}}, {BLSKey: []byte{2}, Weight: 1, NodeID: [20]byte{2}}}, + 200: {{BLSKey: []byte{1}, Weight: 1, NodeID: [20]byte{1}}, {BLSKey: []byte{2}, Weight: 2, NodeID: [20]byte{2}}, + {BLSKey: []byte{3}, Weight: 1, NodeID: [20]byte{3}}}, + 300: {{BLSKey: []byte{1}, Weight: 1, NodeID: [20]byte{1}}, {BLSKey: []byte{2}, Weight: 2, NodeID: [20]byte{2}}, + {BLSKey: []byte{3}, Weight: 3, NodeID: [20]byte{3}}, {BLSKey: []byte{4}, Weight: 1, NodeID: [20]byte{4}}}, + }, + } + + var pChainHeight atomic.Uint64 + pChainHeight.Store(100) + node := newFakeNode(t) + node.sm.GetValidatorSet = validatorSetRetriever.getValidatorSet + node.sm.GetPChainHeight = func() uint64 { + return pChainHeight.Load() + } + + // Create some blocks and finalize them, until we reach height 10 + for node.Height() < 10 { + node.act() + } + + // Next, we increase the P-Chain height, which should cause the node to update its validator set and move to the new epoch. + pChainHeight.Store(200) + + epoch := node.Epoch() + for node.Epoch() == epoch { + node.act() + if flipCoin() { + node.sm.ApprovalsRetriever = &approvalsRetriever{ + result: []ValidatorSetApproval{{NodeID: [20]byte{1}, PChainHeight: 200, Signature: []byte{1}, AuxInfoSeqDigest: [32]byte{}}}, + } + } else { + node.sm.ApprovalsRetriever = &approvalsRetriever{ + result: []ValidatorSetApproval{{NodeID: [20]byte{2}, PChainHeight: 200, Signature: []byte{2}, AuxInfoSeqDigest: [32]byte{}}}, + } + } + } + + t.Log("Epoch:", node.Epoch()) + require.Greater(t, node.Epoch(), uint64(1)) + + // Finally, we increase the P-Chain height again, which should cause the node to update its validator set and move to the new epoch. + + pChainHeight.Store(300) + + epoch = node.Epoch() + for node.Epoch() == epoch { + node.act() + if flipCoin() { + node.sm.ApprovalsRetriever = &approvalsRetriever{ + result: []ValidatorSetApproval{{NodeID: [20]byte{2}, PChainHeight: 300, Signature: []byte{2}, AuxInfoSeqDigest: [32]byte{}}}, + } + } else { + node.sm.ApprovalsRetriever = &approvalsRetriever{ + result: []ValidatorSetApproval{{NodeID: [20]byte{3}, PChainHeight: 300, Signature: []byte{3}, AuxInfoSeqDigest: [32]byte{}}}, + } + } + } + + t.Log("Epoch:", node.Epoch()) + require.Greater(t, node.Epoch(), epoch) +} + +func TestFakeNodeEmptyMempool(t *testing.T) { + validatorSetRetriever := validatorSetRetriever{ + resultMap: map[uint64]NodeBLSMappings{ + 0: {{BLSKey: []byte{1}, Weight: 1, NodeID: [20]byte{1}}, {BLSKey: []byte{2}, Weight: 1, NodeID: [20]byte{2}}}, + 100: {{BLSKey: []byte{1}, Weight: 1, NodeID: [20]byte{1}}, {BLSKey: []byte{2}, Weight: 1, NodeID: [20]byte{2}}}, + 200: {{BLSKey: []byte{1}, Weight: 1, NodeID: [20]byte{1}}, {BLSKey: []byte{2}, Weight: 2, NodeID: [20]byte{2}}, + {BLSKey: []byte{3}, Weight: 1, NodeID: [20]byte{3}}}, + 300: {{BLSKey: []byte{1}, Weight: 1, NodeID: [20]byte{1}}, {BLSKey: []byte{2}, Weight: 2, NodeID: [20]byte{2}}, + {BLSKey: []byte{3}, Weight: 3, NodeID: [20]byte{3}}, {BLSKey: []byte{4}, Weight: 1, NodeID: [20]byte{4}}}, + }, + } + + var pChainHeight uint64 = 100 + node := newFakeNode(t) + node.sm.MaxBlockBuildingWaitTime = 100 * time.Millisecond + node.sm.GetValidatorSet = validatorSetRetriever.getValidatorSet + node.sm.GetPChainHeight = func() uint64 { + return pChainHeight + } + + // Create some blocks and finalize them, until we reach height 10 + for node.Height() < 10 { + node.act() + } + + // Next, we increase the P-Chain height, which should cause the node to update its validator set and move to the new epoch. + pChainHeight = 200 + + // However, we mark the mempool as empty, which should cause the node to wait until it sees a change in the P-Chain height, rather than building blocks on top of the old epoch. + node.mempoolEmpty = true + + // We build blocks until the sealing block is finalized. + for node.lastFinalizedBlock().Metadata.SimplexEpochInfo.BlockValidationDescriptor == nil { + node.act() + if flipCoin() { + node.sm.ApprovalsRetriever = &approvalsRetriever{ + result: []ValidatorSetApproval{{NodeID: [20]byte{1}, PChainHeight: 200, Signature: []byte{1}, AuxInfoSeqDigest: [32]byte{}}}, + } + } else { + node.sm.ApprovalsRetriever = &approvalsRetriever{ + result: []ValidatorSetApproval{{NodeID: [20]byte{2}, PChainHeight: 200, Signature: []byte{2}, AuxInfoSeqDigest: [32]byte{}}}, + } + } + } + + node.mempoolEmpty = false + + // Build a new block and check that the node has transitioned to the new epoch, + // rather than building a block on top of the old epoch. + height := node.Height() + + for node.Height() == height { + node.act() + } + require.Greater(t, node.Epoch(), uint64(1)) + + t.Log("Epoch:", node.Epoch()) + + epoch := node.Epoch() + require.Greater(t, epoch, uint64(1)) + + // Finally, we increase the P-Chain height again, which should cause the node to update its validator set and move to the new epoch. + + pChainHeight = 300 + + for node.Height() < 30 { + node.act() + if flipCoin() { + node.sm.ApprovalsRetriever = &approvalsRetriever{ + result: []ValidatorSetApproval{{NodeID: [20]byte{2}, PChainHeight: 300, Signature: []byte{2}, AuxInfoSeqDigest: [32]byte{}}}, + } + } else { + node.sm.ApprovalsRetriever = &approvalsRetriever{ + result: []ValidatorSetApproval{{NodeID: [20]byte{3}, PChainHeight: 300, Signature: []byte{3}, AuxInfoSeqDigest: [32]byte{}}}, + } + } + } + + t.Log("Epoch:", node.Epoch()) + require.Greater(t, node.Epoch(), epoch) + require.Equal(t, node.Height(), uint64(30)) +} + +type innerBlock struct { + InnerBlock + Prev [32]byte +} + +type blockState struct { + block StateMachineBlock + finalized bool + innerBlock VMBlock +} + +type fakeNode struct { + t *testing.T + sm *StateMachine + mempoolEmpty bool + // blocks holds notarized blocks in order. Finalized blocks always form a + // prefix: all finalized entries precede all non-finalized entries. + blocks []blockState +} + +func (fn *fakeNode) WaitForProgress(ctx context.Context, pChainHeight uint64) error { + for { + select { + case <-ctx.Done(): + return ctx.Err() + case <-time.After(10 * time.Millisecond): + if fn.sm.GetPChainHeight() != pChainHeight { + return nil + } + } + } +} + +func (fn *fakeNode) WaitForPendingBlock(ctx context.Context) { + if fn.mempoolEmpty { + <-ctx.Done() + return + } +} + +func newFakeNode(t *testing.T) *fakeNode { + sm, _ := newStateMachine(t) + + fn := &fakeNode{ + t: t, + sm: sm, + } + + fn.sm.BlockBuilder = fn + fn.sm.PChainProgressListener = fn + + fn.sm.GetBlock = func(seq uint64, digest [32]byte) (StateMachineBlock, *simplex.Finalization, error) { + if seq == 0 { + return genesisBlock, nil, nil + } + for _, bs := range fn.blocks { + match := bs.block.Digest() == digest + if !match { + md, err := simplex.ProtocolMetadataFromBytes(bs.block.Metadata.SimplexProtocolMetadata) + if err != nil { + return StateMachineBlock{}, nil, err + } + match = md.Seq == seq + } + if match { + var fin *simplex.Finalization + if bs.finalized { + fin = &simplex.Finalization{} + } + return bs.block, fin, nil + } + } + + require.Failf(t, "not found block", "height: %d", seq) + return StateMachineBlock{}, nil, fmt.Errorf("block not found") + } + + fn.sm.FirstEverSimplexBlock = func() *StateMachineBlock { + for _, block := range fn.blocks { + if block.block.Metadata.SimplexEpochInfo.EpochNumber == 0 { + continue + } + return &block.block + } + require.FailNow(t, "block not found") + return nil + } + + return fn +} + +// lastFinalizedBlock returns the most recently finalized block. +// Panics if nothing has been finalized. +func (fn *fakeNode) lastFinalizedBlock() StateMachineBlock { + for i := len(fn.blocks) - 1; i >= 0; i-- { + if fn.blocks[i].finalized { + return fn.blocks[i].block + } + } + panic("no finalized block") +} + +func (fn *fakeNode) Height() uint64 { + var count uint64 + for _, bs := range fn.blocks { + if bs.finalized { + count++ + } + } + return count +} + +func (fn *fakeNode) Epoch() uint64 { + return fn.blocks[len(fn.blocks)-1].block.Metadata.SimplexEpochInfo.EpochNumber +} + +func (fn *fakeNode) act() { + if fn.canFinalize() && flipCoin() { + fn.tryFinalizeNextBlock() + return + } + + if flipCoin() { + return + } + + fn.buildAndNotarizeBlock() +} + +func (fn *fakeNode) canFinalize() bool { + return fn.nextUnfinalizedIndex() < len(fn.blocks) +} + +func (fn *fakeNode) nextUnfinalizedIndex() int { + for i, bs := range fn.blocks { + if !bs.finalized { + return i + } + } + return len(fn.blocks) +} + +func (fn *fakeNode) tryFinalizeNextBlock() { + nextIndex := fn.nextUnfinalizedIndex() + + if fn.isNextBlockTelock(nextIndex) { + return + } + + fn.blocks[nextIndex].finalized = true + block := fn.blocks[nextIndex].block + + md, err := simplex.ProtocolMetadataFromBytes(block.Metadata.SimplexProtocolMetadata) + require.NoError(fn.t, err) + + fn.sm.LatestPersistedHeight = md.Seq + fn.t.Logf("Finalized block at height %d with epoch %d", md.Seq, block.Metadata.SimplexEpochInfo.EpochNumber) + + // If we just finalized a sealing block, trim trailing Telock blocks. + if block.Metadata.SimplexEpochInfo.BlockValidationDescriptor != nil { + fn.blocks = fn.blocks[:nextIndex+1] + fn.t.Logf("Trimmed notarized blocks, new length: %d", len(fn.blocks)) + } +} + +func (fn *fakeNode) isNextBlockTelock(nextIndex int) bool { + if nextIndex == 0 { + return false + } + return fn.blocks[nextIndex].block.Metadata.SimplexEpochInfo.SealingBlockSeq > 0 +} + +func (fn *fakeNode) buildAndNotarizeBlock() { + vmBlock, block := fn.buildBlock() + require.NoError(fn.t, fn.sm.VerifyBlock(context.Background(), block)) + + fn.blocks = append(fn.blocks, blockState{block: *block, innerBlock: vmBlock}) +} + +func (fn *fakeNode) buildBlock() (VMBlock, *StateMachineBlock) { + parentBlock := fn.getParentBlock() + + lastMD, prevBlockDigest := fn.prepareMetadataAndPrevBlockDigest() + + _, finalization, err := fn.sm.GetBlock(lastMD.Seq, prevBlockDigest) + require.NoError(fn.t, err) + + finalizedString := "not finalized" + if finalization != nil { + finalizedString = "finalized" + } + + fn.t.Logf("Building a block on top of %s parent with epoch %d", finalizedString, parentBlock.Metadata.SimplexEpochInfo.EpochNumber) + + block, err := fn.sm.BuildBlock(context.Background(), simplex.ProtocolMetadata{ + Seq: lastMD.Seq + 1, + Round: lastMD.Round + 1, + Prev: prevBlockDigest, + }, nil) + require.NoError(fn.t, err) + + return block.InnerBlock, block +} + +func (fn *fakeNode) prepareMetadataAndPrevBlockDigest() (*simplex.ProtocolMetadata, [32]byte) { + var lastMD *simplex.ProtocolMetadata + var err error + lastBlockDigest := genesisBlock.Digest() + if len(fn.blocks) > 0 { + lastBlock := fn.blocks[len(fn.blocks)-1].block + lastBlockDigest = lastBlock.Digest() + lastMD, err = simplex.ProtocolMetadataFromBytes(lastBlock.Metadata.SimplexProtocolMetadata) + require.NoError(fn.t, err) + } else { + lastMD = &simplex.ProtocolMetadata{ + Prev: lastBlockDigest, + } + } + return lastMD, lastBlockDigest +} + +func (fn *fakeNode) BuildBlock(context.Context, uint64) (VMBlock, error) { + // Count the number of inner blocks in the chain + var count int + for _, bs := range fn.blocks { + if bs.block.InnerBlock != nil { + count++ + } + } + + vmBlock := &innerBlock{ + Prev: fn.getLastVMBlockDigest(), + InnerBlock: InnerBlock{ + Bytes: randomBuff(10), + TS: time.Now(), + BlockHeight: uint64(count), + }, + } + return vmBlock, nil +} + +func (fn *fakeNode) getParentBlock() StateMachineBlock { + if len(fn.blocks) > 0 { + return fn.blocks[len(fn.blocks)-1].block + } + gb := genesisBlock.InnerBlock.(*InnerBlock) + return StateMachineBlock{ + InnerBlock: &innerBlock{ + InnerBlock: *gb, + }, + } +} + +func (fn *fakeNode) getLastVMBlockDigest() [32]byte { + for i := len(fn.blocks) - 1; i >= 0; i-- { + if fn.blocks[i].block.InnerBlock != nil { + return fn.blocks[i].block.Digest() + } + } + return genesisBlock.Digest() +} + +func randomBuff(n int) []byte { + buff := make([]byte, n) + _, err := rand.Read(buff) + if err != nil { + panic(err) + } + return buff +} + +func flipCoin() bool { + buff := make([]byte, 1) + _, err := rand.Read(buff) + if err != nil { + panic(err) + } + + lsb := buff[0] & 1 + + return lsb == 1 +} diff --git a/msm/misc_test.go b/msm/misc_test.go index b899aa6a..b78d2cd3 100644 --- a/msm/misc_test.go +++ b/msm/misc_test.go @@ -4,9 +4,19 @@ package metadata import ( + "bytes" + "context" + "crypto/rand" + "crypto/sha256" + "encoding/asn1" + "fmt" + "maps" "math" "testing" + "time" + "github.com/ava-labs/simplex" + "github.com/ava-labs/simplex/testutil" "github.com/stretchr/testify/require" ) @@ -142,3 +152,345 @@ func TestBitmask(t *testing.T) { require.False(t, cloned.Contains(7)) }) } + +// Test helpers + +type InnerBlock struct { + TS time.Time + BlockHeight uint64 + Bytes []byte +} + +func (i *InnerBlock) Digest() [32]byte { + return sha256.Sum256(i.Bytes) +} + +func (i *InnerBlock) Height() uint64 { + return i.BlockHeight +} + +func (i *InnerBlock) Timestamp() time.Time { + return i.TS +} + +func (i *InnerBlock) Verify(_ context.Context) error { + return nil +} + +// fakeVMBlock is a minimal VMBlock implementation for tests. +type fakeVMBlock struct { + height uint64 +} + +func (f *fakeVMBlock) Digest() [32]byte { return [32]byte{} } +func (f *fakeVMBlock) Height() uint64 { return f.height } +func (f *fakeVMBlock) Timestamp() time.Time { return time.Time{} } +func (f *fakeVMBlock) Verify(_ context.Context) error { return nil } + +type outerBlock struct { + finalization *simplex.Finalization + block StateMachineBlock +} + +type blockStore map[uint64]*outerBlock + +func (bs blockStore) clone() blockStore { + newStore := make(blockStore) + maps.Copy(newStore, bs) + return newStore +} + +func (bs blockStore) getBlock(seq uint64, _ [32]byte) (StateMachineBlock, *simplex.Finalization, error) { + blk, exits := bs[seq] + if !exits { + return StateMachineBlock{}, nil, fmt.Errorf("%w: block %d not found", simplex.ErrBlockNotFound, seq) + } + return blk.block, blk.finalization, nil +} + +type approvalsRetriever struct { + result ValidatorSetApprovals +} + +func (a approvalsRetriever) Approvals() ValidatorSetApprovals { + return a.result +} + +type signatureVerifier struct { + err error +} + +func (sv *signatureVerifier) VerifySignature(signature []byte, message []byte, publicKey []byte) error { + return sv.err +} + +type signatureAggregator struct { + weightByNodeID map[string]uint64 + totalWeight uint64 +} + +type aggregatrdSignature struct { + Signatures [][]byte +} + +func (sv *signatureAggregator) Aggregate([]simplex.Signature) (simplex.QuorumCertificate, error) { + panic("unused in tests") +} + +func (sv *signatureAggregator) AppendSignatures(existing []byte, sigs ...[]byte) ([]byte, error) { + all := make([][]byte, 0, len(sigs)+1) + all = append(all, sigs...) + if len(existing) > 0 { + all = append(all, existing) + } + return asn1.Marshal(aggregatrdSignature{Signatures: all}) +} + +func (sv *signatureAggregator) IsQuorum(signers []simplex.NodeID) bool { + var sum uint64 + for _, signer := range signers { + sum += sv.weightByNodeID[string(signer)] + } + return sum*3 > sv.totalWeight*2 +} + +func newSignatureAggregatorCreator() simplex.SignatureAggregatorCreator { + return func(weights []simplex.Node) simplex.SignatureAggregator { + s := &signatureAggregator{weightByNodeID: make(map[string]uint64, len(weights))} + for _, nw := range weights { + s.weightByNodeID[string(nw.Node)] = nw.Weight + s.totalWeight += nw.Weight + } + return s + } +} + +type noOpPChainListener struct{} + +func (n *noOpPChainListener) WaitForProgress(ctx context.Context, _ uint64) error { + <-ctx.Done() + return ctx.Err() +} + +type blockBuilder struct { + block VMBlock + err error +} + +func (bb *blockBuilder) WaitForPendingBlock(_ context.Context) { + // Block is always ready in tests. +} + +func (bb *blockBuilder) BuildBlock(_ context.Context, _ uint64) (VMBlock, error) { + return bb.block, bb.err +} + +type validatorSetRetriever struct { + result NodeBLSMappings + resultMap map[uint64]NodeBLSMappings + err error +} + +func (vsr *validatorSetRetriever) getValidatorSet(height uint64) (NodeBLSMappings, error) { + if vsr.resultMap != nil { + if result, ok := vsr.resultMap[height]; ok { + return result, vsr.err + } + } + return vsr.result, vsr.err +} + +type keyAggregator struct{} + +func (ka *keyAggregator) AggregateKeys(keys ...[]byte) ([]byte, error) { + aggregated := make([]byte, 0) + for _, key := range keys { + aggregated = append(aggregated, key...) + } + return aggregated, nil +} + +var ( + genesisBlock = StateMachineBlock{ + // Genesis block metadata has all zero values + InnerBlock: &InnerBlock{ + TS: time.Now(), + Bytes: []byte{1, 2, 3}, + }, + } +) + +type dynamicApprovalsRetriever struct { + approvals *ValidatorSetApprovals +} + +func (d *dynamicApprovalsRetriever) Approvals() ValidatorSetApprovals { + return *d.approvals +} + +func makeChain(t *testing.T, simplexStartHeight uint64, endHeight uint64) []StateMachineBlock { + startTime := time.Now().Add(-time.Duration(endHeight+2) * time.Second) + blocks := make([]StateMachineBlock, 0, endHeight+1) + var round, seq uint64 + for h := uint64(0); h <= endHeight; h++ { + index := len(blocks) + + if h == 0 { + blocks = append(blocks, genesisBlock) + continue + } + + if h < simplexStartHeight { + blocks = append(blocks, makeNonSimplexBlock(t, simplexStartHeight, startTime, h)) + continue + } + + seq = uint64(index) + + blocks = append(blocks, makeNormalSimplexBlock(t, index, blocks, startTime, h, round, seq)) + round++ + } + return blocks +} + +func makeNormalSimplexBlock(t *testing.T, index int, blocks []StateMachineBlock, start time.Time, h uint64, round uint64, seq uint64) StateMachineBlock { + content := make([]byte, 10) + _, err := rand.Read(content) + require.NoError(t, err) + + prev := genesisBlock.Digest() + if index > 0 { + prev = blocks[index-1].Digest() + } + + return StateMachineBlock{ + InnerBlock: &InnerBlock{ + TS: start.Add(time.Duration(h) * time.Second), + BlockHeight: h, + Bytes: []byte{1, 2, 3}, + }, + Metadata: StateMachineMetadata{ + PChainHeight: 100, + SimplexProtocolMetadata: (&simplex.ProtocolMetadata{ + Round: round, + Seq: seq, + Epoch: 1, + Prev: prev, + }).Bytes(), + SimplexEpochInfo: SimplexEpochInfo{ + PrevSealingBlockHash: [32]byte{}, + PChainReferenceHeight: 100, + EpochNumber: 1, + PrevVMBlockSeq: uint64(index), + }, + }, + } +} + +func makeNonSimplexBlock(t *testing.T, startHeight uint64, start time.Time, h uint64) StateMachineBlock { + content := make([]byte, 10) + _, err := rand.Read(content) + require.NoError(t, err) + + return StateMachineBlock{ + InnerBlock: &InnerBlock{ + TS: start.Add(time.Duration(h-startHeight) * time.Second), + BlockHeight: h, + Bytes: []byte{1, 2, 3}, + }, + } +} + +type testConfig struct { + blockStore blockStore + approvalsRetriever approvalsRetriever + signatureVerifier signatureVerifier + signatureAggregator signatureAggregator + blockBuilder blockBuilder + keyAggregator keyAggregator + validatorSetRetriever validatorSetRetriever +} + +func newStateMachine(t *testing.T) (*StateMachine, *testConfig) { + bs := make(blockStore) + bs[0] = &outerBlock{block: genesisBlock} + + var testConfig testConfig + testConfig.blockStore = bs + testConfig.validatorSetRetriever.result = NodeBLSMappings{ + {BLSKey: []byte{1}, Weight: 1}, {BLSKey: []byte{2}, Weight: 1}, + } + + smConfig := Config{ + GenesisValidatorSet: NodeBLSMappings{{BLSKey: []byte{1}, Weight: 1}, {BLSKey: []byte{2}, Weight: 1}}, + LastNonSimplexBlockPChainHeight: 100, + FirstEverSimplexBlock: func() *StateMachineBlock { + var res *StateMachineBlock + min := uint64(math.MaxUint64) + for seq, block := range testConfig.blockStore { + if block.block.Metadata.SimplexEpochInfo.EpochNumber == 0 { + continue + } + if seq < min { + min = seq + res = &block.block + } + } + return res + }, + GetTime: time.Now, + TimeSkewLimit: time.Second * 5, + Logger: testutil.MakeLogger(t), + GetBlock: testConfig.blockStore.getBlock, + MaxBlockBuildingWaitTime: time.Second, + ApprovalsRetriever: &testConfig.approvalsRetriever, + SignatureVerifier: &testConfig.signatureVerifier, + SignatureAggregatorCreator: newSignatureAggregatorCreator(), + BlockBuilder: &testConfig.blockBuilder, + KeyAggregator: &testConfig.keyAggregator, + GetPChainHeight: func() uint64 { + return 100 + }, + GetUpgrades: func() any { + return nil + }, + GetValidatorSet: testConfig.validatorSetRetriever.getValidatorSet, + PChainProgressListener: &noOpPChainListener{}, + LastNonSimplexInnerBlock: genesisBlock.InnerBlock, + } + + sm, err := NewStateMachine(&smConfig) + require.NoError(t, err) + + return sm, &testConfig +} + +// concatAggregator concatenates signatures for easy verification in tests. +type concatAggregator struct{} + +func (concatAggregator) Aggregate([]simplex.Signature) (simplex.QuorumCertificate, error) { + panic("unused in tests") +} + +func (concatAggregator) AppendSignatures(existing []byte, sigs ...[]byte) ([]byte, error) { + result := bytes.Join(sigs, nil) + return append(result, existing...), nil +} + +func (concatAggregator) IsQuorum([]simplex.NodeID) bool { + return false +} + +type failingAggregator struct{} + +func (failingAggregator) Aggregate([]simplex.Signature) (simplex.QuorumCertificate, error) { + panic("unused in tests") +} + +func (failingAggregator) AppendSignatures([]byte, ...[]byte) ([]byte, error) { + return nil, fmt.Errorf("aggregation failed") +} + +func (failingAggregator) IsQuorum([]simplex.NodeID) bool { + return false +} diff --git a/msm/msm.go b/msm/msm.go index c8efd734..2d240331 100644 --- a/msm/msm.go +++ b/msm/msm.go @@ -4,13 +4,13 @@ package metadata import ( + "context" "crypto/sha256" - "errors" "fmt" - "math" - "sort" + "time" "github.com/ava-labs/simplex" + "go.uber.org/zap" ) // A StateMachineBlock is a representation of a parsed OuterBlock, containing the inner block and the metadata. @@ -36,62 +36,707 @@ func (smb *StateMachineBlock) Digest() [32]byte { return sha256.Sum256(combined) } -// Used to aggregate validator signatures for epoch transitions. -type SignatureAggregator interface { - AggregateSignatures(signatures ...[]byte) ([]byte, error) +// ApprovalsRetriever retrieves the approvals from validators of the next epoch for the epoch change. +type ApprovalsRetriever interface { + Approvals() ValidatorSetApprovals } -// ValidatorSetRetriever retrieves the validator set at a given P-chain height. -type ValidatorSetRetriever func(pChainHeight uint64) (NodeBLSMappings, error) +// KeyAggregator combines multiple public keys into a single aggregated public key. +type KeyAggregator interface { + AggregateKeys(keys ...[]byte) ([]byte, error) +} -// RetrievingOpts specifies the options for retrieving a block by height and/or digest. -type RetrievingOpts struct { - // Height is the sequence number of the block to retrieve. - Height uint64 - // Digest is the expected hash of the block, used for validation. - Digest [32]byte +// SignatureVerifier verifies a cryptographic signature against a message and public key. +// Used to verify Approvals from validators for epoch transitions. +type SignatureVerifier interface { + VerifySignature(signature []byte, message []byte, publicKey []byte) error } -// BlockRetriever retrieves a block and its finalization status given the retrieval options. +// ValidatorSetRetriever retrieves the validator set at a given P-chain height. +type ValidatorSetRetriever func(pChainHeight uint64) (NodeBLSMappings, error) + +// BlockRetriever retrieves a block and its finalization status given the block's sequence number and expected digest. // If the block cannot be found it returns ErrBlockNotFound. // If an error occurs during retrieval, it returns a non-nil error. -type BlockRetriever func(RetrievingOpts) (StateMachineBlock, *simplex.Finalization, error) +type BlockRetriever func(seq uint64, digest [32]byte) (StateMachineBlock, *simplex.Finalization, error) + +// BlockBuilder builds a new VM block with the given observed P-chain height. +type BlockBuilder interface { + BuildBlock(ctx context.Context, pChainHeight uint64) (VMBlock, error) + + // WaitForPendingBlock returns when either the given context is cancelled, + // or when the VM signals that a block should be built. + WaitForPendingBlock(ctx context.Context) +} + +type verificationInput struct { + prevMD StateMachineMetadata + proposedBlockMD StateMachineMetadata + hasInnerBlock bool + innerBlockTimestamp time.Time // only set when hasInnerBlock is true + prevBlockSeq uint64 + nextBlockType BlockType + state state +} + +type verifier interface { + Verify(in verificationInput) error +} + +// StateMachine manages block building and verification across epoch transitions. +type StateMachine struct { + // verifiers is the list of verifiers used to verify proposed blocks. + // Each verifier is responsible for verifying a specific aspect of the block's metadata. + verifiers []verifier + + *Config +} + +// Config contains the dependencies and configuration parameters needed to initialize the StateMachine. +type Config struct { + // LatestPersistedHeight is the height of the most recently persisted block. + LatestPersistedHeight uint64 + // MaxBlockBuildingWaitTime is the maximum duration to wait for the VM to build a block + // before producing a block without an inner block. + MaxBlockBuildingWaitTime time.Duration + // TimeSkewLimit is the maximum allowed time difference between a block's timestamp and the current time. + TimeSkewLimit time.Duration + // GetTime returns the current time. + GetTime func() time.Time + // GetPChainHeight returns the latest known P-chain height. + GetPChainHeight func() uint64 + // GetUpgrades returns the current upgrade configuration. + GetUpgrades func() UpgradeConfig + // BlockBuilder builds new VM blocks. + BlockBuilder BlockBuilder + // Logger is used for logging state machine operations. + Logger simplex.Logger + // GetValidatorSet retrieves the validator set at a given P-chain height. + GetValidatorSet ValidatorSetRetriever + // GetBlock retrieves a previously built or finalized block. + GetBlock BlockRetriever + // ApprovalsRetriever retrieves validator approvals for epoch transitions. + ApprovalsRetriever ApprovalsRetriever + // SignatureAggregatorCreator creates a new SignatureAggregator for aggregating validator signatures for epoch transitions. + SignatureAggregatorCreator simplex.SignatureAggregatorCreator + // KeyAggregator aggregates public keys from validators. + KeyAggregator KeyAggregator + // SignatureVerifier verifies signatures from validators. + SignatureVerifier SignatureVerifier + // PChainProgressListener listens for changes in the P-chain height to trigger block building or epoch transitions. + PChainProgressListener PChainProgressListener + // FirstEverSimplexBlock is the first block ever built by Simplex, or nil if Simplex has yet to build a block. + FirstEverSimplexBlock func() *StateMachineBlock + // LastNonSimplexBlockPChainHeight is the P-chain height of the last block built by a non-Simplex proposer. + // It is used to determine the validator set of the first ever Simplex epoch. + LastNonSimplexBlockPChainHeight uint64 + // LastNonSimplexInnerBlock is the inner block of the last block built by a non-Simplex proposer. + LastNonSimplexInnerBlock VMBlock + // GenesisValidatorSet is the validator set used for the genesis block. + GenesisValidatorSet NodeBLSMappings +} type state uint8 const ( - stateFirstSimplexBlock state = iota + stateFirstSimplexBlock state = iota + 1 stateBuildBlockNormalOp stateBuildCollectingApprovals stateBuildBlockEpochSealed ) -func identifyCurrentState(prevBlockSimplexEpochInfo SimplexEpochInfo) state { - // If this is the first ever epoch, then this is also the first ever block to be built by Simplex. - if prevBlockSimplexEpochInfo.EpochNumber == 0 { - return stateFirstSimplexBlock +func NewStateMachine(config *Config) (*StateMachine, error) { + if config.LastNonSimplexInnerBlock == nil { + config.Logger.Error("Last non-Simplex inner block is nil, cannot build zero block with correct metadata") + return nil, fmt.Errorf("failed constructing zero block: last non-Simplex inner block is nil") + } + sm := StateMachine{Config: config} + return &sm, nil +} + +// BuildBlock constructs the next block on top of the given parent block, and passes in the provided simplex metadata and blacklist. +func (sm *StateMachine) BuildBlock(ctx context.Context, metadata simplex.ProtocolMetadata, blacklist *simplex.Blacklist) (*StateMachineBlock, error) { + // The zero sequence number is reserved for the genesis block, which should never be built. + if metadata.Seq == 0 { + return nil, fmt.Errorf("invalid ProtocolMetadata sequence number: should be > 0, got %d", metadata.Seq) + } + + prevBlockSeq := metadata.Seq - 1 + + parentBlock, _, err := sm.GetBlock(prevBlockSeq, metadata.Prev) + if err != nil { + return nil, fmt.Errorf("failed retrieving parent block at height %d with digest %s: %w", prevBlockSeq, metadata.Prev.String(), err) + } + + start := time.Now() + + sm.Logger.Debug("Building block", + zap.Uint64("seq", metadata.Seq), + zap.Uint64("epoch", metadata.Epoch), + zap.Stringer("prevHash", metadata.Prev)) + + defer func() { + elapsed := time.Since(start) + sm.Logger.Debug("Built block", + zap.Uint64("seq", metadata.Seq), + zap.Uint64("epoch", metadata.Epoch), + zap.Stringer("prevHash", metadata.Prev), + zap.Duration("elapsed", elapsed), + ) + }() + + var simplexBlacklistBytes []byte + if blacklist != nil { + simplexBlacklistBytes = blacklist.Bytes() + } + + // In order to know where in the epoch change process we are, + // we identify the current state by looking at the parent block's epoch info. + currentState := parentBlock.Metadata.SimplexEpochInfo.NextState() + + simplexMetadataBytes := metadata.Bytes() + + switch currentState { + case stateFirstSimplexBlock: + return sm.buildBlockZero(parentBlock, simplexMetadataBytes, simplexBlacklistBytes) + case stateBuildBlockNormalOp: + return sm.buildBlockNormalOp(ctx, parentBlock, simplexMetadataBytes, simplexBlacklistBytes, prevBlockSeq) + case stateBuildCollectingApprovals: + return sm.buildBlockCollectingApprovals(ctx, parentBlock, simplexMetadataBytes, simplexBlacklistBytes, prevBlockSeq) + case stateBuildBlockEpochSealed: + return sm.buildBlockEpochSealed(ctx, parentBlock, simplexMetadataBytes, simplexBlacklistBytes, prevBlockSeq) + default: + return nil, fmt.Errorf("unknown state %d", currentState) + } +} + +// VerifyBlock validates a proposed block by checking its metadata, epoch info, +// and inner block against the previous block and the current state. +func (sm *StateMachine) VerifyBlock(ctx context.Context, block *StateMachineBlock) error { + if block == nil { + return fmt.Errorf("InnerBlock is nil") + } + + pmd, err := simplex.ProtocolMetadataFromBytes(block.Metadata.SimplexProtocolMetadata) + if err != nil { + return fmt.Errorf("failed to parse ProtocolMetadata: %w", err) + } + + seq := pmd.Seq + + if seq == 0 { + return fmt.Errorf("attempted to build a genesis inner block") + } + + prevBlock, _, err := sm.GetBlock(seq-1, pmd.Prev) + if err != nil { + return fmt.Errorf("failed to retrieve previous (%d) inner block: %w", seq-1, err) + } + + prevMD := prevBlock.Metadata + currentState := prevMD.SimplexEpochInfo.NextState() + + switch currentState { + case stateFirstSimplexBlock: + err = sm.verifyBlockZero(block, prevBlock) + default: + err = sm.verifyNonZeroBlock(ctx, block, prevBlock.Metadata, currentState, seq-1) + } + return err +} + +func (sm *StateMachine) verifyNonZeroBlock(ctx context.Context, block *StateMachineBlock, prevBlockMD StateMachineMetadata, state state, prevSeq uint64) error { + blockType := IdentifyBlockType(block.Metadata, prevBlockMD, prevSeq) + sm.Logger.Debug("Identified block type", + zap.Stringer("blockType", blockType), + zap.Bool("nextHasBVD", block.Metadata.SimplexEpochInfo.BlockValidationDescriptor != nil), + zap.Uint64("nextEpochNumber", block.Metadata.SimplexEpochInfo.EpochNumber), + zap.Bool("prevHasBVD", prevBlockMD.SimplexEpochInfo.BlockValidationDescriptor != nil), + zap.Uint64("prevEpochNumber", prevBlockMD.SimplexEpochInfo.EpochNumber), + zap.Uint64("prevNextPChainRefHeight", prevBlockMD.SimplexEpochInfo.NextPChainReferenceHeight), + zap.Uint64("prevSealingBlockSeq", prevBlockMD.SimplexEpochInfo.SealingBlockSeq), + zap.Uint64("prevSeq", prevSeq), + ) + + var innerBlockTimestamp time.Time + if block.InnerBlock != nil { + innerBlockTimestamp = block.InnerBlock.Timestamp() + } + + for _, verifier := range sm.verifiers { + if err := verifier.Verify(verificationInput{ + proposedBlockMD: block.Metadata, + nextBlockType: blockType, + prevMD: prevBlockMD, + state: state, + prevBlockSeq: prevSeq, + hasInnerBlock: block.InnerBlock != nil, + innerBlockTimestamp: innerBlockTimestamp, + }); err != nil { + sm.Logger.Debug("Invalid block", zap.Error(err)) + return err + } + } + + if block.InnerBlock == nil { + return nil + } + + return block.InnerBlock.Verify(ctx) +} + +// buildBlockNormalOp builds a block while not trying to transition to a new epoch. +func (sm *StateMachine) buildBlockNormalOp(ctx context.Context, parentBlock StateMachineBlock, simplexMetadata, simplexBlacklist []byte, prevBlockSeq uint64) (*StateMachineBlock, error) { + // Since in the previous block, we were not transitioning to a new epoch, + // the P-chain reference height and epoch of the new block should remain the same. + newSimplexEpochInfo := SimplexEpochInfo{ + PChainReferenceHeight: parentBlock.Metadata.SimplexEpochInfo.PChainReferenceHeight, + EpochNumber: parentBlock.Metadata.SimplexEpochInfo.EpochNumber, + PrevVMBlockSeq: computePrevVMBlockSeq(parentBlock, prevBlockSeq), + } + + blockBuildingDecider := sm.createBlockBuildingDecider(parentBlock) + decisionToBuildBlock, err := blockBuildingDecider.shouldBuildBlock(ctx) + if err != nil { + return nil, err + } + + sm.Logger.Debug("Block building decision", + zap.Bool("build inner block", decisionToBuildBlock.buildInnerBlock), + zap.Bool("transition epoch", decisionToBuildBlock.transitionEpoch), + zap.Uint64("P-chain height", decisionToBuildBlock.pChainHeight)) + + if decisionToBuildBlock.transitionEpoch { + sm.Logger.Debug("Transitioning epoch after building block", zap.Uint64("newPChainRefHeight", decisionToBuildBlock.pChainHeight)) + newSimplexEpochInfo.NextPChainReferenceHeight = decisionToBuildBlock.pChainHeight + } + + var innerBlock VMBlock + + if decisionToBuildBlock.buildInnerBlock { + // TODO: This P-chain height should be taken from the ICM epoch + innerBlock, err = sm.BlockBuilder.BuildBlock(ctx, decisionToBuildBlock.pChainHeight) + if err != nil { + return nil, err + } } - // If we don't have a next P-chain preference height, it means we are not transitioning to a new epoch just yet. - if prevBlockSimplexEpochInfo.NextPChainReferenceHeight == 0 { - return stateBuildBlockNormalOp + return sm.wrapBlock(parentBlock, innerBlock, newSimplexEpochInfo, decisionToBuildBlock.pChainHeight, simplexMetadata, simplexBlacklist), nil +} + +func (sm *StateMachine) createBlockBuildingDecider(parentBlock StateMachineBlock) blockBuildingDecider { + blockBuildingDecider := blockBuildingDecider{ + logger: sm.Logger, + maxBlockBuildingWaitTime: sm.MaxBlockBuildingWaitTime, + pChainListener: sm.PChainProgressListener, + getPChainHeight: sm.GetPChainHeight, + waitForPendingBlock: sm.BlockBuilder.WaitForPendingBlock, + hasValidatorSetChanged: func(pChainHeight uint64) (bool, error) { + // The given pChainHeight was sampled by the caller of shouldTransitionEpoch(). + // We compare between the current validator set, defined by the P-chain reference height in the parent block, + // and the new validator set defined by the given pChainHeight. + // If they are different, then we should transition to a new epoch. + + currentValidatorSet, err := sm.GetValidatorSet(parentBlock.Metadata.SimplexEpochInfo.PChainReferenceHeight) + if err != nil { + return false, err + } + + newValidatorSet, err := sm.GetValidatorSet(pChainHeight) + if err != nil { + return false, err + } + + if !currentValidatorSet.Equal(newValidatorSet) { + sm.Logger.Debug("Validator set has changed, should transition epoch", + zap.String("currentValidatorSet", fmt.Sprintf("%v", currentValidatorSet.NodeWeights())), + zap.String("newValidatorSet", fmt.Sprintf("%v", newValidatorSet.NodeWeights())), + zap.Uint64("currentPChainRefHeight", parentBlock.Metadata.SimplexEpochInfo.PChainReferenceHeight), + zap.Uint64("newPChainHeight", pChainHeight)) + return true, nil + } + return false, nil + }, } + return blockBuildingDecider +} + +// buildBlockZero builds the first ever block for Simplex, +// which is a special block that introduces the first validator set and starts the first epoch. +func (sm *StateMachine) buildBlockZero(parentBlock StateMachineBlock, simplexMetadata, simplexBlacklist []byte) (*StateMachineBlock, error) { + pChainHeight := sm.LastNonSimplexBlockPChainHeight - // If the previous block has a sealing block sequence, it's a Telock. - // If it has a block validation descriptor, it's a sealing block. - // Either way, the epoch has been sealed. - if prevBlockSimplexEpochInfo.SealingBlockSeq > 0 || prevBlockSimplexEpochInfo.BlockValidationDescriptor != nil { - return stateBuildBlockEpochSealed + var validatorSet NodeBLSMappings + if sm.LastNonSimplexInnerBlock.Height() == 0 { + validatorSet = sm.GenesisValidatorSet + } else { + var err error + validatorSet, err = sm.GetValidatorSet(pChainHeight) + if err != nil { + return nil, err + } + } + + var prevVMBlockSeq uint64 + if parentBlock.InnerBlock != nil { + prevVMBlockSeq = parentBlock.InnerBlock.Height() + } else { + // We can only have blocks without inner blocks in Simplex blocks, but this is the first Simplex block. + // Therefore, the parent block must have an inner block. + sm.Logger.Error("Parent block has no inner block, cannot determine previous VM block sequence for zero block") + return nil, fmt.Errorf("failed constructing zero block: parent block has no inner block") } - // In any other case, NextPChainReferenceHeight > 0 but the previous block is not a Telock or sealing block, - // it means we are in the process of collecting approvals for the next epoch. - return stateBuildCollectingApprovals + timestamp := sm.LastNonSimplexInnerBlock.Timestamp().UnixMilli() + simplexEpochInfo := constructSimplexZeroBlockSimplexEpochInfo(pChainHeight, validatorSet, prevVMBlockSeq) + + md, err := simplex.ProtocolMetadataFromBytes(simplexMetadata) + if err != nil { + return nil, fmt.Errorf("failed to parse simplex metadata: %w", err) + } + md.Prev = sm.LastNonSimplexInnerBlock.Digest() + md.Seq = sm.LastNonSimplexInnerBlock.Height() + + return &StateMachineBlock{ + Metadata: StateMachineMetadata{ + Timestamp: uint64(timestamp), + SimplexProtocolMetadata: simplexMetadata, + SimplexBlacklist: simplexBlacklist, + SimplexEpochInfo: simplexEpochInfo, + PChainHeight: pChainHeight, + }, + }, nil +} + +func (sm *StateMachine) verifyBlockZero(block *StateMachineBlock, prevBlock StateMachineBlock) error { + if block == nil { + return fmt.Errorf("block is nil") + } + + simplexEpochInfo := block.Metadata.SimplexEpochInfo + + if simplexEpochInfo.EpochNumber != 1 { + return fmt.Errorf("invalid epoch number (%d), should be 1", simplexEpochInfo.EpochNumber) + } + + if prevBlock.InnerBlock == nil { + return fmt.Errorf("parent inner block (%s) has no inner block", prevBlock.Digest()) + } + + pChainHeight := sm.LastNonSimplexBlockPChainHeight + prevVMBlockSeq := prevBlock.InnerBlock.Height() + + if block.Metadata.PChainHeight != pChainHeight { + return fmt.Errorf("invalid P-chain height (%d), expected to be %d", + block.Metadata.PChainHeight, pChainHeight) + } + + var expectedValidatorSet NodeBLSMappings + if prevBlock.InnerBlock.Height() == 0 { + expectedValidatorSet = sm.GenesisValidatorSet + } else { + var err error + expectedValidatorSet, err = sm.GetValidatorSet(pChainHeight) + if err != nil { + return fmt.Errorf("failed to retrieve validator set at height %d: %w", pChainHeight, err) + } + } + + if simplexEpochInfo.BlockValidationDescriptor == nil { + return fmt.Errorf("invalid BlockValidationDescriptor: should not be nil") + } + + membership := simplexEpochInfo.BlockValidationDescriptor.AggregatedMembership.Members + if !NodeBLSMappings(membership).Equal(expectedValidatorSet) { + return fmt.Errorf("invalid BlockValidationDescriptor: should match validator set at P-chain height %d", pChainHeight) + } + + // If we have compared all fields so far, the rest of the fields we compare by constructing an explicit expected SimplexEpochInfo + expectedSimplexEpochInfo := constructSimplexZeroBlockSimplexEpochInfo(pChainHeight, expectedValidatorSet, prevVMBlockSeq) + + if !expectedSimplexEpochInfo.Equal(&simplexEpochInfo) { + return fmt.Errorf("invalid SimplexEpochInfo: expected %v, got %v", expectedSimplexEpochInfo, simplexEpochInfo) + } + + // The InnerBlock must match the last non-Simplex inner block. + if block.InnerBlock != nil { + return fmt.Errorf("zero block must not have an inner block") + } + if prevBlock.InnerBlock.Digest() != sm.LastNonSimplexInnerBlock.Digest() { + return fmt.Errorf("zero block inner block digest does not match last non-Simplex inner block digest") + } + + // The timestamp must equal the last non-Simplex inner block's timestamp. + expectedTimestamp := uint64(sm.LastNonSimplexInnerBlock.Timestamp().UnixMilli()) + if block.Metadata.Timestamp != expectedTimestamp { + return fmt.Errorf("expected timestamp to be %d but got %d", expectedTimestamp, block.Metadata.Timestamp) + } + + return nil +} + +func (sm *StateMachine) buildBlockCollectingApprovals(ctx context.Context, parentBlock StateMachineBlock, simplexMetadata, simplexBlacklist []byte, prevBlockSeq uint64) (*StateMachineBlock, error) { + // The P-chain reference height and epoch number should remain the same until we transition to the new epoch. + // The next P-chain reference height should have been set in the previous block, + // which is the reason why we are collecting approvals in the first place. + newSimplexEpochInfo := SimplexEpochInfo{ + PChainReferenceHeight: parentBlock.Metadata.SimplexEpochInfo.PChainReferenceHeight, + EpochNumber: parentBlock.Metadata.SimplexEpochInfo.EpochNumber, + NextPChainReferenceHeight: parentBlock.Metadata.SimplexEpochInfo.NextPChainReferenceHeight, + PrevVMBlockSeq: computePrevVMBlockSeq(parentBlock, prevBlockSeq), + } + + // We prepare information that is needed to compute the approvals for the new epoch, + // such as the validator set for the next epoch, and the approvals from peers. + validators, err := sm.GetValidatorSet(parentBlock.Metadata.SimplexEpochInfo.NextPChainReferenceHeight) + if err != nil { + return nil, err + } + + // We retrieve approvals that validators have sent us for the next epoch. + // These approvals are signed by validators of the next epoch. + approvalsFromPeers := sm.ApprovalsRetriever.Approvals() + sm.Logger.Debug("Retrieved approvals from peers", zap.Int("numApprovals", len(approvalsFromPeers))) + + nextPChainHeight := newSimplexEpochInfo.NextPChainReferenceHeight + prevNextEpochApprovals := parentBlock.Metadata.SimplexEpochInfo.NextEpochApprovals + + sigAggr := sm.SignatureAggregatorCreator(validators.NodeWeights()) + + newApprovals, err := computeNewApprovals(prevNextEpochApprovals, approvalsFromPeers, nextPChainHeight, sigAggr, validators, sm.Logger) + if err != nil { + return nil, err + } + + // This might be the first time we created approvals for the next epoch, + // so we need to initialize the NextEpochApprovals. + if newSimplexEpochInfo.NextEpochApprovals == nil { + newSimplexEpochInfo.NextEpochApprovals = &NextEpochApprovals{} + } + // The node IDs and signature are aggregated across all past and present approvals. + newSimplexEpochInfo.NextEpochApprovals.NodeIDs = newApprovals.nodeIDs + newSimplexEpochInfo.NextEpochApprovals.Signature = newApprovals.signature + pChainHeight := parentBlock.Metadata.PChainHeight + + // We might not have enough approvals to seal the current epoch, + // in which case we just carry over the approvals we have so far to the next block, + // so that eventually we'll have enough approvals to seal the epoch. + if !newApprovals.canSeal { + sm.Logger.Debug("Not enough approvals to seal epoch, building block without sealing the epoch") + return sm.buildBlockImpatiently(ctx, parentBlock, simplexMetadata, simplexBlacklist, newSimplexEpochInfo, pChainHeight) + } + + sm.Logger.Debug("Have enough approvals to seal epoch, building sealing block") + + // Else, we have enough approvals to seal the epoch, so we create the sealing block. + return sm.createSealingBlock(ctx, parentBlock, simplexMetadata, simplexBlacklist, newSimplexEpochInfo, pChainHeight) +} + +// buildBlockImpatiently builds a block by waiting for the VM to build a block until MaxBlockBuildingWaitTime. +// If the VM fails to build a block within that time, we build a block without an inner block, +// so that we can continue making progress and not get stuck waiting for the VM. +func (sm *StateMachine) buildBlockImpatiently(ctx context.Context, parentBlock StateMachineBlock, simplexMetadata []byte, simplexBlacklist []byte, simplexEpochInfo SimplexEpochInfo, pChainHeight uint64) (*StateMachineBlock, error) { + impatientContext, cancel := context.WithTimeout(ctx, sm.MaxBlockBuildingWaitTime) + defer cancel() + + start := time.Now() + + // TODO: This P-chain height should be taken from the ICM epoch + childBlock, err := sm.BlockBuilder.BuildBlock(impatientContext, pChainHeight) + if err != nil && impatientContext.Err() == nil { + // If we got an error building the block, and we didn't time out, log the error but continue building the block without the inner block, + // so that we can continue making progress and not get stuck on a single block. + sm.Logger.Error("Error building block, building block without inner block instead", zap.Error(err)) + } + if impatientContext.Err() != nil { + sm.Logger.Debug("Timed out waiting for block to be built, building block without inner block instead", + zap.Duration("elapsed", time.Since(start)), zap.Duration("maxBlockBuildingWaitTime", sm.MaxBlockBuildingWaitTime)) + } + return sm.wrapBlock(parentBlock, childBlock, simplexEpochInfo, pChainHeight, simplexMetadata, simplexBlacklist), nil +} + +func (sm *StateMachine) createSealingBlock(ctx context.Context, parentBlock StateMachineBlock, simplexMetadata []byte, simplexBlacklist []byte, simplexEpochInfo SimplexEpochInfo, pChainHeight uint64) (*StateMachineBlock, error) { + validators, err := sm.GetValidatorSet(simplexEpochInfo.NextPChainReferenceHeight) + if err != nil { + return nil, err + } + if simplexEpochInfo.BlockValidationDescriptor == nil { + simplexEpochInfo.BlockValidationDescriptor = &BlockValidationDescriptor{} + } + simplexEpochInfo.BlockValidationDescriptor.AggregatedMembership.Members = validators + + // If this is not the first epoch, and this is the sealing block, we set the hash of the previous sealing block. + if simplexEpochInfo.EpochNumber > 1 { + prevSealingBlock, finalization, err := sm.GetBlock(simplexEpochInfo.EpochNumber, [32]byte{}) + if err != nil { + sm.Logger.Error("Error retrieving previous sealing block", zap.Uint64("seq", simplexEpochInfo.EpochNumber), zap.Error(err)) + return nil, fmt.Errorf("failed to retrieve previous sealing InnerBlock at epoch %d: %w", simplexEpochInfo.EpochNumber-1, err) + } + if finalization == nil { + sm.Logger.Error("Previous sealing block is not finalized", zap.Uint64("seq", simplexEpochInfo.EpochNumber)) + return nil, fmt.Errorf("previous sealing InnerBlock at epoch %d is not finalized", simplexEpochInfo.EpochNumber-1) + } + simplexEpochInfo.PrevSealingBlockHash = prevSealingBlock.Digest() + } else { // Else, this is the first epoch, so we use the hash of the first ever Simplex block. + firstSimplexBlock := sm.FirstEverSimplexBlock() + if firstSimplexBlock == nil { + return nil, fmt.Errorf("first ever Simplex block is not set, but attempted to create a sealing block for the first epoch") + } + simplexEpochInfo.PrevSealingBlockHash = firstSimplexBlock.Digest() + } + + return sm.buildBlockImpatiently(ctx, parentBlock, simplexMetadata, simplexBlacklist, simplexEpochInfo, pChainHeight) +} + +// wrapBlock creates a new StateMachineBlock by wrapping the VM block (if applicable) and adding the appropriate metadata. +func (sm *StateMachine) wrapBlock(parentBlock StateMachineBlock, childBlock VMBlock, newSimplexEpochInfo SimplexEpochInfo, pChainHeight uint64, simplexMetadata, simplexBlacklist []byte) *StateMachineBlock { + timestamp := parentBlock.Metadata.Timestamp + + hasChildBlock := childBlock != nil + + var newTimestamp time.Time + if hasChildBlock { + newTimestamp = childBlock.Timestamp() + timestamp = uint64(newTimestamp.UnixMilli()) + } + + return &StateMachineBlock{ + InnerBlock: childBlock, + Metadata: StateMachineMetadata{ + Timestamp: timestamp, + SimplexProtocolMetadata: simplexMetadata, + SimplexBlacklist: simplexBlacklist, + SimplexEpochInfo: newSimplexEpochInfo, + PChainHeight: pChainHeight, + }, + } +} + +// buildBlockEpochSealed builds a block where the epoch is being sealed due to a sealing block already created in this epoch. +func (sm *StateMachine) buildBlockEpochSealed(ctx context.Context, parentBlock StateMachineBlock, simplexMetadata, simplexBlacklist []byte, prevBlockSeq uint64) (*StateMachineBlock, error) { + // We check if the sealing block has already been finalized. + // If not, we build a Telock block. + + sealingBlockSeq := parentBlock.Metadata.SimplexEpochInfo.SealingBlockSeq + + // If the sealing block sequence is still 0, it means previous block was the sealing block. + if sealingBlockSeq == 0 { + sealingBlockSeq = prevBlockSeq + } + + if sealingBlockSeq == 0 { + return nil, fmt.Errorf("cannot build epoch sealed block: sealing block sequence is 0 or undefined") + } + + newSimplexEpochInfo := SimplexEpochInfo{ + PChainReferenceHeight: parentBlock.Metadata.SimplexEpochInfo.PChainReferenceHeight, + EpochNumber: parentBlock.Metadata.SimplexEpochInfo.EpochNumber, + NextPChainReferenceHeight: parentBlock.Metadata.SimplexEpochInfo.NextPChainReferenceHeight, + SealingBlockSeq: sealingBlockSeq, + PrevVMBlockSeq: computePrevVMBlockSeq(parentBlock, prevBlockSeq), + } + + _, finalization, err := sm.GetBlock(sealingBlockSeq, [32]byte{}) + if err != nil { + return nil, fmt.Errorf("failed to retrieve sealing block at sequence %d: %w", sealingBlockSeq, err) + } + + isSealingBlockFinalized := finalization != nil + + if !isSealingBlockFinalized { + pChainHeight := parentBlock.Metadata.PChainHeight + return sm.wrapBlock(parentBlock, nil, newSimplexEpochInfo, pChainHeight, simplexMetadata, simplexBlacklist), nil + } + + // Else, we build a block for the new epoch. + newSimplexEpochInfo = SimplexEpochInfo{ + // P-chain reference height is previous block's NextPChainReferenceHeight. + PChainReferenceHeight: parentBlock.Metadata.SimplexEpochInfo.NextPChainReferenceHeight, + // The epoch number is the sequence of the sealing block. + EpochNumber: sealingBlockSeq, + PrevVMBlockSeq: computePrevVMBlockSeq(parentBlock, prevBlockSeq), + } + + // TODO: This P-chain height should be taken from the ICM epoch + childBlock, err := sm.BlockBuilder.BuildBlock(ctx, sm.GetPChainHeight()) + if err != nil { + return nil, err + } + + return sm.wrapBlock(parentBlock, childBlock, newSimplexEpochInfo, parentBlock.Metadata.PChainHeight, simplexMetadata, simplexBlacklist), nil +} + +// constructSimplexZeroBlockSimplexEpochInfo constructs the SimplexEpochInfo for the zero block, which is the first ever block built by Simplex. +func constructSimplexZeroBlockSimplexEpochInfo(pChainHeight uint64, newValidatorSet NodeBLSMappings, prevVMBlockSeq uint64) SimplexEpochInfo { + newSimplexEpochInfo := SimplexEpochInfo{ + PChainReferenceHeight: pChainHeight, + EpochNumber: 1, + // We treat the zero block as a special case, and we encode in it the block validation descriptor, + // despite it not actually being a sealing block. This is because the zero block is the first block that introduces the validator set. + BlockValidationDescriptor: &BlockValidationDescriptor{ + AggregatedMembership: AggregatedMembership{ + Members: newValidatorSet, + }, + }, + NextEpochApprovals: nil, // We don't need to collect approvals to seal the first ever epoch. + PrevVMBlockSeq: prevVMBlockSeq, + SealingBlockSeq: 0, // We don't have a sealing block in the zero block. + PrevSealingBlockHash: [32]byte{}, // The zero block has no previous sealing block. + NextPChainReferenceHeight: 0, + } + return newSimplexEpochInfo +} + +func computeNewApprovals( + prevNextEpochApprovals *NextEpochApprovals, + approvalsFromPeers ValidatorSetApprovals, + pChainHeight uint64, + sigAggr simplex.SignatureAggregator, + validators NodeBLSMappings, + logger simplex.Logger, +) (*approvals, error) { + if prevNextEpochApprovals == nil { + prevNextEpochApprovals = &NextEpochApprovals{} + } + + oldApprovingNodes := bitmaskFromBytes(prevNextEpochApprovals.NodeIDs) + + nodeID2ValidatorIndex := validators.IndexByNodeID() + + oldApprovalFromPeersCount := len(approvalsFromPeers) + // We have the approvals obtained from peers, but we need to sanitize them by filtering out approvals that are not valid, + // such as approvals that do not agree with our candidate auxiliary info digest and P-Chain height, + // and approvals that are from nodes that are not in the validator set or have already approved in prior blocks. + approvalsFromPeers = sanitizeApprovals(approvalsFromPeers, pChainHeight, nodeID2ValidatorIndex, oldApprovingNodes, logger) + logger.Debug("Sanitized approvals after filtering out invalid approvals", zap.Int("numApprovalsBefore", oldApprovalFromPeersCount), zap.Int("numApprovalsAfter", len(approvalsFromPeers))) + + // Next we aggregate both previous and new approvals to compute the new aggregated signatures and the new bitmask of approving nodes. + aggregatedSignature, newApprovingNodes, err := computeNewApproverSignaturesAndSigners(prevNextEpochApprovals, approvalsFromPeers, oldApprovingNodes, nodeID2ValidatorIndex, sigAggr, logger) + if err != nil { + return nil, err + } + + // we check if we have enough approvals to seal the epoch by computing the relative approval ratio, + // which is the ratio of the total weight of approving nodes divided by the total weight of all validators. + canSeal := sigAggr.IsQuorum(validators.SelectSubset(newApprovingNodes)) + + return &approvals{ + canSeal: canSeal, + signature: aggregatedSignature, + nodeIDs: newApprovingNodes.Bytes(), + }, nil } // computeNewApproverSignaturesAndSigners computes the signatures of the nodes that approve the next epoch including the previous aggregated signature, // and bitmask of nodes that correspond to those signatures, and aggregates all signatures together. -func computeNewApproverSignaturesAndSigners(nextEpochApprovals *NextEpochApprovals, approvalsFromPeers ValidatorSetApprovals, oldApprovingNodes bitmask, nodeID2ValidatorIndex map[nodeID]int, aggregator SignatureAggregator) ([]byte, bitmask, error) { +func computeNewApproverSignaturesAndSigners( + nextEpochApprovals *NextEpochApprovals, + approvalsFromPeers ValidatorSetApprovals, + oldApprovingNodes bitmask, + nodeID2ValidatorIndex map[nodeID]int, + sigAggr simplex.SignatureAggregator, + logger simplex.Logger, +) ([]byte, bitmask, error) { if nextEpochApprovals == nil { return nil, bitmask{}, fmt.Errorf("next epoch approvals is nil") } @@ -101,6 +746,11 @@ func computeNewApproverSignaturesAndSigners(nextEpochApprovals *NextEpochApprova // We will overwrite the old approving nodes with the new approving nodes, by turning on the bits for the new approvers. newApprovingNodes := oldApprovingNodes.Clone() + logger.Debug("Existing approving nodes bitmask before adding new approvals", + zap.Int("count", oldApprovingNodes.Len())) + + logger.Debug("New approvals from peers that we will consider for aggregation", zap.Int("count", len(approvalsFromPeers))) + for _, approval := range approvalsFromPeers { approvingNodeIndexOfNewApprover, exists := nodeID2ValidatorIndex[approval.NodeID] if !exists { @@ -120,12 +770,9 @@ func computeNewApproverSignaturesAndSigners(nextEpochApprovals *NextEpochApprova // Add the existing signature into the list of signatures to aggregate existingSignature := nextEpochApprovals.Signature - if existingSignature != nil { - newSignatures = append(newSignatures, existingSignature) - } // Finally, we aggregate all signatures together, to compute the new aggregated signature. - aggregatedSignature, err := aggregator.AggregateSignatures(newSignatures...) + aggregatedSignature, err := sigAggr.AppendSignatures(existingSignature, newSignatures...) if err != nil { return nil, bitmask{}, fmt.Errorf("failed to aggregate signatures: %w", err) } @@ -135,23 +782,32 @@ func computeNewApproverSignaturesAndSigners(nextEpochApprovals *NextEpochApprova // sanitizeApprovals filters out approvals that are not valid by checking if they agree with our candidate auxiliary info digest and P-Chain height, // and if they are from the validator set and haven't already been approved. -func sanitizeApprovals(approvals ValidatorSetApprovals, pChainHeight uint64, nodeID2ValidatorIndex map[nodeID]int, oldApprovingNodes bitmask) ValidatorSetApprovals { - filter1 := approvalsThatAgreeWithAuxInfoAndPChainHeight(pChainHeight) +func sanitizeApprovals(approvals ValidatorSetApprovals, pChainHeight uint64, nodeID2ValidatorIndex map[nodeID]int, oldApprovingNodes bitmask, logger simplex.Logger) ValidatorSetApprovals { + filter1 := approvalsThatAgreeWithPChainHeight(pChainHeight) filter2 := approvalsThatAreInValidatorSetAndHaveNotAlreadyApproved(oldApprovingNodes.Clone(), nodeID2ValidatorIndex) - return approvals.Filter(filter1).Filter(filter2).UniqueByNodeID() + return approvals.Filter(filter1, logger).Filter(filter2, logger).UniqueByNodeID() } -func approvalsThatAgreeWithAuxInfoAndPChainHeight(pChainHeight uint64) func(i int, approval ValidatorSetApproval) bool { - return func(i int, approval ValidatorSetApproval) bool { +func approvalsThatAgreeWithPChainHeight(pChainHeight uint64) func(approval ValidatorSetApproval, logger simplex.Logger) bool { + return func(approval ValidatorSetApproval, logger simplex.Logger) bool { // Pick only approvals that agree with our P-Chain height - return approval.PChainHeight == pChainHeight + ok := approval.PChainHeight == pChainHeight + if !ok { + logger.Debug("Filtering out approval that does not agree with our P-Chain height", + zap.String("nodeID", fmt.Sprintf("%x", approval.NodeID)), + zap.Uint64("approvalPChainHeight", approval.PChainHeight), + zap.Uint64("expectedPChainHeight", pChainHeight)) + } + return ok } } -func approvalsThatAreInValidatorSetAndHaveNotAlreadyApproved(oldApprovingNodes bitmask, nodeID2ValidatorIndex map[nodeID]int) func(i int, approval ValidatorSetApproval) bool { - return func(i int, approval ValidatorSetApproval) bool { +func approvalsThatAreInValidatorSetAndHaveNotAlreadyApproved(oldApprovingNodes bitmask, nodeID2ValidatorIndex map[nodeID]int) func(approval ValidatorSetApproval, logger simplex.Logger) bool { + return func(approval ValidatorSetApproval, logger simplex.Logger) bool { approvingNodeIndexOfNewApprover, exists := nodeID2ValidatorIndex[approval.NodeID] if !exists { + logger.Debug("Filtering out approval from node that is not in the validator set", + zap.String("nodeID", fmt.Sprintf("%x", approval.NodeID))) // If the approving node is not in the validator set, we ignore this approval. return false } @@ -160,39 +816,6 @@ func approvalsThatAreInValidatorSetAndHaveNotAlreadyApproved(oldApprovingNodes b } } -func findFirstSimplexBlock(getBlock BlockRetriever, endHeight uint64) (uint64, error) { - var haltError error - - if endHeight > math.MaxInt-1 { - return 0, fmt.Errorf("endHeight %d is too big, must be at most %d", endHeight, math.MaxInt-1) - } - - firstSimplexBlock := sort.Search(int(endHeight+1), func(i int) bool { - if haltError != nil { - return true - } - block, _, err := getBlock(RetrievingOpts{Height: uint64(i)}) - if errors.Is(err, simplex.ErrBlockNotFound) { - return false - } - if err != nil { - haltError = fmt.Errorf("error retrieving block at height %d: %w", i, err) - return false - } - // The first Simplex block is such that its epoch info isn't the zero value. - return !block.Metadata.SimplexEpochInfo.IsZero() - }) - if haltError != nil { - return 0, haltError - } - - if uint64(firstSimplexBlock) > endHeight { - return 0, fmt.Errorf("no simplex blocks found in range [%d, %d]", 0, endHeight) - } - - return uint64(firstSimplexBlock), nil -} - func computePrevVMBlockSeq(parentBlock StateMachineBlock, prevBlockSeq uint64) uint64 { // Either our parent block has no inner block, in which case we just inherit its previous VM block sequence, if parentBlock.InnerBlock == nil { @@ -209,7 +832,7 @@ var ( func ensureNextEpochApprovalsSignersSupersetOfApprovalsOfPrevBlock(prev SimplexEpochInfo, next SimplexEpochInfo) error { if prev.NextEpochApprovals == nil { - // Condition satisifed vacuously. + // Condition satisfied vacuously. return nil } // Else, prev.NextEpochApprovals is not nil. @@ -229,3 +852,9 @@ func ensureNextEpochApprovalsSignersSupersetOfApprovalsOfPrevBlock(prev SimplexE } return nil } + +type approvals struct { + canSeal bool + nodeIDs []byte + signature []byte +} diff --git a/msm/msm_test.go b/msm/msm_test.go index 585a1ed8..eff624da 100644 --- a/msm/msm_test.go +++ b/msm/msm_test.go @@ -4,26 +4,700 @@ package metadata import ( - "bytes" "context" + "crypto/rand" "fmt" - "math" "testing" "time" "github.com/ava-labs/simplex" + "github.com/ava-labs/simplex/testutil" "github.com/stretchr/testify/require" ) -// fakeVMBlock is a minimal VMBlock implementation for tests. -type fakeVMBlock struct { - height uint64 +func TestMSMBuildAndVerifyBlocksAfterGenesis(t *testing.T) { + validMD := simplex.ProtocolMetadata{ + Round: 1, + Seq: 1, + Epoch: 1, + Prev: genesisBlock.Digest(), + } + + for _, testCase := range []struct { + name string + md simplex.ProtocolMetadata + err string + configure func(*StateMachine, *testConfig) + mutateBlock func(*StateMachineBlock) + }{ + { + name: "correct information", + md: validMD, + }, + { + name: "verifying a genesis block", + md: validMD, + mutateBlock: func(block *StateMachineBlock) { + md, err := simplex.ProtocolMetadataFromBytes(block.Metadata.SimplexProtocolMetadata) + require.NoError(t, err) + md.Seq = 0 + block.Metadata.SimplexProtocolMetadata = md.Bytes() + }, + err: "attempted to build a genesis inner block", + }, + { + name: "previous block not found", + md: validMD, + configure: func(_ *StateMachine, tc *testConfig) { + delete(tc.blockStore, 0) + }, + err: "failed to retrieve previous (0) inner block", + }, + { + name: "parent has no inner block", + md: validMD, + configure: func(_ *StateMachine, tc *testConfig) { + tc.blockStore[0] = &outerBlock{ + block: StateMachineBlock{}, + } + }, + err: "parent inner block (", + }, + { + name: "wrong epoch number", + md: validMD, + mutateBlock: func(block *StateMachineBlock) { + block.Metadata.SimplexEpochInfo.EpochNumber = 2 + }, + err: "invalid epoch number (2), should be 1", + }, + { + name: "P-chain height too big", + md: validMD, + mutateBlock: func(block *StateMachineBlock) { + block.Metadata.PChainHeight = 110 + }, + err: "invalid P-chain height (110), expected to be 100", + }, + { + name: "P-chain height smaller than parent", + md: validMD, + configure: func(sm *StateMachine, tc *testConfig) { + sm.LastNonSimplexBlockPChainHeight = 99 + }, + err: "invalid P-chain height (100), expected to be 99", + }, + { + name: "nil BlockValidationDescriptor", + md: validMD, + mutateBlock: func(block *StateMachineBlock) { + block.Metadata.SimplexEpochInfo.BlockValidationDescriptor = nil + }, + err: "invalid BlockValidationDescriptor: should not be nil", + }, + { + name: "membership mismatch", + md: validMD, + configure: func(sm *StateMachine, tc *testConfig) { + sm.GenesisValidatorSet = NodeBLSMappings{ + {BLSKey: []byte{1}, Weight: 1}, + } + }, + err: "invalid BlockValidationDescriptor: should match validator set", + }, + { + name: "SimplexEpochInfo mismatch", + md: validMD, + mutateBlock: func(block *StateMachineBlock) { + block.Metadata.SimplexEpochInfo.PrevVMBlockSeq = 999 + }, + err: "invalid SimplexEpochInfo", + }, + } { + t.Run(testCase.name, func(t *testing.T) { + sm1, _ := newStateMachine(t) + sm2, testConfig2 := newStateMachine(t) + + if testCase.configure != nil { + testCase.configure(sm2, testConfig2) + } + + block, err := sm1.BuildBlock(context.Background(), testCase.md, nil) + require.NoError(t, err) + require.NotNil(t, block) + + if testCase.mutateBlock != nil { + testCase.mutateBlock(block) + } + + err = sm2.VerifyBlock(context.Background(), block) + if testCase.err != "" { + require.ErrorContains(t, err, testCase.err) + return + } + require.NoError(t, err) + }) + } } -func (f *fakeVMBlock) Digest() [32]byte { return [32]byte{} } -func (f *fakeVMBlock) Height() uint64 { return f.height } -func (f *fakeVMBlock) Timestamp() time.Time { return time.Time{} } -func (f *fakeVMBlock) Verify(_ context.Context) error { return nil } +func TestMSMFirstSimplexBlockAfterPreSimplexBlocks(t *testing.T) { + preSimplexParent := StateMachineBlock{ + InnerBlock: &InnerBlock{ + TS: time.Now(), + BlockHeight: 42, + Bytes: []byte{4, 5, 6}, + }, + // Zero-valued metadata means this is a pre-Simplex block or a genesis block. + // But since the height is 42, it can't be a genesis block, so it must be a pre-Simplex block. + Metadata: StateMachineMetadata{}, + } + + md := simplex.ProtocolMetadata{ + Round: 0, + Seq: 43, + Epoch: 1, + Prev: preSimplexParent.Digest(), + } + + sm1, testConfig1 := newStateMachine(t) + sm2, testConfig2 := newStateMachine(t) + + testConfig1.blockStore[0] = &outerBlock{ + block: preSimplexParent, + } + + testConfig1.blockStore[42] = &outerBlock{block: preSimplexParent} + testConfig2.blockStore[42] = &outerBlock{block: preSimplexParent} + + sm1.LastNonSimplexInnerBlock = testConfig1.blockStore[42].block.InnerBlock + sm2.LastNonSimplexInnerBlock = testConfig1.blockStore[42].block.InnerBlock + + testConfig1.blockBuilder.block = &InnerBlock{ + TS: time.Now(), + BlockHeight: 43, + Bytes: []byte{7, 8, 9}, + } + + block, err := sm1.BuildBlock(context.Background(), md, nil) + require.NoError(t, err) + require.NotNil(t, block) + + require.NoError(t, sm2.VerifyBlock(context.Background(), block)) + + require.Equal(t, &StateMachineBlock{ + Metadata: StateMachineMetadata{ + Timestamp: uint64(preSimplexParent.InnerBlock.Timestamp().UnixMilli()), + PChainHeight: 100, + SimplexProtocolMetadata: md.Bytes(), + SimplexEpochInfo: SimplexEpochInfo{ + PChainReferenceHeight: 100, + EpochNumber: 1, + PrevVMBlockSeq: 42, + BlockValidationDescriptor: &BlockValidationDescriptor{ + AggregatedMembership: AggregatedMembership{ + Members: testConfig1.validatorSetRetriever.result, + }, + }, + }, + }, + }, block) +} + +func TestMSMNormalOp(t *testing.T) { + newPChainHeight := uint64(200) + newValidatorSet := NodeBLSMappings{ + {BLSKey: []byte{5}, Weight: 1}, {BLSKey: []byte{6}, Weight: 1}, {BLSKey: []byte{7}, Weight: 1}, + } + + for _, testCase := range []struct { + name string + setup func(*StateMachine, *testConfig) + expectedPChainHeight uint64 + expectedNextPChainRefHeight uint64 + }{ + { + name: "correct information", + expectedPChainHeight: 100, + }, + { + name: "validator set change detected", + setup: func(sm *StateMachine, tc *testConfig) { + tc.validatorSetRetriever.resultMap = map[uint64]NodeBLSMappings{ + newPChainHeight: newValidatorSet, + } + sm.GetPChainHeight = func() uint64 { return newPChainHeight } + }, + expectedPChainHeight: newPChainHeight, + expectedNextPChainRefHeight: newPChainHeight, + }, + } { + t.Run(testCase.name, func(t *testing.T) { + chain := makeChain(t, 5, 10) + sm1, testConfig1 := newStateMachine(t) + + for i, block := range chain { + testConfig1.blockStore[uint64(i)] = &outerBlock{block: block} + } + + lastBlock := chain[len(chain)-1] + md, err := simplex.ProtocolMetadataFromBytes(lastBlock.Metadata.SimplexProtocolMetadata) + require.NoError(t, err) + + md.Seq++ + md.Round++ + md.Prev = lastBlock.Digest() + + var blacklist simplex.Blacklist + blacklist.NodeCount = 4 + + blockTime := lastBlock.InnerBlock.Timestamp().Add(time.Second) + + content := make([]byte, 10) + _, err = rand.Read(content) + require.NoError(t, err) + + testConfig1.blockBuilder.block = &InnerBlock{ + TS: blockTime, + BlockHeight: lastBlock.InnerBlock.Height(), + Bytes: content, + } + + if testCase.setup != nil { + testCase.setup(sm1, testConfig1) + } + + block1, err := sm1.BuildBlock(context.Background(), *md, &blacklist) + require.NoError(t, err) + require.NotNil(t, block1) + + require.Equal(t, &StateMachineBlock{ + InnerBlock: &InnerBlock{ + TS: blockTime, + BlockHeight: lastBlock.InnerBlock.Height(), + Bytes: content, + }, + Metadata: StateMachineMetadata{ + SimplexBlacklist: blacklist.Bytes(), + Timestamp: uint64(blockTime.UnixMilli()), + PChainHeight: testCase.expectedPChainHeight, + SimplexProtocolMetadata: md.Bytes(), + SimplexEpochInfo: SimplexEpochInfo{ + PChainReferenceHeight: 100, + EpochNumber: 1, + PrevVMBlockSeq: lastBlock.InnerBlock.Height(), + NextPChainReferenceHeight: testCase.expectedNextPChainRefHeight, + }, + }, + }, block1) + }) + } +} + +func TestMSMFullEpochLifecycle(t *testing.T) { + // Validator sets: epoch 1 uses validatorSet1, epoch 2 uses validatorSet2. + node1 := [20]byte{1} + node2 := [20]byte{2} + node3 := [20]byte{3} + + validatorSet1 := NodeBLSMappings{ + {NodeID: node1, BLSKey: []byte{1}, Weight: 1}, + {NodeID: node2, BLSKey: []byte{2}, Weight: 1}, + {NodeID: node3, BLSKey: []byte{3}, Weight: 1}, + } + validatorSet2 := NodeBLSMappings{ + {NodeID: node1, BLSKey: []byte{1}, Weight: 1}, + {NodeID: node2, BLSKey: []byte{4}, Weight: 1}, + {NodeID: node3, BLSKey: []byte{5}, Weight: 1}, + } + + pChainHeight1 := uint64(100) + pChainHeight2 := uint64(200) + + startTime := time.Now() + + nextBlock := func(height uint64) *InnerBlock { + return &InnerBlock{ + TS: startTime.Add(time.Duration(height) * time.Millisecond), + BlockHeight: height, + Bytes: []byte{byte(height)}, + } + } + + // ----- Step 0: Building on top of genesis or upgrading to Simplex----- + genesis := StateMachineBlock{ + InnerBlock: &InnerBlock{ + BlockHeight: 0, // Genesis block has height 0 + TS: startTime, + Bytes: []byte{0}, + }, + } + + notGenesis := StateMachineBlock{ + InnerBlock: &InnerBlock{ + BlockHeight: 42, + TS: startTime, + Bytes: []byte{0}, + }, + } + for _, testCase := range []struct { + name string + firstBlockBeforeSimplex StateMachineBlock + }{ + { + name: "building on top of genesis", + firstBlockBeforeSimplex: genesis, + }, + { + name: "upgrading to Simplex from pre-Simplex blocks", + firstBlockBeforeSimplex: notGenesis, + }, + } { + t.Run(testCase.name, func(t *testing.T) { + + currentPChainHeight := pChainHeight1 + + getValidatorSet := func(height uint64) (NodeBLSMappings, error) { + if height >= pChainHeight2 { + return validatorSet2, nil + } + return validatorSet1, nil + } + getPChainHeight := func() uint64 { + return currentPChainHeight + } + + // Create fresh state machine instances for each iteration. + sm, tc := newStateMachine(t) + sm.GetValidatorSet = getValidatorSet + sm.GetPChainHeight = getPChainHeight + tc.blockStore[0] = &outerBlock{block: genesis} + tc.blockStore[42] = &outerBlock{block: notGenesis} + + sm.LastNonSimplexInnerBlock = testCase.firstBlockBeforeSimplex.InnerBlock + sm.GenesisValidatorSet = validatorSet1 + sm.LastNonSimplexBlockPChainHeight = pChainHeight1 + + smVerify, tcVerify := newStateMachine(t) + smVerify.GetValidatorSet = getValidatorSet + smVerify.GetPChainHeight = getPChainHeight + + smVerify.LastNonSimplexInnerBlock = testCase.firstBlockBeforeSimplex.InnerBlock + smVerify.GenesisValidatorSet = validatorSet1 + smVerify.LastNonSimplexBlockPChainHeight = pChainHeight1 + + // addBlock adds a block to both block stores so builder and verifier stay in sync. + addBlock := func(seq uint64, block StateMachineBlock, fin *simplex.Finalization) { + tc.blockStore[seq] = &outerBlock{block: block, finalization: fin} + tcVerify.blockStore[seq] = &outerBlock{block: block, finalization: fin} + } + + baseSeq := testCase.firstBlockBeforeSimplex.InnerBlock.Height() + addBlock(baseSeq, testCase.firstBlockBeforeSimplex, nil) + + aggr := &signatureAggregator{} + + // ----- Step 1: Build zero epoch block (first simplex block) ----- + tc.blockBuilder.block = nextBlock(1) + md := simplex.ProtocolMetadata{ + Seq: baseSeq + 1, + Round: 0, + Epoch: 1, + Prev: testCase.firstBlockBeforeSimplex.Digest(), + } + + block1, err := sm.BuildBlock(context.Background(), md, nil) + require.NoError(t, err) + require.Equal(t, &StateMachineBlock{ + Metadata: StateMachineMetadata{ + Timestamp: uint64(startTime.UnixMilli()), + PChainHeight: pChainHeight1, + SimplexProtocolMetadata: md.Bytes(), + SimplexEpochInfo: SimplexEpochInfo{ + PChainReferenceHeight: pChainHeight1, + EpochNumber: 1, + PrevVMBlockSeq: baseSeq, + BlockValidationDescriptor: &BlockValidationDescriptor{ + AggregatedMembership: AggregatedMembership{ + Members: validatorSet1, + }, + }, + }, + }, + }, block1) + addBlock(md.Seq, *block1, nil) + + require.NoError(t, smVerify.VerifyBlock(context.Background(), block1)) + + // After we build the first block, the StateMachine should consider it as the latest persisted height. + sm.LatestPersistedHeight = baseSeq + 1 + smVerify.LatestPersistedHeight = baseSeq + 1 + + // ----- Step 2: Build a normal block (no validator set change) ----- + tc.blockBuilder.block = nextBlock(2) + md = simplex.ProtocolMetadata{Seq: baseSeq + 2, Round: 1, Epoch: 1, Prev: block1.Digest()} + block2, err := sm.BuildBlock(context.Background(), md, nil) + require.NoError(t, err) + require.Equal(t, &StateMachineBlock{ + InnerBlock: nextBlock(2), + Metadata: StateMachineMetadata{ + Timestamp: uint64(startTime.Add(2 * time.Millisecond).UnixMilli()), + PChainHeight: pChainHeight1, + SimplexProtocolMetadata: md.Bytes(), + SimplexEpochInfo: SimplexEpochInfo{ + PChainReferenceHeight: pChainHeight1, + EpochNumber: 1, + PrevVMBlockSeq: baseSeq, + }, + }, + }, block2) + addBlock(md.Seq, *block2, nil) + + require.NoError(t, smVerify.VerifyBlock(context.Background(), block2)) + + // ----- Step 3: Build a normal block that detects a validator set change ----- + // Advance P-chain height so that GetValidatorSet returns a different set. + currentPChainHeight = pChainHeight2 + + tc.blockBuilder.block = nextBlock(3) + md = simplex.ProtocolMetadata{Seq: baseSeq + 3, Round: 2, Epoch: 1, Prev: block2.Digest()} + block3, err := sm.BuildBlock(context.Background(), md, nil) + require.NoError(t, err) + require.Equal(t, &StateMachineBlock{ + InnerBlock: nextBlock(3), + Metadata: StateMachineMetadata{ + Timestamp: uint64(startTime.Add(3 * time.Millisecond).UnixMilli()), + PChainHeight: pChainHeight2, + SimplexProtocolMetadata: md.Bytes(), + SimplexEpochInfo: SimplexEpochInfo{ + PChainReferenceHeight: pChainHeight1, + EpochNumber: 1, + PrevVMBlockSeq: baseSeq + 2, + NextPChainReferenceHeight: pChainHeight2, + }, + }, + }, block3) + addBlock(md.Seq, *block3, nil) + + require.NoError(t, smVerify.VerifyBlock(context.Background(), block3)) + + // ----- Step 4: First collecting block (1/3 approvals, not enough to seal) ----- + + // Override ApprovalsRetriever to use our dynamic approvals. + var approvalsResult ValidatorSetApprovals + sm.ApprovalsRetriever = &dynamicApprovalsRetriever{approvals: &approvalsResult} + + approvalsResult = ValidatorSetApprovals{ + { + NodeID: node1, + PChainHeight: pChainHeight2, + Signature: []byte("sig1"), + }, + } + + // node1 is at index 0 in validatorSet2 → bitmask bit 0 → {1} + bitmask := []byte{1} + sig, err := aggr.AppendSignatures(nil, []byte("sig1")) + require.NoError(t, err) + + tc.blockBuilder.block = nextBlock(4) + md = simplex.ProtocolMetadata{Seq: baseSeq + 4, Round: 3, Epoch: 1, Prev: block3.Digest()} + block4, err := sm.BuildBlock(context.Background(), md, nil) + require.NoError(t, err) + require.Equal(t, &StateMachineBlock{ + InnerBlock: nextBlock(4), + Metadata: StateMachineMetadata{ + Timestamp: uint64(startTime.Add(4 * time.Millisecond).UnixMilli()), + PChainHeight: pChainHeight2, + SimplexProtocolMetadata: md.Bytes(), + SimplexEpochInfo: SimplexEpochInfo{ + PChainReferenceHeight: pChainHeight1, + EpochNumber: 1, + PrevVMBlockSeq: baseSeq + 3, + NextPChainReferenceHeight: pChainHeight2, + NextEpochApprovals: &NextEpochApprovals{ + NodeIDs: bitmask, + Signature: sig, + }, + }, + }, + }, block4) + addBlock(md.Seq, *block4, nil) + + require.NoError(t, smVerify.VerifyBlock(context.Background(), block4)) + + // ----- Step 5: Second collecting block (2/3 approvals, still not enough since threshold is strictly > 2/3) ----- + approvalsResult = ValidatorSetApprovals{ + { + NodeID: node2, + PChainHeight: pChainHeight2, + Signature: []byte("sig2"), + }, + } + + // node2 is at index 1 → bitmask bits 0,1 → {3} + sig, err = aggr.AppendSignatures(sig, []byte("sig2")) + require.NoError(t, err) + bitmask = []byte{3} + + tc.blockBuilder.block = nextBlock(5) + md = simplex.ProtocolMetadata{Seq: baseSeq + 5, Round: 4, Epoch: 1, Prev: block4.Digest()} + block5, err := sm.BuildBlock(context.Background(), md, nil) + require.NoError(t, err) + require.Equal(t, &StateMachineBlock{ + InnerBlock: nextBlock(5), + Metadata: StateMachineMetadata{ + Timestamp: uint64(startTime.Add(5 * time.Millisecond).UnixMilli()), + PChainHeight: pChainHeight2, + SimplexProtocolMetadata: md.Bytes(), + SimplexEpochInfo: SimplexEpochInfo{ + PChainReferenceHeight: pChainHeight1, + EpochNumber: 1, + PrevVMBlockSeq: baseSeq + 4, + NextPChainReferenceHeight: pChainHeight2, + NextEpochApprovals: &NextEpochApprovals{ + NodeIDs: bitmask, + Signature: sig, + }, + }, + }, + }, block5) + addBlock(md.Seq, *block5, nil) + + require.NoError(t, smVerify.VerifyBlock(context.Background(), block5)) + + // ----- Step 6: Sealing block (3/3 approvals, enough to seal) ----- + approvalsResult = ValidatorSetApprovals{ + { + NodeID: node3, + PChainHeight: pChainHeight2, + Signature: []byte("sig3"), + }, + } + + // node3 is at index 2 → bitmask bits 0,1,2 → {7} + sig6, err := aggr.AppendSignatures(sig, []byte("sig3")) + require.NoError(t, err) + bitmask = []byte{7} + + tc.blockBuilder.block = nextBlock(6) + md = simplex.ProtocolMetadata{Seq: baseSeq + 6, Round: 5, Epoch: 1, Prev: block5.Digest()} + block6, err := sm.BuildBlock(context.Background(), md, nil) + require.NoError(t, err) + require.Equal(t, &StateMachineBlock{ + InnerBlock: nextBlock(6), + Metadata: StateMachineMetadata{ + Timestamp: uint64(startTime.Add(6 * time.Millisecond).UnixMilli()), + PChainHeight: pChainHeight2, + SimplexProtocolMetadata: md.Bytes(), + SimplexEpochInfo: SimplexEpochInfo{ + PChainReferenceHeight: pChainHeight1, + EpochNumber: 1, + PrevVMBlockSeq: baseSeq + 5, + NextPChainReferenceHeight: pChainHeight2, + SealingBlockSeq: 0, + PrevSealingBlockHash: block1.Digest(), + BlockValidationDescriptor: &BlockValidationDescriptor{ + AggregatedMembership: AggregatedMembership{ + Members: validatorSet2, + }, + }, + NextEpochApprovals: &NextEpochApprovals{ + NodeIDs: bitmask, + Signature: sig6, + }, + }, + }, + }, block6) + addBlock(md.Seq, *block6, nil) + + require.NoError(t, smVerify.VerifyBlock(context.Background(), block6)) + + sealingSeq := baseSeq + 6 // The sealing block's sequence (md.Seq from step 6) + + backupStoreTC := tc.blockStore.clone() + backupStoreTCVerify := tcVerify.blockStore.clone() + + for _, subTestCase := range []struct { + name string + setup func() + }{ + { + name: "sealing block not finalized yet", + setup: func() { + addBlock(sealingSeq, tc.blockStore[sealingSeq].block, nil) + }, + }, + { + name: "sealing block immediately finalized", + setup: func() { + addBlock(sealingSeq, tc.blockStore[sealingSeq].block, &simplex.Finalization{}) + }, + }, + } { + testName := fmt.Sprintf("%s-%s", testCase.name, subTestCase.name) + t.Run(testName, func(t *testing.T) { + tc.blockStore = backupStoreTC.clone() + sm.GetBlock = tc.blockStore.getBlock + tcVerify.blockStore = backupStoreTCVerify.clone() + smVerify.GetBlock = tcVerify.blockStore.getBlock + + subTestCase.setup() + + tc.blockBuilder.block = nextBlock(7) + md = simplex.ProtocolMetadata{Seq: baseSeq + 7, Round: 6, Epoch: 1, Prev: block6.Digest()} + + // If the sealing block isn't finalized yet, we expect to build a Telock. + // However, despite the fact that the block builder is willing to build a new block, + // a Telock shouldn't contain an inner block. + if tc.blockStore[sealingSeq].finalization == nil { + telock, err := sm.BuildBlock(context.Background(), md, nil) + require.NoError(t, err) + + require.Equal(t, &StateMachineBlock{ + InnerBlock: nil, + Metadata: StateMachineMetadata{ + Timestamp: uint64(startTime.Add(6 * time.Millisecond).UnixMilli()), + PChainHeight: pChainHeight2, + SimplexProtocolMetadata: md.Bytes(), + SimplexEpochInfo: SimplexEpochInfo{ + PChainReferenceHeight: pChainHeight1, + EpochNumber: 1, + NextPChainReferenceHeight: pChainHeight2, + PrevVMBlockSeq: baseSeq + 6, + SealingBlockSeq: sealingSeq, + }, + }, + }, telock) + + // Next, finalize the sealing block after we have built a Telock. + addBlock(sealingSeq, tc.blockStore[sealingSeq].block, &simplex.Finalization{}) + } + + // ----- Step 7: Build a new epoch block (sealing block is finalized) ----- + + block7, err := sm.BuildBlock(context.Background(), md, nil) + require.NoError(t, err) + require.Equal(t, &StateMachineBlock{ + InnerBlock: nextBlock(7), + Metadata: StateMachineMetadata{ + Timestamp: uint64(startTime.Add(7 * time.Millisecond).UnixMilli()), + PChainHeight: pChainHeight2, + SimplexProtocolMetadata: md.Bytes(), + SimplexEpochInfo: SimplexEpochInfo{ + PChainReferenceHeight: pChainHeight2, + EpochNumber: sealingSeq, + PrevVMBlockSeq: baseSeq + 6, + }, + }, + }, block7) + addBlock(md.Seq, *block7, nil) + + require.NoError(t, smVerify.VerifyBlock(context.Background(), block7)) + }) + } + }) + } +} func TestIdentifyCurrentState(t *testing.T) { bvd := &BlockValidationDescriptor{} @@ -59,7 +733,7 @@ func TestIdentifyCurrentState(t *testing.T) { }, } { t.Run(tc.name, func(t *testing.T) { - result := identifyCurrentState(tc.input) + result := tc.input.NextState() require.Equal(t, tc.expected, result) }) } @@ -129,120 +803,6 @@ func TestComputePrevVMBlockSeq(t *testing.T) { }) } -func TestFindFirstSimplexBlock(t *testing.T) { - t.Run("endHeight too big", func(t *testing.T) { - getBlock := func(opts RetrievingOpts) (StateMachineBlock, *simplex.Finalization, error) { - return StateMachineBlock{}, nil, nil - } - _, err := findFirstSimplexBlock(getBlock, math.MaxUint64) - require.ErrorContains(t, err, fmt.Sprintf(" is too big, must be at most %d", math.MaxInt64-1)) - }) - - t.Run("found at height 3", func(t *testing.T) { - getBlock := func(opts RetrievingOpts) (StateMachineBlock, *simplex.Finalization, error) { - if opts.Height < 3 { - return StateMachineBlock{}, nil, nil - } - return StateMachineBlock{ - Metadata: StateMachineMetadata{SimplexEpochInfo: SimplexEpochInfo{EpochNumber: 1}}, - }, nil, nil - } - result, err := findFirstSimplexBlock(getBlock, 5) - require.NoError(t, err) - require.Equal(t, uint64(3), result) - }) - - t.Run("no simplex blocks found", func(t *testing.T) { - getBlock := func(opts RetrievingOpts) (StateMachineBlock, *simplex.Finalization, error) { - return StateMachineBlock{}, nil, nil - } - _, err := findFirstSimplexBlock(getBlock, 5) - require.ErrorContains(t, err, "no simplex blocks found") - }) - - t.Run("block not found errors are skipped", func(t *testing.T) { - getBlock := func(opts RetrievingOpts) (StateMachineBlock, *simplex.Finalization, error) { - if opts.Height < 2 { - return StateMachineBlock{}, nil, simplex.ErrBlockNotFound - } - return StateMachineBlock{ - Metadata: StateMachineMetadata{SimplexEpochInfo: SimplexEpochInfo{EpochNumber: 1}}, - }, nil, nil - } - result, err := findFirstSimplexBlock(getBlock, 5) - require.NoError(t, err) - require.Equal(t, uint64(2), result) - }) - - t.Run("retrieval error propagated", func(t *testing.T) { - getBlock := func(opts RetrievingOpts) (StateMachineBlock, *simplex.Finalization, error) { - return StateMachineBlock{}, nil, fmt.Errorf("disk error") - } - _, err := findFirstSimplexBlock(getBlock, 5) - require.ErrorContains(t, err, "disk error") - }) -} - -func TestComputeTotalWeight(t *testing.T) { - t.Run("valid weights", func(t *testing.T) { - validators := NodeBLSMappings{ - {Weight: 100}, - {Weight: 200}, - {Weight: 300}, - } - total, err := validators.TotalWeight() - require.NoError(t, err) - require.Equal(t, int64(600), total) - }) - - t.Run("zero total weight", func(t *testing.T) { - validators := NodeBLSMappings{{Weight: 0}} - _, err := validators.TotalWeight() - require.ErrorContains(t, err, "total weight of validators is 0") - }) - - t.Run("empty validators", func(t *testing.T) { - _, err := NodeBLSMappings{}.TotalWeight() - require.ErrorContains(t, err, "total weight of validators is 0") - }) -} - -func TestComputeApprovingWeight(t *testing.T) { - validators := NodeBLSMappings{ - {Weight: 100}, - {Weight: 200}, - {Weight: 300}, - } - - t.Run("all approving", func(t *testing.T) { - bm := bitmaskFromBytes([]byte{7}) - weight, err := validators.ApprovingWeight(bm) - require.NoError(t, err) - require.Equal(t, int64(600), weight) - }) - - t.Run("partial approving", func(t *testing.T) { - bm := bitmaskFromBytes([]byte{5}) - weight, err := validators.ApprovingWeight(bm) - require.NoError(t, err) - require.Equal(t, int64(400), weight) - }) - - t.Run("none approving", func(t *testing.T) { - bm := bitmaskFromBytes(nil) - weight, err := validators.ApprovingWeight(bm) - require.NoError(t, err) - require.Equal(t, int64(0), weight) - }) - - t.Run("single validator approving", func(t *testing.T) { - bm := bitmaskFromBytes([]byte{2}) - weight, err := validators.ApprovingWeight(bm) - require.NoError(t, err) - require.Equal(t, int64(200), weight) - }) -} - func TestSanitizeApprovals(t *testing.T) { node0 := nodeID{0} node1 := nodeID{1} @@ -255,13 +815,15 @@ func TestSanitizeApprovals(t *testing.T) { node2: 2, } + logger := testutil.MakeLogger(t) + t.Run("filters by p-chain height", func(t *testing.T) { approvals := ValidatorSetApprovals{ {NodeID: node0, PChainHeight: 100}, {NodeID: node1, PChainHeight: 200}, } oldApproving := bitmaskFromBytes(nil) - result := sanitizeApprovals(approvals, 100, nodeID2Index, oldApproving) + result := sanitizeApprovals(approvals, 100, nodeID2Index, oldApproving, logger) require.Len(t, result, 1) require.Equal(t, node0, result[0].NodeID) }) @@ -272,7 +834,7 @@ func TestSanitizeApprovals(t *testing.T) { {NodeID: node1, PChainHeight: 100}, } oldApproving := bitmaskFromBytes([]byte{1}) - result := sanitizeApprovals(approvals, 100, nodeID2Index, oldApproving) + result := sanitizeApprovals(approvals, 100, nodeID2Index, oldApproving, logger) require.Len(t, result, 1) require.Equal(t, node1, result[0].NodeID) }) @@ -283,7 +845,7 @@ func TestSanitizeApprovals(t *testing.T) { {NodeID: node2, PChainHeight: 100}, } oldApproving := bitmaskFromBytes(nil) - result := sanitizeApprovals(approvals, 100, nodeID2Index, oldApproving) + result := sanitizeApprovals(approvals, 100, nodeID2Index, oldApproving, logger) require.Len(t, result, 1) require.Equal(t, node2, result[0].NodeID) }) @@ -294,24 +856,11 @@ func TestSanitizeApprovals(t *testing.T) { {NodeID: node0, PChainHeight: 100}, } oldApproving := bitmaskFromBytes(nil) - result := sanitizeApprovals(approvals, 100, nodeID2Index, oldApproving) + result := sanitizeApprovals(approvals, 100, nodeID2Index, oldApproving, logger) require.Len(t, result, 1) }) } -// concatAggregator concatenates signatures for easy verification in tests. -type concatAggregator struct{} - -func (concatAggregator) AggregateSignatures(sigs ...[]byte) ([]byte, error) { - return bytes.Join(sigs, nil), nil -} - -type failingAggregator struct{} - -func (failingAggregator) AggregateSignatures(sigs ...[]byte) ([]byte, error) { - return nil, fmt.Errorf("aggregation failed") -} - func TestComputeNewApproverSignaturesAndSigners(t *testing.T) { node0 := nodeID{0} node1 := nodeID{1} @@ -323,6 +872,8 @@ func TestComputeNewApproverSignaturesAndSigners(t *testing.T) { node2: 2, } + logger := testutil.MakeLogger(t) + t.Run("duplicate peer with already-approved node does not double-aggregate", func(t *testing.T) { // node0 is already in the previous approvals (bit 0 set). A duplicate peer // entry for node0 must not append node0's signature to the new aggregate @@ -338,7 +889,7 @@ func TestComputeNewApproverSignaturesAndSigners(t *testing.T) { {NodeID: node0, Signature: []byte("sig0")}, } - aggSig, newApproving, err := computeNewApproverSignaturesAndSigners(prevApprovals, peers, oldApproving, nodeID2Index, concatAggregator{}) + aggSig, newApproving, err := computeNewApproverSignaturesAndSigners(prevApprovals, peers, oldApproving, nodeID2Index, concatAggregator{}, logger) require.NoError(t, err) require.True(t, newApproving.Contains(0)) require.Equal(t, 1, newApproving.Len()) @@ -354,7 +905,7 @@ func TestComputeNewApproverSignaturesAndSigners(t *testing.T) { {NodeID: node1, Signature: []byte("sig1")}, } - _, _, err := computeNewApproverSignaturesAndSigners(nil, peers, oldApproving, nodeID2Index, concatAggregator{}) + _, _, err := computeNewApproverSignaturesAndSigners(nil, peers, oldApproving, nodeID2Index, concatAggregator{}, logger) require.Error(t, err) }) @@ -367,7 +918,7 @@ func TestComputeNewApproverSignaturesAndSigners(t *testing.T) { {NodeID: node1, Signature: []byte("sig1")}, } - aggSig, newApproving, err := computeNewApproverSignaturesAndSigners(prevApprovals, peers, oldApproving, nodeID2Index, concatAggregator{}) + aggSig, newApproving, err := computeNewApproverSignaturesAndSigners(prevApprovals, peers, oldApproving, nodeID2Index, concatAggregator{}, logger) require.NoError(t, err) require.True(t, newApproving.Contains(0)) require.True(t, newApproving.Contains(1)) @@ -386,7 +937,7 @@ func TestComputeNewApproverSignaturesAndSigners(t *testing.T) { {NodeID: node2, Signature: []byte("sig2")}, } - aggSig, newApproving, err := computeNewApproverSignaturesAndSigners(prevApprovals, peers, oldApproving, nodeID2Index, concatAggregator{}) + aggSig, newApproving, err := computeNewApproverSignaturesAndSigners(prevApprovals, peers, oldApproving, nodeID2Index, concatAggregator{}, logger) require.NoError(t, err) require.True(t, newApproving.Contains(0)) // preserved from old require.True(t, newApproving.Contains(2)) // newly added @@ -401,7 +952,7 @@ func TestComputeNewApproverSignaturesAndSigners(t *testing.T) { } oldApproving := bitmaskFromBytes([]byte{1}) - aggSig, newApproving, err := computeNewApproverSignaturesAndSigners(prevApprovals, nil, oldApproving, nodeID2Index, concatAggregator{}) + aggSig, newApproving, err := computeNewApproverSignaturesAndSigners(prevApprovals, nil, oldApproving, nodeID2Index, concatAggregator{}, logger) require.NoError(t, err) require.True(t, newApproving.Contains(0)) require.Equal(t, []byte("existing"), aggSig) @@ -417,7 +968,7 @@ func TestComputeNewApproverSignaturesAndSigners(t *testing.T) { {NodeID: node0, Signature: []byte("sig0")}, } - aggSig, newApproving, err := computeNewApproverSignaturesAndSigners(prevApprovals, peers, oldApproving, nodeID2Index, concatAggregator{}) + aggSig, newApproving, err := computeNewApproverSignaturesAndSigners(prevApprovals, peers, oldApproving, nodeID2Index, concatAggregator{}, logger) require.NoError(t, err) require.True(t, newApproving.Contains(0)) require.Equal(t, 1, newApproving.Len()) @@ -431,7 +982,7 @@ func TestComputeNewApproverSignaturesAndSigners(t *testing.T) { {NodeID: node0, Signature: []byte("sig0")}, } - _, _, err := computeNewApproverSignaturesAndSigners(prevApprovals, peers, oldApproving, nodeID2Index, failingAggregator{}) + _, _, err := computeNewApproverSignaturesAndSigners(prevApprovals, peers, oldApproving, nodeID2Index, failingAggregator{}, logger) require.ErrorContains(t, err, "aggregation failed") }) } diff --git a/pos_test.go b/pos_test.go index 6f7af6bf..1d1d7f08 100644 --- a/pos_test.go +++ b/pos_test.go @@ -23,16 +23,18 @@ func TestPoS(t *testing.T) { nodes := []simplex.NodeID{{1}, {2}, {3}, {4}} - posSigAggregator := &testutil.TestSignatureAggregator{ - IsQuorumFunc: func(signatures []simplex.NodeID) bool { - var totalWeight uint64 - for _, signer := range signatures { - totalWeight += weights[int(signer[0])] - } - return totalWeight > 6 - }, + posSigAggregatorCreator := func(_ []simplex.Node) simplex.SignatureAggregator { + return &testutil.TestSignatureAggregator{ + IsQuorumFunc: func(signatures []simplex.NodeID) bool { + var totalWeight uint64 + for _, signer := range signatures { + totalWeight += weights[int(signer[0])] + } + return totalWeight > 6 + }, + } } - testConf := &testutil.TestNodeConfig{SigAggregator: posSigAggregator, ReplicationEnabled: true} + testConf := &testutil.TestNodeConfig{SigAggregatorCreator: posSigAggregatorCreator, ReplicationEnabled: true} net := testutil.NewControlledNetwork(t, nodes) testutil.NewControlledSimplexNode(t, nodes[0], net, testConf) diff --git a/recovery_test.go b/recovery_test.go index 0e82dccd..e87e8b71 100644 --- a/recovery_test.go +++ b/recovery_test.go @@ -250,7 +250,8 @@ func TestWalCreatedProperly(t *testing.T) { records, err = e.WAL.ReadAll() require.NoError(t, err) require.Len(t, records, 2) - expectedNotarizationRecord, err := testutil.NewNotarizationRecord(conf.Logger, conf.SignatureAggregator, block, nodes[0:quorum]) + sigAggr := e.SignatureAggregatorCreator(conf.Comm.Nodes()) + expectedNotarizationRecord, err := testutil.NewNotarizationRecord(conf.Logger, sigAggr, block, nodes[0:quorum]) require.NoError(t, err) require.Equal(t, expectedNotarizationRecord, records[1]) @@ -423,6 +424,8 @@ func TestRecoverFromMultipleNotarizations(t *testing.T) { require.NoError(t, err) t.Cleanup(e.Stop) + sigAggr := e.SignatureAggregatorCreator(conf.Comm.Nodes()) + protocolMetadata := e.Metadata() firstBlock, ok := bb.BuildBlock(ctx, protocolMetadata, emptyBlacklist) require.True(t, ok) @@ -431,7 +434,7 @@ func TestRecoverFromMultipleNotarizations(t *testing.T) { record := BlockRecord(firstBlock.BlockHeader(), fBytes) wal.Append(record) - firstNotarizationRecord, err := testutil.NewNotarizationRecord(conf.Logger, conf.SignatureAggregator, firstBlock, nodes[0:quorum]) + firstNotarizationRecord, err := testutil.NewNotarizationRecord(conf.Logger, sigAggr, firstBlock, nodes[0:quorum]) require.NoError(t, err) wal.Append(firstNotarizationRecord) @@ -445,12 +448,12 @@ func TestRecoverFromMultipleNotarizations(t *testing.T) { wal.Append(record) // Add notarization for second block - secondNotarizationRecord, err := testutil.NewNotarizationRecord(conf.Logger, conf.SignatureAggregator, secondBlock, nodes[0:quorum]) + secondNotarizationRecord, err := testutil.NewNotarizationRecord(conf.Logger, sigAggr, secondBlock, nodes[0:quorum]) require.NoError(t, err) wal.Append(secondNotarizationRecord) // Create finalization record for second block - finalization2, finalizationRecord := testutil.NewFinalizationRecord(t, conf.Logger, conf.SignatureAggregator, secondBlock, nodes[0:quorum]) + finalization2, finalizationRecord := testutil.NewFinalizationRecord(t, conf.Logger, sigAggr, secondBlock, nodes[0:quorum]) wal.Append(finalizationRecord) err = e.Start() @@ -460,7 +463,7 @@ func TestRecoverFromMultipleNotarizations(t *testing.T) { require.Equal(t, uint64(0), e.Storage.NumBlocks()) // now if we send finalization for block 1, we should index both 1 & 2 - finalization1, _ := testutil.NewFinalizationRecord(t, conf.Logger, conf.SignatureAggregator, firstBlock, nodes[0:quorum]) + finalization1, _ := testutil.NewFinalizationRecord(t, conf.Logger, sigAggr, firstBlock, nodes[0:quorum]) err = e.HandleMessage(&Message{ Finalization: &finalization1, }, nodes[1]) @@ -499,11 +502,12 @@ func TestRecoveryBlocksIndexed(t *testing.T) { record := BlockRecord(firstBlock.BlockHeader(), fBytes) wal.Append(record) - firstNotarizationRecord, err := testutil.NewNotarizationRecord(conf.Logger, conf.SignatureAggregator, firstBlock, nodes[0:quorum]) + sigAggr := conf.SignatureAggregatorCreator(conf.Comm.Nodes()) + firstNotarizationRecord, err := testutil.NewNotarizationRecord(conf.Logger, sigAggr, firstBlock, nodes[0:quorum]) require.NoError(t, err) wal.Append(firstNotarizationRecord) - _, finalizationBytes := testutil.NewFinalizationRecord(t, conf.Logger, conf.SignatureAggregator, firstBlock, nodes[0:quorum]) + _, finalizationBytes := testutil.NewFinalizationRecord(t, conf.Logger, sigAggr, firstBlock, nodes[0:quorum]) wal.Append(finalizationBytes) protocolMetadata.Round = 1 @@ -524,9 +528,9 @@ func TestRecoveryBlocksIndexed(t *testing.T) { record = BlockRecord(thirdBlock.BlockHeader(), tBytes) wal.Append(record) - finalization1, _ := testutil.NewFinalizationRecord(t, conf.Logger, conf.SignatureAggregator, firstBlock, nodes[0:quorum]) - finalization2, _ := testutil.NewFinalizationRecord(t, conf.Logger, conf.SignatureAggregator, secondBlock, nodes[0:quorum]) - fCer3, _ := testutil.NewFinalizationRecord(t, conf.Logger, conf.SignatureAggregator, thirdBlock, nodes[0:quorum]) + finalization1, _ := testutil.NewFinalizationRecord(t, conf.Logger, sigAggr, firstBlock, nodes[0:quorum]) + finalization2, _ := testutil.NewFinalizationRecord(t, conf.Logger, sigAggr, secondBlock, nodes[0:quorum]) + fCer3, _ := testutil.NewFinalizationRecord(t, conf.Logger, sigAggr, thirdBlock, nodes[0:quorum]) conf.Storage.Index(ctx, firstBlock, finalization1) conf.Storage.Index(ctx, secondBlock, finalization2) @@ -655,7 +659,8 @@ func TestWalRecoveryTriggersEmptyVoteTimeout(t *testing.T) { require.NoError(t, wal.Append(blockRecord)) // lets add some notarizations - notarizationRecord, err := testutil.NewNotarizationRecord(conf.Logger, conf.SignatureAggregator, block, nodes[0:quorum]) + sigAggr := e.SignatureAggregatorCreator(conf.Comm.Nodes()) + notarizationRecord, err := testutil.NewNotarizationRecord(conf.Logger, sigAggr, block, nodes[0:quorum]) require.NoError(t, err) require.NoError(t, wal.Append(notarizationRecord)) @@ -761,7 +766,7 @@ func TestWalRecoverySetsRoundCorrectly(t *testing.T) { require.NoError(t, err) blockRecord := BlockRecord(block.BlockHeader(), bBytes) - notarizationRecord, err := testutil.NewNotarizationRecord(conf.Logger, conf.SignatureAggregator, block, nodes[0:quorum]) + notarizationRecord, err := testutil.NewNotarizationRecord(conf.Logger, conf.SignatureAggregatorCreator(conf.Comm.Nodes()), block, nodes[0:quorum]) require.NoError(t, err) return [][]byte{blockRecord, notarizationRecord} @@ -779,7 +784,7 @@ func TestWalRecoverySetsRoundCorrectly(t *testing.T) { require.NoError(t, err) blockRecord1 := BlockRecord(block1.BlockHeader(), bBytes1) - _, finalizationRecord1 := testutil.NewFinalizationRecord(t, conf.Logger, conf.SignatureAggregator, block1, nodes[0:quorum]) + _, finalizationRecord1 := testutil.NewFinalizationRecord(t, conf.Logger, conf.SignatureAggregatorCreator(conf.Comm.Nodes()), block1, nodes[0:quorum]) // Create empty notarization for round 0 emptyNotarization0 := testutil.NewEmptyNotarization(nodes[0:quorum], 0) @@ -800,7 +805,7 @@ func TestWalRecoverySetsRoundCorrectly(t *testing.T) { bBytes0, err := block0.Bytes() require.NoError(t, err) blockRecord0 := BlockRecord(block0.BlockHeader(), bBytes0) - notarizationRecord0, err := testutil.NewNotarizationRecord(conf.Logger, conf.SignatureAggregator, block0, nodes[0:quorum]) + notarizationRecord0, err := testutil.NewNotarizationRecord(conf.Logger, conf.SignatureAggregatorCreator(conf.Comm.Nodes()), block0, nodes[0:quorum]) require.NoError(t, err) block1, ok := bb.BuildBlock(ctx, ProtocolMetadata{Round: 1, Epoch: 0, Seq: 1}, emptyBlacklist) @@ -808,7 +813,7 @@ func TestWalRecoverySetsRoundCorrectly(t *testing.T) { bBytes1, err := block1.Bytes() require.NoError(t, err) blockRecord1 := BlockRecord(block1.BlockHeader(), bBytes1) - notarizationRecord1, err := testutil.NewNotarizationRecord(conf.Logger, conf.SignatureAggregator, block1, nodes[0:quorum]) + notarizationRecord1, err := testutil.NewNotarizationRecord(conf.Logger, conf.SignatureAggregatorCreator(conf.Comm.Nodes()), block1, nodes[0:quorum]) require.NoError(t, err) block2, ok := bb.BuildBlock(ctx, ProtocolMetadata{Round: 2, Epoch: 0, Seq: 2}, emptyBlacklist) @@ -816,7 +821,7 @@ func TestWalRecoverySetsRoundCorrectly(t *testing.T) { bBytes2, err := block2.Bytes() require.NoError(t, err) blockRecord2 := BlockRecord(block2.BlockHeader(), bBytes2) - notarizationRecord2, err := testutil.NewNotarizationRecord(conf.Logger, conf.SignatureAggregator, block2, nodes[0:quorum]) + notarizationRecord2, err := testutil.NewNotarizationRecord(conf.Logger, conf.SignatureAggregatorCreator(conf.Comm.Nodes()), block2, nodes[0:quorum]) require.NoError(t, err) // Create empty notarization for round 3 @@ -843,7 +848,7 @@ func TestWalRecoverySetsRoundCorrectly(t *testing.T) { bBytes3, err := block3.Bytes() require.NoError(t, err) blockRecord3 := BlockRecord(block3.BlockHeader(), bBytes3) - _, finalizationRecord3 := testutil.NewFinalizationRecord(t, conf.Logger, conf.SignatureAggregator, block3, nodes[0:quorum]) + _, finalizationRecord3 := testutil.NewFinalizationRecord(t, conf.Logger, conf.SignatureAggregatorCreator(conf.Comm.Nodes()), block3, nodes[0:quorum]) // Create empty notarization for round 2 emptyNotarization2 := testutil.NewEmptyNotarization(nodes[0:quorum], 2) @@ -856,7 +861,7 @@ func TestWalRecoverySetsRoundCorrectly(t *testing.T) { bBytes1, err := block1.Bytes() require.NoError(t, err) blockRecord1 := BlockRecord(block1.BlockHeader(), bBytes1) - notarizationRecord1, err := testutil.NewNotarizationRecord(conf.Logger, conf.SignatureAggregator, block1, nodes[0:quorum]) + notarizationRecord1, err := testutil.NewNotarizationRecord(conf.Logger, conf.SignatureAggregatorCreator(conf.Comm.Nodes()), block1, nodes[0:quorum]) require.NoError(t, err) // Return in reverse order @@ -878,7 +883,7 @@ func TestWalRecoverySetsRoundCorrectly(t *testing.T) { bBytes0, err := block0.Bytes() require.NoError(t, err) blockRecord0 := BlockRecord(block0.BlockHeader(), bBytes0) - notarizationRecord0, err := testutil.NewNotarizationRecord(conf.Logger, conf.SignatureAggregator, block0, nodes[0:quorum]) + notarizationRecord0, err := testutil.NewNotarizationRecord(conf.Logger, conf.SignatureAggregatorCreator(conf.Comm.Nodes()), block0, nodes[0:quorum]) require.NoError(t, err) // Create finalization for round 10 (highest) @@ -887,7 +892,7 @@ func TestWalRecoverySetsRoundCorrectly(t *testing.T) { bBytes10, err := block10.Bytes() require.NoError(t, err) blockRecord10 := BlockRecord(block10.BlockHeader(), bBytes10) - _, finalizationRecord10 := testutil.NewFinalizationRecord(t, conf.Logger, conf.SignatureAggregator, block10, nodes[0:quorum]) + _, finalizationRecord10 := testutil.NewFinalizationRecord(t, conf.Logger, conf.SignatureAggregatorCreator(conf.Comm.Nodes()), block10, nodes[0:quorum]) // Create empty notarization for round 5 emptyNotarization5 := testutil.NewEmptyNotarization(nodes[0:quorum], 5) @@ -913,10 +918,10 @@ func TestWalRecoverySetsRoundCorrectly(t *testing.T) { require.NoError(t, err) blockRecord2 := BlockRecord(block2.BlockHeader(), bBytes2) - notarizationRecord2, err := testutil.NewNotarizationRecord(conf.Logger, conf.SignatureAggregator, block2, nodes[0:quorum]) + notarizationRecord2, err := testutil.NewNotarizationRecord(conf.Logger, conf.SignatureAggregatorCreator(conf.Comm.Nodes()), block2, nodes[0:quorum]) require.NoError(t, err) - _, finalizationRecord2 := testutil.NewFinalizationRecord(t, conf.Logger, conf.SignatureAggregator, block2, nodes[0:quorum]) + _, finalizationRecord2 := testutil.NewFinalizationRecord(t, conf.Logger, conf.SignatureAggregatorCreator(conf.Comm.Nodes()), block2, nodes[0:quorum]) // All records for same round return [][]byte{ diff --git a/replication_test.go b/replication_test.go index 7b5988a7..d7e4685a 100644 --- a/replication_test.go +++ b/replication_test.go @@ -132,7 +132,8 @@ func TestReplicationAdversarialNode(t *testing.T) { require.Equal(t, uint64(0), laggingNode.E.Metadata().Round) net.Connect(laggingNode.E.ID) - finalization, _ := NewFinalizationRecord(t, laggingNode.E.Logger, laggingNode.E.SignatureAggregator, blocks[1], nodes[:quorum]) + sigAggr := laggingNode.E.SignatureAggregatorCreator(laggingNode.E.Comm.Nodes()) + finalization, _ := NewFinalizationRecord(t, laggingNode.E.Logger, sigAggr, blocks[1], nodes[:quorum]) finalizationMsg := &simplex.Message{ Finalization: &finalization, } @@ -375,7 +376,8 @@ func TestReplicationStartsBeforeCurrentRound(t *testing.T) { record := simplex.BlockRecord(firstBlock.BlockHeader(), fBytes) laggingNode.WAL.Append(record) - firstNotarizationRecord, err := NewNotarizationRecord(laggingNode.E.Logger, laggingNode.E.SignatureAggregator, firstBlock, nodes[0:quorum]) + sigAggr := laggingNode.E.SignatureAggregatorCreator(laggingNode.E.Comm.Nodes()) + firstNotarizationRecord, err := NewNotarizationRecord(laggingNode.E.Logger, sigAggr, firstBlock, nodes[0:quorum]) require.NoError(t, err) laggingNode.WAL.Append(firstNotarizationRecord) @@ -385,7 +387,7 @@ func TestReplicationStartsBeforeCurrentRound(t *testing.T) { record = simplex.BlockRecord(secondBlock.BlockHeader(), sBytes) laggingNode.WAL.Append(record) - secondNotarizationRecord, err := NewNotarizationRecord(laggingNode.E.Logger, laggingNode.E.SignatureAggregator, secondBlock, nodes[0:quorum]) + secondNotarizationRecord, err := NewNotarizationRecord(laggingNode.E.Logger, sigAggr, secondBlock, nodes[0:quorum]) require.NoError(t, err) laggingNode.WAL.Append(secondNotarizationRecord) @@ -414,7 +416,7 @@ func TestReplicationFutureFinalization(t *testing.T) { nodes := []simplex.NodeID{{1}, {2}, {3}, {4}} quorum := simplex.Quorum(len(nodes)) - conf, _, storage := DefaultTestNodeEpochConfig(t, nodes[1], NoopComm(nodes), bb) + conf, _, storage := DefaultTestNodeEpochConfig(t, nodes[1], NewNoopComm(nodes), bb) e, err := simplex.NewEpoch(conf) require.NoError(t, err) @@ -441,7 +443,8 @@ func TestReplicationFutureFinalization(t *testing.T) { }, nodes[0]) require.NoError(t, err) - finalization, _ := NewFinalizationRecord(t, e.Logger, e.SignatureAggregator, block, nodes[0:quorum]) + sigAggr := e.SignatureAggregatorCreator(conf.Comm.Nodes()) + finalization, _ := NewFinalizationRecord(t, e.Logger, sigAggr, block, nodes[0:quorum]) // send finalization err = e.HandleMessage(&simplex.Message{ Finalization: &finalization, @@ -600,7 +603,7 @@ func TestReplicationStuckInProposingBlock(t *testing.T) { sentMessages := make(chan *simplex.Message, 100) conf, _, storage := DefaultTestNodeEpochConfig(t, nodes[0], &recordingComm{ - Communication: NoopComm(nodes), + Communication: NewNoopComm(nodes), SentMessages: sentMessages, }, bb) @@ -627,7 +630,8 @@ func TestReplicationStuckInProposingBlock(t *testing.T) { highBlock, _ := blocks[3].VerifiedBlock.(*TestBlock) - highFinalization, _ := NewFinalizationRecord(t, e.Logger, e.SignatureAggregator, highBlock, nodes[0:quorum]) + sigAggr := e.SignatureAggregatorCreator(conf.Comm.Nodes()) + highFinalization, _ := NewFinalizationRecord(t, e.Logger, sigAggr, highBlock, nodes[0:quorum]) // Trigger the replication process to start by sending a finalization for a block we do not have e.HandleMessage(&simplex.Message{ @@ -973,7 +977,8 @@ func TestReplicationVerifyNotarization(t *testing.T) { block := bb.GetBuiltBlock() - finalization, _ := NewFinalizationRecord(t, e.Logger, e.SignatureAggregator, block, nodes[0:quorum]) + sigAggr := e.SignatureAggregatorCreator(conf.Comm.Nodes()) + finalization, _ := NewFinalizationRecord(t, e.Logger, sigAggr, block, nodes[0:quorum]) // Trigger the replication process to start by sending a finalization for a block we do not have e.HandleMessage(&simplex.Message{ @@ -988,7 +993,7 @@ func TestReplicationVerifyNotarization(t *testing.T) { } } - notarization, err := NewNotarization(e.Logger, e.SignatureAggregator, block, nodes[0:quorum]) + notarization, err := NewNotarization(e.Logger, sigAggr, block, nodes[0:quorum]) require.NoError(t, err) // Corrupt the QC @@ -1060,7 +1065,8 @@ func TestReplicationVerifyEmptyNotarization(t *testing.T) { block := bb.GetBuiltBlock() - finalization, _ := NewFinalizationRecord(t, e.Logger, e.SignatureAggregator, block, nodes[0:quorum]) + sigAggr := e.SignatureAggregatorCreator(conf.Comm.Nodes()) + finalization, _ := NewFinalizationRecord(t, e.Logger, sigAggr, block, nodes[0:quorum]) // Trigger the replication process to start by sending a finalization for a block we do not have e.HandleMessage(&simplex.Message{ @@ -1348,7 +1354,7 @@ func TestReplicationStoresFinalization(t *testing.T) { sentMessages := make(chan *simplex.Message, 100) conf, _, storage := DefaultTestNodeEpochConfig(t, nodes[3], &recordingComm{ - Communication: NoopComm(nodes), + Communication: NewNoopComm(nodes), SentMessages: sentMessages, }, bb) conf.ReplicationEnabled = true @@ -1533,7 +1539,7 @@ func TestReplicationStartsRoundFromFinalization(t *testing.T) { sentMessages := make(chan *simplex.Message, 100) broadcastMessages := make(chan *simplex.Message, 100) conf, wal, storage := DefaultTestNodeEpochConfig(t, nodes[0], &recordingComm{ - Communication: NoopComm(nodes), + Communication: NewNoopComm(nodes), SentMessages: sentMessages, BroadcastMessages: broadcastMessages, }, bb) @@ -1643,7 +1649,7 @@ func TestReplicationStartsRoundFromFinalizationWithBlock(t *testing.T) { sentMessages := make(chan *simplex.Message, 100) broadcastMessages := make(chan *simplex.Message, 100) conf, wal, storage := DefaultTestNodeEpochConfig(t, nodes[0], &recordingComm{ - Communication: NoopComm(nodes), + Communication: NewNoopComm(nodes), SentMessages: sentMessages, BroadcastMessages: broadcastMessages, }, bb) diff --git a/replication_timeout_test.go b/replication_timeout_test.go index cbaa6ead..e64e99e8 100644 --- a/replication_timeout_test.go +++ b/replication_timeout_test.go @@ -622,7 +622,8 @@ func TestReplicationResendsFinalizedBlocksThatFailedVerification(t *testing.T) { block := bb.GetBuiltBlock() block.VerificationError = errors.New("block verification failed") - finalization, _ := testutil.NewFinalizationRecord(t, l, e.SignatureAggregator, block, nodes[0:quorum]) + sigAggr := e.SignatureAggregatorCreator(conf.Comm.Nodes()) + finalization, _ := testutil.NewFinalizationRecord(t, l, sigAggr, block, nodes[0:quorum]) // send the finalization to start the replication process e.HandleMessage(&simplex.Message{ @@ -658,7 +659,7 @@ func TestReplicationResendsFinalizedBlocksThatFailedVerification(t *testing.T) { block.Data = append(block.Data, 0) block.ComputeDigest() - finalization, _ = testutil.NewFinalizationRecord(t, l, e.SignatureAggregator, block, nodes[0:quorum]) + finalization, _ = testutil.NewFinalizationRecord(t, l, sigAggr, block, nodes[0:quorum]) replicationResponse = &simplex.ReplicationResponse{ Data: []simplex.QuorumRound{ { diff --git a/testutil/comm.go b/testutil/comm.go index f5fb9cca..71f33191 100644 --- a/testutil/comm.go +++ b/testutil/comm.go @@ -22,10 +22,10 @@ import ( // - bool: true if the message can be transmitted, false otherwise type MessageFilter func(msg *simplex.Message, from simplex.NodeID, to simplex.NodeID) bool -type NoopComm []simplex.NodeID +type NoopComm simplex.Nodes -func (n NoopComm) Nodes() []simplex.NodeID { - return n +func (n NoopComm) Nodes() simplex.Nodes { + return simplex.Nodes(n) } func (n NoopComm) Send(*simplex.Message, simplex.NodeID) { @@ -51,8 +51,8 @@ func NewTestComm(from simplex.NodeID, net *BasicInMemoryNetwork, messageFilter M } } -func (c *TestComm) Nodes() []simplex.NodeID { - return c.net.nodes +func (c *TestComm) Nodes() simplex.Nodes { + return c.net.nodeWeights } func (c *TestComm) Send(msg *simplex.Message, destination simplex.NodeID) { @@ -177,7 +177,7 @@ func (c *TestComm) Broadcast(msg *simplex.Message) { if !c.isMessagePermitted(msg, instance.E.ID) { continue } - // Skip sending the message to yourself or disconnected nodes + // Skip sending the message to yourself or disconnected nodeWeights if bytes.Equal(c.from, instance.E.ID) || c.net.IsDisconnected(instance.E.ID) { continue } @@ -191,6 +191,6 @@ func AllowAllMessages(*simplex.Message, simplex.NodeID, simplex.NodeID) bool { return true } -func NewNoopComm(nodes []simplex.NodeID) NoopComm { - return NoopComm(nodes) +func NewNoopComm(nodes simplex.NodeIDs) NoopComm { + return NoopComm(nodes.EqualWeightedNodes()) } diff --git a/testutil/controlled.go b/testutil/controlled.go index fe45eec5..e9a337d3 100644 --- a/testutil/controlled.go +++ b/testutil/controlled.go @@ -19,9 +19,9 @@ type ControlledInMemoryNetwork struct { } // NewControlledNetwork creates an in-memory network. Node IDs must be provided before -// adding instances, as nodes require prior knowledge of all participants. -func NewControlledNetwork(t *testing.T, nodes []simplex.NodeID) *ControlledInMemoryNetwork { - simplex.SortNodes(nodes) +// adding instances, as nodeWeights require prior knowledge of all participants. +func NewControlledNetwork(t *testing.T, nodes simplex.NodeIDs) *ControlledInMemoryNetwork { + simplex.SortNodes(nodes.EqualWeightedNodes()) net := &ControlledInMemoryNetwork{ BasicInMemoryNetwork: NewBasicInMemoryNetwork(t, nodes), Instances: make([]*ControlledNode, 0), @@ -72,7 +72,7 @@ func (n *ControlledInMemoryNetwork) AdvanceWithoutLeader(round uint64, laggingNo } for _, n := range n.Instances { - leader := n.E.ID.Equals(simplex.LeaderForRound(n.E.Comm.Nodes(), n.E.Metadata().Round)) + leader := n.E.ID.Equals(simplex.LeaderForRound(n.E.Comm.Nodes().NodeIDs(), n.E.Metadata().Round)) if leader || laggingNodeId.Equals(n.E.ID) { continue } diff --git a/testutil/network.go b/testutil/network.go index fa4555a8..8ce6ac58 100644 --- a/testutil/network.go +++ b/testutil/network.go @@ -16,15 +16,18 @@ import ( type BasicInMemoryNetwork struct { t *testing.T nodes []simplex.NodeID + nodeWeights simplex.Nodes lock sync.RWMutex disconnected map[string]struct{} instances []*BasicNode } -func NewBasicInMemoryNetwork(t *testing.T, nodes []simplex.NodeID) *BasicInMemoryNetwork { - simplex.SortNodes(nodes) +func NewBasicInMemoryNetwork(t *testing.T, nodes simplex.NodeIDs) *BasicInMemoryNetwork { + nodeWeights := nodes.EqualWeightedNodes() + simplex.SortNodes(nodeWeights) return &BasicInMemoryNetwork{ t: t, + nodeWeights: nodeWeights, nodes: nodes, disconnected: make(map[string]struct{}), instances: make([]*BasicNode, 0), @@ -99,9 +102,9 @@ func (b *BasicInMemoryNetwork) StartInstances() { b.lock.RLock() defer b.lock.RUnlock() - require.Equal(b.t, len(b.nodes), len(b.instances)) + require.Equal(b.t, len(b.nodeWeights), len(b.instances)) - for i := len(b.nodes) - 1; i >= 0; i-- { + for i := len(b.nodeWeights) - 1; i >= 0; i-- { b.instances[i].Start() } } @@ -142,8 +145,8 @@ func (b *BasicInMemoryNetwork) AddNode(node *BasicNode) { defer b.lock.Unlock() allowed := false - for _, id := range b.nodes { - if bytes.Equal(id, node.E.ID) { + for _, nodeWeight := range b.nodeWeights { + if bytes.Equal(nodeWeight.Node, node.E.ID) { allowed = true break } diff --git a/testutil/node.go b/testutil/node.go index 2cc09eee..21784a1e 100644 --- a/testutil/node.go +++ b/testutil/node.go @@ -198,8 +198,8 @@ func UpdateEpochConfig(epochConfig *simplex.EpochConfig, testConfig *TestNodeCon epochConfig.Comm = testConfig.Comm } - if testConfig.SigAggregator != nil { - epochConfig.SignatureAggregator = testConfig.SigAggregator + if testConfig.SigAggregatorCreator != nil { + epochConfig.SignatureAggregatorCreator = testConfig.SigAggregatorCreator } if testConfig.BlockBuilder != nil { @@ -234,11 +234,11 @@ func UpdateEpochConfig(epochConfig *simplex.EpochConfig, testConfig *TestNodeCon // NodeConfig type TestNodeConfig struct { // optional - InitialStorage []simplex.VerifiedFinalizedBlock - Comm simplex.Communication - SigAggregator simplex.SignatureAggregator - ReplicationEnabled bool - BlockBuilder *testControlledBlockBuilder + InitialStorage []simplex.VerifiedFinalizedBlock + Comm simplex.Communication + SigAggregatorCreator simplex.SignatureAggregatorCreator + ReplicationEnabled bool + BlockBuilder *testControlledBlockBuilder // Long Running Tests MaxRoundWindow uint64 diff --git a/testutil/util.go b/testutil/util.go index d8923672..4be23174 100644 --- a/testutil/util.go +++ b/testutil/util.go @@ -32,10 +32,12 @@ func DefaultTestNodeEpochConfig(t *testing.T, nodeID simplex.NodeID, comm simple Verifier: &testVerifier{}, Storage: storage, BlockBuilder: bb, - SignatureAggregator: &TestSignatureAggregator{N: len(comm.Nodes())}, - BlockDeserializer: &BlockDeserializer{}, - QCDeserializer: &testQCDeserializer{t: t}, - StartTime: time.Now(), + SignatureAggregatorCreator: func(weights []simplex.Node) simplex.SignatureAggregator { + return &TestSignatureAggregator{N: len(weights)} + }, + BlockDeserializer: &BlockDeserializer{}, + QCDeserializer: &testQCDeserializer{t: t}, + StartTime: time.Now(), } return conf, wal, storage } @@ -120,12 +122,29 @@ func (t *testQCDeserializer) DeserializeQuorumCertificate(bytes []byte) (simplex return TestQC(qc), err } +type TestSignatureAggregatorCreator struct { + Err error + N int + IsQuorumFunc func(signatures []simplex.NodeID) bool +} + type TestSignatureAggregator struct { Err error N int IsQuorumFunc func(signatures []simplex.NodeID) bool } +func (t *TestSignatureAggregator) AppendSignatures(existing []byte, sigs ...[]byte) ([]byte, error) { + if t.Err != nil { + return nil, t.Err + } + result := append([]byte{}, existing...) + for _, s := range sigs { + result = append(result, s...) + } + return result, nil +} + func (t *TestSignatureAggregator) Aggregate(signatures []simplex.Signature) (simplex.QuorumCertificate, error) { return TestQC(signatures), t.Err }