From 7bf5e973f7415a2be6d3ba162cb607bd4b96b73b Mon Sep 17 00:00:00 2001 From: Yacov Manevich Date: Thu, 16 Apr 2026 21:30:17 +0200 Subject: [PATCH 01/16] Simplex reconfiguration framework - Part III (MSM implementation) - Add block building to msm.go - Add verification.go which contains logic for block verification - Add tests that mimic Simplex flow (fake_node_test.go) Signed-off-by: Yacov Manevich --- msm/fake_node_test.go | 428 ++++++++++++++++ msm/msm.go | 726 +++++++++++++++++++++++++++ msm/msm_test.go | 1021 ++++++++++++++++++++++++++++++++++++++ msm/verification.go | 515 +++++++++++++++++++ msm/verification_test.go | 1016 +++++++++++++++++++++++++++++++++++++ 5 files changed, 3706 insertions(+) create mode 100644 msm/fake_node_test.go create mode 100644 msm/verification.go create mode 100644 msm/verification_test.go diff --git a/msm/fake_node_test.go b/msm/fake_node_test.go new file mode 100644 index 00000000..4c980a44 --- /dev/null +++ b/msm/fake_node_test.go @@ -0,0 +1,428 @@ +// Copyright (C) 2019-2025, Ava Labs, Inc. All rights reserved. +// See the file LICENSE for licensing terms. + +package metadata + +import ( + "context" + "crypto/rand" + "fmt" + "sync/atomic" + "testing" + "time" + + "github.com/ava-labs/simplex" + "github.com/stretchr/testify/require" +) + +func TestFakeNode(t *testing.T) { + validatorSetRetriever := validatorSetRetriever{ + resultMap: map[uint64]NodeBLSMappings{ + 100: {{BLSKey: []byte{1}, Weight: 1, NodeID: [20]byte{1}}, {BLSKey: []byte{2}, Weight: 1, NodeID: [20]byte{2}}}, + 200: {{BLSKey: []byte{1}, Weight: 1, NodeID: [20]byte{1}}, {BLSKey: []byte{2}, Weight: 2, NodeID: [20]byte{2}}, + {BLSKey: []byte{3}, Weight: 1, NodeID: [20]byte{3}}}, + 300: {{BLSKey: []byte{1}, Weight: 1, NodeID: [20]byte{1}}, {BLSKey: []byte{2}, Weight: 2, NodeID: [20]byte{2}}, + {BLSKey: []byte{3}, Weight: 3, NodeID: [20]byte{3}}, {BLSKey: []byte{4}, Weight: 1, NodeID: [20]byte{4}}}, + }, + } + + var pChainHeight atomic.Uint64 + pChainHeight.Store(100) + node := newFakeNode(t) + node.sm.GetValidatorSet = validatorSetRetriever.getValidatorSet + node.sm.GetPChainHeight = func() uint64 { + return pChainHeight.Load() + } + + // Create some blocks and finalize them, until we reach height 10 + for node.Height() < 10 { + node.act() + } + + // Next, we increase the P-Chain height, which should cause the node to update its validator set and move to the new epoch. + pChainHeight.Store(200) + + epoch := node.Epoch() + for node.Epoch() == epoch { + node.act() + if flipCoin() { + node.sm.ApprovalsRetriever = &approvalsRetriever{ + result: []ValidatorSetApproval{{NodeID: [20]byte{1}, PChainHeight: 200, Signature: []byte{1}, AuxInfoSeqDigest: [32]byte{}}}, + } + } else { + node.sm.ApprovalsRetriever = &approvalsRetriever{ + result: []ValidatorSetApproval{{NodeID: [20]byte{2}, PChainHeight: 200, Signature: []byte{2}, AuxInfoSeqDigest: [32]byte{}}}, + } + } + } + + t.Log("Epoch:", node.Epoch()) + require.Greater(t, node.Epoch(), uint64(1)) + + // Finally, we increase the P-Chain height again, which should cause the node to update its validator set and move to the new epoch. + + pChainHeight.Store(300) + + epoch = node.Epoch() + for node.Epoch() == epoch { + node.act() + if flipCoin() { + node.sm.ApprovalsRetriever = &approvalsRetriever{ + result: []ValidatorSetApproval{{NodeID: [20]byte{2}, PChainHeight: 300, Signature: []byte{2}, AuxInfoSeqDigest: [32]byte{}}}, + } + } else { + node.sm.ApprovalsRetriever = &approvalsRetriever{ + result: []ValidatorSetApproval{{NodeID: [20]byte{3}, PChainHeight: 300, Signature: []byte{3}, AuxInfoSeqDigest: [32]byte{}}}, + } + } + } + + t.Log("Epoch:", node.Epoch()) + require.Greater(t, node.Epoch(), epoch) +} + +func TestFakeNodeEmptyMempool(t *testing.T) { + validatorSetRetriever := validatorSetRetriever{ + resultMap: map[uint64]NodeBLSMappings{ + 100: {{BLSKey: []byte{1}, Weight: 1, NodeID: [20]byte{1}}, {BLSKey: []byte{2}, Weight: 1, NodeID: [20]byte{2}}}, + 200: {{BLSKey: []byte{1}, Weight: 1, NodeID: [20]byte{1}}, {BLSKey: []byte{2}, Weight: 2, NodeID: [20]byte{2}}, + {BLSKey: []byte{3}, Weight: 1, NodeID: [20]byte{3}}}, + 300: {{BLSKey: []byte{1}, Weight: 1, NodeID: [20]byte{1}}, {BLSKey: []byte{2}, Weight: 2, NodeID: [20]byte{2}}, + {BLSKey: []byte{3}, Weight: 3, NodeID: [20]byte{3}}, {BLSKey: []byte{4}, Weight: 1, NodeID: [20]byte{4}}}, + }, + } + + var pChainHeight uint64 = 100 + node := newFakeNode(t) + node.sm.MaxBlockBuildingWaitTime = 100 * time.Millisecond + node.sm.GetValidatorSet = validatorSetRetriever.getValidatorSet + node.sm.GetPChainHeight = func() uint64 { + return pChainHeight + } + + // Create some blocks and finalize them, until we reach height 10 + for node.Height() < 10 { + node.act() + } + + // Next, we increase the P-Chain height, which should cause the node to update its validator set and move to the new epoch. + pChainHeight = 200 + + // However, we mark the mempool as empty, which should cause the node to wait until it sees a change in the P-Chain height, rather than building blocks on top of the old epoch. + node.mempoolEmpty = true + + // We build blocks until the sealing block is finalized. + for node.finalizedBlocks[len(node.finalizedBlocks)-1].Metadata.SimplexEpochInfo.BlockValidationDescriptor == nil { + node.act() + if flipCoin() { + node.sm.ApprovalsRetriever = &approvalsRetriever{ + result: []ValidatorSetApproval{{NodeID: [20]byte{1}, PChainHeight: 200, Signature: []byte{1}, AuxInfoSeqDigest: [32]byte{}}}, + } + } else { + node.sm.ApprovalsRetriever = &approvalsRetriever{ + result: []ValidatorSetApproval{{NodeID: [20]byte{2}, PChainHeight: 200, Signature: []byte{2}, AuxInfoSeqDigest: [32]byte{}}}, + } + } + } + + node.mempoolEmpty = false + + // Build a new block and check that the node has transitioned to the new epoch, + // rather than building a block on top of the old epoch. + height := node.Height() + + for node.Height() == height { + node.act() + } + require.Greater(t, node.Epoch(), uint64(1)) + + t.Log("Epoch:", node.Epoch()) + + epoch := node.Epoch() + require.Greater(t, epoch, uint64(1)) + + // Finally, we increase the P-Chain height again, which should cause the node to update its validator set and move to the new epoch. + + pChainHeight = 300 + + for node.Height() < 30 { + node.act() + if flipCoin() { + node.sm.ApprovalsRetriever = &approvalsRetriever{ + result: []ValidatorSetApproval{{NodeID: [20]byte{2}, PChainHeight: 300, Signature: []byte{2}, AuxInfoSeqDigest: [32]byte{}}}, + } + } else { + node.sm.ApprovalsRetriever = &approvalsRetriever{ + result: []ValidatorSetApproval{{NodeID: [20]byte{3}, PChainHeight: 300, Signature: []byte{3}, AuxInfoSeqDigest: [32]byte{}}}, + } + } + } + + t.Log("Epoch:", node.Epoch()) + require.Greater(t, node.Epoch(), epoch) + require.Equal(t, node.Height(), uint64(30)) +} + +type innerBlock struct { + InnerBlock + Prev [32]byte +} + +type fakeNode struct { + t *testing.T + sm StateMachine + mempoolEmpty bool + notarizedBlocks []StateMachineBlock + finalizedBlocks []StateMachineBlock + innerChain []innerBlock +} + +func (fn *fakeNode) WaitForProgress(ctx context.Context, pChainHeight uint64) error { + for { + select { + case <-ctx.Done(): + return ctx.Err() + case <-time.After(10 * time.Millisecond): + if fn.sm.GetPChainHeight() != pChainHeight { + return nil + } + } + } +} + +func (fn *fakeNode) WaitForPendingBlock(ctx context.Context) { + if fn.mempoolEmpty { + <-ctx.Done() + return + } +} + +func newFakeNode(t *testing.T) *fakeNode { + sm, _ := newStateMachine(t) + + fn := &fakeNode{ + t: t, + sm: sm, + } + + fn.sm.BlockBuilder = fn + fn.sm.PChainProgressListener = fn + + fn.sm.GetBlock = func(opts RetrievingOpts) (StateMachineBlock, *simplex.Finalization, error) { + if opts.Height == 0 { + return genesisBlock, nil, nil + } + for _, block := range fn.finalizedBlocks { + if block.Digest() == opts.Digest { + return block, &simplex.Finalization{}, nil + } + md, err := simplex.ProtocolMetadataFromBytes(block.Metadata.SimplexProtocolMetadata) + if err != nil { + return StateMachineBlock{}, nil, err + } + if md.Seq == opts.Height { + return block, &simplex.Finalization{}, nil + } + } + for _, block := range fn.notarizedBlocks { + if block.Digest() == opts.Digest { + return block, nil, nil + } + md, err := simplex.ProtocolMetadataFromBytes(block.Metadata.SimplexProtocolMetadata) + if err != nil { + return StateMachineBlock{}, nil, err + } + if md.Seq == opts.Height { + return block, nil, nil + } + } + + require.Failf(t, "not found block", "height: %d", opts.Height) + return StateMachineBlock{}, nil, fmt.Errorf("block not found") + } + + return fn +} + +func (fn *fakeNode) Height() uint64 { + return uint64(len(fn.finalizedBlocks)) +} + +func (fn *fakeNode) Epoch() uint64 { + return fn.notarizedBlocks[len(fn.notarizedBlocks)-1].Metadata.SimplexEpochInfo.EpochNumber +} + +func (fn *fakeNode) act() { + if fn.canFinalize() && flipCoin() { + fn.tryFinalizeNextBlock() + return + } + + if flipCoin() { + return + } + + fn.buildAndNotarizeBlock() +} + +func (fn *fakeNode) canFinalize() bool { + return len(fn.notarizedBlocks) > len(fn.finalizedBlocks) +} + +func (fn *fakeNode) tryFinalizeNextBlock() { + nextIndex := len(fn.finalizedBlocks) + + if fn.isNextBlockTelock() { + return + } + + block := fn.notarizedBlocks[nextIndex] + fn.finalizedBlocks = append(fn.finalizedBlocks, block) + + md, err := simplex.ProtocolMetadataFromBytes(block.Metadata.SimplexProtocolMetadata) + require.NoError(fn.t, err) + + fn.sm.LatestPersistedHeight = md.Seq + fn.t.Logf("Finalized block at height %d with epoch %d", md.Seq, block.Metadata.SimplexEpochInfo.EpochNumber) + + // If we just finalized a sealing block, trim trailing Telock blocks. + if block.Metadata.SimplexEpochInfo.BlockValidationDescriptor != nil { + fn.notarizedBlocks = fn.notarizedBlocks[:len(fn.finalizedBlocks)] + fn.t.Logf("Trimmed notarized blocks, new length: %d", len(fn.notarizedBlocks)) + } +} + +func (fn *fakeNode) isNextBlockTelock() bool { + if len(fn.finalizedBlocks) == 0 { + return false + } + return fn.notarizedBlocks[len(fn.finalizedBlocks)].Metadata.SimplexEpochInfo.SealingBlockSeq > 0 +} + +func (fn *fakeNode) buildAndNotarizeBlock() { + vmBlock, block := fn.buildBlock() + require.NoError(fn.t, fn.sm.VerifyBlock(context.Background(), block)) + + fn.notarizedBlocks = append(fn.notarizedBlocks, *block) + + if vmBlock != nil { + fn.innerChain = append(fn.innerChain, *vmBlock.(*innerBlock)) + } +} + +func (fn *fakeNode) buildBlock() (VMBlock, *StateMachineBlock) { + parentBlock := fn.getParentBlock() + + lastMD, prevBlockDigest := fn.prepareMetadataAndPrevBlockDigest() + + _, finalization, err := fn.sm.GetBlock(RetrievingOpts{ + Digest: prevBlockDigest, + Height: lastMD.Seq, + }) + require.NoError(fn.t, err) + + finalizedString := "not finalized" + if finalization != nil { + finalizedString = "finalized" + } + + fn.t.Logf("Building a block on top of %s parent with epoch %d", finalizedString, parentBlock.Metadata.SimplexEpochInfo.EpochNumber) + + block, err := fn.sm.BuildBlock(context.Background(), parentBlock, simplex.ProtocolMetadata{ + Seq: lastMD.Seq + 1, + Round: lastMD.Round + 1, + Prev: prevBlockDigest, + }, nil) + require.NoError(fn.t, err) + + return block.InnerBlock, block +} + +func (fn *fakeNode) prepareMetadataAndPrevBlockDigest() (*simplex.ProtocolMetadata, [32]byte) { + var lastMD *simplex.ProtocolMetadata + var err error + lastBlockDigest := genesisBlock.Digest() + if len(fn.notarizedBlocks) > 0 { + lastBlock := fn.notarizedBlocks[len(fn.notarizedBlocks)-1] + lastBlockDigest = lastBlock.Digest() + lastMD, err = simplex.ProtocolMetadataFromBytes(lastBlock.Metadata.SimplexProtocolMetadata) + require.NoError(fn.t, err) + } else { + lastMD = &simplex.ProtocolMetadata{ + Prev: lastBlockDigest, + } + } + return lastMD, lastBlockDigest +} + +func (fn *fakeNode) BuildBlock(context.Context, uint64) (VMBlock, error) { + // Count the number of inner blocks in the chain + var count int + for _, block := range fn.notarizedBlocks { + if block.InnerBlock != nil { + count++ + } + } + + vmBlock := &innerBlock{ + Prev: fn.getLastVMBlockDigest(), + InnerBlock: InnerBlock{ + Bytes: randomBuff(10), + TS: time.Now(), + BlockHeight: uint64(count), + }, + } + return vmBlock, nil +} + +func (fn *fakeNode) getParentBlock() StateMachineBlock { + var parentBlock StateMachineBlock + if len(fn.notarizedBlocks) > 0 { + parentBlock = fn.notarizedBlocks[len(fn.notarizedBlocks)-1] + } else { + gb := genesisBlock.InnerBlock.(*InnerBlock) + parentBlock = StateMachineBlock{ + InnerBlock: &innerBlock{ + InnerBlock: *gb, + }, + } + } + return parentBlock +} + +func (fn *fakeNode) getLastVMBlockDigest() [32]byte { + var lastVMBlockDigest = genesisBlock.Digest() + + notarizedBlocks := fn.notarizedBlocks + for len(notarizedBlocks) > 0 { + lastNotarizedBlock := notarizedBlocks[len(notarizedBlocks)-1] + if lastNotarizedBlock.InnerBlock == nil { + notarizedBlocks = notarizedBlocks[:len(notarizedBlocks)-1] + continue + } + lastVMBlockDigest = lastNotarizedBlock.Digest() + break + } + return lastVMBlockDigest +} + +func randomBuff(n int) []byte { + buff := make([]byte, n) + _, err := rand.Read(buff) + if err != nil { + panic(err) + } + return buff +} + +func flipCoin() bool { + buff := make([]byte, 1) + _, err := rand.Read(buff) + if err != nil { + panic(err) + } + + lsb := buff[0] & 1 + + return lsb == 1 +} diff --git a/msm/msm.go b/msm/msm.go index 9022267e..cf78fffe 100644 --- a/msm/msm.go +++ b/msm/msm.go @@ -4,13 +4,17 @@ package metadata import ( + "context" "crypto/sha256" "errors" "fmt" "math" + "math/big" "sort" + "time" "github.com/ava-labs/simplex" + "go.uber.org/zap" ) // A StateMachineBlock is a representation of a parsed OuterBlock, containing the inner block and the metadata. @@ -41,6 +45,24 @@ type SignatureAggregator interface { AggregateSignatures(signatures ...[]byte) ([]byte, error) } +// ApprovalsRetriever retrieves the approvals from validators of the next epoch for the epoch change. +type ApprovalsRetriever interface { + RetrieveApprovals() ValidatorSetApprovals +} + + +// KeyAggregator combines multiple public keys into a single aggregated public key. +type KeyAggregator interface { + AggregateKeys(keys ...[]byte) ([]byte, error) +} + + +// SignatureVerifier verifies a cryptographic signature against a message and public key. +// Used to verify Approvals from validators for epoch transitions. +type SignatureVerifier interface { + VerifySignature(signature []byte, message []byte, publicKey []byte) error +} + // ValidatorSetRetriever retrieves the validator set at a given P-chain height. type ValidatorSetRetriever func(pChainHeight uint64) (NodeBLSMappings, error) @@ -57,6 +79,58 @@ type RetrievingOpts struct { // If an error occurs during retrieval, it returns a non-nil error. type BlockRetriever func(RetrievingOpts) (StateMachineBlock, *simplex.Finalization, error) +// BlockBuilder builds a new VM block with the given observed P-chain height. +type BlockBuilder interface { + BuildBlock(ctx context.Context, pChainHeight uint64) (VMBlock, error) + + // WaitForPendingBlock returns when either the given context is cancelled, + // or when the VM signals that a block should be built. + WaitForPendingBlock(ctx context.Context) +} + +// StateMachine manages block building and verification across epoch transitions. +type StateMachine struct { + // LatestPersistedHeight is the height of the most recently persisted block. + LatestPersistedHeight uint64 + // MaxBlockBuildingWaitTime is the maximum duration to wait for the VM to build a block + // before producing a block without an inner block. + MaxBlockBuildingWaitTime time.Duration + // TimeSkewLimit is the maximum allowed time difference between a block's timestamp and the current time. + TimeSkewLimit time.Duration + // GetTime returns the current time. + GetTime func() time.Time + // GetPChainHeight returns the latest known P-chain height. + GetPChainHeight func() uint64 + // GetUpgrades returns the current upgrade configuration. + GetUpgrades func() UpgradeConfig + // BlockBuilder builds new VM blocks. + BlockBuilder BlockBuilder + // Logger is used for logging state machine operations. + Logger Logger + // GetValidatorSet retrieves the validator set at a given P-chain height. + GetValidatorSet ValidatorSetRetriever + // GetBlock retrieves a previously built or finalized block. + GetBlock BlockRetriever + // ApprovalsRetriever retrieves validator approvals for epoch transitions. + ApprovalsRetriever ApprovalsRetriever + // SignatureAggregator aggregates signatures from validators. + SignatureAggregator SignatureAggregator + // KeyAggregator aggregates public keys from validators. + KeyAggregator KeyAggregator + // SignatureVerifier verifies signatures from validators. + SignatureVerifier SignatureVerifier + // PChainProgressListener listens for changes in the P-chain height to trigger block building or epoch transitions. + PChainProgressListener PChainProgressListener + + // initialized tracks whether the state machine has been initialized. + // This is used to lazily initialize the verifiers. + initialized bool + + // verifiers is the list of verifiers used to verify proposed blocks. + // Each verifier is responsible for verifying a specific aspect of the block's metadata. + verifiers []verifier +} + type state uint8 const ( @@ -66,6 +140,184 @@ const ( stateBuildBlockEpochSealed ) + +// BuildBlock constructs the next block on top of the given parent block, and passes in the provided simplex metadata and blacklist. +func (sm *StateMachine) BuildBlock(ctx context.Context, parentBlock StateMachineBlock, simplexMetadata simplex.ProtocolMetadata, simplexBlacklist *simplex.Blacklist) (*StateMachineBlock, error) { + sm.maybeInit() + + // The zero sequence number is reserved for the genesis block, which should never be built. + if simplexMetadata.Seq == 0 { + return nil, fmt.Errorf("invalid ProtocolMetadata sequence number: should be > 0, got %d", simplexMetadata.Seq) + } + + start := time.Now() + + sm.Logger.Debug("Building block", + zap.Uint64("seq", simplexMetadata.Seq), + zap.Uint64("epoch", simplexMetadata.Epoch), + zap.Stringer("prevHash", simplexMetadata.Prev)) + + defer func() { + elapsed := time.Since(start) + sm.Logger.Debug("Built block", + zap.Uint64("seq", simplexMetadata.Seq), + zap.Uint64("epoch", simplexMetadata.Epoch), + zap.Stringer("prevHash", simplexMetadata.Prev), + zap.Duration("elapsed", elapsed), + ) + }() + + var simplexBlacklistBytes []byte + if simplexBlacklist != nil { + simplexBlacklistBytes = simplexBlacklist.Bytes() + } + + // In order to know where in the epoch change process we are, + // we identify the current state by looking at the parent block's epoch info. + currentState, err := identifyCurrentState(parentBlock.Metadata.SimplexEpochInfo) + if err != nil { + return nil, err + } + + simplexMetadataBytes := simplexMetadata.Bytes() + prevBlockSeq := simplexMetadata.Seq - 1 + + switch currentState { + case stateFirstSimplexBlock: + return sm.buildBlockZero(ctx, parentBlock, simplexMetadataBytes, simplexBlacklistBytes) + case stateBuildBlockNormalOp: + return sm.buildBlockNormalOp(ctx, parentBlock, simplexMetadataBytes, simplexBlacklistBytes, prevBlockSeq) + case stateBuildCollectingApprovals: + return sm.buildBlockCollectingApprovals(ctx, parentBlock, simplexMetadataBytes, simplexBlacklistBytes, prevBlockSeq) + case stateBuildBlockEpochSealed: + return sm.buildBlockEpochSealed(ctx, parentBlock, simplexMetadataBytes, simplexBlacklistBytes, prevBlockSeq) + default: + return nil, fmt.Errorf("unknown state %d", currentState) + } +} + +// VerifyBlock validates a proposed block by checking its metadata, epoch info, +// and inner block against the previous block and the current state. +func (sm *StateMachine) VerifyBlock(ctx context.Context, block *StateMachineBlock) error { + sm.maybeInit() + + if block == nil { + return fmt.Errorf("InnerBlock is nil") + } + + pmd, err := simplex.ProtocolMetadataFromBytes(block.Metadata.SimplexProtocolMetadata) + if err != nil { + return fmt.Errorf("failed to parse ProtocolMetadata: %w", err) + } + + seq := pmd.Seq + + if seq == 0 { + return fmt.Errorf("attempted to build a genesis inner block") + } + + prevBlock, _, err := sm.GetBlock(RetrievingOpts{Digest: pmd.Prev, Height: seq - 1}) + if err != nil { + return fmt.Errorf("failed to retrieve previous (%d) inner block: %w", seq-1, err) + } + + prevMD := prevBlock.Metadata + currentState, err := identifyCurrentState(prevMD.SimplexEpochInfo) + if err != nil { + return fmt.Errorf("failed to identify previous state: %w", err) + } + + switch currentState { + case stateFirstSimplexBlock: + err = sm.verifyBlockZero(ctx, block, prevBlock) + default: + err = sm.verifyNonZeroBlock(ctx, block, prevBlock.Metadata, currentState, seq-1) + } + return err +} + +func (sm *StateMachine) maybeInit() { + if sm.initialized { + return + } + sm.init() + sm.initialized = true +} + +func (sm *StateMachine) init() { + sm.verifiers = []verifier{ + &pChainHeightVerifier{ + getPChainHeight: sm.GetPChainHeight, + }, + ×tampVerifier{ + timeSkewLimit: sm.TimeSkewLimit, + getTime: sm.GetTime, + }, + &pChainReferenceHeightVerifier{}, + &epochNumberVerifier{}, + &prevSealingBlockHashVerifier{ + getBlock: sm.GetBlock, + latestPersistedHeight: &sm.LatestPersistedHeight, + }, + &nextPChainReferenceHeightVerifier{ + getPChainHeight: sm.GetPChainHeight, + getValidatorSet: sm.GetValidatorSet, + }, + &vmBlockSeqVerifier{ + getBlock: sm.GetBlock, + }, + &validationDescriptorVerifier{ + getValidatorSet: sm.GetValidatorSet, + }, + &nextEpochApprovalsVerifier{ + getValidatorSet: sm.GetValidatorSet, + keyAggregator: sm.KeyAggregator, + sigVerifier: sm.SignatureVerifier, + }, + &sealingBlockSeqVerifier{}, + } +} + +func (sm *StateMachine) verifyNonZeroBlock(ctx context.Context, block *StateMachineBlock, prevBlockMD StateMachineMetadata, state state, prevSeq uint64) error { + blockType := IdentifyBlockType(block.Metadata, prevBlockMD, prevSeq) + sm.Logger.Debug("Identified block type", + zap.Stringer("blockType", blockType), + zap.Bool("nextHasBVD", block.Metadata.SimplexEpochInfo.BlockValidationDescriptor != nil), + zap.Uint64("nextEpochNumber", block.Metadata.SimplexEpochInfo.EpochNumber), + zap.Bool("prevHasBVD", prevBlockMD.SimplexEpochInfo.BlockValidationDescriptor != nil), + zap.Uint64("prevEpochNumber", prevBlockMD.SimplexEpochInfo.EpochNumber), + zap.Uint64("prevNextPChainRefHeight", prevBlockMD.SimplexEpochInfo.NextPChainReferenceHeight), + zap.Uint64("prevSealingBlockSeq", prevBlockMD.SimplexEpochInfo.SealingBlockSeq), + zap.Uint64("prevSeq", prevSeq), + ) + + var innerBlockTimestamp time.Time + if block.InnerBlock != nil { + innerBlockTimestamp = block.InnerBlock.Timestamp() + } + + for _, verifier := range sm.verifiers { + if err := verifier.Verify(verificationInput{ + proposedBlockMD: block.Metadata, + nextBlockType: blockType, + prevMD: prevBlockMD, + state: state, + prevBlockSeq: prevSeq, + hasInnerBlock: block.InnerBlock != nil, + innerBlockTimestamp: innerBlockTimestamp, + }); err != nil { + sm.Logger.Debug("Invalid block", zap.Error(err)) + return err + } + } + + if block.InnerBlock == nil { + return nil + } + + return block.InnerBlock.Verify(ctx) +} + func identifyCurrentState(prevBlockSimplexEpochInfo SimplexEpochInfo) state { // If this is the first ever epoch, then this is also the first ever block to be built by Simplex. if prevBlockSimplexEpochInfo.EpochNumber == 0 { @@ -89,6 +341,455 @@ func identifyCurrentState(prevBlockSimplexEpochInfo SimplexEpochInfo) state { return stateBuildCollectingApprovals } + +// buildBlockNormalOp builds a block while not trying to transition to a new epoch. +func (sm *StateMachine) buildBlockNormalOp(ctx context.Context, parentBlock StateMachineBlock, simplexMetadata, simplexBlacklist []byte, prevBlockSeq uint64) (*StateMachineBlock, error) { + // Since in the previous block, we were not transitioning to a new epoch, + // the P-chain reference height and epoch of the new block should remain the same. + newSimplexEpochInfo := SimplexEpochInfo{ + PChainReferenceHeight: parentBlock.Metadata.SimplexEpochInfo.PChainReferenceHeight, + EpochNumber: parentBlock.Metadata.SimplexEpochInfo.EpochNumber, + PrevVMBlockSeq: computePrevVMBlockSeq(parentBlock, prevBlockSeq), + } + + blockBuildingDecider := sm.createBlockBuildingDecider(parentBlock) + decisionToBuildBlock, pChainHeight, err := blockBuildingDecider.shouldBuildBlock(ctx) + if err != nil { + return nil, err + } + + sm.Logger.Debug("Block building decision", zap.Stringer("decision", decisionToBuildBlock)) + + var childBlock VMBlock + + switch decisionToBuildBlock { + case blockBuildingDecisionBuildBlock, blockBuildingDecisionBuildBlockAndTransitionEpoch: + // If we reached here, we need to build a new block, and maybe also transition to a new epoch. + return sm.buildBlockAndMaybeTransitionEpoch(ctx, parentBlock, simplexMetadata, simplexBlacklist, childBlock, decisionToBuildBlock, newSimplexEpochInfo, pChainHeight) + case blockBuildingDecisionTransitionEpoch: + // If we reached here, we don't need to build an inner block, yet we need to transition to a new epoch. + // Initiate the epoch transition by setting the next P-chain reference height for the new epoch info, + // and build a block without an inner block. + newSimplexEpochInfo.NextPChainReferenceHeight = pChainHeight + return sm.wrapBlock(parentBlock, nil, newSimplexEpochInfo, pChainHeight, simplexMetadata, simplexBlacklist), nil + case blockBuildingDecisionContextCanceled: + return nil, ctx.Err() + default: + return nil, fmt.Errorf("unknown block building decision %d", decisionToBuildBlock) + } +} + +func (sm *StateMachine) createBlockBuildingDecider(parentBlock StateMachineBlock) blockBuildingDecider { + blockBuildingDecider := blockBuildingDecider{ + logger: sm.Logger, + maxBlockBuildingWaitTime: sm.MaxBlockBuildingWaitTime, + pChainlistener: sm.PChainProgressListener, + getPChainHeight: sm.GetPChainHeight, + waitForPendingBlock: sm.BlockBuilder.WaitForPendingBlock, + shouldTransitionEpoch: func(pChainHeight uint64) (bool, error) { + // The given pChainHeight was sampled by the caller of shouldTransitionEpoch(). + // We compare between the current validator set, defined by the P-chain reference height in the parent block, + // and the new validator set defined by the given pChainHeight. + // If they are different, then we should transition to a new epoch. + + currentValidatorSet, err := sm.GetValidatorSet(parentBlock.Metadata.SimplexEpochInfo.PChainReferenceHeight) + if err != nil { + return false, err + } + + newValidatorSet, err := sm.GetValidatorSet(pChainHeight) + if err != nil { + return false, err + } + + if !currentValidatorSet.Equal(newValidatorSet) { + return true, nil + } + return false, nil + }, + } + return blockBuildingDecider +} + +func (sm *StateMachine) buildBlockAndMaybeTransitionEpoch(ctx context.Context, parentBlock StateMachineBlock, simplexMetadata []byte, simplexBlacklist []byte, childBlock VMBlock, decisionToBuildBlock blockBuildingDecision, newSimplexEpochInfo SimplexEpochInfo, pChainHeight uint64) (*StateMachineBlock, error) { + // TODO: This P-chain height should be taken from the ICM epoch + childBlock, err := sm.BlockBuilder.BuildBlock(ctx, pChainHeight) + if err != nil { + return nil, err + } + + if decisionToBuildBlock == blockBuildingDecisionBuildBlockAndTransitionEpoch { + // We need to also transition to a new epoch, in addition to building an inner block, + // so set the next P-chain reference height for the new epoch info. + newSimplexEpochInfo.NextPChainReferenceHeight = pChainHeight + } + + return sm.wrapBlock(parentBlock, childBlock, newSimplexEpochInfo, pChainHeight, simplexMetadata, simplexBlacklist), nil +} + +// buildBlockZero builds the first ever block for Simplex, +// which is a special block that introduces the first validator set and starts the first epoch. +func (sm *StateMachine) buildBlockZero(ctx context.Context, parentBlock StateMachineBlock, simplexMetadata, simplexBlacklist []byte) (*StateMachineBlock, error) { + pChainHeight := sm.GetPChainHeight() + + newValidatorSet, err := sm.GetValidatorSet(pChainHeight) + if err != nil { + return nil, err + } + + var prevVMBlockSeq uint64 + if parentBlock.InnerBlock != nil { + prevVMBlockSeq = parentBlock.InnerBlock.Height() + } else { + // We can only have blocks without inner blocks in Simplex blocks, but this is the first Simplex block. + // Therefore, the parent block must have an inner block. + sm.Logger.Error("Parent block has no inner block, cannot determine previous VM block sequence for zero block") + return nil, fmt.Errorf("failed constructing zero block: parent block has no inner block") + } + simplexEpochInfo := constructSimplexZeroBlock(pChainHeight, newValidatorSet, prevVMBlockSeq) + + return sm.buildBlockImpatiently(ctx, parentBlock, simplexMetadata, simplexBlacklist, simplexEpochInfo, pChainHeight) +} + +func (sm *StateMachine) verifyBlockZero(ctx context.Context, block *StateMachineBlock, prevBlock StateMachineBlock) error { + if block == nil { + return fmt.Errorf("block is nil") + } + + simplexEpochInfo := block.Metadata.SimplexEpochInfo + + if simplexEpochInfo.EpochNumber != 1 { + return fmt.Errorf("invalid epoch number (%d), should be 1", simplexEpochInfo.EpochNumber) + } + + if prevBlock.InnerBlock == nil { + return fmt.Errorf("parent inner block (%s) has no inner block", prevBlock.Digest()) + } + + prevVMBlockSeq := prevBlock.InnerBlock.Height() + + currentPChainHeight := sm.GetPChainHeight() + + if block.Metadata.PChainHeight > currentPChainHeight { + return fmt.Errorf("invalid P-chain height (%d) is too big, expected to be ≤ %d", + block.Metadata.PChainHeight, currentPChainHeight) + } + + if prevBlock.Metadata.PChainHeight > block.Metadata.PChainHeight { + return fmt.Errorf("invalid P-chain height (%d) is smaller than parent InnerBlock's P-chain height (%d)", + block.Metadata.PChainHeight, prevBlock.Metadata.PChainHeight) + } + + expectedValidatorSet, err := sm.GetValidatorSet(simplexEpochInfo.PChainReferenceHeight) + if err != nil { + return fmt.Errorf("failed to retrieve validator set at height %d: %w", simplexEpochInfo.PChainReferenceHeight, err) + } + + if simplexEpochInfo.BlockValidationDescriptor == nil { + return fmt.Errorf("invalid BlockValidationDescriptor: should not be nil") + } + + membership := simplexEpochInfo.BlockValidationDescriptor.AggregatedMembership.Members + if !NodeBLSMappings(membership).Equal(expectedValidatorSet) { + return fmt.Errorf("invalid BlockValidationDescriptor: should match validator set at P-chain height %d", simplexEpochInfo.PChainReferenceHeight) + } + + // If we have compared all fields so far, the rest of the fields we compare by constructing an explicit expected SimplexEpochInfo + expectedSimplexEpochInfo := constructSimplexZeroBlock(simplexEpochInfo.PChainReferenceHeight, expectedValidatorSet, prevVMBlockSeq) + + if !expectedSimplexEpochInfo.Equal(&simplexEpochInfo) { + return fmt.Errorf("invalid SimplexEpochInfo: expected %v, got %v", expectedSimplexEpochInfo, simplexEpochInfo) + } + + _, err = sm.verifyZeroBlockTimestamp(block, prevBlock) + if err != nil { + return err + } + + if block.InnerBlock == nil { + return nil + } + + return block.InnerBlock.Verify(ctx) +} + +func (sm *StateMachine) verifyZeroBlockTimestamp(block *StateMachineBlock, prevBlock StateMachineBlock) (time.Time, error) { + var proposedTime time.Time + if block.InnerBlock != nil { + proposedTime = block.InnerBlock.Timestamp() + } else { + proposedTime = time.UnixMilli(int64(prevBlock.Metadata.Timestamp)) + } + + expectedTimestamp := proposedTime.UnixMilli() + if expectedTimestamp != int64(block.Metadata.Timestamp) { + return time.Time{}, fmt.Errorf("expected timestamp to be %d but got %d", expectedTimestamp, int64(block.Metadata.Timestamp)) + } + currentTime := sm.GetTime() + if currentTime.Add(sm.TimeSkewLimit).Before(proposedTime) { + return time.Time{}, fmt.Errorf("proposed block timestamp is too far in the future, current time is %s but got %s", currentTime.String(), proposedTime.String()) + } + if prevBlock.Metadata.Timestamp > block.Metadata.Timestamp { + return time.Time{}, fmt.Errorf("proposed block timestamp is older than parent block's timestamp, parent timestamp is %d but got %d", prevBlock.Metadata.Timestamp, block.Metadata.Timestamp) + } + return proposedTime, nil +} + +func (sm *StateMachine) buildBlockCollectingApprovals(ctx context.Context, parentBlock StateMachineBlock, simplexMetadata, simplexBlacklist []byte, prevBlockSeq uint64) (*StateMachineBlock, error) { + // The P-chain reference height and epoch number should remain the same until we transition to the new epoch. + // The next P-chain reference height should have been set in the previous block, + // which is the reason why we are collecting approvals in the first place. + newSimplexEpochInfo := SimplexEpochInfo{ + PChainReferenceHeight: parentBlock.Metadata.SimplexEpochInfo.PChainReferenceHeight, + EpochNumber: parentBlock.Metadata.SimplexEpochInfo.EpochNumber, + NextPChainReferenceHeight: parentBlock.Metadata.SimplexEpochInfo.NextPChainReferenceHeight, + PrevVMBlockSeq: computePrevVMBlockSeq(parentBlock, prevBlockSeq), + } + + // We prepare information that is needed to compute the approvals for the new epoch, + // such as the validator set for the next epoch, and the approvals from peers. + validators, err := sm.GetValidatorSet(parentBlock.Metadata.SimplexEpochInfo.NextPChainReferenceHeight) + if err != nil { + return nil, err + } + + // We retrieve approvals that validators have sent us for the next epoch. + // These approvals are signed by validators of the next epoch. + approvalsFromPeers := sm.ApprovalsRetriever.RetrieveApprovals() + nextPChainHeight := newSimplexEpochInfo.NextPChainReferenceHeight + prevNextEpochApprovals := parentBlock.Metadata.SimplexEpochInfo.NextEpochApprovals + + newApprovals, err := computeNewApprovals(prevNextEpochApprovals, approvalsFromPeers, nextPChainHeight, sm.SignatureAggregator, validators) + if err != nil { + return nil, err + } + + // This might be the first time we created approvals for the next epoch, + // so we need to initialize the NextEpochApprovals. + if newSimplexEpochInfo.NextEpochApprovals == nil { + newSimplexEpochInfo.NextEpochApprovals = &NextEpochApprovals{} + } + // The node IDs and signature are aggregated across all past and present approvals. + newSimplexEpochInfo.NextEpochApprovals.NodeIDs = newApprovals.nodeIDs + newSimplexEpochInfo.NextEpochApprovals.Signature = newApprovals.signature + pChainHeight := parentBlock.Metadata.PChainHeight + + // We might not have enough approvals to seal the current epoch, + // in which case we just carry over the approvals we have so far to the next block, + // so that eventually we'll have enough approvals to seal the epoch. + if !newApprovals.canSeal { + return sm.buildBlockImpatiently(ctx, parentBlock, simplexMetadata, simplexBlacklist, newSimplexEpochInfo, pChainHeight) + } + + // Else, we have enough approvals to seal the epoch, so we create the sealing block. + return sm.createSealingBlock(ctx, parentBlock, simplexMetadata, simplexBlacklist, newSimplexEpochInfo, newApprovals, pChainHeight) +} + +// buildBlockImpatiently builds a block by waiting for the VM to build a block until MaxBlockBuildingWaitTime. +// If the VM fails to build a block within that time, we build a block without an inner block, +// so that we can continue making progress and not get stuck waiting for the VM. +func (sm *StateMachine) buildBlockImpatiently(ctx context.Context, parentBlock StateMachineBlock, simplexMetadata []byte, simplexBlacklist []byte, simplexEpochInfo SimplexEpochInfo, pChainHeight uint64) (*StateMachineBlock, error) { + impatientContext, cancel := context.WithTimeout(ctx, sm.MaxBlockBuildingWaitTime) + defer cancel() + + start := time.Now() + + // TODO: This P-chain height should be taken from the ICM epoch + childBlock, err := sm.BlockBuilder.BuildBlock(impatientContext, pChainHeight) + if err != nil && impatientContext.Err() == nil { + // If we got an error building the block, and we didn't time out, log the error but continue building the block without the inner block, + // so that we can continue making progress and not get stuck on a single block. + sm.Logger.Error("Error building block, building block without inner block instead", zap.Error(err)) + } + if impatientContext.Err() != nil { + sm.Logger.Debug("Timed out waiting for block to be built, building block without inner block instead", + zap.Duration("elapsed", time.Since(start)), zap.Duration("maxBlockBuildingWaitTime", sm.MaxBlockBuildingWaitTime)) + } + return sm.wrapBlock(parentBlock, childBlock, simplexEpochInfo, pChainHeight, simplexMetadata, simplexBlacklist), nil +} + +func (sm *StateMachine) createSealingBlock(ctx context.Context, parentBlock StateMachineBlock, simplexMetadata []byte, simplexBlacklist []byte, simplexEpochInfo SimplexEpochInfo, newApprovals *approvals, pChainHeight uint64) (*StateMachineBlock, error) { + validators, err := sm.GetValidatorSet(simplexEpochInfo.NextPChainReferenceHeight) + if err != nil { + return nil, err + } + if simplexEpochInfo.BlockValidationDescriptor == nil { + simplexEpochInfo.BlockValidationDescriptor = &BlockValidationDescriptor{} + } + simplexEpochInfo.BlockValidationDescriptor.AggregatedMembership.Members = validators + + // If this is not the first epoch, and this is the sealing block, we set the hash of the previous sealing block. + if simplexEpochInfo.EpochNumber > 1 { + prevSealingBlock, finalization, err := sm.GetBlock(RetrievingOpts{Height: simplexEpochInfo.EpochNumber}) + if err != nil { + sm.Logger.Error("Error retrieving previous sealing block", zap.Uint64("seq", simplexEpochInfo.EpochNumber), zap.Error(err)) + return nil, fmt.Errorf("failed to retrieve previous sealing InnerBlock at epoch %d: %w", simplexEpochInfo.EpochNumber-1, err) + } + if finalization == nil { + sm.Logger.Error("Previous sealing block is not finalized", zap.Uint64("seq", simplexEpochInfo.EpochNumber)) + return nil, fmt.Errorf("previous sealing InnerBlock at epoch %d is not finalized", simplexEpochInfo.EpochNumber-1) + } + simplexEpochInfo.PrevSealingBlockHash = prevSealingBlock.Digest() + } else { // Else, this is the first epoch, so we use the hash of the first ever Simplex block. + + firstSimplexBlock, err := findFirstSimplexBlock(sm.GetBlock, sm.LatestPersistedHeight+1) + if err != nil { + return nil, fmt.Errorf("failed to find first simplex block: %w", err) + } + firstSimplexBlockRetrieved, _, err := sm.GetBlock(RetrievingOpts{Height: firstSimplexBlock}) + if err != nil { + return nil, fmt.Errorf("failed to retrieve first simplex block at height %d: %w", firstSimplexBlock, err) + } + simplexEpochInfo.PrevSealingBlockHash = firstSimplexBlockRetrieved.Digest() + } + + return sm.buildBlockImpatiently(ctx, parentBlock, simplexMetadata, simplexBlacklist, simplexEpochInfo, pChainHeight) +} + +// wrapBlock creates a new StateMachineBlock by wrapping the VM block (if applicable) and adding the appropriate metadata. +func (sm *StateMachine) wrapBlock(parentBlock StateMachineBlock, childBlock VMBlock, newSimplexEpochInfo SimplexEpochInfo, pChainHeight uint64, simplexMetadata, simplexBlacklist []byte) *StateMachineBlock { + parentMetadata := parentBlock.Metadata + timestamp := parentMetadata.Timestamp + + hasChildBlock := childBlock != nil + + var newTimestamp time.Time + if hasChildBlock { + newTimestamp = childBlock.Timestamp() + timestamp = uint64(newTimestamp.UnixMilli()) + } + + return &StateMachineBlock{ + InnerBlock: childBlock, + Metadata: StateMachineMetadata{ + Timestamp: timestamp, + SimplexProtocolMetadata: simplexMetadata, + SimplexBlacklist: simplexBlacklist, + SimplexEpochInfo: newSimplexEpochInfo, + PChainHeight: pChainHeight, + }, + } +} + +// buildBlockEpochSealed builds a block where the epoch is being sealed due to a sealing block already created in this epoch. +func (sm *StateMachine) buildBlockEpochSealed(ctx context.Context, parentBlock StateMachineBlock, simplexMetadata, simplexBlacklist []byte, prevBlockSeq uint64) (*StateMachineBlock, error) { + // We check if the sealing block has already been finalized. + // If not, we build a Telock block. + + sealingBlockSeq := parentBlock.Metadata.SimplexEpochInfo.SealingBlockSeq + + // If the sealing block sequence is still 0, it means previous block was the sealing block. + if sealingBlockSeq == 0 { + sealingBlockSeq = prevBlockSeq + } + + if sealingBlockSeq == 0 { + return nil, fmt.Errorf("cannot build epoch sealed block: sealing block sequence is 0 or undefined") + } + + newSimplexEpochInfo := SimplexEpochInfo{ + PChainReferenceHeight: parentBlock.Metadata.SimplexEpochInfo.PChainReferenceHeight, + EpochNumber: parentBlock.Metadata.SimplexEpochInfo.EpochNumber, + NextPChainReferenceHeight: parentBlock.Metadata.SimplexEpochInfo.NextPChainReferenceHeight, + SealingBlockSeq: sealingBlockSeq, + PrevVMBlockSeq: computePrevVMBlockSeq(parentBlock, prevBlockSeq), + } + + _, finalization, err := sm.GetBlock(RetrievingOpts{Height: sealingBlockSeq}) + if err != nil { + return nil, fmt.Errorf("failed to retrieve sealing block at sequence %d: %w", sealingBlockSeq, err) + } + + isSealingBlockFinalized := finalization != nil + + if !isSealingBlockFinalized { + pChainHeight := parentBlock.Metadata.PChainHeight + return sm.wrapBlock(parentBlock, nil, newSimplexEpochInfo, pChainHeight, simplexMetadata, simplexBlacklist), nil + } + + // Else, we build a block for the new epoch. + newSimplexEpochInfo = SimplexEpochInfo{ + // P-chain reference height is previous block's NextPChainReferenceHeight. + PChainReferenceHeight: parentBlock.Metadata.SimplexEpochInfo.NextPChainReferenceHeight, + // The epoch number is the sequence of the sealing block. + EpochNumber: sealingBlockSeq, + PrevVMBlockSeq: computePrevVMBlockSeq(parentBlock, prevBlockSeq), + } + + // TODO: This P-chain height should be taken from the ICM epoch + childBlock, err := sm.BlockBuilder.BuildBlock(ctx, sm.GetPChainHeight()) + if err != nil { + return nil, err + } + + return sm.wrapBlock(parentBlock, childBlock, newSimplexEpochInfo, parentBlock.Metadata.PChainHeight, simplexMetadata, simplexBlacklist), nil +} + +// constructSimplexZeroBlock constructs the SimplexEpochInfo for the zero block, which is the first ever block built by Simplex. +func constructSimplexZeroBlock(pChainHeight uint64, newValidatorSet NodeBLSMappings, prevVMBlockSeq uint64) SimplexEpochInfo { + newSimplexEpochInfo := SimplexEpochInfo{ + PChainReferenceHeight: pChainHeight, + EpochNumber: 1, + // We treat the zero block as a special case, and we encode in it the block validation descriptor, + // despite it not actually being a sealing block. This is because the zero block is the first block that introduces the validator set. + BlockValidationDescriptor: &BlockValidationDescriptor{ + AggregatedMembership: AggregatedMembership{ + Members: newValidatorSet, + }, + }, + NextEpochApprovals: nil, // We don't need to collect approvals to seal the first ever epoch. + PrevVMBlockSeq: prevVMBlockSeq, + SealingBlockSeq: 0, // We don't have a sealing block in the zero block. + PrevSealingBlockHash: [32]byte{}, // The zero block has no previous sealing block. + NextPChainReferenceHeight: 0, + } + return newSimplexEpochInfo +} + +func computeNewApprovals( + nextEpochApprovals *NextEpochApprovals, + approvalsFromPeers ValidatorSetApprovals, + pChainHeight uint64, + aggregator SignatureAggregator, + validators NodeBLSMappings, +) (*approvals, error) { + if nextEpochApprovals == nil { + nextEpochApprovals = &NextEpochApprovals{} + } + + oldApprovingNodes := bitmaskFromBytes(nextEpochApprovals.NodeIDs) + + // We map each validator to its relative index in the validator set. + nodeID2ValidatorIndex := make(map[nodeID]int) + validators.ForEach(func(i int, nbm NodeBLSMapping) { + nodeID2ValidatorIndex[nbm.NodeID] = i + }) + + // We have the approvals obtained from peers, but we need to sanitize them by filtering out approvals that are not valid, + // such as approvals that do not agree with our candidate auxiliary info digest and P-Chain height, + // and approvals that are from nodes that are not in the validator set or have already approved in prior blocks. + approvalsFromPeers = sanitizeApprovals(approvalsFromPeers, pChainHeight, nodeID2ValidatorIndex, oldApprovingNodes) + + // Next we aggregate both previous and new approvals to compute the new aggregated signatures and the new bitmask of approving nodes. + aggregatedSignature, newApprovingNodes, err := computeNewApproverSignaturesAndSigners(nextEpochApprovals, approvalsFromPeers, oldApprovingNodes, nodeID2ValidatorIndex, aggregator) + if err != nil { + return nil, err + } + + // we check if we have enough approvals to seal the epoch by computing the relative approval ratio, + // which is the ratio of the total weight of approving nodes divided by the total weight of all validators. + canSeal, err := canSealBlock(validators, newApprovingNodes) + if err != nil { + return nil, err + } + + return &approvals{ + canSeal: canSeal, + signature: aggregatedSignature, + nodeIDs: newApprovingNodes.Bytes(), + }, nil +} + // computeNewApproverSignaturesAndSigners computes the signatures of the nodes that approve the next epoch including the previous aggregated signature, // and bitmask of nodes that correspond to those signatures, and aggregates all signatures together. func computeNewApproverSignaturesAndSigners(nextEpochApprovals *NextEpochApprovals, approvalsFromPeers ValidatorSetApprovals, oldApprovingNodes bitmask, nodeID2ValidatorIndex map[nodeID]int, aggregator SignatureAggregator) ([]byte, bitmask, error) { @@ -133,6 +834,25 @@ func computeNewApproverSignaturesAndSigners(nextEpochApprovals *NextEpochApprova return aggregatedSignature, newApprovingNodes, nil } +func canSealBlock(validators NodeBLSMappings, newApprovingNodes bitmask) (bool, error) { + approvingWeight, err := computeApprovingWeight(validators, &newApprovingNodes) + if err != nil { + return false, err + } + + totalWeight, err := computeTotalWeight(validators) + if err != nil { + return false, err + } + + threshold := big.NewRat(2, 3) + + approvingRatio := big.NewRat(approvingWeight, totalWeight) + + canSeal := approvingRatio.Cmp(threshold) > 0 + return canSeal, nil +} + // sanitizeApprovals filters out approvals that are not valid by checking if they agree with our candidate auxiliary info digest and P-Chain height, // and if they are from the validator set and haven't already been approved. func sanitizeApprovals(approvals ValidatorSetApprovals, pChainHeight uint64, nodeID2ValidatorIndex map[nodeID]int, oldApprovingNodes bitmask) ValidatorSetApprovals { @@ -229,3 +949,9 @@ func ensureNextEpochApprovalsSignersSupersetOfApprovalsOfPrevBlock(prev SimplexE } return nil } + +type approvals struct { + canSeal bool + nodeIDs []byte + signature []byte +} \ No newline at end of file diff --git a/msm/msm_test.go b/msm/msm_test.go index f8014254..bb9e8a6b 100644 --- a/msm/msm_test.go +++ b/msm/msm_test.go @@ -6,12 +6,15 @@ package metadata import ( "bytes" "context" + "crypto/rand" + "encoding/asn1" "fmt" "math" "testing" "time" "github.com/ava-labs/simplex" + "github.com/ava-labs/simplex/testutil" "github.com/stretchr/testify/require" ) @@ -25,6 +28,1024 @@ func (f *fakeVMBlock) Height() uint64 { return f.height } func (f *fakeVMBlock) Timestamp() time.Time { return time.Time{} } func (f *fakeVMBlock) Verify(_ context.Context) error { return nil } +type outerBlock struct { + finalization *simplex.Finalization + block StateMachineBlock +} + +type blockStore map[uint64]*outerBlock + +func (bs blockStore) clone() blockStore { + newStore := make(blockStore) + for k, v := range bs { + newStore[k] = v + } + return newStore +} + +func (bs blockStore) getBlock(opts RetrievingOpts) (StateMachineBlock, *simplex.Finalization, error) { + blk, exits := bs[opts.Height] + if !exits { + return StateMachineBlock{}, nil, fmt.Errorf("%w: block %d not found", simplex.ErrBlockNotFound, opts.Height) + } + return blk.block, blk.finalization, nil +} + +type approvalsRetriever struct { + result ValidatorSetApprovals +} + +func (a approvalsRetriever) RetrieveApprovals() ValidatorSetApprovals { + return a.result +} + +type signatureVerifier struct { + err error +} + +func (sv *signatureVerifier) VerifySignature(signature []byte, message []byte, publicKey []byte) error { + return sv.err +} + +type signatureAggregator struct { +} + +type aggregatrdSignature struct { + Signatures [][]byte +} + +func (sv *signatureAggregator) AggregateSignatures(signatures ...[]byte) ([]byte, error) { + bytes, err := asn1.Marshal(aggregatrdSignature{Signatures: signatures}) + if err != nil { + return nil, err + } + return bytes, nil +} + +type noOpPChainListener struct{} + +func (n *noOpPChainListener) WaitForProgress(ctx context.Context, _ uint64) error { + <-ctx.Done() + return ctx.Err() +} + +type blockBuilder struct { + block VMBlock + err error +} + +func (bb *blockBuilder) WaitForPendingBlock(_ context.Context) { + // Block is always ready in tests. +} + +func (bb *blockBuilder) BuildBlock(_ context.Context, _ uint64) (VMBlock, error) { + return bb.block, bb.err +} + +type validatorSetRetriever struct { + result NodeBLSMappings + resultMap map[uint64]NodeBLSMappings + err error +} + +func (vsr *validatorSetRetriever) getValidatorSet(height uint64) (NodeBLSMappings, error) { + if vsr.resultMap != nil { + if result, ok := vsr.resultMap[height]; ok { + return result, vsr.err + } + } + return vsr.result, vsr.err +} + +type keyAggregator struct{} + +func (ka *keyAggregator) AggregateKeys(keys ...[]byte) ([]byte, error) { + aggregated := make([]byte, 0) + for _, key := range keys { + aggregated = append(aggregated, key...) + } + return aggregated, nil +} + +var ( + genesisBlock = StateMachineBlock{ + // Genesis block metadata has all zero values + InnerBlock: &InnerBlock{ + TS: time.Now(), + Bytes: []byte{1, 2, 3}, + }, + } +) + +func TestMSMFirstBlockAfterGenesis(t *testing.T) { + validMD := simplex.ProtocolMetadata{ + Round: 0, + Seq: 1, + Epoch: 1, + Prev: genesisBlock.Digest(), + } + + for _, testCase := range []struct { + name string + md simplex.ProtocolMetadata + err string + configure func(*StateMachine, *testConfig) + mutateBlock func(*StateMachineBlock) + }{ + { + name: "correct information", + md: validMD, + }, + { + name: "trying to build a genesis block", + md: validMD, + mutateBlock: func(block *StateMachineBlock) { + md, err := simplex.ProtocolMetadataFromBytes(block.Metadata.SimplexProtocolMetadata) + require.NoError(t, err) + md.Seq = 0 + block.Metadata.SimplexProtocolMetadata = md.Bytes() + }, + err: "attempted to build a genesis inner block", + }, + { + name: "previous block not found", + md: validMD, + configure: func(_ *StateMachine, tc *testConfig) { + delete(tc.blockStore, 0) + }, + err: "failed to retrieve previous (0) inner block", + }, + { + name: "parent has no inner block", + md: validMD, + configure: func(_ *StateMachine, tc *testConfig) { + tc.blockStore[0] = &outerBlock{ + block: StateMachineBlock{}, + } + }, + err: "parent inner block (", + }, + { + name: "wrong epoch number", + md: validMD, + mutateBlock: func(block *StateMachineBlock) { + block.Metadata.SimplexEpochInfo.EpochNumber = 2 + }, + err: "invalid epoch number (2), should be 1", + }, + { + name: "P-chain height too big", + md: validMD, + mutateBlock: func(block *StateMachineBlock) { + block.Metadata.PChainHeight = 110 + }, + err: "invalid P-chain height (110) is too big", + }, + { + name: "P-chain height smaller than parent", + md: validMD, + configure: func(_ *StateMachine, tc *testConfig) { + tc.blockStore[0] = &outerBlock{ + block: StateMachineBlock{ + InnerBlock: &InnerBlock{TS: time.Now(), Bytes: []byte{1, 2, 3}}, + Metadata: StateMachineMetadata{PChainHeight: 110}, + }, + } + }, + err: "invalid P-chain height (100) is smaller than parent InnerBlock's P-chain height (110)", + }, + { + name: "validator set retrieval fails", + md: validMD, + configure: func(_ *StateMachine, tc *testConfig) { + tc.validatorSetRetriever.err = fmt.Errorf("validator set unavailable") + }, + err: "failed to retrieve validator set", + }, + { + name: "nil BlockValidationDescriptor", + md: validMD, + mutateBlock: func(block *StateMachineBlock) { + block.Metadata.SimplexEpochInfo.BlockValidationDescriptor = nil + }, + err: "invalid BlockValidationDescriptor: should not be nil", + }, + { + name: "membership mismatch", + md: validMD, + configure: func(_ *StateMachine, tc *testConfig) { + tc.validatorSetRetriever.result = NodeBLSMappings{ + {BLSKey: []byte{1}, Weight: 1}, + } + }, + err: "invalid BlockValidationDescriptor: should match validator set", + }, + { + name: "SimplexEpochInfo mismatch", + md: validMD, + mutateBlock: func(block *StateMachineBlock) { + block.Metadata.SimplexEpochInfo.PrevVMBlockSeq = 999 + }, + err: "invalid SimplexEpochInfo", + }, + } { + t.Run(testCase.name, func(t *testing.T) { + sm1, testConfig1 := newStateMachine(t) + sm2, testConfig2 := newStateMachine(t) + + testConfig1.blockStore[0] = &outerBlock{ + block: genesisBlock, + } + + testConfig2.blockStore[0] = &outerBlock{ + block: genesisBlock, + } + + if testCase.configure != nil { + testCase.configure(&sm2, testConfig2) + } + + block, err := sm1.BuildBlock(context.Background(), genesisBlock, testCase.md, nil) + require.NoError(t, err) + require.NotNil(t, block) + + if testCase.mutateBlock != nil { + testCase.mutateBlock(block) + } + + err = sm2.VerifyBlock(context.Background(), block) + if testCase.err != "" { + require.ErrorContains(t, err, testCase.err) + return + } + require.NoError(t, err) + }) + } +} + +func TestMSMFirstSimplexBlockAfterPreSimplexBlocks(t *testing.T) { + preSimplexParent := StateMachineBlock{ + InnerBlock: &InnerBlock{ + TS: time.Now(), + BlockHeight: 42, + Bytes: []byte{4, 5, 6}, + }, + // Zero-valued metadata means this is a pre-Simplex block or a genesis block. + // But since the height is 42, it can't be a genesis block, so it must be a pre-Simplex block. + Metadata: StateMachineMetadata{}, + } + + md := simplex.ProtocolMetadata{ + Round: 0, + Seq: 43, + Epoch: 1, + Prev: preSimplexParent.Digest(), + } + + sm1, testConfig1 := newStateMachine(t) + sm2, testConfig2 := newStateMachine(t) + + testConfig1.blockStore[42] = &outerBlock{block: preSimplexParent} + testConfig2.blockStore[42] = &outerBlock{block: preSimplexParent} + + testConfig1.blockBuilder.block = &InnerBlock{ + TS: time.Now(), + BlockHeight: 43, + Bytes: []byte{7, 8, 9}, + } + + block, err := sm1.BuildBlock(context.Background(), preSimplexParent, md, nil) + require.NoError(t, err) + require.NotNil(t, block) + + require.NoError(t, sm2.VerifyBlock(context.Background(), block)) + + require.Equal(t, &StateMachineBlock{ + InnerBlock: &InnerBlock{ + TS: testConfig1.blockBuilder.block.Timestamp(), + BlockHeight: 43, + Bytes: []byte{7, 8, 9}, + }, + Metadata: StateMachineMetadata{ + Timestamp: uint64(testConfig1.blockBuilder.block.Timestamp().UnixMilli()), + PChainHeight: 100, + SimplexProtocolMetadata: md.Bytes(), + SimplexEpochInfo: SimplexEpochInfo{ + PChainReferenceHeight: 100, + EpochNumber: 1, + PrevVMBlockSeq: 42, + BlockValidationDescriptor: &BlockValidationDescriptor{ + AggregatedMembership: AggregatedMembership{ + Members: testConfig1.validatorSetRetriever.result, + }, + }, + }, + }, + }, block) +} + +func TestMSMNormalOp(t *testing.T) { + newPChainHeight := uint64(200) + newValidatorSet := NodeBLSMappings{ + {BLSKey: []byte{5}, Weight: 1}, {BLSKey: []byte{6}, Weight: 1}, {BLSKey: []byte{7}, Weight: 1}, + } + + for _, testCase := range []struct { + name string + setup func(*StateMachine, *testConfig) + mutateBlock func(*StateMachineBlock) + err string + expectedPChainHeight uint64 + expectedNextPChainRefHeight uint64 + }{ + { + name: "correct information", + expectedPChainHeight: 100, + }, + { + name: "trying to build a genesis block", + mutateBlock: func(block *StateMachineBlock) { + md, err := simplex.ProtocolMetadataFromBytes(block.Metadata.SimplexProtocolMetadata) + require.NoError(t, err) + md.Seq = 0 + block.Metadata.SimplexProtocolMetadata = md.Bytes() + }, + err: "attempted to build a genesis inner block", + }, + { + name: "previous block not found", + mutateBlock: func(block *StateMachineBlock) { + md, err := simplex.ProtocolMetadataFromBytes(block.Metadata.SimplexProtocolMetadata) + require.NoError(t, err) + md.Seq = 999 + block.Metadata.SimplexProtocolMetadata = md.Bytes() + }, + err: "failed to retrieve previous (998) inner block", + }, + { + name: "P-chain height too big", + mutateBlock: func(block *StateMachineBlock) { + block.Metadata.PChainHeight = 110 + }, + err: "invalid P-chain height (110) is too big", + }, + { + name: "P-chain height smaller than parent", + mutateBlock: func(block *StateMachineBlock) { + block.Metadata.PChainHeight = 0 + }, + err: "invalid P-chain height (0) is smaller than parent block's P-chain height (100)", + }, + { + name: "wrong epoch number", + mutateBlock: func(block *StateMachineBlock) { + block.Metadata.SimplexEpochInfo.EpochNumber = 2 + }, + err: "expected epoch number to be 1 but got 2", + }, + { + name: "non-nil BlockValidationDescriptor", + mutateBlock: func(block *StateMachineBlock) { + block.Metadata.SimplexEpochInfo.BlockValidationDescriptor = &BlockValidationDescriptor{} + }, + err: "failed to find first Simplex block", + }, + { + name: "non-zero sealing block seq", + mutateBlock: func(block *StateMachineBlock) { + block.Metadata.SimplexEpochInfo.SealingBlockSeq = 5 + }, + err: "expected sealing block sequence number to be 0 but got 5", + }, + { + name: "wrong PChainReferenceHeight", + mutateBlock: func(block *StateMachineBlock) { + block.Metadata.SimplexEpochInfo.PChainReferenceHeight = 50 + }, + err: "expected P-chain reference height to be 100 but got 50", + }, + { + name: "non-empty PrevSealingBlockHash", + mutateBlock: func(block *StateMachineBlock) { + block.Metadata.SimplexEpochInfo.PrevSealingBlockHash = [32]byte{1, 2, 3} + }, + err: "expected prev sealing block hash of a non sealing block to be empty", + }, + { + name: "wrong PrevVMBlockSeq", + mutateBlock: func(block *StateMachineBlock) { + block.Metadata.SimplexEpochInfo.PrevVMBlockSeq = 999 + }, + err: "expected PrevVMBlockSeq to be", + }, + { + name: "validator set change detected", + setup: func(sm *StateMachine, tc *testConfig) { + tc.validatorSetRetriever.resultMap = map[uint64]NodeBLSMappings{ + newPChainHeight: newValidatorSet, + } + sm.GetPChainHeight = func() uint64 { return newPChainHeight } + }, + expectedPChainHeight: newPChainHeight, + expectedNextPChainRefHeight: newPChainHeight, + }, + } { + t.Run(testCase.name, func(t *testing.T) { + chain := makeChain(t, 5, 10) + sm1, testConfig1 := newStateMachine(t) + sm2, testConfig2 := newStateMachine(t) + + for i, block := range chain { + testConfig1.blockStore[uint64(i)] = &outerBlock{block: block} + testConfig2.blockStore[uint64(i)] = &outerBlock{block: block} + } + + lastBlock := chain[len(chain)-1] + md, err := simplex.ProtocolMetadataFromBytes(lastBlock.Metadata.SimplexProtocolMetadata) + require.NoError(t, err) + + md.Seq++ + md.Round++ + md.Prev = lastBlock.Digest() + + var blacklist simplex.Blacklist + blacklist.NodeCount = 4 + + blockTime := lastBlock.InnerBlock.Timestamp().Add(time.Second) + + content := make([]byte, 10) + _, err = rand.Read(content) + require.NoError(t, err) + + testConfig1.blockBuilder.block = &InnerBlock{ + TS: blockTime, + BlockHeight: lastBlock.InnerBlock.Height(), + Bytes: content, + } + + if testCase.setup != nil { + testCase.setup(&sm1, testConfig1) + testCase.setup(&sm2, testConfig2) + } + + block1, err := sm1.BuildBlock(context.Background(), lastBlock, *md, &blacklist) + require.NoError(t, err) + require.NotNil(t, block1) + + if testCase.mutateBlock != nil { + testCase.mutateBlock(block1) + } + + err = sm2.VerifyBlock(context.Background(), block1) + if testCase.err != "" { + require.ErrorContains(t, err, testCase.err) + return + } + require.NoError(t, err) + + require.Equal(t, &StateMachineBlock{ + InnerBlock: &InnerBlock{ + TS: blockTime, + BlockHeight: lastBlock.InnerBlock.Height(), + Bytes: content, + }, + Metadata: StateMachineMetadata{ + SimplexBlacklist: blacklist.Bytes(), + Timestamp: uint64(blockTime.UnixMilli()), + PChainHeight: testCase.expectedPChainHeight, + SimplexProtocolMetadata: md.Bytes(), + SimplexEpochInfo: SimplexEpochInfo{ + PChainReferenceHeight: 100, + EpochNumber: 1, + PrevVMBlockSeq: lastBlock.InnerBlock.Height(), + NextPChainReferenceHeight: testCase.expectedNextPChainRefHeight, + }, + }, + }, block1) + }) + } +} + +func TestMSMFullEpochLifecycle(t *testing.T) { + // Validator sets: epoch 1 uses validatorSet1, epoch 2 uses validatorSet2. + node1 := [20]byte{1} + node2 := [20]byte{2} + node3 := [20]byte{3} + + validatorSet1 := NodeBLSMappings{ + {NodeID: node1, BLSKey: []byte{1}, Weight: 1}, + {NodeID: node2, BLSKey: []byte{2}, Weight: 1}, + {NodeID: node3, BLSKey: []byte{3}, Weight: 1}, + } + validatorSet2 := NodeBLSMappings{ + {NodeID: node1, BLSKey: []byte{1}, Weight: 1}, + {NodeID: node2, BLSKey: []byte{4}, Weight: 1}, + {NodeID: node3, BLSKey: []byte{5}, Weight: 1}, + } + + pChainHeight1 := uint64(100) + pChainHeight2 := uint64(200) + + startTime := time.Now() + + nextBlock := func(height uint64) *InnerBlock { + return &InnerBlock{ + TS: startTime.Add(time.Duration(height) * time.Millisecond), + BlockHeight: height, + Bytes: []byte{byte(height)}, + } + } + + // ----- Step 0: Building on top of genesis or upgrading to Simplex----- + genesis := StateMachineBlock{ + InnerBlock: &InnerBlock{ + BlockHeight: 0, // Genesis block has height 0 + TS: startTime, + Bytes: []byte{0}, + }, + } + + notGenesis := StateMachineBlock{ + InnerBlock: &InnerBlock{ + BlockHeight: 42, + TS: startTime, + Bytes: []byte{0}, + }, + } + for _, testCase := range []struct { + name string + firstBlockBeforeSimplex StateMachineBlock + }{ + { + name: "building on top of genesis", + firstBlockBeforeSimplex: genesis, + }, + { + name: "upgrading to Simplex from pre-Simplex blocks", + firstBlockBeforeSimplex: notGenesis, + }, + } { + t.Run(testCase.name, func(t *testing.T) { + + currentPChainHeight := pChainHeight1 + + getValidatorSet := func(height uint64) (NodeBLSMappings, error) { + if height >= pChainHeight2 { + return validatorSet2, nil + } + return validatorSet1, nil + } + getPChainHeight := func() uint64 { + return currentPChainHeight + } + + // Create fresh state machine instances for each iteration. + sm, tc := newStateMachine(t) + sm.GetValidatorSet = getValidatorSet + sm.GetPChainHeight = getPChainHeight + + smVerify, tcVerify := newStateMachine(t) + smVerify.GetValidatorSet = getValidatorSet + smVerify.GetPChainHeight = getPChainHeight + + // addBlock adds a block to both block stores so builder and verifier stay in sync. + addBlock := func(seq uint64, block StateMachineBlock, fin *simplex.Finalization) { + tc.blockStore[seq] = &outerBlock{block: block, finalization: fin} + tcVerify.blockStore[seq] = &outerBlock{block: block, finalization: fin} + } + + baseSeq := testCase.firstBlockBeforeSimplex.InnerBlock.Height() + addBlock(baseSeq, testCase.firstBlockBeforeSimplex, nil) + + aggr := &signatureAggregator{} + + // ----- Step 1: Build zero epoch block (first simplex block) ----- + tc.blockBuilder.block = nextBlock(1) + md := simplex.ProtocolMetadata{ + Seq: baseSeq + 1, + Round: 0, + Epoch: 1, + Prev: testCase.firstBlockBeforeSimplex.Digest(), + } + + block1, err := sm.BuildBlock(context.Background(), testCase.firstBlockBeforeSimplex, md, nil) + require.NoError(t, err) + require.Equal(t, &StateMachineBlock{ + InnerBlock: nextBlock(1), + Metadata: StateMachineMetadata{ + Timestamp: uint64(startTime.Add(1 * time.Millisecond).UnixMilli()), + PChainHeight: pChainHeight1, + SimplexProtocolMetadata: md.Bytes(), + SimplexEpochInfo: SimplexEpochInfo{ + PChainReferenceHeight: pChainHeight1, + EpochNumber: 1, + PrevVMBlockSeq: baseSeq, + BlockValidationDescriptor: &BlockValidationDescriptor{ + AggregatedMembership: AggregatedMembership{ + Members: validatorSet1, + }, + }, + }, + }, + }, block1) + addBlock(md.Seq, *block1, nil) + + require.NoError(t, smVerify.VerifyBlock(context.Background(), block1)) + + // After we build the first block, the StateMachine should consider it as the latest persisted height. + sm.LatestPersistedHeight = baseSeq + 1 + smVerify.LatestPersistedHeight = baseSeq + 1 + + // ----- Step 2: Build a normal block (no validator set change) ----- + tc.blockBuilder.block = nextBlock(2) + md = simplex.ProtocolMetadata{Seq: baseSeq + 2, Round: 1, Epoch: 1, Prev: block1.Digest()} + block2, err := sm.BuildBlock(context.Background(), *block1, md, nil) + require.NoError(t, err) + require.Equal(t, &StateMachineBlock{ + InnerBlock: nextBlock(2), + Metadata: StateMachineMetadata{ + Timestamp: uint64(startTime.Add(2 * time.Millisecond).UnixMilli()), + PChainHeight: pChainHeight1, + SimplexProtocolMetadata: md.Bytes(), + SimplexEpochInfo: SimplexEpochInfo{ + PChainReferenceHeight: pChainHeight1, + EpochNumber: 1, + PrevVMBlockSeq: baseSeq + 1, + }, + }, + }, block2) + addBlock(md.Seq, *block2, nil) + + require.NoError(t, smVerify.VerifyBlock(context.Background(), block2)) + + // ----- Step 3: Build a normal block that detects a validator set change ----- + // Advance P-chain height so that GetValidatorSet returns a different set. + currentPChainHeight = pChainHeight2 + + tc.blockBuilder.block = nextBlock(3) + md = simplex.ProtocolMetadata{Seq: baseSeq + 3, Round: 2, Epoch: 1, Prev: block2.Digest()} + block3, err := sm.BuildBlock(context.Background(), *block2, md, nil) + require.NoError(t, err) + require.Equal(t, &StateMachineBlock{ + InnerBlock: nextBlock(3), + Metadata: StateMachineMetadata{ + Timestamp: uint64(startTime.Add(3 * time.Millisecond).UnixMilli()), + PChainHeight: pChainHeight2, + SimplexProtocolMetadata: md.Bytes(), + SimplexEpochInfo: SimplexEpochInfo{ + PChainReferenceHeight: pChainHeight1, + EpochNumber: 1, + PrevVMBlockSeq: baseSeq + 2, + NextPChainReferenceHeight: pChainHeight2, + }, + }, + }, block3) + addBlock(md.Seq, *block3, nil) + + require.NoError(t, smVerify.VerifyBlock(context.Background(), block3)) + + // ----- Step 4: First collecting block (1/3 approvals, not enough to seal) ----- + + // Override ApprovalsRetriever to use our dynamic approvals. + var approvalsResult ValidatorSetApprovals + sm.ApprovalsRetriever = &dynamicApprovalsRetriever{approvals: &approvalsResult} + + approvalsResult = ValidatorSetApprovals{ + { + NodeID: node1, + PChainHeight: pChainHeight2, + Signature: []byte("sig1"), + }, + } + + // node1 is at index 0 in validatorSet2 → bitmask bit 0 → {1} + bitmask := []byte{1} + sig, err := aggr.AggregateSignatures([]byte("sig1")) + require.NoError(t, err) + + tc.blockBuilder.block = nextBlock(4) + md = simplex.ProtocolMetadata{Seq: baseSeq + 4, Round: 3, Epoch: 1, Prev: block3.Digest()} + block4, err := sm.BuildBlock(context.Background(), *block3, md, nil) + require.NoError(t, err) + require.Equal(t, &StateMachineBlock{ + InnerBlock: nextBlock(4), + Metadata: StateMachineMetadata{ + Timestamp: uint64(startTime.Add(4 * time.Millisecond).UnixMilli()), + PChainHeight: pChainHeight2, + SimplexProtocolMetadata: md.Bytes(), + SimplexEpochInfo: SimplexEpochInfo{ + PChainReferenceHeight: pChainHeight1, + EpochNumber: 1, + PrevVMBlockSeq: baseSeq + 3, + NextPChainReferenceHeight: pChainHeight2, + NextEpochApprovals: &NextEpochApprovals{ + NodeIDs: bitmask, + Signature: sig, + }, + }, + }, + }, block4) + addBlock(md.Seq, *block4, nil) + + require.NoError(t, smVerify.VerifyBlock(context.Background(), block4)) + + // ----- Step 5: Second collecting block (2/3 approvals, still not enough since threshold is strictly > 2/3) ----- + approvalsResult = ValidatorSetApprovals{ + { + NodeID: node2, + PChainHeight: pChainHeight2, + Signature: []byte("sig2"), + }, + } + + // node2 is at index 1 → bitmask bits 0,1 → {3} + sig, err = aggr.AggregateSignatures([]byte("sig2"), sig) + require.NoError(t, err) + bitmask = []byte{3} + + tc.blockBuilder.block = nextBlock(5) + md = simplex.ProtocolMetadata{Seq: baseSeq + 5, Round: 4, Epoch: 1, Prev: block4.Digest()} + block5, err := sm.BuildBlock(context.Background(), *block4, md, nil) + require.NoError(t, err) + require.Equal(t, &StateMachineBlock{ + InnerBlock: nextBlock(5), + Metadata: StateMachineMetadata{ + Timestamp: uint64(startTime.Add(5 * time.Millisecond).UnixMilli()), + PChainHeight: pChainHeight2, + SimplexProtocolMetadata: md.Bytes(), + SimplexEpochInfo: SimplexEpochInfo{ + PChainReferenceHeight: pChainHeight1, + EpochNumber: 1, + PrevVMBlockSeq: baseSeq + 4, + NextPChainReferenceHeight: pChainHeight2, + NextEpochApprovals: &NextEpochApprovals{ + NodeIDs: bitmask, + Signature: sig, + }, + }, + }, + }, block5) + addBlock(md.Seq, *block5, nil) + + require.NoError(t, smVerify.VerifyBlock(context.Background(), block5)) + + // ----- Step 6: Sealing block (3/3 approvals, enough to seal) ----- + approvalsResult = ValidatorSetApprovals{ + { + NodeID: node3, + PChainHeight: pChainHeight2, + Signature: []byte("sig3"), + }, + } + + // node3 is at index 2 → bitmask bits 0,1,2 → {7} + sig6, err := aggr.AggregateSignatures([]byte("sig3"), sig) + require.NoError(t, err) + bitmask = []byte{7} + + tc.blockBuilder.block = nextBlock(6) + md = simplex.ProtocolMetadata{Seq: baseSeq + 6, Round: 5, Epoch: 1, Prev: block5.Digest()} + block6, err := sm.BuildBlock(context.Background(), *block5, md, nil) + require.NoError(t, err) + require.Equal(t, &StateMachineBlock{ + InnerBlock: nextBlock(6), + Metadata: StateMachineMetadata{ + Timestamp: uint64(startTime.Add(6 * time.Millisecond).UnixMilli()), + PChainHeight: pChainHeight2, + SimplexProtocolMetadata: md.Bytes(), + SimplexEpochInfo: SimplexEpochInfo{ + PChainReferenceHeight: pChainHeight1, + EpochNumber: 1, + PrevVMBlockSeq: baseSeq + 5, + NextPChainReferenceHeight: pChainHeight2, + SealingBlockSeq: 0, + PrevSealingBlockHash: block1.Digest(), + BlockValidationDescriptor: &BlockValidationDescriptor{ + AggregatedMembership: AggregatedMembership{ + Members: validatorSet2, + }, + }, + NextEpochApprovals: &NextEpochApprovals{ + NodeIDs: bitmask, + Signature: sig6, + }, + }, + }, + }, block6) + addBlock(md.Seq, *block6, nil) + + require.NoError(t, smVerify.VerifyBlock(context.Background(), block6)) + + sealingSeq := baseSeq + 6 // The sealing block's sequence (md.Seq from step 6) + + backupStoreTC := tc.blockStore.clone() + backupStoreTCVerify := tcVerify.blockStore.clone() + + for _, subTestCase := range []struct { + name string + setup func() + }{ + { + name: "sealing block not finalized yet", + setup: func() { + addBlock(sealingSeq, tc.blockStore[sealingSeq].block, nil) + }, + }, + { + name: "sealing block immediately finalized", + setup: func() { + addBlock(sealingSeq, tc.blockStore[sealingSeq].block, &simplex.Finalization{}) + }, + }, + } { + testName := fmt.Sprintf("%s-%s", testCase.name, subTestCase.name) + t.Run(testName, func(t *testing.T) { + tc.blockStore = backupStoreTC.clone() + sm.GetBlock = tc.blockStore.getBlock + tcVerify.blockStore = backupStoreTCVerify.clone() + smVerify.GetBlock = tcVerify.blockStore.getBlock + + subTestCase.setup() + + tc.blockBuilder.block = nextBlock(7) + md = simplex.ProtocolMetadata{Seq: baseSeq + 7, Round: 6, Epoch: 1, Prev: block6.Digest()} + + // If the sealing block isn't finalized yet, we expect to build a Telock. + // However, despite the fact that the block builder is willing to build a new block, + // a Telock shouldn't contain an inner block. + if tc.blockStore[sealingSeq].finalization == nil { + telock, err := sm.BuildBlock(context.Background(), *block6, md, nil) + require.NoError(t, err) + + require.Equal(t, &StateMachineBlock{ + InnerBlock: nil, + Metadata: StateMachineMetadata{ + Timestamp: uint64(startTime.Add(6 * time.Millisecond).UnixMilli()), + PChainHeight: pChainHeight2, + SimplexProtocolMetadata: md.Bytes(), + SimplexEpochInfo: SimplexEpochInfo{ + PChainReferenceHeight: pChainHeight1, + EpochNumber: 1, + NextPChainReferenceHeight: pChainHeight2, + PrevVMBlockSeq: baseSeq + 6, + SealingBlockSeq: sealingSeq, + }, + }, + }, telock) + + // Next, finalize the sealing block after we have built a Telock. + addBlock(sealingSeq, tc.blockStore[sealingSeq].block, &simplex.Finalization{}) + } + + // ----- Step 7: Build a new epoch block (sealing block is finalized) ----- + + block7, err := sm.BuildBlock(context.Background(), *block6, md, nil) + require.NoError(t, err) + require.Equal(t, &StateMachineBlock{ + InnerBlock: nextBlock(7), + Metadata: StateMachineMetadata{ + Timestamp: uint64(startTime.Add(7 * time.Millisecond).UnixMilli()), + PChainHeight: pChainHeight2, + SimplexProtocolMetadata: md.Bytes(), + SimplexEpochInfo: SimplexEpochInfo{ + PChainReferenceHeight: pChainHeight2, + EpochNumber: sealingSeq, + PrevVMBlockSeq: baseSeq + 6, + }, + }, + }, block7) + addBlock(md.Seq, *block7, nil) + + require.NoError(t, smVerify.VerifyBlock(context.Background(), block7)) + }) + } + }) + } +} + +type dynamicApprovalsRetriever struct { + approvals *ValidatorSetApprovals +} + +func (d *dynamicApprovalsRetriever) RetrieveApprovals() ValidatorSetApprovals { + return *d.approvals +} + +func makeChain(t *testing.T, simplexStartHeight uint64, endHeight uint64) []StateMachineBlock { + startTime := time.Now().Add(-time.Duration(endHeight+2) * time.Second) + blocks := make([]StateMachineBlock, 0, endHeight+1) + var round, seq uint64 + for h := uint64(0); h <= endHeight; h++ { + index := len(blocks) + + if h == 0 { + blocks = append(blocks, genesisBlock) + continue + } + + if h < simplexStartHeight { + blocks = append(blocks, makeNonSimplexBlock(t, simplexStartHeight, startTime, h)) + continue + } + + seq = uint64(index) + + blocks = append(blocks, makeNormalSimplexBlock(t, index, blocks, startTime, h, round, seq)) + round++ + } + return blocks +} + +func makeNormalSimplexBlock(t *testing.T, index int, blocks []StateMachineBlock, start time.Time, h uint64, round uint64, seq uint64) StateMachineBlock { + content := make([]byte, 10) + _, err := rand.Read(content) + require.NoError(t, err) + + prev := genesisBlock.Digest() + if index > 0 { + prev = blocks[index-1].Digest() + } + + return StateMachineBlock{ + InnerBlock: &InnerBlock{ + TS: start.Add(time.Duration(h) * time.Second), + BlockHeight: h, + Bytes: []byte{1, 2, 3}, + }, + Metadata: StateMachineMetadata{ + PChainHeight: 100, + SimplexProtocolMetadata: (&simplex.ProtocolMetadata{ + Round: round, + Seq: seq, + Epoch: 1, + Prev: prev, + }).Bytes(), + SimplexEpochInfo: SimplexEpochInfo{ + PrevSealingBlockHash: [32]byte{}, + PChainReferenceHeight: 100, + EpochNumber: 1, + PrevVMBlockSeq: uint64(index), + }, + }, + } +} + +func makeNonSimplexBlock(t *testing.T, startHeight uint64, start time.Time, h uint64) StateMachineBlock { + content := make([]byte, 10) + _, err := rand.Read(content) + require.NoError(t, err) + + return StateMachineBlock{ + InnerBlock: &InnerBlock{ + TS: start.Add(time.Duration(h-startHeight) * time.Second), + BlockHeight: h, + Bytes: []byte{1, 2, 3}, + }, + } +} + +type testConfig struct { + blockStore blockStore + approvalsRetriever approvalsRetriever + signatureVerifier signatureVerifier + signatureAggregator signatureAggregator + blockBuilder blockBuilder + keyAggregator keyAggregator + validatorSetRetriever validatorSetRetriever +} + +func newStateMachine(t *testing.T) (StateMachine, *testConfig) { + bs := make(blockStore) + + var testConfig testConfig + testConfig.blockStore = bs + testConfig.validatorSetRetriever.result = NodeBLSMappings{ + {BLSKey: []byte{1}, Weight: 1}, {BLSKey: []byte{2}, Weight: 1}, + } + + sm := StateMachine{ + GetTime: time.Now, + TimeSkewLimit: time.Second * 5, + Logger: testutil.MakeLogger(t), + GetBlock: testConfig.blockStore.getBlock, + MaxBlockBuildingWaitTime: time.Second, + ApprovalsRetriever: &testConfig.approvalsRetriever, + SignatureVerifier: &testConfig.signatureVerifier, + SignatureAggregator: &testConfig.signatureAggregator, + BlockBuilder: &testConfig.blockBuilder, + KeyAggregator: &testConfig.keyAggregator, + GetPChainHeight: func() uint64 { + return 100 + }, + GetUpgrades: func() any { + return nil + }, + GetValidatorSet: testConfig.validatorSetRetriever.getValidatorSet, + PChainProgressListener: &noOpPChainListener{}, + } + return sm, &testConfig +} + func TestIdentifyCurrentState(t *testing.T) { bvd := &BlockValidationDescriptor{} for _, tc := range []struct { diff --git a/msm/verification.go b/msm/verification.go new file mode 100644 index 00000000..ecf314aa --- /dev/null +++ b/msm/verification.go @@ -0,0 +1,515 @@ +// Copyright (C) 2019-2025, Ava Labs, Inc. All rights reserved. +// See the file LICENSE for licensing terms. + +package metadata + +import ( + "bytes" + "encoding/binary" + "fmt" + "time" + + "github.com/ava-labs/simplex" +) + +type verificationInput struct { + prevMD StateMachineMetadata + proposedBlockMD StateMachineMetadata + hasInnerBlock bool + innerBlockTimestamp time.Time // only set when hasInnerBlock is true + prevBlockSeq uint64 + nextBlockType BlockType + state state +} + +type verifier interface { + Verify(in verificationInput) error +} +type validationDescriptorVerifier struct { + getValidatorSet ValidatorSetRetriever +} + +func (vd *validationDescriptorVerifier) Verify(in verificationInput) error { + prev, next := in.prevMD.SimplexEpochInfo, in.proposedBlockMD.SimplexEpochInfo + switch in.nextBlockType { + case BlockTypeSealing: + return vd.verifySealingBlock(prev, next) + default: + return vd.verifyEmptyValidationDescriptor(prev, next) + } +} + +func (vd *validationDescriptorVerifier) verifySealingBlock(prev SimplexEpochInfo, next SimplexEpochInfo) error { + validators, err := vd.getValidatorSet(prev.NextPChainReferenceHeight) + if err != nil { + return err + } + + if next.BlockValidationDescriptor == nil { + return fmt.Errorf("validation descriptor should not be nil for a sealing block") + } + + if !validators.Equal(next.BlockValidationDescriptor.AggregatedMembership.Members) { + return fmt.Errorf("expected validator set specified at P-chain height %d does not match validator set encoded in new block", next.NextPChainReferenceHeight) + } + + return nil +} + +func (vd *validationDescriptorVerifier) verifyEmptyValidationDescriptor(_ SimplexEpochInfo, next SimplexEpochInfo) error { + if next.BlockValidationDescriptor != nil { + return fmt.Errorf("block validation descriptor should be nil but got %v", next.BlockValidationDescriptor) + } + return nil +} + +type nextEpochApprovalsVerifier struct { + sigVerifier SignatureVerifier + getValidatorSet ValidatorSetRetriever + keyAggregator KeyAggregator +} + +func (nv *nextEpochApprovalsVerifier) Verify(in verificationInput) error { + prev, next := in.prevMD.SimplexEpochInfo, in.proposedBlockMD.SimplexEpochInfo + + switch in.nextBlockType { + case BlockTypeSealing: + return nv.verifySealingBlock(prev, next) + case BlockTypeNormal: + return nv.verifyNormal(prev, next) + default: + return nv.verifyEmptyNextEpochApprovals(prev, next) + } +} + +func (nv *nextEpochApprovalsVerifier) verifySealingBlock(prev SimplexEpochInfo, next SimplexEpochInfo) error { + if next.NextEpochApprovals == nil { + return fmt.Errorf("next epoch approvals should not be nil for a sealing block") + } + + validators, err := nv.getValidatorSet(prev.NextPChainReferenceHeight) + if err != nil { + return err + } + + err = nv.verifySignature(prev, next, validators) + if err != nil { + return err + } + + approvingNodes := bitmaskFromBytes(next.NextEpochApprovals.NodeIDs) + canSeal, err := canSealBlock(validators, approvingNodes) + if err != nil { + return err + } + + if !canSeal { + return fmt.Errorf("not enough approvals to seal block") + } + + return nil +} + +func (nv *nextEpochApprovalsVerifier) verifyNormal(prev SimplexEpochInfo, next SimplexEpochInfo) error { + if prev.NextPChainReferenceHeight == 0 { + return nil + } + + // Otherwise, prev.NextPChainReferenceHeight > 0, so this means we're collecting approvals + + if next.NextEpochApprovals == nil { + // The node that proposed the block should have included at least its own approval. + return fmt.Errorf("next epoch approvals should not be nil when collecting approvals") + } + + validators, err := nv.getValidatorSet(prev.NextPChainReferenceHeight) + if err != nil { + return err + } + + err = nv.verifySignature(prev, next, validators) + if err != nil { + return err + } + + // A node cannot remove other nodes' approvals, only add its own approval if it wasn't included in the previous block. + // So the set of signers in next.NextEpochApprovals should be a superset of the set of signers in prev.NextEpochApprovals. + if err := areNextEpochApprovalsSignersSupersetOfApprovalsOfPrevBlock(prev, next); err != nil { + return err + } + + return nil +} + +func (nv *nextEpochApprovalsVerifier) verifyEmptyNextEpochApprovals(_ SimplexEpochInfo, next SimplexEpochInfo) error { + if next.NextEpochApprovals != nil { + return fmt.Errorf("next epoch approvals should be nil but got %v", next.NextEpochApprovals) + } + return nil +} + +func (nv *nextEpochApprovalsVerifier) verifySignature(prev SimplexEpochInfo, next SimplexEpochInfo, validators NodeBLSMappings) error { + // First figure out which validators are approving the next epoch by looking at the bitmask of approving nodes, + // and then aggregate their public keys together to verify the signature. + + nodeIDsBitmask := next.NextEpochApprovals.NodeIDs + aggPK, err := nv.aggregatePubKeysForBitmask(nodeIDsBitmask, validators) + if err != nil { + return err + } + + message := nv.createMessageToBeVerified(prev) + + if err := nv.sigVerifier.VerifySignature(next.NextEpochApprovals.Signature, message, aggPK); err != nil { + return fmt.Errorf("failed to verify signature: %w", err) + } + return nil +} + +func (nv *nextEpochApprovalsVerifier) createMessageToBeVerified(prev SimplexEpochInfo) []byte { + pChainHeightBuff := pChainNextReferenceHeightAsBytes(prev) + + var bb bytes.Buffer + bb.Write(pChainHeightBuff) + + message := bb.Bytes() + return message +} + +func (nv *nextEpochApprovalsVerifier) aggregatePubKeysForBitmask(nodeIDsBitmask []byte, validators NodeBLSMappings) ([]byte, error) { + approvingNodes := bitmaskFromBytes(nodeIDsBitmask) + publicKeys := make([][]byte, 0, len(validators)) + validators.ForEach(func(i int, nbm NodeBLSMapping) { + if !approvingNodes.Contains(i) { + return + } + publicKeys = append(publicKeys, nbm.BLSKey) + }) + + aggPK, err := nv.keyAggregator.AggregateKeys(publicKeys...) + if err != nil { + return nil, fmt.Errorf("failed to aggregate public keys: %w", err) + } + return aggPK, nil +} + +func pChainNextReferenceHeightAsBytes(prev SimplexEpochInfo) []byte { + pChainHeight := prev.NextPChainReferenceHeight + pChainHeightBuff := make([]byte, 8) + binary.BigEndian.PutUint64(pChainHeightBuff, pChainHeight) + return pChainHeightBuff +} + +type nextPChainReferenceHeightVerifier struct { + getValidatorSet ValidatorSetRetriever + getPChainHeight func() uint64 +} + +func (n *nextPChainReferenceHeightVerifier) Verify(in verificationInput) error { + prev, next := in.prevMD.SimplexEpochInfo, in.proposedBlockMD.SimplexEpochInfo + switch in.nextBlockType { + case BlockTypeTelock, BlockTypeSealing: + if prev.NextPChainReferenceHeight != next.NextPChainReferenceHeight { + return fmt.Errorf("expected P-chain reference height to be %d but got %d", prev.NextPChainReferenceHeight, next.NextPChainReferenceHeight) + } + case BlockTypeNormal: + return n.verifyNextPChainRefHeightNormal(in.prevMD, prev, next) + case BlockTypeNewEpoch: + if next.NextPChainReferenceHeight != 0 { + return fmt.Errorf("expected P-chain reference height to be 0 but got %d", next.NextPChainReferenceHeight) + } + default: + return fmt.Errorf("unknown block type: %d", in.nextBlockType) + } + return nil +} + +func (n *nextPChainReferenceHeightVerifier) verifyNextPChainRefHeightNormal(prevMD StateMachineMetadata, prev SimplexEpochInfo, next SimplexEpochInfo) error { + // Next P-chain height can only increase, not decrease. + if next.NextPChainReferenceHeight > 0 && prev.PChainReferenceHeight > next.NextPChainReferenceHeight { + return fmt.Errorf("expected P-chain reference height to be non-decreasing, "+ + "but the previous P-chain reference height is %d and the proposed P-chain reference height is %d", prev.PChainReferenceHeight, next.NextPChainReferenceHeight) + } + + // If the previous block already has a next P-chain reference height, + // we should keep the same next P-chain reference height until we reach it. + if prev.NextPChainReferenceHeight > 0 { + if next.NextPChainReferenceHeight != prev.NextPChainReferenceHeight { + return fmt.Errorf("expected P-chain reference height to be %d but got %d", prev.NextPChainReferenceHeight, next.NextPChainReferenceHeight) + } + return nil + } + + // If we reached here, then prev.NextPChainReferenceHeight == 0. + // It might be that this block is the first block that has set the next P-chain reference height for the epoch, + // so check if it has done so correctly by observing whether the validator set has indeed changed. + + currentValidatorSet, err := n.getValidatorSet(prevMD.SimplexEpochInfo.PChainReferenceHeight) + if err != nil { + return err + } + + newValidatorSet, err := n.getValidatorSet(next.NextPChainReferenceHeight) + if err != nil { + return err + } + + // If the validator set doesn't change, we shouldn't have increased the next P-chain reference height. + if currentValidatorSet.Equal(newValidatorSet) && next.NextPChainReferenceHeight > 0 { + return fmt.Errorf("validator set at proposed next P-chain reference height %d is the same as "+ + "validator set at previous block's P-chain reference height %d,"+ + "so expected next P-chain reference height to remain the same but got %d", + next.NextPChainReferenceHeight, prev.PChainReferenceHeight, next.NextPChainReferenceHeight) + } + + // Else, either the validator set has changed, or the next P-chain reference height is still 0. + // Both of these cases are fine, but we should verify that we have observed the next P-chain reference height if it is > 0. + + pChainHeight := n.getPChainHeight() + + if pChainHeight < next.NextPChainReferenceHeight { + return fmt.Errorf("haven't reached P-chain height %d yet, current P-chain height is only %d", next.NextPChainReferenceHeight, pChainHeight) + } + + return nil +} + +type epochNumberVerifier struct{} + +func (e *epochNumberVerifier) Verify(in verificationInput) error { + prev, next := in.prevMD.SimplexEpochInfo, in.proposedBlockMD.SimplexEpochInfo + + // An epoch number of 0 means this is not a Simplex block, so the next block should be the first Simplex block with epoch number 1. + if in.prevMD.SimplexEpochInfo.EpochNumber == 0 && in.proposedBlockMD.SimplexEpochInfo.EpochNumber != 1 { + return fmt.Errorf("expected epoch number of the first block created to be 1 but got %d", next.EpochNumber) + } + + // The only time in which we should increase the epoch number is when we have a block that marks the start of a new epoch. + switch in.nextBlockType { + case BlockTypeNewEpoch: + // TODO: we have to make sure that Telocks are pruned before moving to a new epoch, otherwise we hit a false negative below. + if in.prevBlockSeq != next.EpochNumber { + return fmt.Errorf("expected epoch number to be %d but got %d", in.prevBlockSeq, next.EpochNumber) + } + default: + if prev.EpochNumber != next.EpochNumber { + return fmt.Errorf("expected epoch number to be %d but got %d", prev.EpochNumber, next.EpochNumber) + } + } + return nil +} + +type sealingBlockSeqVerifier struct{} + +func (s *sealingBlockSeqVerifier) Verify(in verificationInput) error { + prev, next := in.prevMD.SimplexEpochInfo, in.proposedBlockMD.SimplexEpochInfo + + // A block should only have a sealing block if it is a Telock. + switch in.nextBlockType { + case BlockTypeNewEpoch, BlockTypeNormal, BlockTypeSealing: + if next.SealingBlockSeq != 0 { + return fmt.Errorf("expected sealing block sequence number to be 0 but got %d", next.SealingBlockSeq) + } + case BlockTypeTelock: + // This is not the first Telock, make sure the sealing block sequence number doesn't change. + + // prev.SealingBlockSeq > 0 means the previous block is a Telock. + if prev.SealingBlockSeq > 0 && next.SealingBlockSeq != prev.SealingBlockSeq { + return fmt.Errorf("expected sealing block sequence number to be %d but got %d", prev.SealingBlockSeq, next.SealingBlockSeq) + } + + // Else, either this is the first Telock, or the previous block's sealing block sequence is equal to this block's sealing block sequence. + + // We need to check the first case has a valid sealing block sequence, as the second case is fine by definition. + if prev.BlockValidationDescriptor != nil { + md, err := simplex.ProtocolMetadataFromBytes(in.prevMD.SimplexProtocolMetadata) + if err != nil { + return fmt.Errorf("failed parsing protocol metadata: %w", err) + } + if next.SealingBlockSeq != md.Seq { + return fmt.Errorf("expected sealing block sequence number to be %d but got %d", md.Seq, next.SealingBlockSeq) + } + } + default: + return fmt.Errorf("unknown block type: %d", in.nextBlockType) + } + + return nil +} + +type pChainHeightVerifier struct { + getPChainHeight func() uint64 +} + +func (p *pChainHeightVerifier) Verify(in verificationInput) error { + currentPChainHeight := p.getPChainHeight() + + if in.proposedBlockMD.PChainHeight > currentPChainHeight { + return fmt.Errorf("invalid P-chain height (%d) is too big, expected to be ≤ %d", + in.proposedBlockMD.PChainHeight, currentPChainHeight) + } + + if in.prevMD.PChainHeight > in.proposedBlockMD.PChainHeight { + return fmt.Errorf("invalid P-chain height (%d) is smaller than parent block's P-chain height (%d)", + in.proposedBlockMD.PChainHeight, in.prevMD.PChainHeight) + } + + return nil +} + +type pChainReferenceHeightVerifier struct{} + +func (p *pChainReferenceHeightVerifier) Verify(in verificationInput) error { + prev, next := in.prevMD.SimplexEpochInfo, in.proposedBlockMD.SimplexEpochInfo + + switch in.nextBlockType { + case BlockTypeNewEpoch: + if prev.NextPChainReferenceHeight != next.PChainReferenceHeight { + return fmt.Errorf("expected P-chain reference height of the first block of epoch %d to be %d but got %d", + prev.SealingBlockSeq, prev.NextPChainReferenceHeight, next.PChainReferenceHeight) + } + default: + if prev.PChainReferenceHeight != next.PChainReferenceHeight { + return fmt.Errorf("expected P-chain reference height to be %d but got %d", prev.PChainReferenceHeight, next.PChainReferenceHeight) + } + } + + return nil +} + +type timestampVerifier struct { + getTime func() time.Time + timeSkewLimit time.Duration +} + +func (t *timestampVerifier) Verify(in verificationInput) error { + if !in.hasInnerBlock { + // If no inner block, the timestamp is inherited from the parent block. + if in.proposedBlockMD.Timestamp != in.prevMD.Timestamp { + return fmt.Errorf("block without inner block should inherit parent timestamp %d but got %d", in.prevMD.Timestamp, in.proposedBlockMD.Timestamp) + } + } else { + // If there is an inner block, the timestamp should be the same as the inner block's timestamp. + if in.proposedBlockMD.Timestamp != uint64(in.innerBlockTimestamp.UnixMilli()) { + return fmt.Errorf("block timestamp %d does not match inner block timestamp %d", in.proposedBlockMD.Timestamp, in.innerBlockTimestamp.UnixMilli()) + } + } + + timestamp := time.UnixMilli(int64(in.proposedBlockMD.Timestamp)) + + currentTime := t.getTime() + if currentTime.Add(t.timeSkewLimit).Before(timestamp) { + return fmt.Errorf("proposed block timestamp is too far in the future, current time is %v but got %v", currentTime, timestamp) + } + + if in.prevMD.Timestamp > in.proposedBlockMD.Timestamp { + return fmt.Errorf("proposed block timestamp is older than parent block's timestamp, parent timestamp is %d but got %d", in.prevMD.Timestamp, in.proposedBlockMD.Timestamp) + } + return nil +} + +type prevSealingBlockHashVerifier struct { + getBlock BlockRetriever + latestPersistedHeight *uint64 +} + +func (p *prevSealingBlockHashVerifier) Verify(in verificationInput) error { + prev, _ := in.prevMD.SimplexEpochInfo, in.proposedBlockMD.SimplexEpochInfo + + // Sealing block of the first epoch must point to the first ever Simplex block as the previous sealing block. + if prev.EpochNumber == 1 && in.nextBlockType == BlockTypeSealing { + firstEverSimplexBlockSeq, err := findFirstSimplexBlock(p.getBlock, *p.latestPersistedHeight+1) + if err != nil { + return fmt.Errorf("failed to find first Simplex block: %w", err) + } + + block, _, err := p.getBlock(RetrievingOpts{Height: firstEverSimplexBlockSeq}) + if err != nil { + return fmt.Errorf("failed retrieving first ever simplex block %d: %w", firstEverSimplexBlockSeq, err) + } + + hash := block.Digest() + if !bytes.Equal(in.proposedBlockMD.SimplexEpochInfo.PrevSealingBlockHash[:], hash[:]) { + return fmt.Errorf("expected prev sealing block hash of the first ever simplex block to be %x but got %x", hash, in.proposedBlockMD.SimplexEpochInfo.PrevSealingBlockHash) + } + + return nil + } + + // Otherwise, we can only have a previous sealing block hash if this is a sealing block, + // and in that case, the previous sealing block hash should match the hash of the sealing block of the previous epoch. + + switch in.nextBlockType { + case BlockTypeSealing: + prevSealingBlock, _, err := p.getBlock(RetrievingOpts{Height: in.prevMD.SimplexEpochInfo.EpochNumber}) + if err != nil { + return fmt.Errorf("failed retrieving block: %w", err) + } + hash := prevSealingBlock.Digest() + if !bytes.Equal(in.proposedBlockMD.SimplexEpochInfo.PrevSealingBlockHash[:], hash[:]) { + return fmt.Errorf("expected prev sealing block hash to be %x but got %x", hash, in.proposedBlockMD.SimplexEpochInfo.PrevSealingBlockHash) + } + default: // non-sealing blocks should have an empty previous sealing block hash + if in.proposedBlockMD.SimplexEpochInfo.PrevSealingBlockHash != [32]byte{} { + return fmt.Errorf("expected prev sealing block hash of a non sealing block to be empty but got %x", in.proposedBlockMD.SimplexEpochInfo.PrevSealingBlockHash) + } + } + + return nil +} + +type vmBlockSeqVerifier struct { + getBlock BlockRetriever +} + +func (v *vmBlockSeqVerifier) Verify(in verificationInput) error { + prev, next := in.prevMD.SimplexEpochInfo, in.proposedBlockMD.SimplexEpochInfo + + // If this is the first ever Simplex block, the PrevVMBlockSeq is simply the seq of the previous block. + if prev.EpochNumber == 0 { + if next.PrevVMBlockSeq != in.prevBlockSeq { + return fmt.Errorf("expected PrevVMBlockSeq to be %d but got %d", in.prevBlockSeq, next.PrevVMBlockSeq) + } + return nil + } + + md, err := simplex.ProtocolMetadataFromBytes(in.proposedBlockMD.SimplexProtocolMetadata) + if err != nil { + return fmt.Errorf("failed parsing protocol metadata: %w", err) + } + + // Else, if the previous block has an inner block, we point to it. + // Otherwise, we point to the parent block's previous VM block seq. + prevBlock, _, err := v.getBlock(RetrievingOpts{Height: in.prevBlockSeq, Digest: md.Prev}) + if err != nil { + return fmt.Errorf("failed retrieving block: %w", err) + } + + expectedPrevVMBlockSeq := in.prevMD.SimplexEpochInfo.PrevVMBlockSeq + + if prevBlock.InnerBlock != nil { + expectedPrevVMBlockSeq = in.prevBlockSeq + } + + if next.PrevVMBlockSeq != expectedPrevVMBlockSeq { + return fmt.Errorf("expected PrevVMBlockSeq to be %d but got %d", expectedPrevVMBlockSeq, next.PrevVMBlockSeq) + } + + return nil +} + +func areNextEpochApprovalsSignersSupersetOfApprovalsOfPrevBlock(prev SimplexEpochInfo, next SimplexEpochInfo) error { + if prev.NextEpochApprovals == nil { + return nil + } + // Make sure that previous signers are still there. + prevSigners := bitmaskFromBytes(prev.NextEpochApprovals.NodeIDs) + nextSigners := bitmaskFromBytes(next.NextEpochApprovals.NodeIDs) + // Remove all bits in nextSigners from prevSigners + prevSigners.Difference(&nextSigners) + // If we have some bits left, it means there was a bit in prevSigners that wasn't in nextSigners + if prevSigners.Len() > 0 { + return fmt.Errorf("some signers from parent block are missing from next epoch approvals of proposed block") + } + return nil +} diff --git a/msm/verification_test.go b/msm/verification_test.go new file mode 100644 index 00000000..7881a133 --- /dev/null +++ b/msm/verification_test.go @@ -0,0 +1,1016 @@ +// Copyright (C) 2019-2025, Ava Labs, Inc. All rights reserved. +// See the file LICENSE for licensing terms. + +package metadata + +import ( + "context" + "crypto/sha256" + "fmt" + "testing" + "time" + + "github.com/ava-labs/simplex" + "github.com/stretchr/testify/require" +) + +func TestPChainHeightVerifier(t *testing.T) { + for _, tc := range []struct { + name string + pChainHeight uint64 + prevHeight uint64 + nextHeight uint64 + err string + }{ + { + name: "valid height", + pChainHeight: 200, + prevHeight: 100, + nextHeight: 150, + }, + { + name: "height equal to current", + pChainHeight: 200, + prevHeight: 100, + nextHeight: 200, + }, + { + name: "height too big", + pChainHeight: 100, + prevHeight: 50, + nextHeight: 150, + err: "invalid P-chain height (150) is too big, expected to be ≤ 100", + }, + { + name: "height smaller than parent", + pChainHeight: 200, + prevHeight: 150, + nextHeight: 100, + err: "invalid P-chain height (100) is smaller than parent block's P-chain height (150)", + }, + { + name: "height equal to parent", + pChainHeight: 200, + prevHeight: 100, + nextHeight: 100, + }, + } { + t.Run(tc.name, func(t *testing.T) { + v := &pChainHeightVerifier{ + getPChainHeight: func() uint64 { return tc.pChainHeight }, + } + err := v.Verify(verificationInput{ + prevMD: StateMachineMetadata{PChainHeight: tc.prevHeight}, + proposedBlockMD: StateMachineMetadata{PChainHeight: tc.nextHeight}, + }) + if tc.err != "" { + require.EqualError(t, err, tc.err) + } else { + require.NoError(t, err) + } + }) + } +} + +func TestTimestampVerifier(t *testing.T) { + now := time.Now() + + timeSkewLimit := 5 * time.Second + + futureTime := now.Add(10 * time.Second) + + for _, tc := range []struct { + name string + hasInnerBlock bool + innerBlockTimestamp time.Time + timestamp uint64 + parentTimestamp uint64 + err string + }{ + { + name: "valid timestamp with inner block", + hasInnerBlock: true, + innerBlockTimestamp: now, + timestamp: uint64(now.UnixMilli()), + }, + { + name: "metadata timestamp does not match inner block", + hasInnerBlock: true, + innerBlockTimestamp: now, + timestamp: uint64(now.UnixMilli()) + 100, + err: fmt.Sprintf("block timestamp %d does not match inner block timestamp %d", uint64(now.UnixMilli())+100, now.UnixMilli()), + }, + { + name: "timestamp too far in the future", + hasInnerBlock: true, + innerBlockTimestamp: futureTime, + timestamp: uint64(futureTime.UnixMilli()), + err: fmt.Sprintf("proposed block timestamp is too far in the future, current time is %v but got %v", now, time.UnixMilli(futureTime.UnixMilli())), + }, + { + name: "timestamp older than parent", + hasInnerBlock: true, + innerBlockTimestamp: now, + timestamp: uint64(now.UnixMilli()), + parentTimestamp: uint64(now.UnixMilli()) + 10, + err: fmt.Sprintf("proposed block timestamp is older than parent block's timestamp, parent timestamp is %d but got %d", uint64(now.UnixMilli())+10, uint64(now.UnixMilli())), + }, + { + name: "no inner block inherits parent timestamp", + hasInnerBlock: false, + timestamp: uint64(now.UnixMilli()), + parentTimestamp: uint64(now.UnixMilli()), + }, + { + name: "no inner block with different timestamp than parent", + hasInnerBlock: false, + timestamp: uint64(now.UnixMilli()) + 100, + parentTimestamp: uint64(now.UnixMilli()), + err: fmt.Sprintf("block without inner block should inherit parent timestamp %d but got %d", uint64(now.UnixMilli()), uint64(now.UnixMilli())+100), + }, + } { + t.Run(tc.name, func(t *testing.T) { + v := ×tampVerifier{ + getTime: func() time.Time { return now }, + timeSkewLimit: timeSkewLimit, + } + err := v.Verify(verificationInput{ + hasInnerBlock: tc.hasInnerBlock, + innerBlockTimestamp: tc.innerBlockTimestamp, + proposedBlockMD: StateMachineMetadata{Timestamp: tc.timestamp}, + prevMD: StateMachineMetadata{Timestamp: tc.parentTimestamp}, + }) + if tc.err != "" { + require.EqualError(t, err, tc.err) + } else { + require.NoError(t, err) + } + }) + } +} + +func TestPChainReferenceHeightVerifier(t *testing.T) { + for _, tc := range []struct { + name string + nextBlockType BlockType + prev SimplexEpochInfo + next SimplexEpochInfo + err string + }{ + { + name: "new epoch block matching prev NextPChainReferenceHeight", + nextBlockType: BlockTypeNewEpoch, + prev: SimplexEpochInfo{NextPChainReferenceHeight: 200, SealingBlockSeq: 5}, + next: SimplexEpochInfo{PChainReferenceHeight: 200}, + }, + { + name: "new epoch block not matching prev NextPChainReferenceHeight", + nextBlockType: BlockTypeNewEpoch, + prev: SimplexEpochInfo{NextPChainReferenceHeight: 200, SealingBlockSeq: 5}, + next: SimplexEpochInfo{PChainReferenceHeight: 100}, + err: "expected P-chain reference height of the first block of epoch 5 to be 200 but got 100", + }, + { + name: "normal block matching prev PChainReferenceHeight", + nextBlockType: BlockTypeNormal, + prev: SimplexEpochInfo{PChainReferenceHeight: 100}, + next: SimplexEpochInfo{PChainReferenceHeight: 100}, + }, + { + name: "normal block not matching prev PChainReferenceHeight", + nextBlockType: BlockTypeNormal, + prev: SimplexEpochInfo{PChainReferenceHeight: 100}, + next: SimplexEpochInfo{PChainReferenceHeight: 200}, + err: "expected P-chain reference height to be 100 but got 200", + }, + { + name: "sealing block matching prev PChainReferenceHeight", + nextBlockType: BlockTypeSealing, + prev: SimplexEpochInfo{PChainReferenceHeight: 100}, + next: SimplexEpochInfo{PChainReferenceHeight: 100}, + }, + { + name: "telock block matching prev PChainReferenceHeight", + nextBlockType: BlockTypeTelock, + prev: SimplexEpochInfo{PChainReferenceHeight: 100}, + next: SimplexEpochInfo{PChainReferenceHeight: 100}, + }, + } { + t.Run(tc.name, func(t *testing.T) { + v := &pChainReferenceHeightVerifier{} + err := v.Verify(verificationInput{ + nextBlockType: tc.nextBlockType, + prevMD: StateMachineMetadata{SimplexEpochInfo: tc.prev}, + proposedBlockMD: StateMachineMetadata{SimplexEpochInfo: tc.next}, + }) + if tc.err != "" { + require.EqualError(t, err, tc.err) + } else { + require.NoError(t, err) + } + }) + } +} + +func TestEpochNumberVerifier(t *testing.T) { + for _, tc := range []struct { + name string + nextBlockType BlockType + prevBlockSeq uint64 + prev SimplexEpochInfo + next SimplexEpochInfo + err string + }{ + { + name: "prev epoch 0 with wrong next epoch", + nextBlockType: BlockTypeNormal, + prev: SimplexEpochInfo{EpochNumber: 0}, + next: SimplexEpochInfo{EpochNumber: 5}, + err: "expected epoch number of the first block created to be 1 but got 5", + }, + { + name: "new epoch block matching sealing seq", + nextBlockType: BlockTypeNewEpoch, + prevBlockSeq: 10, + prev: SimplexEpochInfo{EpochNumber: 1}, + next: SimplexEpochInfo{EpochNumber: 10}, + }, + { + name: "new epoch block not matching sealing seq", + nextBlockType: BlockTypeNewEpoch, + prevBlockSeq: 10, + prev: SimplexEpochInfo{EpochNumber: 1}, + next: SimplexEpochInfo{EpochNumber: 5}, + err: "expected epoch number to be 10 but got 5", + }, + { + name: "normal block same epoch", + nextBlockType: BlockTypeNormal, + prev: SimplexEpochInfo{EpochNumber: 3}, + next: SimplexEpochInfo{EpochNumber: 3}, + }, + { + name: "normal block different epoch", + nextBlockType: BlockTypeNormal, + prev: SimplexEpochInfo{EpochNumber: 3}, + next: SimplexEpochInfo{EpochNumber: 4}, + err: "expected epoch number to be 3 but got 4", + }, + { + name: "sealing block same epoch", + nextBlockType: BlockTypeSealing, + prev: SimplexEpochInfo{EpochNumber: 2}, + next: SimplexEpochInfo{EpochNumber: 2}, + }, + { + name: "telock block same epoch", + nextBlockType: BlockTypeTelock, + prev: SimplexEpochInfo{EpochNumber: 2}, + next: SimplexEpochInfo{EpochNumber: 2}, + }, + } { + t.Run(tc.name, func(t *testing.T) { + v := &epochNumberVerifier{} + err := v.Verify(verificationInput{ + nextBlockType: tc.nextBlockType, + prevBlockSeq: tc.prevBlockSeq, + prevMD: StateMachineMetadata{SimplexEpochInfo: tc.prev}, + proposedBlockMD: StateMachineMetadata{SimplexEpochInfo: tc.next}, + }) + if tc.err != "" { + require.EqualError(t, err, tc.err) + } else { + require.NoError(t, err) + } + }) + } +} + +func TestPrevSealingBlockHashVerifier(t *testing.T) { + // A simplex block (EpochNumber > 0) so findFirstSimplexBlock can locate it. + firstSimplexBlock := StateMachineBlock{ + InnerBlock: &testVMBlock{bytes: []byte{1, 2, 3}}, + Metadata: StateMachineMetadata{SimplexEpochInfo: SimplexEpochInfo{EpochNumber: 1}}, + } + firstSimplexBlockHash := firstSimplexBlock.Digest() + + // A block used for epoch >1 sealing lookups. + prevSealingBlock := StateMachineBlock{ + InnerBlock: &testVMBlock{bytes: []byte{4, 5, 6}}, + Metadata: StateMachineMetadata{SimplexEpochInfo: SimplexEpochInfo{EpochNumber: 5}}, + } + prevSealingBlockHash := prevSealingBlock.Digest() + + bs := make(testBlockStore) + bs[1] = firstSimplexBlock + bs[5] = prevSealingBlock + latestPersisted := uint64(1) + + for _, tc := range []struct { + name string + nextBlockType BlockType + prev SimplexEpochInfo + next SimplexEpochInfo + err string + }{ + { + name: "epoch 1 sealing block with correct hash", + nextBlockType: BlockTypeSealing, + prev: SimplexEpochInfo{EpochNumber: 1}, + next: SimplexEpochInfo{ + PrevSealingBlockHash: firstSimplexBlockHash, + }, + }, + { + name: "epoch 1 sealing block with wrong hash", + nextBlockType: BlockTypeSealing, + prev: SimplexEpochInfo{EpochNumber: 1}, + next: SimplexEpochInfo{ + PrevSealingBlockHash: [32]byte{9, 9, 9}, + }, + err: fmt.Sprintf("expected prev sealing block hash of the first ever simplex block to be %x but got %x", firstSimplexBlockHash, [32]byte{9, 9, 9}), + }, + { + name: "epoch >1 sealing block with correct hash", + nextBlockType: BlockTypeSealing, + prev: SimplexEpochInfo{EpochNumber: 5}, + next: SimplexEpochInfo{ + PrevSealingBlockHash: prevSealingBlockHash, + }, + }, + { + name: "epoch >1 sealing block with wrong hash", + nextBlockType: BlockTypeSealing, + prev: SimplexEpochInfo{EpochNumber: 5}, + next: SimplexEpochInfo{ + PrevSealingBlockHash: [32]byte{9, 9, 9}, + }, + err: fmt.Sprintf("expected prev sealing block hash to be %x but got %x", prevSealingBlockHash, [32]byte{9, 9, 9}), + }, + { + name: "non-sealing block with empty hash", + nextBlockType: BlockTypeNormal, + prev: SimplexEpochInfo{EpochNumber: 1}, + next: SimplexEpochInfo{}, + }, + { + name: "non-sealing block with non-empty hash", + nextBlockType: BlockTypeNormal, + prev: SimplexEpochInfo{EpochNumber: 1}, + next: SimplexEpochInfo{ + PrevSealingBlockHash: [32]byte{1}, + }, + err: fmt.Sprintf("expected prev sealing block hash of a non sealing block to be empty but got %x", [32]byte{1}), + }, + { + name: "telock block with empty hash", + nextBlockType: BlockTypeTelock, + prev: SimplexEpochInfo{EpochNumber: 2}, + next: SimplexEpochInfo{}, + }, + { + name: "new epoch block with empty hash", + nextBlockType: BlockTypeNewEpoch, + prev: SimplexEpochInfo{EpochNumber: 2}, + next: SimplexEpochInfo{}, + }, + } { + t.Run(tc.name, func(t *testing.T) { + v := &prevSealingBlockHashVerifier{ + getBlock: bs.getBlock, + latestPersistedHeight: &latestPersisted, + } + err := v.Verify(verificationInput{ + nextBlockType: tc.nextBlockType, + prevMD: StateMachineMetadata{SimplexEpochInfo: tc.prev}, + proposedBlockMD: StateMachineMetadata{SimplexEpochInfo: tc.next}, + }) + if tc.err != "" { + require.EqualError(t, err, tc.err) + } else { + require.NoError(t, err) + } + }) + } +} + +func TestNextPChainReferenceHeightVerifier(t *testing.T) { + validators1 := NodeBLSMappings{{BLSKey: []byte{1}, Weight: 1}} + validators2 := NodeBLSMappings{{BLSKey: []byte{2}, Weight: 1}} + + for _, tc := range []struct { + name string + nextBlockType BlockType + prev SimplexEpochInfo + prevPChainRef uint64 + next SimplexEpochInfo + getValidator ValidatorSetRetriever + pChainHeight uint64 + err string + }{ + { + name: "telock block matching height", + nextBlockType: BlockTypeTelock, + prev: SimplexEpochInfo{NextPChainReferenceHeight: 200}, + next: SimplexEpochInfo{NextPChainReferenceHeight: 200}, + }, + { + name: "telock block mismatched height", + nextBlockType: BlockTypeTelock, + prev: SimplexEpochInfo{NextPChainReferenceHeight: 200}, + next: SimplexEpochInfo{NextPChainReferenceHeight: 300}, + err: "expected P-chain reference height to be 200 but got 300", + }, + { + name: "sealing block matching height", + nextBlockType: BlockTypeSealing, + prev: SimplexEpochInfo{NextPChainReferenceHeight: 200}, + next: SimplexEpochInfo{NextPChainReferenceHeight: 200}, + }, + { + name: "sealing block mismatched height", + nextBlockType: BlockTypeSealing, + prev: SimplexEpochInfo{NextPChainReferenceHeight: 200}, + next: SimplexEpochInfo{NextPChainReferenceHeight: 100}, + err: "expected P-chain reference height to be 200 but got 100", + }, + { + name: "normal block prev already has next height set", + nextBlockType: BlockTypeNormal, + prev: SimplexEpochInfo{NextPChainReferenceHeight: 200}, + next: SimplexEpochInfo{NextPChainReferenceHeight: 200}, + }, + { + name: "normal block prev already has next height set mismatch", + nextBlockType: BlockTypeNormal, + prev: SimplexEpochInfo{NextPChainReferenceHeight: 200}, + next: SimplexEpochInfo{NextPChainReferenceHeight: 300}, + err: "expected P-chain reference height to be 200 but got 300", + }, + { + name: "normal block next p-chain reference height less than current", + nextBlockType: BlockTypeNormal, + prev: SimplexEpochInfo{PChainReferenceHeight: 200}, + next: SimplexEpochInfo{NextPChainReferenceHeight: 100}, + err: "expected P-chain reference height to be non-decreasing, but the previous P-chain reference height is 200 and the proposed P-chain reference height is 100", + }, + { + name: "normal block same validator set with non-zero next height", + nextBlockType: BlockTypeNormal, + prev: SimplexEpochInfo{PChainReferenceHeight: 100}, + next: SimplexEpochInfo{NextPChainReferenceHeight: 200}, + getValidator: func(h uint64) (NodeBLSMappings, error) { return validators1, nil }, + err: "validator set at proposed next P-chain reference height 200 is the same as validator set at previous block's P-chain reference height 100,so expected next P-chain reference height to remain the same but got 200", + }, + { + name: "normal block no validator change and next height is zero", + nextBlockType: BlockTypeNormal, + prev: SimplexEpochInfo{PChainReferenceHeight: 100}, + next: SimplexEpochInfo{NextPChainReferenceHeight: 0}, + getValidator: func(h uint64) (NodeBLSMappings, error) { return validators1, nil }, + }, + { + name: "normal block validator change detected and p-chain height reached", + nextBlockType: BlockTypeNormal, + prev: SimplexEpochInfo{PChainReferenceHeight: 100}, + next: SimplexEpochInfo{NextPChainReferenceHeight: 200}, + getValidator: func(h uint64) (NodeBLSMappings, error) { + if h == 200 { + return validators2, nil + } + return validators1, nil + }, + pChainHeight: 200, + }, + { + name: "normal block validator change but p-chain height not reached", + nextBlockType: BlockTypeNormal, + prev: SimplexEpochInfo{PChainReferenceHeight: 100}, + next: SimplexEpochInfo{NextPChainReferenceHeight: 200}, + getValidator: func(h uint64) (NodeBLSMappings, error) { + if h == 200 { + return validators2, nil + } + return validators1, nil + }, + pChainHeight: 150, + err: "haven't reached P-chain height 200 yet, current P-chain height is only 150", + }, + { + name: "new epoch block with zero next height", + nextBlockType: BlockTypeNewEpoch, + next: SimplexEpochInfo{NextPChainReferenceHeight: 0}, + }, + { + name: "new epoch block with non-zero next height", + nextBlockType: BlockTypeNewEpoch, + next: SimplexEpochInfo{NextPChainReferenceHeight: 100}, + err: "expected P-chain reference height to be 0 but got 100", + }, + } { + t.Run(tc.name, func(t *testing.T) { + v := &nextPChainReferenceHeightVerifier{ + getValidatorSet: tc.getValidator, + getPChainHeight: func() uint64 { return tc.pChainHeight }, + } + err := v.Verify(verificationInput{ + nextBlockType: tc.nextBlockType, + prevMD: StateMachineMetadata{SimplexEpochInfo: tc.prev}, + proposedBlockMD: StateMachineMetadata{SimplexEpochInfo: tc.next}, + }) + if tc.err != "" { + require.EqualError(t, err, tc.err) + } else { + require.NoError(t, err) + } + }) + } +} + +func TestVMBlockSeqVerifier(t *testing.T) { + prevMDBytes := (&simplex.ProtocolMetadata{Seq: 5, Prev: [32]byte{1}}).Bytes() + proposedMDBytes := (&simplex.ProtocolMetadata{Seq: 6, Prev: [32]byte{2}}).Bytes() + + blockWithInner := StateMachineBlock{ + InnerBlock: &testVMBlock{bytes: []byte{1}}, + } + blockWithoutInner := StateMachineBlock{} + + for _, tc := range []struct { + name string + prev SimplexEpochInfo + prevMD StateMachineMetadata + next SimplexEpochInfo + prevBlockSeq uint64 + block StateMachineBlock + err string + }{ + { + name: "first simplex block matching seq", + prev: SimplexEpochInfo{EpochNumber: 0}, + next: SimplexEpochInfo{PrevVMBlockSeq: 42}, + prevBlockSeq: 42, + }, + { + name: "first simplex block wrong seq", + prev: SimplexEpochInfo{EpochNumber: 0}, + next: SimplexEpochInfo{PrevVMBlockSeq: 10}, + prevBlockSeq: 42, + err: "expected PrevVMBlockSeq to be 42 but got 10", + }, + { + name: "prev block has block", + prev: SimplexEpochInfo{EpochNumber: 1, PrevVMBlockSeq: 3}, + prevMD: StateMachineMetadata{SimplexProtocolMetadata: prevMDBytes, SimplexEpochInfo: SimplexEpochInfo{EpochNumber: 1, PrevVMBlockSeq: 3}}, + next: SimplexEpochInfo{PrevVMBlockSeq: 4}, + prevBlockSeq: 4, + block: blockWithInner, + }, + { + name: "prev block has block wrong seq", + prev: SimplexEpochInfo{EpochNumber: 1, PrevVMBlockSeq: 3}, + prevMD: StateMachineMetadata{SimplexProtocolMetadata: prevMDBytes, SimplexEpochInfo: SimplexEpochInfo{EpochNumber: 1, PrevVMBlockSeq: 3}}, + next: SimplexEpochInfo{PrevVMBlockSeq: 99}, + prevBlockSeq: 4, + block: blockWithInner, + err: "expected PrevVMBlockSeq to be 4 but got 99", + }, + { + name: "prev block has no block uses parent PrevVMBlockSeq", + prev: SimplexEpochInfo{EpochNumber: 1, PrevVMBlockSeq: 3}, + prevMD: StateMachineMetadata{SimplexProtocolMetadata: prevMDBytes, SimplexEpochInfo: SimplexEpochInfo{EpochNumber: 1, PrevVMBlockSeq: 3}}, + next: SimplexEpochInfo{PrevVMBlockSeq: 3}, + prevBlockSeq: 4, + block: blockWithoutInner, + }, + { + name: "prev block has no block wrong seq", + prev: SimplexEpochInfo{EpochNumber: 1, PrevVMBlockSeq: 3}, + prevMD: StateMachineMetadata{SimplexProtocolMetadata: prevMDBytes, SimplexEpochInfo: SimplexEpochInfo{EpochNumber: 1, PrevVMBlockSeq: 3}}, + next: SimplexEpochInfo{PrevVMBlockSeq: 99}, + prevBlockSeq: 4, + block: blockWithoutInner, + err: "expected PrevVMBlockSeq to be 3 but got 99", + }, + } { + t.Run(tc.name, func(t *testing.T) { + bs := make(testBlockStore) + bs[tc.prevBlockSeq] = tc.block + + v := &vmBlockSeqVerifier{ + getBlock: bs.getBlock, + } + + prevMD := tc.prevMD + if prevMD.SimplexEpochInfo.EpochNumber == 0 && tc.prev.EpochNumber == 0 { + prevMD.SimplexEpochInfo = tc.prev + } + + err := v.Verify(verificationInput{ + prevMD: prevMD, + proposedBlockMD: StateMachineMetadata{SimplexEpochInfo: tc.next, SimplexProtocolMetadata: proposedMDBytes}, + prevBlockSeq: tc.prevBlockSeq, + }) + if tc.err != "" { + require.EqualError(t, err, tc.err) + } else { + require.NoError(t, err) + } + }) + } +} + +func TestValidationDescriptorVerifier(t *testing.T) { + validators := NodeBLSMappings{ + {BLSKey: []byte{1}, Weight: 1}, + {BLSKey: []byte{2}, Weight: 1}, + } + + otherValidators := NodeBLSMappings{ + {BLSKey: []byte{3}, Weight: 1}, + } + + for _, tc := range []struct { + name string + nextBlockType BlockType + next SimplexEpochInfo + getValidator ValidatorSetRetriever + err string + }{ + { + name: "sealing block with matching validators", + nextBlockType: BlockTypeSealing, + next: SimplexEpochInfo{ + NextPChainReferenceHeight: 100, + BlockValidationDescriptor: &BlockValidationDescriptor{ + AggregatedMembership: AggregatedMembership{Members: validators}, + }, + }, + getValidator: func(h uint64) (NodeBLSMappings, error) { return validators, nil }, + }, + { + name: "sealing block with mismatching validators", + nextBlockType: BlockTypeSealing, + next: SimplexEpochInfo{ + NextPChainReferenceHeight: 100, + BlockValidationDescriptor: &BlockValidationDescriptor{ + AggregatedMembership: AggregatedMembership{Members: otherValidators}, + }, + }, + getValidator: func(h uint64) (NodeBLSMappings, error) { return validators, nil }, + err: "expected validator set specified at P-chain height 100 does not match validator set encoded in new block", + }, + { + name: "sealing block with validator retrieval error", + nextBlockType: BlockTypeSealing, + next: SimplexEpochInfo{ + NextPChainReferenceHeight: 100, + BlockValidationDescriptor: &BlockValidationDescriptor{}, + }, + getValidator: func(h uint64) (NodeBLSMappings, error) { return nil, fmt.Errorf("unavailable") }, + err: "unavailable", + }, + { + name: "normal block with nil descriptor", + nextBlockType: BlockTypeNormal, + next: SimplexEpochInfo{}, + }, + { + name: "normal block with non-nil descriptor", + nextBlockType: BlockTypeNormal, + next: SimplexEpochInfo{ + BlockValidationDescriptor: &BlockValidationDescriptor{}, + }, + err: "block validation descriptor should be nil but got &{{[] {0}} {0}}", + }, + { + name: "telock block with nil descriptor", + nextBlockType: BlockTypeTelock, + next: SimplexEpochInfo{}, + }, + { + name: "new epoch block with nil descriptor", + nextBlockType: BlockTypeNewEpoch, + next: SimplexEpochInfo{}, + }, + } { + t.Run(tc.name, func(t *testing.T) { + v := &validationDescriptorVerifier{ + getValidatorSet: tc.getValidator, + } + err := v.Verify(verificationInput{ + nextBlockType: tc.nextBlockType, + proposedBlockMD: StateMachineMetadata{SimplexEpochInfo: tc.next}, + }) + if tc.err != "" { + require.EqualError(t, err, tc.err) + } else { + require.NoError(t, err) + } + }) + } +} + +func TestNextEpochApprovalsVerifier(t *testing.T) { + validators := NodeBLSMappings{ + {BLSKey: []byte{1}, Weight: 1}, + {BLSKey: []byte{2}, Weight: 1}, + {BLSKey: []byte{3}, Weight: 1}, + } + + for _, tc := range []struct { + name string + nextBlockType BlockType + prev SimplexEpochInfo + next SimplexEpochInfo + getValidator ValidatorSetRetriever + sigVerifier SignatureVerifier + keyAggregator KeyAggregator + err string + }{ + { + name: "sealing block with nil approvals", + nextBlockType: BlockTypeSealing, + next: SimplexEpochInfo{}, + err: "next epoch approvals should not be nil for a sealing block", + }, + { + name: "sealing block with validator retrieval error", + nextBlockType: BlockTypeSealing, + next: SimplexEpochInfo{ + NextPChainReferenceHeight: 100, + NextEpochApprovals: &NextEpochApprovals{NodeIDs: []byte{7}, Signature: []byte("sig")}, + }, + getValidator: func(h uint64) (NodeBLSMappings, error) { return nil, fmt.Errorf("unavailable") }, + err: "unavailable", + }, + { + name: "sealing block not enough approvals", + nextBlockType: BlockTypeSealing, + next: SimplexEpochInfo{ + NextPChainReferenceHeight: 100, + NextEpochApprovals: &NextEpochApprovals{NodeIDs: []byte{1}, Signature: []byte("sig")}, + }, + getValidator: func(h uint64) (NodeBLSMappings, error) { return validators, nil }, + sigVerifier: &testSigVerifier{}, + keyAggregator: &testKeyAggregator{}, + err: "not enough approvals to seal block", + }, + { + name: "sealing block enough approvals", + nextBlockType: BlockTypeSealing, + next: SimplexEpochInfo{ + NextPChainReferenceHeight: 100, + NextEpochApprovals: &NextEpochApprovals{NodeIDs: []byte{7}, Signature: []byte("sig")}, + }, + getValidator: func(h uint64) (NodeBLSMappings, error) { return validators, nil }, + sigVerifier: &testSigVerifier{}, + keyAggregator: &testKeyAggregator{}, + }, + { + name: "normal block no validator change", + nextBlockType: BlockTypeNormal, + prev: SimplexEpochInfo{NextPChainReferenceHeight: 0}, + next: SimplexEpochInfo{}, + }, + { + name: "normal block collecting approvals with nil approvals", + nextBlockType: BlockTypeNormal, + prev: SimplexEpochInfo{NextPChainReferenceHeight: 100}, + next: SimplexEpochInfo{NextPChainReferenceHeight: 100}, + err: "next epoch approvals should not be nil when collecting approvals", + }, + { + name: "normal block collecting approvals valid", + nextBlockType: BlockTypeNormal, + prev: SimplexEpochInfo{ + NextPChainReferenceHeight: 100, + PChainReferenceHeight: 50, + }, + next: SimplexEpochInfo{ + NextPChainReferenceHeight: 100, + NextEpochApprovals: &NextEpochApprovals{NodeIDs: []byte{1}, Signature: []byte("sig")}, + }, + getValidator: func(h uint64) (NodeBLSMappings, error) { return validators, nil }, + sigVerifier: &testSigVerifier{}, + keyAggregator: &testKeyAggregator{}, + }, + { + name: "normal block collecting approvals signers not superset of prev", + nextBlockType: BlockTypeNormal, + prev: SimplexEpochInfo{ + NextPChainReferenceHeight: 100, + PChainReferenceHeight: 50, + NextEpochApprovals: &NextEpochApprovals{NodeIDs: []byte{3}, Signature: []byte("sig")}, // bits 0,1 + }, + next: SimplexEpochInfo{ + NextPChainReferenceHeight: 100, + NextEpochApprovals: &NextEpochApprovals{NodeIDs: []byte{1}, Signature: []byte("sig")}, // bit 0 only + }, + getValidator: func(h uint64) (NodeBLSMappings, error) { return validators, nil }, + sigVerifier: &testSigVerifier{}, + keyAggregator: &testKeyAggregator{}, + err: "some signers from parent block are missing from next epoch approvals of proposed block", + }, + { + name: "telock block with nil approvals", + nextBlockType: BlockTypeTelock, + next: SimplexEpochInfo{}, + }, + { + name: "telock block with non-nil approvals", + nextBlockType: BlockTypeTelock, + next: SimplexEpochInfo{ + NextEpochApprovals: &NextEpochApprovals{}, + }, + err: "next epoch approvals should be nil but got &{[] [] {0}}", + }, + { + name: "new epoch block with nil approvals", + nextBlockType: BlockTypeNewEpoch, + next: SimplexEpochInfo{}, + }, + { + name: "new epoch block with non-nil approvals", + nextBlockType: BlockTypeNewEpoch, + next: SimplexEpochInfo{ + NextEpochApprovals: &NextEpochApprovals{}, + }, + err: "next epoch approvals should be nil but got &{[] [] {0}}", + }, + } { + t.Run(tc.name, func(t *testing.T) { + v := &nextEpochApprovalsVerifier{ + sigVerifier: tc.sigVerifier, + getValidatorSet: tc.getValidator, + keyAggregator: tc.keyAggregator, + } + err := v.Verify(verificationInput{ + nextBlockType: tc.nextBlockType, + prevMD: StateMachineMetadata{SimplexEpochInfo: tc.prev}, + proposedBlockMD: StateMachineMetadata{SimplexEpochInfo: tc.next}, + }) + if tc.err != "" { + require.EqualError(t, err, tc.err) + } else { + require.NoError(t, err) + } + }) + } +} + +func TestSealingBlockSeqVerifier(t *testing.T) { + prevProtocolMD := (&simplex.ProtocolMetadata{Seq: 5}).Bytes() + + for _, tc := range []struct { + name string + nextBlockType BlockType + prev SimplexEpochInfo + prevMD StateMachineMetadata + next SimplexEpochInfo + err string + }{ + { + name: "normal block with zero sealing seq", + nextBlockType: BlockTypeNormal, + next: SimplexEpochInfo{SealingBlockSeq: 0}, + }, + { + name: "normal block with non-zero sealing seq", + nextBlockType: BlockTypeNormal, + next: SimplexEpochInfo{SealingBlockSeq: 5}, + err: "expected sealing block sequence number to be 0 but got 5", + }, + { + name: "new epoch block with zero sealing seq", + nextBlockType: BlockTypeNewEpoch, + next: SimplexEpochInfo{SealingBlockSeq: 0}, + }, + { + name: "new epoch block with non-zero sealing seq", + nextBlockType: BlockTypeNewEpoch, + next: SimplexEpochInfo{SealingBlockSeq: 3}, + err: "expected sealing block sequence number to be 0 but got 3", + }, + { + name: "telock block matching prev sealing seq", + nextBlockType: BlockTypeTelock, + prev: SimplexEpochInfo{SealingBlockSeq: 10}, + next: SimplexEpochInfo{SealingBlockSeq: 10}, + }, + { + name: "telock block mismatching prev sealing seq", + nextBlockType: BlockTypeTelock, + prev: SimplexEpochInfo{SealingBlockSeq: 10}, + next: SimplexEpochInfo{SealingBlockSeq: 11}, + err: "expected sealing block sequence number to be 10 but got 11", + }, + { + name: "sealing block with zero seq", + nextBlockType: BlockTypeSealing, + prevMD: StateMachineMetadata{SimplexProtocolMetadata: prevProtocolMD}, + next: SimplexEpochInfo{SealingBlockSeq: 0}, + }, + { + name: "sealing block with non-zero seq", + nextBlockType: BlockTypeSealing, + prevMD: StateMachineMetadata{SimplexProtocolMetadata: prevProtocolMD}, + next: SimplexEpochInfo{SealingBlockSeq: 10}, + err: "expected sealing block sequence number to be 0 but got 10", + }, + } { + t.Run(tc.name, func(t *testing.T) { + v := &sealingBlockSeqVerifier{} + prevMD := tc.prevMD + prevMD.SimplexEpochInfo = tc.prev + err := v.Verify(verificationInput{ + nextBlockType: tc.nextBlockType, + prevMD: prevMD, + proposedBlockMD: StateMachineMetadata{SimplexEpochInfo: tc.next}, + }) + if tc.err != "" { + require.EqualError(t, err, tc.err) + } else { + require.NoError(t, err) + } + }) + } +} + +// Test helpers + +type testBlockStore map[uint64]StateMachineBlock + +func (bs testBlockStore) getBlock(opts RetrievingOpts) (StateMachineBlock, *simplex.Finalization, error) { + blk, ok := bs[opts.Height] + if !ok { + return StateMachineBlock{}, nil, fmt.Errorf("%w: block %d", simplex.ErrBlockNotFound, opts.Height) + } + return blk, nil, nil +} + +type testVMBlock struct { + bytes []byte + height uint64 +} + +func (b *testVMBlock) Digest() [32]byte { + return sha256.Sum256(b.bytes) +} + +func (b *testVMBlock) Height() uint64 { + return b.height +} + +func (b *testVMBlock) Timestamp() time.Time { + return time.Now() +} + +func (b *testVMBlock) Verify(_ context.Context) error { + return nil +} + +type testSigVerifier struct { + err error +} + +func (sv *testSigVerifier) VerifySignature(_, _, _ []byte) error { + return sv.err +} + +type testKeyAggregator struct { + err error +} + +func (ka *testKeyAggregator) AggregateKeys(keys ...[]byte) ([]byte, error) { + if ka.err != nil { + return nil, ka.err + } + var agg []byte + for _, k := range keys { + agg = append(agg, k...) + } + return agg, nil +} + +type InnerBlock struct { + TS time.Time + BlockHeight uint64 + Bytes []byte +} + +func (i *InnerBlock) Digest() [32]byte { + return sha256.Sum256(i.Bytes) +} + +func (i *InnerBlock) Height() uint64 { + return i.BlockHeight +} + +func (i *InnerBlock) Timestamp() time.Time { + return i.TS +} + +func (i *InnerBlock) Verify(_ context.Context) error { + return nil +} From b9604659838ce925ff91aa5fd0bcf56711bd226a Mon Sep 17 00:00:00 2001 From: Yacov Manevich Date: Tue, 5 May 2026 22:02:01 +0200 Subject: [PATCH 02/16] rebase Signed-off-by: Yacov Manevich --- msm/encoding.go | 20 ++++------- msm/fake_node_test.go | 2 +- msm/msm.go | 66 ++++++++++++++++--------------------- msm/msm_test.go | 71 ++++++++++++++++++++++++++-------------- msm/verification.go | 11 ++++--- msm/verification_test.go | 1 + 6 files changed, 89 insertions(+), 82 deletions(-) diff --git a/msm/encoding.go b/msm/encoding.go index 59e425ad..a2eafb7d 100644 --- a/msm/encoding.go +++ b/msm/encoding.go @@ -199,7 +199,7 @@ func (nea *NextEpochApprovals) Equals(other *NextEpochApprovals) bool { type NodeBLSMappings []NodeBLSMapping -func (nbms NodeBLSMappings) TotalWeight() (int64, error) { +func (nbms NodeBLSMappings) TotalWeight() (uint64, error) { var totalWeight uint64 for _, nbm := range nbms { var err error @@ -216,27 +216,19 @@ func (nbms NodeBLSMappings) TotalWeight() (int64, error) { if totalWeight > math.MaxInt64 { return 0, fmt.Errorf("total weight of validators is too big, overflows int64: %d", totalWeight) } - return int64(totalWeight), nil + return totalWeight, nil } -func (nbms NodeBLSMappings) ApprovingWeight(approvingNodes bitmask) (int64, error) { - var approvingWeight uint64 +func (nbms NodeBLSMappings) ApprovingWeights(approvingNodes bitmask) []uint64 { + approvingWeights := make([]uint64, 0, len(nbms)) for i, nbm := range nbms { if !approvingNodes.Contains(i) { continue } - var err error - approvingWeight, err = safeAdd(approvingWeight, nbm.Weight) - if err != nil { - return 0, fmt.Errorf("failed to compute approving weights: %w", err) - } - } - - if approvingWeight > math.MaxInt64 { - return 0, fmt.Errorf("approving weight of validators is too big, overflows int64: %d", approvingWeight) + approvingWeights = append(approvingWeights, nbm.Weight) } - return int64(approvingWeight), nil + return approvingWeights } func (nbms NodeBLSMappings) Clone() NodeBLSMappings { diff --git a/msm/fake_node_test.go b/msm/fake_node_test.go index 4c980a44..3f5e6130 100644 --- a/msm/fake_node_test.go +++ b/msm/fake_node_test.go @@ -328,7 +328,7 @@ func (fn *fakeNode) buildBlock() (VMBlock, *StateMachineBlock) { fn.t.Logf("Building a block on top of %s parent with epoch %d", finalizedString, parentBlock.Metadata.SimplexEpochInfo.EpochNumber) - block, err := fn.sm.BuildBlock(context.Background(), parentBlock, simplex.ProtocolMetadata{ + block, err := fn.sm.BuildBlock(context.Background(), simplex.ProtocolMetadata{ Seq: lastMD.Seq + 1, Round: lastMD.Round + 1, Prev: prevBlockDigest, diff --git a/msm/msm.go b/msm/msm.go index cf78fffe..02262587 100644 --- a/msm/msm.go +++ b/msm/msm.go @@ -9,7 +9,6 @@ import ( "errors" "fmt" "math" - "math/big" "sort" "time" @@ -43,6 +42,8 @@ func (smb *StateMachineBlock) Digest() [32]byte { // Used to aggregate validator signatures for epoch transitions. type SignatureAggregator interface { AggregateSignatures(signatures ...[]byte) ([]byte, error) + + IsQuorum(approverWeights []uint64, totalWeight uint64) bool } // ApprovalsRetriever retrieves the approvals from validators of the next epoch for the epoch change. @@ -50,13 +51,11 @@ type ApprovalsRetriever interface { RetrieveApprovals() ValidatorSetApprovals } - // KeyAggregator combines multiple public keys into a single aggregated public key. type KeyAggregator interface { AggregateKeys(keys ...[]byte) ([]byte, error) } - // SignatureVerifier verifies a cryptographic signature against a message and public key. // Used to verify Approvals from validators for epoch transitions. type SignatureVerifier interface { @@ -106,7 +105,7 @@ type StateMachine struct { // BlockBuilder builds new VM blocks. BlockBuilder BlockBuilder // Logger is used for logging state machine operations. - Logger Logger + Logger simplex.Logger // GetValidatorSet retrieves the validator set at a given P-chain height. GetValidatorSet ValidatorSetRetriever // GetBlock retrieves a previously built or finalized block. @@ -140,9 +139,8 @@ const ( stateBuildBlockEpochSealed ) - // BuildBlock constructs the next block on top of the given parent block, and passes in the provided simplex metadata and blacklist. -func (sm *StateMachine) BuildBlock(ctx context.Context, parentBlock StateMachineBlock, simplexMetadata simplex.ProtocolMetadata, simplexBlacklist *simplex.Blacklist) (*StateMachineBlock, error) { +func (sm *StateMachine) BuildBlock(ctx context.Context, simplexMetadata simplex.ProtocolMetadata, simplexBlacklist *simplex.Blacklist) (*StateMachineBlock, error) { sm.maybeInit() // The zero sequence number is reserved for the genesis block, which should never be built. @@ -150,6 +148,11 @@ func (sm *StateMachine) BuildBlock(ctx context.Context, parentBlock StateMachine return nil, fmt.Errorf("invalid ProtocolMetadata sequence number: should be > 0, got %d", simplexMetadata.Seq) } + parentBlock, _, err := sm.GetBlock(RetrievingOpts{Height: simplexMetadata.Seq - 1, Digest: simplexMetadata.Prev}) + if err != nil { + return nil, fmt.Errorf("failed retrieving parent block at height %d with digest %s: %w", simplexMetadata.Seq-1, simplexMetadata.Prev.String(), err) + } + start := time.Now() sm.Logger.Debug("Building block", @@ -174,10 +177,7 @@ func (sm *StateMachine) BuildBlock(ctx context.Context, parentBlock StateMachine // In order to know where in the epoch change process we are, // we identify the current state by looking at the parent block's epoch info. - currentState, err := identifyCurrentState(parentBlock.Metadata.SimplexEpochInfo) - if err != nil { - return nil, err - } + currentState := identifyCurrentState(parentBlock.Metadata.SimplexEpochInfo) simplexMetadataBytes := simplexMetadata.Bytes() prevBlockSeq := simplexMetadata.Seq - 1 @@ -222,10 +222,7 @@ func (sm *StateMachine) VerifyBlock(ctx context.Context, block *StateMachineBloc } prevMD := prevBlock.Metadata - currentState, err := identifyCurrentState(prevMD.SimplexEpochInfo) - if err != nil { - return fmt.Errorf("failed to identify previous state: %w", err) - } + currentState := identifyCurrentState(prevMD.SimplexEpochInfo) switch currentState { case stateFirstSimplexBlock: @@ -273,6 +270,7 @@ func (sm *StateMachine) init() { getValidatorSet: sm.GetValidatorSet, keyAggregator: sm.KeyAggregator, sigVerifier: sm.SignatureVerifier, + sigAggregator: sm.SignatureAggregator, }, &sealingBlockSeqVerifier{}, } @@ -341,7 +339,6 @@ func identifyCurrentState(prevBlockSimplexEpochInfo SimplexEpochInfo) state { return stateBuildCollectingApprovals } - // buildBlockNormalOp builds a block while not trying to transition to a new epoch. func (sm *StateMachine) buildBlockNormalOp(ctx context.Context, parentBlock StateMachineBlock, simplexMetadata, simplexBlacklist []byte, prevBlockSeq uint64) (*StateMachineBlock, error) { // Since in the previous block, we were not transitioning to a new epoch, @@ -363,16 +360,16 @@ func (sm *StateMachine) buildBlockNormalOp(ctx context.Context, parentBlock Stat var childBlock VMBlock switch decisionToBuildBlock { - case blockBuildingDecisionBuildBlock, blockBuildingDecisionBuildBlockAndTransitionEpoch: + case decisionBuild, decisionBuildAndTransitionEpoch: // If we reached here, we need to build a new block, and maybe also transition to a new epoch. return sm.buildBlockAndMaybeTransitionEpoch(ctx, parentBlock, simplexMetadata, simplexBlacklist, childBlock, decisionToBuildBlock, newSimplexEpochInfo, pChainHeight) - case blockBuildingDecisionTransitionEpoch: + case decisionTransitionEpoch: // If we reached here, we don't need to build an inner block, yet we need to transition to a new epoch. // Initiate the epoch transition by setting the next P-chain reference height for the new epoch info, // and build a block without an inner block. newSimplexEpochInfo.NextPChainReferenceHeight = pChainHeight return sm.wrapBlock(parentBlock, nil, newSimplexEpochInfo, pChainHeight, simplexMetadata, simplexBlacklist), nil - case blockBuildingDecisionContextCanceled: + case decisionContextCanceled: return nil, ctx.Err() default: return nil, fmt.Errorf("unknown block building decision %d", decisionToBuildBlock) @@ -383,10 +380,10 @@ func (sm *StateMachine) createBlockBuildingDecider(parentBlock StateMachineBlock blockBuildingDecider := blockBuildingDecider{ logger: sm.Logger, maxBlockBuildingWaitTime: sm.MaxBlockBuildingWaitTime, - pChainlistener: sm.PChainProgressListener, + pChainListener: sm.PChainProgressListener, getPChainHeight: sm.GetPChainHeight, waitForPendingBlock: sm.BlockBuilder.WaitForPendingBlock, - shouldTransitionEpoch: func(pChainHeight uint64) (bool, error) { + hasValidatorSetChanged: func(pChainHeight uint64) (bool, error) { // The given pChainHeight was sampled by the caller of shouldTransitionEpoch(). // We compare between the current validator set, defined by the P-chain reference height in the parent block, // and the new validator set defined by the given pChainHeight. @@ -418,7 +415,7 @@ func (sm *StateMachine) buildBlockAndMaybeTransitionEpoch(ctx context.Context, p return nil, err } - if decisionToBuildBlock == blockBuildingDecisionBuildBlockAndTransitionEpoch { + if decisionToBuildBlock == decisionBuildAndTransitionEpoch { // We need to also transition to a new epoch, in addition to building an inner block, // so set the next P-chain reference height for the new epoch info. newSimplexEpochInfo.NextPChainReferenceHeight = pChainHeight @@ -559,7 +556,7 @@ func (sm *StateMachine) buildBlockCollectingApprovals(ctx context.Context, paren nextPChainHeight := newSimplexEpochInfo.NextPChainReferenceHeight prevNextEpochApprovals := parentBlock.Metadata.SimplexEpochInfo.NextEpochApprovals - newApprovals, err := computeNewApprovals(prevNextEpochApprovals, approvalsFromPeers, nextPChainHeight, sm.SignatureAggregator, validators) + newApprovals, err := computeNewApprovals(prevNextEpochApprovals, approvalsFromPeers, nextPChainHeight, sm.SignatureAggregator, validators, sm.SignatureAggregator) if err != nil { return nil, err } @@ -752,6 +749,7 @@ func computeNewApprovals( pChainHeight uint64, aggregator SignatureAggregator, validators NodeBLSMappings, + sigAggr SignatureAggregator, ) (*approvals, error) { if nextEpochApprovals == nil { nextEpochApprovals = &NextEpochApprovals{} @@ -761,9 +759,9 @@ func computeNewApprovals( // We map each validator to its relative index in the validator set. nodeID2ValidatorIndex := make(map[nodeID]int) - validators.ForEach(func(i int, nbm NodeBLSMapping) { + for i, nbm := range validators { nodeID2ValidatorIndex[nbm.NodeID] = i - }) + } // We have the approvals obtained from peers, but we need to sanitize them by filtering out approvals that are not valid, // such as approvals that do not agree with our candidate auxiliary info digest and P-Chain height, @@ -778,7 +776,7 @@ func computeNewApprovals( // we check if we have enough approvals to seal the epoch by computing the relative approval ratio, // which is the ratio of the total weight of approving nodes divided by the total weight of all validators. - canSeal, err := canSealBlock(validators, newApprovingNodes) + canSeal, err := canSealBlock(validators, newApprovingNodes, sigAggr) if err != nil { return nil, err } @@ -834,23 +832,15 @@ func computeNewApproverSignaturesAndSigners(nextEpochApprovals *NextEpochApprova return aggregatedSignature, newApprovingNodes, nil } -func canSealBlock(validators NodeBLSMappings, newApprovingNodes bitmask) (bool, error) { - approvingWeight, err := computeApprovingWeight(validators, &newApprovingNodes) - if err != nil { - return false, err - } +func canSealBlock(validators NodeBLSMappings, newApprovingNodes bitmask, sigAggr SignatureAggregator) (bool, error) { + approvingWeights := validators.ApprovingWeights(newApprovingNodes) - totalWeight, err := computeTotalWeight(validators) + totalWeight, err := validators.TotalWeight() if err != nil { return false, err } - threshold := big.NewRat(2, 3) - - approvingRatio := big.NewRat(approvingWeight, totalWeight) - - canSeal := approvingRatio.Cmp(threshold) > 0 - return canSeal, nil + return sigAggr.IsQuorum(approvingWeights, totalWeight), nil } // sanitizeApprovals filters out approvals that are not valid by checking if they agree with our candidate auxiliary info digest and P-Chain height, @@ -954,4 +944,4 @@ type approvals struct { canSeal bool nodeIDs []byte signature []byte -} \ No newline at end of file +} diff --git a/msm/msm_test.go b/msm/msm_test.go index bb9e8a6b..c8651a15 100644 --- a/msm/msm_test.go +++ b/msm/msm_test.go @@ -82,6 +82,14 @@ func (sv *signatureAggregator) AggregateSignatures(signatures ...[]byte) ([]byte return bytes, nil } +func (sv *signatureAggregator) IsQuorum(approverWeights []uint64, totalWeights uint64) bool { + var sum uint64 + for _, w := range approverWeights { + sum += w + } + return sum*3 > totalWeights*2 +} + type noOpPChainListener struct{} func (n *noOpPChainListener) WaitForProgress(ctx context.Context, _ uint64) error { @@ -265,7 +273,7 @@ func TestMSMFirstBlockAfterGenesis(t *testing.T) { testCase.configure(&sm2, testConfig2) } - block, err := sm1.BuildBlock(context.Background(), genesisBlock, testCase.md, nil) + block, err := sm1.BuildBlock(context.Background(), testCase.md, nil) require.NoError(t, err) require.NotNil(t, block) @@ -305,6 +313,10 @@ func TestMSMFirstSimplexBlockAfterPreSimplexBlocks(t *testing.T) { sm1, testConfig1 := newStateMachine(t) sm2, testConfig2 := newStateMachine(t) + testConfig1.blockStore[0] = &outerBlock{ + block: preSimplexParent, + } + testConfig1.blockStore[42] = &outerBlock{block: preSimplexParent} testConfig2.blockStore[42] = &outerBlock{block: preSimplexParent} @@ -314,7 +326,7 @@ func TestMSMFirstSimplexBlockAfterPreSimplexBlocks(t *testing.T) { Bytes: []byte{7, 8, 9}, } - block, err := sm1.BuildBlock(context.Background(), preSimplexParent, md, nil) + block, err := sm1.BuildBlock(context.Background(), md, nil) require.NoError(t, err) require.NotNil(t, block) @@ -488,7 +500,7 @@ func TestMSMNormalOp(t *testing.T) { testCase.setup(&sm2, testConfig2) } - block1, err := sm1.BuildBlock(context.Background(), lastBlock, *md, &blacklist) + block1, err := sm1.BuildBlock(context.Background(), *md, &blacklist) require.NoError(t, err) require.NotNil(t, block1) @@ -603,6 +615,8 @@ func TestMSMFullEpochLifecycle(t *testing.T) { sm, tc := newStateMachine(t) sm.GetValidatorSet = getValidatorSet sm.GetPChainHeight = getPChainHeight + tc.blockStore[0] = &outerBlock{block: genesis} + tc.blockStore[42] = &outerBlock{block: notGenesis} smVerify, tcVerify := newStateMachine(t) smVerify.GetValidatorSet = getValidatorSet @@ -628,7 +642,7 @@ func TestMSMFullEpochLifecycle(t *testing.T) { Prev: testCase.firstBlockBeforeSimplex.Digest(), } - block1, err := sm.BuildBlock(context.Background(), testCase.firstBlockBeforeSimplex, md, nil) + block1, err := sm.BuildBlock(context.Background(), md, nil) require.NoError(t, err) require.Equal(t, &StateMachineBlock{ InnerBlock: nextBlock(1), @@ -659,7 +673,7 @@ func TestMSMFullEpochLifecycle(t *testing.T) { // ----- Step 2: Build a normal block (no validator set change) ----- tc.blockBuilder.block = nextBlock(2) md = simplex.ProtocolMetadata{Seq: baseSeq + 2, Round: 1, Epoch: 1, Prev: block1.Digest()} - block2, err := sm.BuildBlock(context.Background(), *block1, md, nil) + block2, err := sm.BuildBlock(context.Background(), md, nil) require.NoError(t, err) require.Equal(t, &StateMachineBlock{ InnerBlock: nextBlock(2), @@ -684,7 +698,7 @@ func TestMSMFullEpochLifecycle(t *testing.T) { tc.blockBuilder.block = nextBlock(3) md = simplex.ProtocolMetadata{Seq: baseSeq + 3, Round: 2, Epoch: 1, Prev: block2.Digest()} - block3, err := sm.BuildBlock(context.Background(), *block2, md, nil) + block3, err := sm.BuildBlock(context.Background(), md, nil) require.NoError(t, err) require.Equal(t, &StateMachineBlock{ InnerBlock: nextBlock(3), @@ -725,7 +739,7 @@ func TestMSMFullEpochLifecycle(t *testing.T) { tc.blockBuilder.block = nextBlock(4) md = simplex.ProtocolMetadata{Seq: baseSeq + 4, Round: 3, Epoch: 1, Prev: block3.Digest()} - block4, err := sm.BuildBlock(context.Background(), *block3, md, nil) + block4, err := sm.BuildBlock(context.Background(), md, nil) require.NoError(t, err) require.Equal(t, &StateMachineBlock{ InnerBlock: nextBlock(4), @@ -765,7 +779,7 @@ func TestMSMFullEpochLifecycle(t *testing.T) { tc.blockBuilder.block = nextBlock(5) md = simplex.ProtocolMetadata{Seq: baseSeq + 5, Round: 4, Epoch: 1, Prev: block4.Digest()} - block5, err := sm.BuildBlock(context.Background(), *block4, md, nil) + block5, err := sm.BuildBlock(context.Background(), md, nil) require.NoError(t, err) require.Equal(t, &StateMachineBlock{ InnerBlock: nextBlock(5), @@ -805,7 +819,7 @@ func TestMSMFullEpochLifecycle(t *testing.T) { tc.blockBuilder.block = nextBlock(6) md = simplex.ProtocolMetadata{Seq: baseSeq + 6, Round: 5, Epoch: 1, Prev: block5.Digest()} - block6, err := sm.BuildBlock(context.Background(), *block5, md, nil) + block6, err := sm.BuildBlock(context.Background(), md, nil) require.NoError(t, err) require.Equal(t, &StateMachineBlock{ InnerBlock: nextBlock(6), @@ -874,7 +888,7 @@ func TestMSMFullEpochLifecycle(t *testing.T) { // However, despite the fact that the block builder is willing to build a new block, // a Telock shouldn't contain an inner block. if tc.blockStore[sealingSeq].finalization == nil { - telock, err := sm.BuildBlock(context.Background(), *block6, md, nil) + telock, err := sm.BuildBlock(context.Background(), md, nil) require.NoError(t, err) require.Equal(t, &StateMachineBlock{ @@ -899,7 +913,7 @@ func TestMSMFullEpochLifecycle(t *testing.T) { // ----- Step 7: Build a new epoch block (sealing block is finalized) ----- - block7, err := sm.BuildBlock(context.Background(), *block6, md, nil) + block7, err := sm.BuildBlock(context.Background(), md, nil) require.NoError(t, err) require.Equal(t, &StateMachineBlock{ InnerBlock: nextBlock(7), @@ -1016,6 +1030,7 @@ type testConfig struct { func newStateMachine(t *testing.T) (StateMachine, *testConfig) { bs := make(blockStore) + bs[0] = &outerBlock{block: genesisBlock} var testConfig testConfig testConfig.blockStore = bs @@ -1213,7 +1228,7 @@ func TestComputeTotalWeight(t *testing.T) { } total, err := validators.TotalWeight() require.NoError(t, err) - require.Equal(t, int64(600), total) + require.Equal(t, uint64(600), total) }) t.Run("zero total weight", func(t *testing.T) { @@ -1237,30 +1252,26 @@ func TestComputeApprovingWeight(t *testing.T) { t.Run("all approving", func(t *testing.T) { bm := bitmaskFromBytes([]byte{7}) - weight, err := validators.ApprovingWeight(bm) - require.NoError(t, err) - require.Equal(t, int64(600), weight) + weights := validators.ApprovingWeights(bm) + require.Equal(t, []uint64{100, 200, 300}, weights) }) t.Run("partial approving", func(t *testing.T) { bm := bitmaskFromBytes([]byte{5}) - weight, err := validators.ApprovingWeight(bm) - require.NoError(t, err) - require.Equal(t, int64(400), weight) + weights := validators.ApprovingWeights(bm) + require.Equal(t, []uint64{100, 300}, weights) }) t.Run("none approving", func(t *testing.T) { bm := bitmaskFromBytes(nil) - weight, err := validators.ApprovingWeight(bm) - require.NoError(t, err) - require.Equal(t, int64(0), weight) + weights := validators.ApprovingWeights(bm) + require.Empty(t, weights) }) t.Run("single validator approving", func(t *testing.T) { bm := bitmaskFromBytes([]byte{2}) - weight, err := validators.ApprovingWeight(bm) - require.NoError(t, err) - require.Equal(t, int64(200), weight) + weights := validators.ApprovingWeights(bm) + require.Equal(t, []uint64{200}, weights) }) } @@ -1327,12 +1338,24 @@ func (concatAggregator) AggregateSignatures(sigs ...[]byte) ([]byte, error) { return bytes.Join(sigs, nil), nil } +func (concatAggregator) IsQuorum(approverWeights []uint64, totalWeights uint64) bool { + var sum uint64 + for _, w := range approverWeights { + sum += w + } + return sum*3 >= totalWeights*2 +} + type failingAggregator struct{} func (failingAggregator) AggregateSignatures(sigs ...[]byte) ([]byte, error) { return nil, fmt.Errorf("aggregation failed") } +func (failingAggregator) IsQuorum([]uint64, uint64) bool { + return false +} + func TestComputeNewApproverSignaturesAndSigners(t *testing.T) { node0 := nodeID{0} node1 := nodeID{1} diff --git a/msm/verification.go b/msm/verification.go index ecf314aa..6ca66019 100644 --- a/msm/verification.go +++ b/msm/verification.go @@ -67,6 +67,7 @@ type nextEpochApprovalsVerifier struct { sigVerifier SignatureVerifier getValidatorSet ValidatorSetRetriever keyAggregator KeyAggregator + sigAggregator SignatureAggregator } func (nv *nextEpochApprovalsVerifier) Verify(in verificationInput) error { @@ -98,7 +99,7 @@ func (nv *nextEpochApprovalsVerifier) verifySealingBlock(prev SimplexEpochInfo, } approvingNodes := bitmaskFromBytes(next.NextEpochApprovals.NodeIDs) - canSeal, err := canSealBlock(validators, approvingNodes) + canSeal, err := canSealBlock(validators, approvingNodes, nv.sigAggregator) if err != nil { return err } @@ -179,12 +180,12 @@ func (nv *nextEpochApprovalsVerifier) createMessageToBeVerified(prev SimplexEpoc func (nv *nextEpochApprovalsVerifier) aggregatePubKeysForBitmask(nodeIDsBitmask []byte, validators NodeBLSMappings) ([]byte, error) { approvingNodes := bitmaskFromBytes(nodeIDsBitmask) publicKeys := make([][]byte, 0, len(validators)) - validators.ForEach(func(i int, nbm NodeBLSMapping) { + for i := range validators { if !approvingNodes.Contains(i) { - return + continue } - publicKeys = append(publicKeys, nbm.BLSKey) - }) + publicKeys = append(publicKeys, validators[i].BLSKey) + } aggPK, err := nv.keyAggregator.AggregateKeys(publicKeys...) if err != nil { diff --git a/msm/verification_test.go b/msm/verification_test.go index 7881a133..4ea6fbeb 100644 --- a/msm/verification_test.go +++ b/msm/verification_test.go @@ -844,6 +844,7 @@ func TestNextEpochApprovalsVerifier(t *testing.T) { sigVerifier: tc.sigVerifier, getValidatorSet: tc.getValidator, keyAggregator: tc.keyAggregator, + sigAggregator: &signatureAggregator{}, } err := v.Verify(verificationInput{ nextBlockType: tc.nextBlockType, From 80044bfac8452e72f739316cb2179676924bd4c9 Mon Sep 17 00:00:00 2001 From: Yacov Manevich Date: Thu, 16 Apr 2026 22:47:56 +0200 Subject: [PATCH 03/16] msm: consolidate notarized/finalized blocks into blockState slice Addresses review comments: finalized blocks were a prefix of notarized blocks, so keeping two separate slices was redundant and error prone (e.g. tryFinalizeNextBlock panicked when the two went out of sync). Replaces both with a single blocks []blockState slice where each entry carries a finalized flag. Also removes the duplicate lookup in the GetBlock test fixture that would return finalized blocks as non-finalized. Co-Authored-By: Claude Opus 4.7 (1M context) --- msm/fake_node_test.go | 153 ++++++++++++++++++++++-------------------- 1 file changed, 82 insertions(+), 71 deletions(-) diff --git a/msm/fake_node_test.go b/msm/fake_node_test.go index 3f5e6130..de50a4bb 100644 --- a/msm/fake_node_test.go +++ b/msm/fake_node_test.go @@ -112,7 +112,7 @@ func TestFakeNodeEmptyMempool(t *testing.T) { node.mempoolEmpty = true // We build blocks until the sealing block is finalized. - for node.finalizedBlocks[len(node.finalizedBlocks)-1].Metadata.SimplexEpochInfo.BlockValidationDescriptor == nil { + for node.lastFinalizedBlock().Metadata.SimplexEpochInfo.BlockValidationDescriptor == nil { node.act() if flipCoin() { node.sm.ApprovalsRetriever = &approvalsRetriever{ @@ -168,13 +168,19 @@ type innerBlock struct { Prev [32]byte } +type blockState struct { + block StateMachineBlock + finalized bool + innerBlock VMBlock +} + type fakeNode struct { - t *testing.T - sm StateMachine - mempoolEmpty bool - notarizedBlocks []StateMachineBlock - finalizedBlocks []StateMachineBlock - innerChain []innerBlock + t *testing.T + sm StateMachine + mempoolEmpty bool + // blocks holds notarized blocks in order. Finalized blocks always form a + // prefix: all finalized entries precede all non-finalized entries. + blocks []blockState } func (fn *fakeNode) WaitForProgress(ctx context.Context, pChainHeight uint64) error { @@ -212,28 +218,21 @@ func newFakeNode(t *testing.T) *fakeNode { if opts.Height == 0 { return genesisBlock, nil, nil } - for _, block := range fn.finalizedBlocks { - if block.Digest() == opts.Digest { - return block, &simplex.Finalization{}, nil - } - md, err := simplex.ProtocolMetadataFromBytes(block.Metadata.SimplexProtocolMetadata) - if err != nil { - return StateMachineBlock{}, nil, err - } - if md.Seq == opts.Height { - return block, &simplex.Finalization{}, nil - } - } - for _, block := range fn.notarizedBlocks { - if block.Digest() == opts.Digest { - return block, nil, nil + for _, bs := range fn.blocks { + match := bs.block.Digest() == opts.Digest + if !match { + md, err := simplex.ProtocolMetadataFromBytes(bs.block.Metadata.SimplexProtocolMetadata) + if err != nil { + return StateMachineBlock{}, nil, err + } + match = md.Seq == opts.Height } - md, err := simplex.ProtocolMetadataFromBytes(block.Metadata.SimplexProtocolMetadata) - if err != nil { - return StateMachineBlock{}, nil, err - } - if md.Seq == opts.Height { - return block, nil, nil + if match { + var fin *simplex.Finalization + if bs.finalized { + fin = &simplex.Finalization{} + } + return bs.block, fin, nil } } @@ -244,12 +243,29 @@ func newFakeNode(t *testing.T) *fakeNode { return fn } +// lastFinalizedBlock returns the most recently finalized block. +// Panics if nothing has been finalized. +func (fn *fakeNode) lastFinalizedBlock() StateMachineBlock { + for i := len(fn.blocks) - 1; i >= 0; i-- { + if fn.blocks[i].finalized { + return fn.blocks[i].block + } + } + panic("no finalized block") +} + func (fn *fakeNode) Height() uint64 { - return uint64(len(fn.finalizedBlocks)) + var count uint64 + for _, bs := range fn.blocks { + if bs.finalized { + count++ + } + } + return count } func (fn *fakeNode) Epoch() uint64 { - return fn.notarizedBlocks[len(fn.notarizedBlocks)-1].Metadata.SimplexEpochInfo.EpochNumber + return fn.blocks[len(fn.blocks)-1].block.Metadata.SimplexEpochInfo.EpochNumber } func (fn *fakeNode) act() { @@ -266,18 +282,27 @@ func (fn *fakeNode) act() { } func (fn *fakeNode) canFinalize() bool { - return len(fn.notarizedBlocks) > len(fn.finalizedBlocks) + return fn.nextUnfinalizedIndex() < len(fn.blocks) +} + +func (fn *fakeNode) nextUnfinalizedIndex() int { + for i, bs := range fn.blocks { + if !bs.finalized { + return i + } + } + return len(fn.blocks) } func (fn *fakeNode) tryFinalizeNextBlock() { - nextIndex := len(fn.finalizedBlocks) + nextIndex := fn.nextUnfinalizedIndex() - if fn.isNextBlockTelock() { + if fn.isNextBlockTelock(nextIndex) { return } - block := fn.notarizedBlocks[nextIndex] - fn.finalizedBlocks = append(fn.finalizedBlocks, block) + fn.blocks[nextIndex].finalized = true + block := fn.blocks[nextIndex].block md, err := simplex.ProtocolMetadataFromBytes(block.Metadata.SimplexProtocolMetadata) require.NoError(fn.t, err) @@ -287,27 +312,23 @@ func (fn *fakeNode) tryFinalizeNextBlock() { // If we just finalized a sealing block, trim trailing Telock blocks. if block.Metadata.SimplexEpochInfo.BlockValidationDescriptor != nil { - fn.notarizedBlocks = fn.notarizedBlocks[:len(fn.finalizedBlocks)] - fn.t.Logf("Trimmed notarized blocks, new length: %d", len(fn.notarizedBlocks)) + fn.blocks = fn.blocks[:nextIndex+1] + fn.t.Logf("Trimmed notarized blocks, new length: %d", len(fn.blocks)) } } -func (fn *fakeNode) isNextBlockTelock() bool { - if len(fn.finalizedBlocks) == 0 { +func (fn *fakeNode) isNextBlockTelock(nextIndex int) bool { + if nextIndex == 0 { return false } - return fn.notarizedBlocks[len(fn.finalizedBlocks)].Metadata.SimplexEpochInfo.SealingBlockSeq > 0 + return fn.blocks[nextIndex].block.Metadata.SimplexEpochInfo.SealingBlockSeq > 0 } func (fn *fakeNode) buildAndNotarizeBlock() { vmBlock, block := fn.buildBlock() require.NoError(fn.t, fn.sm.VerifyBlock(context.Background(), block)) - fn.notarizedBlocks = append(fn.notarizedBlocks, *block) - - if vmBlock != nil { - fn.innerChain = append(fn.innerChain, *vmBlock.(*innerBlock)) - } + fn.blocks = append(fn.blocks, blockState{block: *block, innerBlock: vmBlock}) } func (fn *fakeNode) buildBlock() (VMBlock, *StateMachineBlock) { @@ -342,8 +363,8 @@ func (fn *fakeNode) prepareMetadataAndPrevBlockDigest() (*simplex.ProtocolMetada var lastMD *simplex.ProtocolMetadata var err error lastBlockDigest := genesisBlock.Digest() - if len(fn.notarizedBlocks) > 0 { - lastBlock := fn.notarizedBlocks[len(fn.notarizedBlocks)-1] + if len(fn.blocks) > 0 { + lastBlock := fn.blocks[len(fn.blocks)-1].block lastBlockDigest = lastBlock.Digest() lastMD, err = simplex.ProtocolMetadataFromBytes(lastBlock.Metadata.SimplexProtocolMetadata) require.NoError(fn.t, err) @@ -358,8 +379,8 @@ func (fn *fakeNode) prepareMetadataAndPrevBlockDigest() (*simplex.ProtocolMetada func (fn *fakeNode) BuildBlock(context.Context, uint64) (VMBlock, error) { // Count the number of inner blocks in the chain var count int - for _, block := range fn.notarizedBlocks { - if block.InnerBlock != nil { + for _, bs := range fn.blocks { + if bs.block.InnerBlock != nil { count++ } } @@ -376,34 +397,24 @@ func (fn *fakeNode) BuildBlock(context.Context, uint64) (VMBlock, error) { } func (fn *fakeNode) getParentBlock() StateMachineBlock { - var parentBlock StateMachineBlock - if len(fn.notarizedBlocks) > 0 { - parentBlock = fn.notarizedBlocks[len(fn.notarizedBlocks)-1] - } else { - gb := genesisBlock.InnerBlock.(*InnerBlock) - parentBlock = StateMachineBlock{ - InnerBlock: &innerBlock{ - InnerBlock: *gb, - }, - } + if len(fn.blocks) > 0 { + return fn.blocks[len(fn.blocks)-1].block + } + gb := genesisBlock.InnerBlock.(*InnerBlock) + return StateMachineBlock{ + InnerBlock: &innerBlock{ + InnerBlock: *gb, + }, } - return parentBlock } func (fn *fakeNode) getLastVMBlockDigest() [32]byte { - var lastVMBlockDigest = genesisBlock.Digest() - - notarizedBlocks := fn.notarizedBlocks - for len(notarizedBlocks) > 0 { - lastNotarizedBlock := notarizedBlocks[len(notarizedBlocks)-1] - if lastNotarizedBlock.InnerBlock == nil { - notarizedBlocks = notarizedBlocks[:len(notarizedBlocks)-1] - continue + for i := len(fn.blocks) - 1; i >= 0; i-- { + if fn.blocks[i].block.InnerBlock != nil { + return fn.blocks[i].block.Digest() } - lastVMBlockDigest = lastNotarizedBlock.Digest() - break } - return lastVMBlockDigest + return genesisBlock.Digest() } func randomBuff(n int) []byte { From 84c769656b60c2101ea4cf390b6a046bc4851194 Mon Sep 17 00:00:00 2001 From: Yacov Manevich Date: Tue, 5 May 2026 22:48:19 +0200 Subject: [PATCH 04/16] Remove RetrievingOpts in favor of explicit parameters Signed-off-by: Yacov Manevich --- msm/fake_node_test.go | 19 ++++++++----------- msm/msm.go | 24 ++++++++---------------- msm/msm_test.go | 20 ++++++++++---------- msm/verification.go | 6 +++--- msm/verification_test.go | 6 +++--- 5 files changed, 32 insertions(+), 43 deletions(-) diff --git a/msm/fake_node_test.go b/msm/fake_node_test.go index de50a4bb..e4084e12 100644 --- a/msm/fake_node_test.go +++ b/msm/fake_node_test.go @@ -169,8 +169,8 @@ type innerBlock struct { } type blockState struct { - block StateMachineBlock - finalized bool + block StateMachineBlock + finalized bool innerBlock VMBlock } @@ -214,18 +214,18 @@ func newFakeNode(t *testing.T) *fakeNode { fn.sm.BlockBuilder = fn fn.sm.PChainProgressListener = fn - fn.sm.GetBlock = func(opts RetrievingOpts) (StateMachineBlock, *simplex.Finalization, error) { - if opts.Height == 0 { + fn.sm.GetBlock = func(seq uint64, digest [32]byte) (StateMachineBlock, *simplex.Finalization, error) { + if seq == 0 { return genesisBlock, nil, nil } for _, bs := range fn.blocks { - match := bs.block.Digest() == opts.Digest + match := bs.block.Digest() == digest if !match { md, err := simplex.ProtocolMetadataFromBytes(bs.block.Metadata.SimplexProtocolMetadata) if err != nil { return StateMachineBlock{}, nil, err } - match = md.Seq == opts.Height + match = md.Seq == seq } if match { var fin *simplex.Finalization @@ -236,7 +236,7 @@ func newFakeNode(t *testing.T) *fakeNode { } } - require.Failf(t, "not found block", "height: %d", opts.Height) + require.Failf(t, "not found block", "height: %d", seq) return StateMachineBlock{}, nil, fmt.Errorf("block not found") } @@ -336,10 +336,7 @@ func (fn *fakeNode) buildBlock() (VMBlock, *StateMachineBlock) { lastMD, prevBlockDigest := fn.prepareMetadataAndPrevBlockDigest() - _, finalization, err := fn.sm.GetBlock(RetrievingOpts{ - Digest: prevBlockDigest, - Height: lastMD.Seq, - }) + _, finalization, err := fn.sm.GetBlock(lastMD.Seq, prevBlockDigest) require.NoError(fn.t, err) finalizedString := "not finalized" diff --git a/msm/msm.go b/msm/msm.go index 02262587..668e5c5b 100644 --- a/msm/msm.go +++ b/msm/msm.go @@ -65,18 +65,10 @@ type SignatureVerifier interface { // ValidatorSetRetriever retrieves the validator set at a given P-chain height. type ValidatorSetRetriever func(pChainHeight uint64) (NodeBLSMappings, error) -// RetrievingOpts specifies the options for retrieving a block by height and/or digest. -type RetrievingOpts struct { - // Height is the sequence number of the block to retrieve. - Height uint64 - // Digest is the expected hash of the block, used for validation. - Digest [32]byte -} - -// BlockRetriever retrieves a block and its finalization status given the retrieval options. +// BlockRetriever retrieves a block and its finalization status given the block's sequence number and expected digest. // If the block cannot be found it returns ErrBlockNotFound. // If an error occurs during retrieval, it returns a non-nil error. -type BlockRetriever func(RetrievingOpts) (StateMachineBlock, *simplex.Finalization, error) +type BlockRetriever func(seq uint64, digest [32]byte) (StateMachineBlock, *simplex.Finalization, error) // BlockBuilder builds a new VM block with the given observed P-chain height. type BlockBuilder interface { @@ -148,7 +140,7 @@ func (sm *StateMachine) BuildBlock(ctx context.Context, simplexMetadata simplex. return nil, fmt.Errorf("invalid ProtocolMetadata sequence number: should be > 0, got %d", simplexMetadata.Seq) } - parentBlock, _, err := sm.GetBlock(RetrievingOpts{Height: simplexMetadata.Seq - 1, Digest: simplexMetadata.Prev}) + parentBlock, _, err := sm.GetBlock(simplexMetadata.Seq-1, simplexMetadata.Prev) if err != nil { return nil, fmt.Errorf("failed retrieving parent block at height %d with digest %s: %w", simplexMetadata.Seq-1, simplexMetadata.Prev.String(), err) } @@ -216,7 +208,7 @@ func (sm *StateMachine) VerifyBlock(ctx context.Context, block *StateMachineBloc return fmt.Errorf("attempted to build a genesis inner block") } - prevBlock, _, err := sm.GetBlock(RetrievingOpts{Digest: pmd.Prev, Height: seq - 1}) + prevBlock, _, err := sm.GetBlock(seq-1, pmd.Prev) if err != nil { return fmt.Errorf("failed to retrieve previous (%d) inner block: %w", seq-1, err) } @@ -617,7 +609,7 @@ func (sm *StateMachine) createSealingBlock(ctx context.Context, parentBlock Stat // If this is not the first epoch, and this is the sealing block, we set the hash of the previous sealing block. if simplexEpochInfo.EpochNumber > 1 { - prevSealingBlock, finalization, err := sm.GetBlock(RetrievingOpts{Height: simplexEpochInfo.EpochNumber}) + prevSealingBlock, finalization, err := sm.GetBlock(simplexEpochInfo.EpochNumber, [32]byte{}) if err != nil { sm.Logger.Error("Error retrieving previous sealing block", zap.Uint64("seq", simplexEpochInfo.EpochNumber), zap.Error(err)) return nil, fmt.Errorf("failed to retrieve previous sealing InnerBlock at epoch %d: %w", simplexEpochInfo.EpochNumber-1, err) @@ -633,7 +625,7 @@ func (sm *StateMachine) createSealingBlock(ctx context.Context, parentBlock Stat if err != nil { return nil, fmt.Errorf("failed to find first simplex block: %w", err) } - firstSimplexBlockRetrieved, _, err := sm.GetBlock(RetrievingOpts{Height: firstSimplexBlock}) + firstSimplexBlockRetrieved, _, err := sm.GetBlock(firstSimplexBlock, [32]byte{}) if err != nil { return nil, fmt.Errorf("failed to retrieve first simplex block at height %d: %w", firstSimplexBlock, err) } @@ -692,7 +684,7 @@ func (sm *StateMachine) buildBlockEpochSealed(ctx context.Context, parentBlock S PrevVMBlockSeq: computePrevVMBlockSeq(parentBlock, prevBlockSeq), } - _, finalization, err := sm.GetBlock(RetrievingOpts{Height: sealingBlockSeq}) + _, finalization, err := sm.GetBlock(sealingBlockSeq, [32]byte{}) if err != nil { return nil, fmt.Errorf("failed to retrieve sealing block at sequence %d: %w", sealingBlockSeq, err) } @@ -881,7 +873,7 @@ func findFirstSimplexBlock(getBlock BlockRetriever, endHeight uint64) (uint64, e if haltError != nil { return true } - block, _, err := getBlock(RetrievingOpts{Height: uint64(i)}) + block, _, err := getBlock(uint64(i), [32]byte{}) if errors.Is(err, simplex.ErrBlockNotFound) { return false } diff --git a/msm/msm_test.go b/msm/msm_test.go index c8651a15..c76da241 100644 --- a/msm/msm_test.go +++ b/msm/msm_test.go @@ -43,10 +43,10 @@ func (bs blockStore) clone() blockStore { return newStore } -func (bs blockStore) getBlock(opts RetrievingOpts) (StateMachineBlock, *simplex.Finalization, error) { - blk, exits := bs[opts.Height] +func (bs blockStore) getBlock(seq uint64, _ [32]byte) (StateMachineBlock, *simplex.Finalization, error) { + blk, exits := bs[seq] if !exits { - return StateMachineBlock{}, nil, fmt.Errorf("%w: block %d not found", simplex.ErrBlockNotFound, opts.Height) + return StateMachineBlock{}, nil, fmt.Errorf("%w: block %d not found", simplex.ErrBlockNotFound, seq) } return blk.block, blk.finalization, nil } @@ -1167,7 +1167,7 @@ func TestComputePrevVMBlockSeq(t *testing.T) { func TestFindFirstSimplexBlock(t *testing.T) { t.Run("endHeight too big", func(t *testing.T) { - getBlock := func(opts RetrievingOpts) (StateMachineBlock, *simplex.Finalization, error) { + getBlock := func(_ uint64, _ [32]byte) (StateMachineBlock, *simplex.Finalization, error) { return StateMachineBlock{}, nil, nil } _, err := findFirstSimplexBlock(getBlock, math.MaxUint64) @@ -1175,8 +1175,8 @@ func TestFindFirstSimplexBlock(t *testing.T) { }) t.Run("found at height 3", func(t *testing.T) { - getBlock := func(opts RetrievingOpts) (StateMachineBlock, *simplex.Finalization, error) { - if opts.Height < 3 { + getBlock := func(seq uint64, _ [32]byte) (StateMachineBlock, *simplex.Finalization, error) { + if seq < 3 { return StateMachineBlock{}, nil, nil } return StateMachineBlock{ @@ -1189,7 +1189,7 @@ func TestFindFirstSimplexBlock(t *testing.T) { }) t.Run("no simplex blocks found", func(t *testing.T) { - getBlock := func(opts RetrievingOpts) (StateMachineBlock, *simplex.Finalization, error) { + getBlock := func(_ uint64, _ [32]byte) (StateMachineBlock, *simplex.Finalization, error) { return StateMachineBlock{}, nil, nil } _, err := findFirstSimplexBlock(getBlock, 5) @@ -1197,8 +1197,8 @@ func TestFindFirstSimplexBlock(t *testing.T) { }) t.Run("block not found errors are skipped", func(t *testing.T) { - getBlock := func(opts RetrievingOpts) (StateMachineBlock, *simplex.Finalization, error) { - if opts.Height < 2 { + getBlock := func(seq uint64, _ [32]byte) (StateMachineBlock, *simplex.Finalization, error) { + if seq < 2 { return StateMachineBlock{}, nil, simplex.ErrBlockNotFound } return StateMachineBlock{ @@ -1211,7 +1211,7 @@ func TestFindFirstSimplexBlock(t *testing.T) { }) t.Run("retrieval error propagated", func(t *testing.T) { - getBlock := func(opts RetrievingOpts) (StateMachineBlock, *simplex.Finalization, error) { + getBlock := func(_ uint64, _ [32]byte) (StateMachineBlock, *simplex.Finalization, error) { return StateMachineBlock{}, nil, fmt.Errorf("disk error") } _, err := findFirstSimplexBlock(getBlock, 5) diff --git a/msm/verification.go b/msm/verification.go index 6ca66019..d3af5cc3 100644 --- a/msm/verification.go +++ b/msm/verification.go @@ -424,7 +424,7 @@ func (p *prevSealingBlockHashVerifier) Verify(in verificationInput) error { return fmt.Errorf("failed to find first Simplex block: %w", err) } - block, _, err := p.getBlock(RetrievingOpts{Height: firstEverSimplexBlockSeq}) + block, _, err := p.getBlock(firstEverSimplexBlockSeq, [32]byte{}) if err != nil { return fmt.Errorf("failed retrieving first ever simplex block %d: %w", firstEverSimplexBlockSeq, err) } @@ -442,7 +442,7 @@ func (p *prevSealingBlockHashVerifier) Verify(in verificationInput) error { switch in.nextBlockType { case BlockTypeSealing: - prevSealingBlock, _, err := p.getBlock(RetrievingOpts{Height: in.prevMD.SimplexEpochInfo.EpochNumber}) + prevSealingBlock, _, err := p.getBlock(in.prevMD.SimplexEpochInfo.EpochNumber, [32]byte{}) if err != nil { return fmt.Errorf("failed retrieving block: %w", err) } @@ -481,7 +481,7 @@ func (v *vmBlockSeqVerifier) Verify(in verificationInput) error { // Else, if the previous block has an inner block, we point to it. // Otherwise, we point to the parent block's previous VM block seq. - prevBlock, _, err := v.getBlock(RetrievingOpts{Height: in.prevBlockSeq, Digest: md.Prev}) + prevBlock, _, err := v.getBlock(in.prevBlockSeq, md.Prev) if err != nil { return fmt.Errorf("failed retrieving block: %w", err) } diff --git a/msm/verification_test.go b/msm/verification_test.go index 4ea6fbeb..43f99a58 100644 --- a/msm/verification_test.go +++ b/msm/verification_test.go @@ -942,10 +942,10 @@ func TestSealingBlockSeqVerifier(t *testing.T) { type testBlockStore map[uint64]StateMachineBlock -func (bs testBlockStore) getBlock(opts RetrievingOpts) (StateMachineBlock, *simplex.Finalization, error) { - blk, ok := bs[opts.Height] +func (bs testBlockStore) getBlock(seq uint64, _ [32]byte) (StateMachineBlock, *simplex.Finalization, error) { + blk, ok := bs[seq] if !ok { - return StateMachineBlock{}, nil, fmt.Errorf("%w: block %d", simplex.ErrBlockNotFound, opts.Height) + return StateMachineBlock{}, nil, fmt.Errorf("%w: block %d", simplex.ErrBlockNotFound, seq) } return blk, nil, nil } From 5911aee56cffb974b7f033e55bd43e993ec30be0 Mon Sep 17 00:00:00 2001 From: Yacov Manevich Date: Thu, 7 May 2026 18:34:04 +0200 Subject: [PATCH 05/16] Use signature aggregator creator Signed-off-by: Yacov Manevich --- api.go | 27 +++++++- blacklist.go | 2 +- blacklist_test.go | 4 +- epoch.go | 44 ++++++------ epoch_failover_test.go | 20 +++--- epoch_test.go | 53 +++++++++------ global.go | 11 +++ msm/encoding.go | 37 ++++------ msm/msm.go | 48 ++++--------- msm/msm_test.go | 132 +++++++++++++----------------------- msm/verification.go | 8 +-- msm/verification_test.go | 8 +-- pos_test.go | 20 +++--- recovery_test.go | 49 +++++++------ replication_test.go | 32 +++++---- replication_timeout_test.go | 5 +- testutil/comm.go | 16 ++--- testutil/controlled.go | 8 +-- testutil/network.go | 21 +++--- testutil/node.go | 6 +- testutil/util.go | 22 +++++- 21 files changed, 298 insertions(+), 275 deletions(-) diff --git a/api.go b/api.go index 50fa88e3..2e01b78d 100644 --- a/api.go +++ b/api.go @@ -58,7 +58,7 @@ type Storage interface { type Communication interface { // Nodes returns all nodes that participate in the epoch. - Nodes() []NodeID + Nodes() NodeWeights // Send sends a message to the given destination node Send(msg *Message, destination NodeID) @@ -130,8 +130,33 @@ type SignatureAggregator interface { // Aggregate aggregates several signatures into a QuorumCertificate Aggregate([]Signature) (QuorumCertificate, error) + // AppendSignatures appends signatures to an existing signature. + // If the existing signature is empty, it just aggregates the given signatures. + AppendSignatures([]byte, ...[]byte) ([]byte, error) + // IsQuorum returns true if the given signers constitute a quorum. // In the case of PoA, this means at least a quorum of the nodes are given. // In the case of PoS, this means at least two thirds of the st. IsQuorum([]NodeID) bool } + +// NodeWeights is a list of NodeWeight elements. +type NodeWeights []NodeWeight + +// NodesIDs returns the NodeIDs of the nodes in the NodeWeights. +func (nws NodeWeights) NodesIDs() []NodeID { + nodes := make([]NodeID, len(nws)) + for i, nw := range nws { + nodes[i] = nw.Node + } + return nodes +} + +// NodeWeight is a struct that pairs a node with its weight in the signature aggregator. +type NodeWeight struct { + Node NodeID + Weight uint64 +} + +// SignatureAggregatorCreator creates a SignatureAggregator from a list of nodes and their weights. +type SignatureAggregatorCreator func([]NodeWeight) SignatureAggregator diff --git a/blacklist.go b/blacklist.go index 842ba638..edc7abbb 100644 --- a/blacklist.go +++ b/blacklist.go @@ -206,7 +206,7 @@ func (bl *Blacklist) ApplyUpdates(updates []BlacklistUpdate, round uint64) Black } // garbageCollectSuspectedNodes returns a new list of suspected nodes for the given round. -// Nodes that are no longer suspected or have been redeemed, will not be included in the returned suspected nodes. +// NodesIDs that are no longer suspected or have been redeemed, will not be included in the returned suspected nodes. // It will also garbage-collect any redeem votes from past orbits, unless hey have surpassed the threshold of f+1. // It does not modify the current blacklist. func (bl *Blacklist) garbageCollectSuspectedNodes(round uint64) SuspectedNodes { diff --git a/blacklist_test.go b/blacklist_test.go index f47fb5ea..faee4080 100644 --- a/blacklist_test.go +++ b/blacklist_test.go @@ -465,8 +465,8 @@ func TestComputeBlacklistUpdates(t *testing.T) { func TestAdvanceRound(t *testing.T) { nodes := []uint16{0, 1, 2, 3} - // Nodes 0, 2 are suspected. - // Nodes 1 and 3 are not suspected. + // NodesIDs 0, 2 are suspected. + // NodesIDs 1 and 3 are not suspected. // Node 2 can be redeemed. suspectedNodesBefore := SuspectedNodes{ {NodeIndex: 0, SuspectingCount: 2, OrbitSuspected: 1, RedeemingCount: 1, OrbitToRedeem: 1}, diff --git a/epoch.go b/epoch.go index c81a51e8..220abce8 100644 --- a/epoch.go +++ b/epoch.go @@ -69,7 +69,7 @@ type EpochConfig struct { Signer Signer Verifier SignatureVerifier BlockDeserializer BlockDeserializer - SignatureAggregator SignatureAggregator + SignatureAggregatorCreator SignatureAggregatorCreator Comm Communication Storage Storage WAL WriteAheadLog @@ -83,7 +83,8 @@ type EpochConfig struct { type Epoch struct { EpochConfig // Runtime - oneTimeVerifier *OneTimeVerifier + signatureAggregator SignatureAggregator + oneTimeVerifier *OneTimeVerifier buildBlockScheduler *BasicScheduler blockVerificationScheduler *BlockDependencyManager lock sync.Mutex @@ -94,6 +95,7 @@ type Epoch struct { blockBuilderCtx context.Context blockBuilderCancelFunc context.CancelFunc nodes NodeIDs + nodeWeights NodeWeights eligibleNodeIDs map[string]struct{} rounds map[uint64]*Round emptyVotes map[uint64]*EmptyVoteSet @@ -198,8 +200,9 @@ func (e *Epoch) init() error { e.finishCtx, e.finishFn = context.WithCancel(context.Background()) e.blockBuilderCtx = context.Background() e.blockBuilderCancelFunc = func() {} - e.nodes = e.Comm.Nodes() - SortNodes(e.nodes) + e.nodeWeights = e.Comm.Nodes() + SortNodesWeights(e.nodeWeights) + e.nodes = e.nodeWeights.NodesIDs() e.timedOutRounds = make(map[uint16]uint64, len(e.nodes)) e.redeemedRounds = make(map[uint16]uint64, len(e.nodes)) e.rounds = make(map[uint64]*Round) @@ -208,6 +211,7 @@ func (e *Epoch) init() error { e.futureMessages = make(messagesFromNode, len(e.nodes)) e.replicationState = NewReplicationState(e.Logger, e.Comm, e.ID, e.MaxRoundWindow, e.ReplicationEnabled, e.StartTime, &e.lock, e.RandomSource) e.timeoutHandler = NewTimeoutHandler(e.Logger, "emptyVoteRebroadcast", e.StartTime, e.MaxRebroadcastWait, e.emptyVoteTimeoutTaskRunner) + e.signatureAggregator = e.SignatureAggregatorCreator(e.nodeWeights) for _, node := range e.nodes { e.futureMessages[string(node)] = make(map[uint64]*messagesForRound) @@ -739,7 +743,7 @@ func (e *Epoch) handleFinalizationMessage(message *Finalization, from NodeID) er return nil } - if err := VerifyQC(message.QC, e.Logger, "Finalization", e.SignatureAggregator.IsQuorum, e.eligibleNodeIDs, message, from); err != nil { + if err := VerifyQC(message.QC, e.Logger, "Finalization", e.signatureAggregator.IsQuorum, e.eligibleNodeIDs, message, from); err != nil { e.Logger.Debug("Received an invalid finalization", zap.Int("round", int(message.Finalization.Round)), zap.Stringer("NodeID", from)) @@ -1159,7 +1163,7 @@ func (e *Epoch) maybeCollectFinalization(round *Round) error { var finalizations []*FinalizeVote for _, finalizationsWithTheSameDigest := range finalizationsByMD { - if e.SignatureAggregator.IsQuorum(NodeIDsFromVotes(finalizationsWithTheSameDigest)) { + if e.signatureAggregator.IsQuorum(NodeIDsFromVotes(finalizationsWithTheSameDigest)) { finalizations = finalizationsWithTheSameDigest break } @@ -1174,7 +1178,7 @@ func (e *Epoch) maybeCollectFinalization(round *Round) error { } func (e *Epoch) assembleFinalization(round *Round, finalizationVotes []*FinalizeVote) error { - finalization, err := NewFinalization(e.Logger, e.SignatureAggregator, finalizationVotes) + finalization, err := NewFinalization(e.Logger, e.signatureAggregator, finalizationVotes) if err != nil { return err } @@ -1387,13 +1391,13 @@ func (e *Epoch) maybeAssembleEmptyNotarization() error { } // Check if we found a quorum of votes for the same metadata - popularEmptyVote, signatures, found := findEmptyVoteThatIsQuorum(emptyVotes.votes, e.SignatureAggregator.IsQuorum) + popularEmptyVote, signatures, found := findEmptyVoteThatIsQuorum(emptyVotes.votes, e.signatureAggregator.IsQuorum) if !found { e.Logger.Debug("Could not find empty vote with a quorum or more votes", zap.Uint64("round", e.round)) return nil } - qc, err := e.SignatureAggregator.Aggregate(signatures) + qc, err := e.signatureAggregator.Aggregate(signatures) if err != nil { e.Logger.Error("Could not aggregate empty votes signatures", zap.Error(err), zap.Uint64("round", e.round)) return nil @@ -1500,7 +1504,7 @@ func (e *Epoch) maybeCollectNotarization() error { } } - if !e.SignatureAggregator.IsQuorum(NodeIDsFromVotes(votesForOurBlock)) { + if !e.signatureAggregator.IsQuorum(NodeIDsFromVotes(votesForOurBlock)) { e.Logger.Debug("Not enough votes to form a notarization for our block", zap.Uint64("round", e.round), zap.Int("voteForOurBlock", len(votesForOurBlock)), @@ -1508,7 +1512,7 @@ func (e *Epoch) maybeCollectNotarization() error { return nil } - notarization, err := NewNotarization(e.Logger, e.SignatureAggregator, votesForCurrentRound, block.BlockHeader()) + notarization, err := NewNotarization(e.Logger, e.signatureAggregator, votesForCurrentRound, block.BlockHeader()) if err != nil { return err } @@ -1594,7 +1598,7 @@ func (e *Epoch) handleEmptyNotarizationMessage(emptyNotarization *EmptyNotarizat } // Otherwise, this round is not notarized or finalized yet, so verify the empty notarization and store it. - if err := VerifyQC(emptyNotarization.QC, e.Logger, "Empty notarization", e.SignatureAggregator.IsQuorum, e.eligibleNodeIDs, emptyNotarization, from); err != nil { + if err := VerifyQC(emptyNotarization.QC, e.Logger, "Empty notarization", e.signatureAggregator.IsQuorum, e.eligibleNodeIDs, emptyNotarization, from); err != nil { return nil } @@ -1650,7 +1654,7 @@ func (e *Epoch) handleNotarizationMessage(message *Notarization, from NodeID) er return nil } - if err := VerifyQC(message.QC, e.Logger, "Notarization", e.SignatureAggregator.IsQuorum, e.eligibleNodeIDs, message, from); err != nil { + if err := VerifyQC(message.QC, e.Logger, "Notarization", e.signatureAggregator.IsQuorum, e.eligibleNodeIDs, message, from); err != nil { return nil } @@ -3206,20 +3210,20 @@ func (e *Epoch) verifyQuorumRound(q QuorumRound, from NodeID) error { if q.Finalization != nil { // extra check needed if we have a finalized block - err := VerifyQC(q.Finalization.QC, e.Logger, "Finalization", e.SignatureAggregator.IsQuorum, e.eligibleNodeIDs, q.Finalization, from) + err := VerifyQC(q.Finalization.QC, e.Logger, "Finalization", e.signatureAggregator.IsQuorum, e.eligibleNodeIDs, q.Finalization, from) if err != nil { return errors.New("invalid finalization") } } if q.Notarization != nil { - if err := VerifyQC(q.Notarization.QC, e.Logger, "Notarization", e.SignatureAggregator.IsQuorum, e.eligibleNodeIDs, q.Notarization, from); err != nil { + if err := VerifyQC(q.Notarization.QC, e.Logger, "Notarization", e.signatureAggregator.IsQuorum, e.eligibleNodeIDs, q.Notarization, from); err != nil { return fmt.Errorf("invalid notarization: %v", err) } } if q.EmptyNotarization != nil { - err := VerifyQC(q.EmptyNotarization.QC, e.Logger, "Empty notarization", e.SignatureAggregator.IsQuorum, e.eligibleNodeIDs, q.EmptyNotarization, from) + err := VerifyQC(q.EmptyNotarization.QC, e.Logger, "Empty notarization", e.signatureAggregator.IsQuorum, e.eligibleNodeIDs, q.EmptyNotarization, from) if err != nil { return fmt.Errorf("invalid empty notarization QC: %v", err) } @@ -3420,10 +3424,10 @@ func (e *Epoch) nextSeqToCommit() uint64 { return e.Storage.NumBlocks() } -// SortNodes sorts the nodes in place by their byte representations. -func SortNodes(nodes []NodeID) { - slices.SortFunc(nodes, func(a, b NodeID) int { - return bytes.Compare(a[:], b[:]) +// SortNodesWeights sorts the nodes in place by their byte representations. +func SortNodesWeights(nodes NodeWeights) { + slices.SortFunc(nodes, func(a, b NodeWeight) int { + return bytes.Compare(a.Node[:], b.Node[:]) }) } diff --git a/epoch_failover_test.go b/epoch_failover_test.go index a6c39042..321acb1f 100644 --- a/epoch_failover_test.go +++ b/epoch_failover_test.go @@ -85,9 +85,9 @@ func TestEpochLeaderFailoverWithEmptyNotarization(t *testing.T) { func TestEpochRebroadcastsEmptyVoteAfterBlockProposalReceived(t *testing.T) { bb := testutil.NewTestBlockBuilder() - nodes := []NodeID{{1}, {2}, {3}, {4}} + nodes := NodeIDs{{1}, {2}, {3}, {4}} - comm := newRebroadcastComm(nodes) + comm := newRebroadcastComm(nodes.EqualWeightedNodeWeights()) conf, wal, _ := testutil.DefaultTestNodeEpochConfig(t, nodes[3], comm, bb) epochTime := conf.StartTime e, err := NewEpoch(conf) @@ -349,12 +349,12 @@ func TestEpochLeaderFailoverDoNotPersistEmptyRoundTwice(t *testing.T) { } func TestEpochLeaderRecursivelyFetchNotarizedBlocks(t *testing.T) { - nodes := []NodeID{{1}, {2}, {3}, {4}} + nodes := NodeIDs{{1}, {2}, {3}, {4}} bb := testutil.NewTestBlockBuilder() recordedMessages := make(chan *Message, 100) - comm := &recordingComm{Communication: testutil.NoopComm(nodes), SentMessages: recordedMessages} + comm := &recordingComm{Communication: testutil.NoopComm(nodes.EqualWeightedNodeWeights()), SentMessages: recordedMessages} conf, wal, _ := testutil.DefaultTestNodeEpochConfig(t, nodes[0], comm, bb) e, err := NewEpoch(conf) @@ -1115,18 +1115,18 @@ func TestEpochBlacklist(t *testing.T) { } type rebroadcastComm struct { - nodes []NodeID + nodes NodeWeights emptyVotes chan *EmptyVote } -func newRebroadcastComm(nodes []NodeID) *rebroadcastComm { +func newRebroadcastComm(nodes NodeWeights) *rebroadcastComm { return &rebroadcastComm{ nodes: nodes, emptyVotes: make(chan *EmptyVote, 10), } } -func (r *rebroadcastComm) Nodes() []NodeID { +func (r *rebroadcastComm) Nodes() NodeWeights { return r.nodes } @@ -1142,9 +1142,9 @@ func (r *rebroadcastComm) Broadcast(msg *Message) { func TestEpochRebroadcastsEmptyVote(t *testing.T) { bb := testutil.NewTestBlockBuilder() - nodes := []NodeID{{1}, {2}, {3}, {4}} + nodes := NodeIDs{{1}, {2}, {3}, {4}} - comm := newRebroadcastComm(nodes) + comm := newRebroadcastComm(nodes.EqualWeightedNodeWeights()) conf, wal, _ := testutil.DefaultTestNodeEpochConfig(t, nodes[3], comm, bb) epochTime := conf.StartTime e, err := NewEpoch(conf) @@ -1227,7 +1227,7 @@ func runCrashAndRestartExecution(t *testing.T, e *Epoch, bb *testutil.TestBlockB // Case 2: t.Run(fmt.Sprintf("%s-with-crash", t.Name()), func(t *testing.T) { - conf, _, _ := testutil.DefaultTestNodeEpochConfig(t, nodes[0], testutil.NewNoopComm(nodes), bbAfterCrash) + conf, _, _ := testutil.DefaultTestNodeEpochConfig(t, nodes[0].Node, testutil.NewNoopComm(nodes.NodesIDs()), bbAfterCrash) conf.Storage = cloneStorage conf.WAL = cloneWAL diff --git a/epoch_test.go b/epoch_test.go index 83004dff..39bb1aab 100644 --- a/epoch_test.go +++ b/epoch_test.go @@ -118,7 +118,8 @@ func TestFinalizeSameSequence(t *testing.T) { require.NoError(t, err) // create a notarization and now we should send a finalize vote for seq 1 again - notarization, err := testutil.NewNotarization(e.Logger, e.SignatureAggregator, block, nodes[1:]) + sigAggr := e.SignatureAggregatorCreator(conf.Comm.Nodes()) + notarization, err := testutil.NewNotarization(e.Logger, sigAggr, block, nodes[1:]) require.NoError(t, err) testutil.InjectTestNotarization(t, e, notarization, nodes[1]) @@ -202,7 +203,7 @@ func testFinalizeSameSequenceGap(t *testing.T, nodes []NodeID, numEmptyNotarizat } for range numEmptyNotarizations { - leader := LeaderForRound(e.Comm.Nodes(), e.Metadata().Round) + leader := LeaderForRound(e.Comm.Nodes().NodesIDs(), e.Metadata().Round) if e.ID.Equals(leader) { fVote := advanceWithFinalizeCheck(t, e, recordingComm, bb) finalizeVoteSeqs[fVote.Finalization.Seq] = fVote @@ -236,7 +237,7 @@ func testFinalizeSameSequenceGap(t *testing.T, nodes []NodeID, numEmptyNotarizat verified <- struct{}{} } - leader := LeaderForRound(e.Comm.Nodes(), 1+numEmptyNotarizations+numNotarizations) + leader := LeaderForRound(e.Comm.Nodes().NodesIDs(), 1+numEmptyNotarizations+numNotarizations) if e.ID.Equals(leader) { return } @@ -275,7 +276,8 @@ func testFinalizeSameSequenceGap(t *testing.T, nodes []NodeID, numEmptyNotarizat } // create a notarization and now we should send a finalize vote for seqToDoubleFinalize again - notarization, err := testutil.NewNotarization(e.Logger, e.SignatureAggregator, block, nodes[1:]) + sigAggr := e.SignatureAggregatorCreator(conf.Comm.Nodes()) + notarization, err := testutil.NewNotarization(e.Logger, sigAggr, block, nodes[1:]) require.NoError(t, err) testutil.InjectTestNotarization(t, e, notarization, nodes[1]) @@ -377,7 +379,8 @@ func TestEpochHandleNotarizationFutureRound(t *testing.T) { require.NoError(t, e.Start()) // Create a notarization for round 1 which is a future round because we haven't gone through round 0 yet. - notarization, err := testutil.NewNotarization(conf.Logger, conf.SignatureAggregator, secondBlock, nodes) + sigAggr := e.SignatureAggregatorCreator(conf.Comm.Nodes()) + notarization, err := testutil.NewNotarization(conf.Logger, sigAggr, secondBlock, nodes) require.NoError(t, err) // Give the node the notarization message before receiving the first block @@ -442,7 +445,8 @@ func TestEpochIndexFinalization(t *testing.T) { // when we receive that finalization, we should commit the rest of the finalizations for seqs // 1 & 2 - finalization, _ := testutil.NewFinalizationRecord(t, conf.Logger, conf.SignatureAggregator, firstBlock, e.Comm.Nodes()) + sigAggr := e.SignatureAggregatorCreator(conf.Comm.Nodes()) + finalization, _ := testutil.NewFinalizationRecord(t, conf.Logger, sigAggr, firstBlock, e.Comm.Nodes().NodesIDs()) testutil.InjectTestFinalization(t, e, &finalization, nodes[1]) storage.WaitForBlockCommit(2) @@ -539,7 +543,8 @@ func TestEpochIncreasesRoundAfterFinalization(t *testing.T) { require.Equal(t, uint64(0), storage.NumBlocks()) // create the finalized block - finalization, _ := testutil.NewFinalizationRecord(t, l, conf.SignatureAggregator, block, nodes) + sigAggr := e.SignatureAggregatorCreator(conf.Comm.Nodes()) + finalization, _ := testutil.NewFinalizationRecord(t, l, sigAggr, block, nodes) testutil.InjectTestFinalization(t, e, &finalization, nodes[1]) storage.WaitForBlockCommit(1) @@ -746,10 +751,10 @@ func TestEpochStartedTwice(t *testing.T) { } func advanceRoundFromEmpty(t *testing.T, e *Epoch) { - leader := LeaderForRound(e.Comm.Nodes(), e.Metadata().Round) + leader := LeaderForRound(e.Comm.Nodes().NodesIDs(), e.Metadata().Round) require.False(t, e.ID.Equals(leader), "epoch cannot be the leader for the empty round") - emptyNote := testutil.NewEmptyNotarization(e.Comm.Nodes(), e.Metadata().Round) + emptyNote := testutil.NewEmptyNotarization(e.Comm.Nodes().NodesIDs(), e.Metadata().Round) err := e.HandleMessage(&Message{ EmptyNotarization: emptyNote, }, leader) @@ -1010,8 +1015,10 @@ func TestEpochQCSignedByNonExistentNodes(t *testing.T) { wal.AssertWALSize(1) + sigAggr := e.SignatureAggregatorCreator(conf.Comm.Nodes()) + t.Run("notarization with unknown signer isn't taken into account", func(t *testing.T) { - notarization, err := testutil.NewNotarization(conf.Logger, conf.SignatureAggregator, block, []NodeID{{2}, {3}, {5}}) + notarization, err := testutil.NewNotarization(conf.Logger, sigAggr, block, []NodeID{{2}, {3}, {5}}) require.NoError(t, err) err = e.HandleMessage(&Message{ @@ -1025,7 +1032,7 @@ func TestEpochQCSignedByNonExistentNodes(t *testing.T) { }) t.Run("notarization with double signer isn't taken into account", func(t *testing.T) { - notarization, err := testutil.NewNotarization(conf.Logger, conf.SignatureAggregator, block, []NodeID{{2}, {3}}) + notarization, err := testutil.NewNotarization(conf.Logger, sigAggr, block, []NodeID{{2}, {3}}) require.NoError(t, err) tqc := notarization.QC.(testutil.TestQC) @@ -1081,7 +1088,7 @@ func TestEpochQCSignedByNonExistentNodes(t *testing.T) { }) t.Run("finalization with unknown signer isn't taken into account", func(t *testing.T) { - finalization, _ := testutil.NewFinalizationRecord(t, conf.Logger, conf.SignatureAggregator, block, []NodeID{{2}, {3}, {5}}) + finalization, _ := testutil.NewFinalizationRecord(t, conf.Logger, sigAggr, block, []NodeID{{2}, {3}, {5}}) err = e.HandleMessage(&Message{ Finalization: &finalization, @@ -1092,7 +1099,7 @@ func TestEpochQCSignedByNonExistentNodes(t *testing.T) { }) t.Run("finalization with double signer isn't taken into account", func(t *testing.T) { - finalization, _ := testutil.NewFinalizationRecord(t, conf.Logger, conf.SignatureAggregator, block, []NodeID{{2}, {3}, {3}}) + finalization, _ := testutil.NewFinalizationRecord(t, conf.Logger, sigAggr, block, []NodeID{{2}, {3}, {3}}) err = e.HandleMessage(&Message{ Finalization: &finalization, @@ -1250,7 +1257,8 @@ func TestEpochSendsBlockDigestRequest(t *testing.T) { require.True(t, built) block := bb.GetBuiltBlock() - notarization, err := testutil.NewNotarization(conf.Logger, conf.SignatureAggregator, block, nodes) + sigAggr := e.SignatureAggregatorCreator(conf.Comm.Nodes()) + notarization, err := testutil.NewNotarization(conf.Logger, sigAggr, block, nodes) require.NoError(t, err) err = e.HandleMessage(&Message{ @@ -1495,7 +1503,8 @@ func TestDoubleIncrementOnPersistNotarization(t *testing.T) { require.True(t, ok) block := bb.GetBuiltBlock() - notarization, err := testutil.NewNotarization(conf.Logger, conf.SignatureAggregator, block, nodes) + sigAggr := e.SignatureAggregatorCreator(conf.Comm.Nodes()) + notarization, err := testutil.NewNotarization(conf.Logger, sigAggr, block, nodes) require.NoError(t, err) err = e.HandleMessage(&Message{ @@ -1586,7 +1595,8 @@ func TestRejectsOldNotarizationAndVotes(t *testing.T) { } // send notarization for round 1, after the finalization was sent - notarization, err := testutil.NewNotarization(conf.Logger, conf.SignatureAggregator, block, nodes) + sigAggr := e.SignatureAggregatorCreator(conf.Comm.Nodes()) + notarization, err := testutil.NewNotarization(conf.Logger, sigAggr, block, nodes) require.NoError(t, err) err = e.HandleMessage(&Message{ @@ -1634,7 +1644,7 @@ func advanceRound(t *testing.T, e *simplex.Epoch, bb *testutil.TestBlockBuilder, nodes := e.Comm.Nodes() quorum := simplex.Quorum(len(nodes)) // leader is the proposer of the new block for the given round - leader := simplex.LeaderForRound(nodes, e.Metadata().Round) + leader := simplex.LeaderForRound(nodes.NodesIDs(), e.Metadata().Round) md := e.Metadata() if injectedMD != nil { md = *injectedMD @@ -1667,8 +1677,9 @@ func advanceRound(t *testing.T, e *simplex.Epoch, bb *testutil.TestBlockBuilder, var notarization *simplex.Notarization if notarize { // start at one since our node has already voted - n, err := testutil.NewNotarization(e.Logger, e.SignatureAggregator, block, nodes[0:quorum]) - testutil.InjectTestNotarization(t, e, n, nodes[1]) + sigAggr := e.SignatureAggregatorCreator(nodes) + n, err := testutil.NewNotarization(e.Logger, sigAggr, block, nodes.NodesIDs()[0:quorum]) + testutil.InjectTestNotarization(t, e, n, nodes[1].Node) e.WAL.(*testutil.TestWAL).AssertNotarization(block.Metadata.Round) require.NoError(t, err) @@ -1677,10 +1688,10 @@ func advanceRound(t *testing.T, e *simplex.Epoch, bb *testutil.TestBlockBuilder, if finalize { for i := 0; i <= quorum; i++ { - if nodes[i].Equals(e.ID) { + if nodes[i].Node.Equals(e.ID) { continue } - testutil.InjectTestFinalizeVote(t, e, block, nodes[i]) + testutil.InjectTestFinalizeVote(t, e, block, nodes[i].Node) } if nextSeqToCommit != block.Metadata.Seq { diff --git a/global.go b/global.go index 4fdec9bd..6e70a0a6 100644 --- a/global.go +++ b/global.go @@ -51,3 +51,14 @@ func (nodes NodeIDs) IndexOf(id NodeID) int { } return -1 } + +func (nodes NodeIDs) EqualWeightedNodeWeights() NodeWeights { + weights := make(NodeWeights, len(nodes)) + for i, node := range nodes { + weights[i] = NodeWeight{ + Node: node, + Weight: 1, + } + } + return weights +} \ No newline at end of file diff --git a/msm/encoding.go b/msm/encoding.go index a2eafb7d..64d09cda 100644 --- a/msm/encoding.go +++ b/msm/encoding.go @@ -5,9 +5,9 @@ package metadata import ( "bytes" - "fmt" - "math" "slices" + + "github.com/ava-labs/simplex" ) //go:generate go run github.com/StephenButtolph/canoto/canoto encoding.go @@ -199,36 +199,27 @@ func (nea *NextEpochApprovals) Equals(other *NextEpochApprovals) bool { type NodeBLSMappings []NodeBLSMapping -func (nbms NodeBLSMappings) TotalWeight() (uint64, error) { - var totalWeight uint64 - for _, nbm := range nbms { - var err error - totalWeight, err = safeAdd(totalWeight, nbm.Weight) - if err != nil { - return 0, fmt.Errorf("failed to sum weights of all nodes: %w", err) +func (nbms NodeBLSMappings) NodeWeights() simplex.NodeWeights { + nodeWeights := make(simplex.NodeWeights, len(nbms)) + for i, nbm := range nbms { + nodeWeights[i] = simplex.NodeWeight{ + Node: nbm.NodeID[:], + Weight: nbm.Weight, } } - - if totalWeight == 0 { - return 0, fmt.Errorf("total weight of validators is 0") - } - - if totalWeight > math.MaxInt64 { - return 0, fmt.Errorf("total weight of validators is too big, overflows int64: %d", totalWeight) - } - return totalWeight, nil + return nodeWeights } -func (nbms NodeBLSMappings) ApprovingWeights(approvingNodes bitmask) []uint64 { - approvingWeights := make([]uint64, 0, len(nbms)) +func (nbms NodeBLSMappings) SelectSubset(bitmask bitmask) []simplex.NodeID { + nodeIDs := make([]simplex.NodeID, 0, len(nbms)) for i, nbm := range nbms { - if !approvingNodes.Contains(i) { + if !bitmask.Contains(i) { continue } - approvingWeights = append(approvingWeights, nbm.Weight) + nodeIDs = append(nodeIDs, nbm.NodeID[:]) } - return approvingWeights + return nodeIDs } func (nbms NodeBLSMappings) Clone() NodeBLSMappings { diff --git a/msm/msm.go b/msm/msm.go index 668e5c5b..796d86c0 100644 --- a/msm/msm.go +++ b/msm/msm.go @@ -39,13 +39,6 @@ func (smb *StateMachineBlock) Digest() [32]byte { return sha256.Sum256(combined) } -// Used to aggregate validator signatures for epoch transitions. -type SignatureAggregator interface { - AggregateSignatures(signatures ...[]byte) ([]byte, error) - - IsQuorum(approverWeights []uint64, totalWeight uint64) bool -} - // ApprovalsRetriever retrieves the approvals from validators of the next epoch for the epoch change. type ApprovalsRetriever interface { RetrieveApprovals() ValidatorSetApprovals @@ -104,8 +97,8 @@ type StateMachine struct { GetBlock BlockRetriever // ApprovalsRetriever retrieves validator approvals for epoch transitions. ApprovalsRetriever ApprovalsRetriever - // SignatureAggregator aggregates signatures from validators. - SignatureAggregator SignatureAggregator + // SignatureAggregatorCreator creates a new SignatureAggregator for aggregating validator signatures for epoch transitions. + SignatureAggregatorCreator simplex.SignatureAggregatorCreator // KeyAggregator aggregates public keys from validators. KeyAggregator KeyAggregator // SignatureVerifier verifies signatures from validators. @@ -262,7 +255,7 @@ func (sm *StateMachine) init() { getValidatorSet: sm.GetValidatorSet, keyAggregator: sm.KeyAggregator, sigVerifier: sm.SignatureVerifier, - sigAggregator: sm.SignatureAggregator, + sigAggregatorCreator: sm.SignatureAggregatorCreator, }, &sealingBlockSeqVerifier{}, } @@ -548,7 +541,9 @@ func (sm *StateMachine) buildBlockCollectingApprovals(ctx context.Context, paren nextPChainHeight := newSimplexEpochInfo.NextPChainReferenceHeight prevNextEpochApprovals := parentBlock.Metadata.SimplexEpochInfo.NextEpochApprovals - newApprovals, err := computeNewApprovals(prevNextEpochApprovals, approvalsFromPeers, nextPChainHeight, sm.SignatureAggregator, validators, sm.SignatureAggregator) + sigAggr := sm.SignatureAggregatorCreator(validators.NodeWeights()) + + newApprovals, err := computeNewApprovals(prevNextEpochApprovals, approvalsFromPeers, nextPChainHeight, sigAggr, validators) if err != nil { return nil, err } @@ -567,9 +562,12 @@ func (sm *StateMachine) buildBlockCollectingApprovals(ctx context.Context, paren // in which case we just carry over the approvals we have so far to the next block, // so that eventually we'll have enough approvals to seal the epoch. if !newApprovals.canSeal { + sm.Logger.Debug("Not enough approvals to seal epoch, building block without sealing the epoch",) return sm.buildBlockImpatiently(ctx, parentBlock, simplexMetadata, simplexBlacklist, newSimplexEpochInfo, pChainHeight) } + sm.Logger.Debug("Have enough approvals to seal epoch, building sealing block") + // Else, we have enough approvals to seal the epoch, so we create the sealing block. return sm.createSealingBlock(ctx, parentBlock, simplexMetadata, simplexBlacklist, newSimplexEpochInfo, newApprovals, pChainHeight) } @@ -739,9 +737,8 @@ func computeNewApprovals( nextEpochApprovals *NextEpochApprovals, approvalsFromPeers ValidatorSetApprovals, pChainHeight uint64, - aggregator SignatureAggregator, + sigAggr simplex.SignatureAggregator, validators NodeBLSMappings, - sigAggr SignatureAggregator, ) (*approvals, error) { if nextEpochApprovals == nil { nextEpochApprovals = &NextEpochApprovals{} @@ -761,17 +758,14 @@ func computeNewApprovals( approvalsFromPeers = sanitizeApprovals(approvalsFromPeers, pChainHeight, nodeID2ValidatorIndex, oldApprovingNodes) // Next we aggregate both previous and new approvals to compute the new aggregated signatures and the new bitmask of approving nodes. - aggregatedSignature, newApprovingNodes, err := computeNewApproverSignaturesAndSigners(nextEpochApprovals, approvalsFromPeers, oldApprovingNodes, nodeID2ValidatorIndex, aggregator) + aggregatedSignature, newApprovingNodes, err := computeNewApproverSignaturesAndSigners(nextEpochApprovals, approvalsFromPeers, oldApprovingNodes, nodeID2ValidatorIndex, sigAggr) if err != nil { return nil, err } // we check if we have enough approvals to seal the epoch by computing the relative approval ratio, // which is the ratio of the total weight of approving nodes divided by the total weight of all validators. - canSeal, err := canSealBlock(validators, newApprovingNodes, sigAggr) - if err != nil { - return nil, err - } + canSeal := sigAggr.IsQuorum(validators.SelectSubset(newApprovingNodes)) return &approvals{ canSeal: canSeal, @@ -782,7 +776,7 @@ func computeNewApprovals( // computeNewApproverSignaturesAndSigners computes the signatures of the nodes that approve the next epoch including the previous aggregated signature, // and bitmask of nodes that correspond to those signatures, and aggregates all signatures together. -func computeNewApproverSignaturesAndSigners(nextEpochApprovals *NextEpochApprovals, approvalsFromPeers ValidatorSetApprovals, oldApprovingNodes bitmask, nodeID2ValidatorIndex map[nodeID]int, aggregator SignatureAggregator) ([]byte, bitmask, error) { +func computeNewApproverSignaturesAndSigners(nextEpochApprovals *NextEpochApprovals, approvalsFromPeers ValidatorSetApprovals, oldApprovingNodes bitmask, nodeID2ValidatorIndex map[nodeID]int, sigAggr simplex.SignatureAggregator) ([]byte, bitmask, error) { if nextEpochApprovals == nil { return nil, bitmask{}, fmt.Errorf("next epoch approvals is nil") } @@ -811,12 +805,9 @@ func computeNewApproverSignaturesAndSigners(nextEpochApprovals *NextEpochApprova // Add the existing signature into the list of signatures to aggregate existingSignature := nextEpochApprovals.Signature - if existingSignature != nil { - newSignatures = append(newSignatures, existingSignature) - } // Finally, we aggregate all signatures together, to compute the new aggregated signature. - aggregatedSignature, err := aggregator.AggregateSignatures(newSignatures...) + aggregatedSignature, err := sigAggr.AppendSignatures(existingSignature, newSignatures...) if err != nil { return nil, bitmask{}, fmt.Errorf("failed to aggregate signatures: %w", err) } @@ -824,17 +815,6 @@ func computeNewApproverSignaturesAndSigners(nextEpochApprovals *NextEpochApprova return aggregatedSignature, newApprovingNodes, nil } -func canSealBlock(validators NodeBLSMappings, newApprovingNodes bitmask, sigAggr SignatureAggregator) (bool, error) { - approvingWeights := validators.ApprovingWeights(newApprovingNodes) - - totalWeight, err := validators.TotalWeight() - if err != nil { - return false, err - } - - return sigAggr.IsQuorum(approvingWeights, totalWeight), nil -} - // sanitizeApprovals filters out approvals that are not valid by checking if they agree with our candidate auxiliary info digest and P-Chain height, // and if they are from the validator set and haven't already been approved. func sanitizeApprovals(approvals ValidatorSetApprovals, pChainHeight uint64, nodeID2ValidatorIndex map[nodeID]int, oldApprovingNodes bitmask) ValidatorSetApprovals { diff --git a/msm/msm_test.go b/msm/msm_test.go index c76da241..6ec03b39 100644 --- a/msm/msm_test.go +++ b/msm/msm_test.go @@ -9,6 +9,7 @@ import ( "crypto/rand" "encoding/asn1" "fmt" + "maps" "math" "testing" "time" @@ -37,9 +38,7 @@ type blockStore map[uint64]*outerBlock func (bs blockStore) clone() blockStore { newStore := make(blockStore) - for k, v := range bs { - newStore[k] = v - } + maps.Copy(newStore, bs) return newStore } @@ -68,26 +67,44 @@ func (sv *signatureVerifier) VerifySignature(signature []byte, message []byte, p } type signatureAggregator struct { + weightByNodeID map[string]uint64 + totalWeight uint64 } type aggregatrdSignature struct { Signatures [][]byte } -func (sv *signatureAggregator) AggregateSignatures(signatures ...[]byte) ([]byte, error) { - bytes, err := asn1.Marshal(aggregatrdSignature{Signatures: signatures}) - if err != nil { - return nil, err +func (sv *signatureAggregator) Aggregate([]simplex.Signature) (simplex.QuorumCertificate, error) { + panic("unused in tests") +} + +func (sv *signatureAggregator) AppendSignatures(existing []byte, sigs ...[]byte) ([]byte, error) { + all := make([][]byte, 0, len(sigs)+1) + all = append(all, sigs...) + if len(existing) > 0 { + all = append(all, existing) } - return bytes, nil + return asn1.Marshal(aggregatrdSignature{Signatures: all}) } -func (sv *signatureAggregator) IsQuorum(approverWeights []uint64, totalWeights uint64) bool { +func (sv *signatureAggregator) IsQuorum(signers []simplex.NodeID) bool { var sum uint64 - for _, w := range approverWeights { - sum += w + for _, signer := range signers { + sum += sv.weightByNodeID[string(signer)] + } + return sum*3 > sv.totalWeight*2 +} + +func newSignatureAggregatorCreator() simplex.SignatureAggregatorCreator { + return func(weights []simplex.NodeWeight) simplex.SignatureAggregator { + s := &signatureAggregator{weightByNodeID: make(map[string]uint64, len(weights))} + for _, nw := range weights { + s.weightByNodeID[string(nw.Node)] = nw.Weight + s.totalWeight += nw.Weight + } + return s } - return sum*3 > totalWeights*2 } type noOpPChainListener struct{} @@ -147,7 +164,7 @@ var ( func TestMSMFirstBlockAfterGenesis(t *testing.T) { validMD := simplex.ProtocolMetadata{ - Round: 0, + Round: 1, Seq: 1, Epoch: 1, Prev: genesisBlock.Digest(), @@ -734,7 +751,7 @@ func TestMSMFullEpochLifecycle(t *testing.T) { // node1 is at index 0 in validatorSet2 → bitmask bit 0 → {1} bitmask := []byte{1} - sig, err := aggr.AggregateSignatures([]byte("sig1")) + sig, err := aggr.AppendSignatures(nil, []byte("sig1")) require.NoError(t, err) tc.blockBuilder.block = nextBlock(4) @@ -773,7 +790,7 @@ func TestMSMFullEpochLifecycle(t *testing.T) { } // node2 is at index 1 → bitmask bits 0,1 → {3} - sig, err = aggr.AggregateSignatures([]byte("sig2"), sig) + sig, err = aggr.AppendSignatures(sig, []byte("sig2")) require.NoError(t, err) bitmask = []byte{3} @@ -813,7 +830,7 @@ func TestMSMFullEpochLifecycle(t *testing.T) { } // node3 is at index 2 → bitmask bits 0,1,2 → {7} - sig6, err := aggr.AggregateSignatures([]byte("sig3"), sig) + sig6, err := aggr.AppendSignatures(sig, []byte("sig3")) require.NoError(t, err) bitmask = []byte{7} @@ -1046,7 +1063,7 @@ func newStateMachine(t *testing.T) (StateMachine, *testConfig) { MaxBlockBuildingWaitTime: time.Second, ApprovalsRetriever: &testConfig.approvalsRetriever, SignatureVerifier: &testConfig.signatureVerifier, - SignatureAggregator: &testConfig.signatureAggregator, + SignatureAggregatorCreator: newSignatureAggregatorCreator(), BlockBuilder: &testConfig.blockBuilder, KeyAggregator: &testConfig.keyAggregator, GetPChainHeight: func() uint64 { @@ -1219,62 +1236,6 @@ func TestFindFirstSimplexBlock(t *testing.T) { }) } -func TestComputeTotalWeight(t *testing.T) { - t.Run("valid weights", func(t *testing.T) { - validators := NodeBLSMappings{ - {Weight: 100}, - {Weight: 200}, - {Weight: 300}, - } - total, err := validators.TotalWeight() - require.NoError(t, err) - require.Equal(t, uint64(600), total) - }) - - t.Run("zero total weight", func(t *testing.T) { - validators := NodeBLSMappings{{Weight: 0}} - _, err := validators.TotalWeight() - require.ErrorContains(t, err, "total weight of validators is 0") - }) - - t.Run("empty validators", func(t *testing.T) { - _, err := NodeBLSMappings{}.TotalWeight() - require.ErrorContains(t, err, "total weight of validators is 0") - }) -} - -func TestComputeApprovingWeight(t *testing.T) { - validators := NodeBLSMappings{ - {Weight: 100}, - {Weight: 200}, - {Weight: 300}, - } - - t.Run("all approving", func(t *testing.T) { - bm := bitmaskFromBytes([]byte{7}) - weights := validators.ApprovingWeights(bm) - require.Equal(t, []uint64{100, 200, 300}, weights) - }) - - t.Run("partial approving", func(t *testing.T) { - bm := bitmaskFromBytes([]byte{5}) - weights := validators.ApprovingWeights(bm) - require.Equal(t, []uint64{100, 300}, weights) - }) - - t.Run("none approving", func(t *testing.T) { - bm := bitmaskFromBytes(nil) - weights := validators.ApprovingWeights(bm) - require.Empty(t, weights) - }) - - t.Run("single validator approving", func(t *testing.T) { - bm := bitmaskFromBytes([]byte{2}) - weights := validators.ApprovingWeights(bm) - require.Equal(t, []uint64{200}, weights) - }) -} - func TestSanitizeApprovals(t *testing.T) { node0 := nodeID{0} node1 := nodeID{1} @@ -1334,25 +1295,30 @@ func TestSanitizeApprovals(t *testing.T) { // concatAggregator concatenates signatures for easy verification in tests. type concatAggregator struct{} -func (concatAggregator) AggregateSignatures(sigs ...[]byte) ([]byte, error) { - return bytes.Join(sigs, nil), nil +func (concatAggregator) Aggregate([]simplex.Signature) (simplex.QuorumCertificate, error) { + panic("unused in tests") } -func (concatAggregator) IsQuorum(approverWeights []uint64, totalWeights uint64) bool { - var sum uint64 - for _, w := range approverWeights { - sum += w - } - return sum*3 >= totalWeights*2 +func (concatAggregator) AppendSignatures(existing []byte, sigs ...[]byte) ([]byte, error) { + result := bytes.Join(sigs, nil) + return append(result, existing...), nil +} + +func (concatAggregator) IsQuorum([]simplex.NodeID) bool { + return false } type failingAggregator struct{} -func (failingAggregator) AggregateSignatures(sigs ...[]byte) ([]byte, error) { +func (failingAggregator) Aggregate([]simplex.Signature) (simplex.QuorumCertificate, error) { + panic("unused in tests") +} + +func (failingAggregator) AppendSignatures([]byte, ...[]byte) ([]byte, error) { return nil, fmt.Errorf("aggregation failed") } -func (failingAggregator) IsQuorum([]uint64, uint64) bool { +func (failingAggregator) IsQuorum([]simplex.NodeID) bool { return false } diff --git a/msm/verification.go b/msm/verification.go index d3af5cc3..ba8ef333 100644 --- a/msm/verification.go +++ b/msm/verification.go @@ -67,7 +67,7 @@ type nextEpochApprovalsVerifier struct { sigVerifier SignatureVerifier getValidatorSet ValidatorSetRetriever keyAggregator KeyAggregator - sigAggregator SignatureAggregator + sigAggregatorCreator simplex.SignatureAggregatorCreator } func (nv *nextEpochApprovalsVerifier) Verify(in verificationInput) error { @@ -99,10 +99,8 @@ func (nv *nextEpochApprovalsVerifier) verifySealingBlock(prev SimplexEpochInfo, } approvingNodes := bitmaskFromBytes(next.NextEpochApprovals.NodeIDs) - canSeal, err := canSealBlock(validators, approvingNodes, nv.sigAggregator) - if err != nil { - return err - } + sigAggr := nv.sigAggregatorCreator(validators.NodeWeights()) + canSeal := sigAggr.IsQuorum(validators.SelectSubset(approvingNodes)) if !canSeal { return fmt.Errorf("not enough approvals to seal block") diff --git a/msm/verification_test.go b/msm/verification_test.go index 43f99a58..b515a4f1 100644 --- a/msm/verification_test.go +++ b/msm/verification_test.go @@ -841,10 +841,10 @@ func TestNextEpochApprovalsVerifier(t *testing.T) { } { t.Run(tc.name, func(t *testing.T) { v := &nextEpochApprovalsVerifier{ - sigVerifier: tc.sigVerifier, - getValidatorSet: tc.getValidator, - keyAggregator: tc.keyAggregator, - sigAggregator: &signatureAggregator{}, + sigVerifier: tc.sigVerifier, + getValidatorSet: tc.getValidator, + keyAggregator: tc.keyAggregator, + sigAggregatorCreator: newSignatureAggregatorCreator(), } err := v.Verify(verificationInput{ nextBlockType: tc.nextBlockType, diff --git a/pos_test.go b/pos_test.go index 6f7af6bf..177492e1 100644 --- a/pos_test.go +++ b/pos_test.go @@ -23,16 +23,18 @@ func TestPoS(t *testing.T) { nodes := []simplex.NodeID{{1}, {2}, {3}, {4}} - posSigAggregator := &testutil.TestSignatureAggregator{ - IsQuorumFunc: func(signatures []simplex.NodeID) bool { - var totalWeight uint64 - for _, signer := range signatures { - totalWeight += weights[int(signer[0])] - } - return totalWeight > 6 - }, + posSigAggregatorCreator := func(_ []simplex.NodeWeight) simplex.SignatureAggregator { + return &testutil.TestSignatureAggregator{ + IsQuorumFunc: func(signatures []simplex.NodeID) bool { + var totalWeight uint64 + for _, signer := range signatures { + totalWeight += weights[int(signer[0])] + } + return totalWeight > 6 + }, + } } - testConf := &testutil.TestNodeConfig{SigAggregator: posSigAggregator, ReplicationEnabled: true} + testConf := &testutil.TestNodeConfig{SigAggregatorCreator: posSigAggregatorCreator, ReplicationEnabled: true} net := testutil.NewControlledNetwork(t, nodes) testutil.NewControlledSimplexNode(t, nodes[0], net, testConf) diff --git a/recovery_test.go b/recovery_test.go index 0e82dccd..e87e8b71 100644 --- a/recovery_test.go +++ b/recovery_test.go @@ -250,7 +250,8 @@ func TestWalCreatedProperly(t *testing.T) { records, err = e.WAL.ReadAll() require.NoError(t, err) require.Len(t, records, 2) - expectedNotarizationRecord, err := testutil.NewNotarizationRecord(conf.Logger, conf.SignatureAggregator, block, nodes[0:quorum]) + sigAggr := e.SignatureAggregatorCreator(conf.Comm.Nodes()) + expectedNotarizationRecord, err := testutil.NewNotarizationRecord(conf.Logger, sigAggr, block, nodes[0:quorum]) require.NoError(t, err) require.Equal(t, expectedNotarizationRecord, records[1]) @@ -423,6 +424,8 @@ func TestRecoverFromMultipleNotarizations(t *testing.T) { require.NoError(t, err) t.Cleanup(e.Stop) + sigAggr := e.SignatureAggregatorCreator(conf.Comm.Nodes()) + protocolMetadata := e.Metadata() firstBlock, ok := bb.BuildBlock(ctx, protocolMetadata, emptyBlacklist) require.True(t, ok) @@ -431,7 +434,7 @@ func TestRecoverFromMultipleNotarizations(t *testing.T) { record := BlockRecord(firstBlock.BlockHeader(), fBytes) wal.Append(record) - firstNotarizationRecord, err := testutil.NewNotarizationRecord(conf.Logger, conf.SignatureAggregator, firstBlock, nodes[0:quorum]) + firstNotarizationRecord, err := testutil.NewNotarizationRecord(conf.Logger, sigAggr, firstBlock, nodes[0:quorum]) require.NoError(t, err) wal.Append(firstNotarizationRecord) @@ -445,12 +448,12 @@ func TestRecoverFromMultipleNotarizations(t *testing.T) { wal.Append(record) // Add notarization for second block - secondNotarizationRecord, err := testutil.NewNotarizationRecord(conf.Logger, conf.SignatureAggregator, secondBlock, nodes[0:quorum]) + secondNotarizationRecord, err := testutil.NewNotarizationRecord(conf.Logger, sigAggr, secondBlock, nodes[0:quorum]) require.NoError(t, err) wal.Append(secondNotarizationRecord) // Create finalization record for second block - finalization2, finalizationRecord := testutil.NewFinalizationRecord(t, conf.Logger, conf.SignatureAggregator, secondBlock, nodes[0:quorum]) + finalization2, finalizationRecord := testutil.NewFinalizationRecord(t, conf.Logger, sigAggr, secondBlock, nodes[0:quorum]) wal.Append(finalizationRecord) err = e.Start() @@ -460,7 +463,7 @@ func TestRecoverFromMultipleNotarizations(t *testing.T) { require.Equal(t, uint64(0), e.Storage.NumBlocks()) // now if we send finalization for block 1, we should index both 1 & 2 - finalization1, _ := testutil.NewFinalizationRecord(t, conf.Logger, conf.SignatureAggregator, firstBlock, nodes[0:quorum]) + finalization1, _ := testutil.NewFinalizationRecord(t, conf.Logger, sigAggr, firstBlock, nodes[0:quorum]) err = e.HandleMessage(&Message{ Finalization: &finalization1, }, nodes[1]) @@ -499,11 +502,12 @@ func TestRecoveryBlocksIndexed(t *testing.T) { record := BlockRecord(firstBlock.BlockHeader(), fBytes) wal.Append(record) - firstNotarizationRecord, err := testutil.NewNotarizationRecord(conf.Logger, conf.SignatureAggregator, firstBlock, nodes[0:quorum]) + sigAggr := conf.SignatureAggregatorCreator(conf.Comm.Nodes()) + firstNotarizationRecord, err := testutil.NewNotarizationRecord(conf.Logger, sigAggr, firstBlock, nodes[0:quorum]) require.NoError(t, err) wal.Append(firstNotarizationRecord) - _, finalizationBytes := testutil.NewFinalizationRecord(t, conf.Logger, conf.SignatureAggregator, firstBlock, nodes[0:quorum]) + _, finalizationBytes := testutil.NewFinalizationRecord(t, conf.Logger, sigAggr, firstBlock, nodes[0:quorum]) wal.Append(finalizationBytes) protocolMetadata.Round = 1 @@ -524,9 +528,9 @@ func TestRecoveryBlocksIndexed(t *testing.T) { record = BlockRecord(thirdBlock.BlockHeader(), tBytes) wal.Append(record) - finalization1, _ := testutil.NewFinalizationRecord(t, conf.Logger, conf.SignatureAggregator, firstBlock, nodes[0:quorum]) - finalization2, _ := testutil.NewFinalizationRecord(t, conf.Logger, conf.SignatureAggregator, secondBlock, nodes[0:quorum]) - fCer3, _ := testutil.NewFinalizationRecord(t, conf.Logger, conf.SignatureAggregator, thirdBlock, nodes[0:quorum]) + finalization1, _ := testutil.NewFinalizationRecord(t, conf.Logger, sigAggr, firstBlock, nodes[0:quorum]) + finalization2, _ := testutil.NewFinalizationRecord(t, conf.Logger, sigAggr, secondBlock, nodes[0:quorum]) + fCer3, _ := testutil.NewFinalizationRecord(t, conf.Logger, sigAggr, thirdBlock, nodes[0:quorum]) conf.Storage.Index(ctx, firstBlock, finalization1) conf.Storage.Index(ctx, secondBlock, finalization2) @@ -655,7 +659,8 @@ func TestWalRecoveryTriggersEmptyVoteTimeout(t *testing.T) { require.NoError(t, wal.Append(blockRecord)) // lets add some notarizations - notarizationRecord, err := testutil.NewNotarizationRecord(conf.Logger, conf.SignatureAggregator, block, nodes[0:quorum]) + sigAggr := e.SignatureAggregatorCreator(conf.Comm.Nodes()) + notarizationRecord, err := testutil.NewNotarizationRecord(conf.Logger, sigAggr, block, nodes[0:quorum]) require.NoError(t, err) require.NoError(t, wal.Append(notarizationRecord)) @@ -761,7 +766,7 @@ func TestWalRecoverySetsRoundCorrectly(t *testing.T) { require.NoError(t, err) blockRecord := BlockRecord(block.BlockHeader(), bBytes) - notarizationRecord, err := testutil.NewNotarizationRecord(conf.Logger, conf.SignatureAggregator, block, nodes[0:quorum]) + notarizationRecord, err := testutil.NewNotarizationRecord(conf.Logger, conf.SignatureAggregatorCreator(conf.Comm.Nodes()), block, nodes[0:quorum]) require.NoError(t, err) return [][]byte{blockRecord, notarizationRecord} @@ -779,7 +784,7 @@ func TestWalRecoverySetsRoundCorrectly(t *testing.T) { require.NoError(t, err) blockRecord1 := BlockRecord(block1.BlockHeader(), bBytes1) - _, finalizationRecord1 := testutil.NewFinalizationRecord(t, conf.Logger, conf.SignatureAggregator, block1, nodes[0:quorum]) + _, finalizationRecord1 := testutil.NewFinalizationRecord(t, conf.Logger, conf.SignatureAggregatorCreator(conf.Comm.Nodes()), block1, nodes[0:quorum]) // Create empty notarization for round 0 emptyNotarization0 := testutil.NewEmptyNotarization(nodes[0:quorum], 0) @@ -800,7 +805,7 @@ func TestWalRecoverySetsRoundCorrectly(t *testing.T) { bBytes0, err := block0.Bytes() require.NoError(t, err) blockRecord0 := BlockRecord(block0.BlockHeader(), bBytes0) - notarizationRecord0, err := testutil.NewNotarizationRecord(conf.Logger, conf.SignatureAggregator, block0, nodes[0:quorum]) + notarizationRecord0, err := testutil.NewNotarizationRecord(conf.Logger, conf.SignatureAggregatorCreator(conf.Comm.Nodes()), block0, nodes[0:quorum]) require.NoError(t, err) block1, ok := bb.BuildBlock(ctx, ProtocolMetadata{Round: 1, Epoch: 0, Seq: 1}, emptyBlacklist) @@ -808,7 +813,7 @@ func TestWalRecoverySetsRoundCorrectly(t *testing.T) { bBytes1, err := block1.Bytes() require.NoError(t, err) blockRecord1 := BlockRecord(block1.BlockHeader(), bBytes1) - notarizationRecord1, err := testutil.NewNotarizationRecord(conf.Logger, conf.SignatureAggregator, block1, nodes[0:quorum]) + notarizationRecord1, err := testutil.NewNotarizationRecord(conf.Logger, conf.SignatureAggregatorCreator(conf.Comm.Nodes()), block1, nodes[0:quorum]) require.NoError(t, err) block2, ok := bb.BuildBlock(ctx, ProtocolMetadata{Round: 2, Epoch: 0, Seq: 2}, emptyBlacklist) @@ -816,7 +821,7 @@ func TestWalRecoverySetsRoundCorrectly(t *testing.T) { bBytes2, err := block2.Bytes() require.NoError(t, err) blockRecord2 := BlockRecord(block2.BlockHeader(), bBytes2) - notarizationRecord2, err := testutil.NewNotarizationRecord(conf.Logger, conf.SignatureAggregator, block2, nodes[0:quorum]) + notarizationRecord2, err := testutil.NewNotarizationRecord(conf.Logger, conf.SignatureAggregatorCreator(conf.Comm.Nodes()), block2, nodes[0:quorum]) require.NoError(t, err) // Create empty notarization for round 3 @@ -843,7 +848,7 @@ func TestWalRecoverySetsRoundCorrectly(t *testing.T) { bBytes3, err := block3.Bytes() require.NoError(t, err) blockRecord3 := BlockRecord(block3.BlockHeader(), bBytes3) - _, finalizationRecord3 := testutil.NewFinalizationRecord(t, conf.Logger, conf.SignatureAggregator, block3, nodes[0:quorum]) + _, finalizationRecord3 := testutil.NewFinalizationRecord(t, conf.Logger, conf.SignatureAggregatorCreator(conf.Comm.Nodes()), block3, nodes[0:quorum]) // Create empty notarization for round 2 emptyNotarization2 := testutil.NewEmptyNotarization(nodes[0:quorum], 2) @@ -856,7 +861,7 @@ func TestWalRecoverySetsRoundCorrectly(t *testing.T) { bBytes1, err := block1.Bytes() require.NoError(t, err) blockRecord1 := BlockRecord(block1.BlockHeader(), bBytes1) - notarizationRecord1, err := testutil.NewNotarizationRecord(conf.Logger, conf.SignatureAggregator, block1, nodes[0:quorum]) + notarizationRecord1, err := testutil.NewNotarizationRecord(conf.Logger, conf.SignatureAggregatorCreator(conf.Comm.Nodes()), block1, nodes[0:quorum]) require.NoError(t, err) // Return in reverse order @@ -878,7 +883,7 @@ func TestWalRecoverySetsRoundCorrectly(t *testing.T) { bBytes0, err := block0.Bytes() require.NoError(t, err) blockRecord0 := BlockRecord(block0.BlockHeader(), bBytes0) - notarizationRecord0, err := testutil.NewNotarizationRecord(conf.Logger, conf.SignatureAggregator, block0, nodes[0:quorum]) + notarizationRecord0, err := testutil.NewNotarizationRecord(conf.Logger, conf.SignatureAggregatorCreator(conf.Comm.Nodes()), block0, nodes[0:quorum]) require.NoError(t, err) // Create finalization for round 10 (highest) @@ -887,7 +892,7 @@ func TestWalRecoverySetsRoundCorrectly(t *testing.T) { bBytes10, err := block10.Bytes() require.NoError(t, err) blockRecord10 := BlockRecord(block10.BlockHeader(), bBytes10) - _, finalizationRecord10 := testutil.NewFinalizationRecord(t, conf.Logger, conf.SignatureAggregator, block10, nodes[0:quorum]) + _, finalizationRecord10 := testutil.NewFinalizationRecord(t, conf.Logger, conf.SignatureAggregatorCreator(conf.Comm.Nodes()), block10, nodes[0:quorum]) // Create empty notarization for round 5 emptyNotarization5 := testutil.NewEmptyNotarization(nodes[0:quorum], 5) @@ -913,10 +918,10 @@ func TestWalRecoverySetsRoundCorrectly(t *testing.T) { require.NoError(t, err) blockRecord2 := BlockRecord(block2.BlockHeader(), bBytes2) - notarizationRecord2, err := testutil.NewNotarizationRecord(conf.Logger, conf.SignatureAggregator, block2, nodes[0:quorum]) + notarizationRecord2, err := testutil.NewNotarizationRecord(conf.Logger, conf.SignatureAggregatorCreator(conf.Comm.Nodes()), block2, nodes[0:quorum]) require.NoError(t, err) - _, finalizationRecord2 := testutil.NewFinalizationRecord(t, conf.Logger, conf.SignatureAggregator, block2, nodes[0:quorum]) + _, finalizationRecord2 := testutil.NewFinalizationRecord(t, conf.Logger, conf.SignatureAggregatorCreator(conf.Comm.Nodes()), block2, nodes[0:quorum]) // All records for same round return [][]byte{ diff --git a/replication_test.go b/replication_test.go index 7b5988a7..d7e4685a 100644 --- a/replication_test.go +++ b/replication_test.go @@ -132,7 +132,8 @@ func TestReplicationAdversarialNode(t *testing.T) { require.Equal(t, uint64(0), laggingNode.E.Metadata().Round) net.Connect(laggingNode.E.ID) - finalization, _ := NewFinalizationRecord(t, laggingNode.E.Logger, laggingNode.E.SignatureAggregator, blocks[1], nodes[:quorum]) + sigAggr := laggingNode.E.SignatureAggregatorCreator(laggingNode.E.Comm.Nodes()) + finalization, _ := NewFinalizationRecord(t, laggingNode.E.Logger, sigAggr, blocks[1], nodes[:quorum]) finalizationMsg := &simplex.Message{ Finalization: &finalization, } @@ -375,7 +376,8 @@ func TestReplicationStartsBeforeCurrentRound(t *testing.T) { record := simplex.BlockRecord(firstBlock.BlockHeader(), fBytes) laggingNode.WAL.Append(record) - firstNotarizationRecord, err := NewNotarizationRecord(laggingNode.E.Logger, laggingNode.E.SignatureAggregator, firstBlock, nodes[0:quorum]) + sigAggr := laggingNode.E.SignatureAggregatorCreator(laggingNode.E.Comm.Nodes()) + firstNotarizationRecord, err := NewNotarizationRecord(laggingNode.E.Logger, sigAggr, firstBlock, nodes[0:quorum]) require.NoError(t, err) laggingNode.WAL.Append(firstNotarizationRecord) @@ -385,7 +387,7 @@ func TestReplicationStartsBeforeCurrentRound(t *testing.T) { record = simplex.BlockRecord(secondBlock.BlockHeader(), sBytes) laggingNode.WAL.Append(record) - secondNotarizationRecord, err := NewNotarizationRecord(laggingNode.E.Logger, laggingNode.E.SignatureAggregator, secondBlock, nodes[0:quorum]) + secondNotarizationRecord, err := NewNotarizationRecord(laggingNode.E.Logger, sigAggr, secondBlock, nodes[0:quorum]) require.NoError(t, err) laggingNode.WAL.Append(secondNotarizationRecord) @@ -414,7 +416,7 @@ func TestReplicationFutureFinalization(t *testing.T) { nodes := []simplex.NodeID{{1}, {2}, {3}, {4}} quorum := simplex.Quorum(len(nodes)) - conf, _, storage := DefaultTestNodeEpochConfig(t, nodes[1], NoopComm(nodes), bb) + conf, _, storage := DefaultTestNodeEpochConfig(t, nodes[1], NewNoopComm(nodes), bb) e, err := simplex.NewEpoch(conf) require.NoError(t, err) @@ -441,7 +443,8 @@ func TestReplicationFutureFinalization(t *testing.T) { }, nodes[0]) require.NoError(t, err) - finalization, _ := NewFinalizationRecord(t, e.Logger, e.SignatureAggregator, block, nodes[0:quorum]) + sigAggr := e.SignatureAggregatorCreator(conf.Comm.Nodes()) + finalization, _ := NewFinalizationRecord(t, e.Logger, sigAggr, block, nodes[0:quorum]) // send finalization err = e.HandleMessage(&simplex.Message{ Finalization: &finalization, @@ -600,7 +603,7 @@ func TestReplicationStuckInProposingBlock(t *testing.T) { sentMessages := make(chan *simplex.Message, 100) conf, _, storage := DefaultTestNodeEpochConfig(t, nodes[0], &recordingComm{ - Communication: NoopComm(nodes), + Communication: NewNoopComm(nodes), SentMessages: sentMessages, }, bb) @@ -627,7 +630,8 @@ func TestReplicationStuckInProposingBlock(t *testing.T) { highBlock, _ := blocks[3].VerifiedBlock.(*TestBlock) - highFinalization, _ := NewFinalizationRecord(t, e.Logger, e.SignatureAggregator, highBlock, nodes[0:quorum]) + sigAggr := e.SignatureAggregatorCreator(conf.Comm.Nodes()) + highFinalization, _ := NewFinalizationRecord(t, e.Logger, sigAggr, highBlock, nodes[0:quorum]) // Trigger the replication process to start by sending a finalization for a block we do not have e.HandleMessage(&simplex.Message{ @@ -973,7 +977,8 @@ func TestReplicationVerifyNotarization(t *testing.T) { block := bb.GetBuiltBlock() - finalization, _ := NewFinalizationRecord(t, e.Logger, e.SignatureAggregator, block, nodes[0:quorum]) + sigAggr := e.SignatureAggregatorCreator(conf.Comm.Nodes()) + finalization, _ := NewFinalizationRecord(t, e.Logger, sigAggr, block, nodes[0:quorum]) // Trigger the replication process to start by sending a finalization for a block we do not have e.HandleMessage(&simplex.Message{ @@ -988,7 +993,7 @@ func TestReplicationVerifyNotarization(t *testing.T) { } } - notarization, err := NewNotarization(e.Logger, e.SignatureAggregator, block, nodes[0:quorum]) + notarization, err := NewNotarization(e.Logger, sigAggr, block, nodes[0:quorum]) require.NoError(t, err) // Corrupt the QC @@ -1060,7 +1065,8 @@ func TestReplicationVerifyEmptyNotarization(t *testing.T) { block := bb.GetBuiltBlock() - finalization, _ := NewFinalizationRecord(t, e.Logger, e.SignatureAggregator, block, nodes[0:quorum]) + sigAggr := e.SignatureAggregatorCreator(conf.Comm.Nodes()) + finalization, _ := NewFinalizationRecord(t, e.Logger, sigAggr, block, nodes[0:quorum]) // Trigger the replication process to start by sending a finalization for a block we do not have e.HandleMessage(&simplex.Message{ @@ -1348,7 +1354,7 @@ func TestReplicationStoresFinalization(t *testing.T) { sentMessages := make(chan *simplex.Message, 100) conf, _, storage := DefaultTestNodeEpochConfig(t, nodes[3], &recordingComm{ - Communication: NoopComm(nodes), + Communication: NewNoopComm(nodes), SentMessages: sentMessages, }, bb) conf.ReplicationEnabled = true @@ -1533,7 +1539,7 @@ func TestReplicationStartsRoundFromFinalization(t *testing.T) { sentMessages := make(chan *simplex.Message, 100) broadcastMessages := make(chan *simplex.Message, 100) conf, wal, storage := DefaultTestNodeEpochConfig(t, nodes[0], &recordingComm{ - Communication: NoopComm(nodes), + Communication: NewNoopComm(nodes), SentMessages: sentMessages, BroadcastMessages: broadcastMessages, }, bb) @@ -1643,7 +1649,7 @@ func TestReplicationStartsRoundFromFinalizationWithBlock(t *testing.T) { sentMessages := make(chan *simplex.Message, 100) broadcastMessages := make(chan *simplex.Message, 100) conf, wal, storage := DefaultTestNodeEpochConfig(t, nodes[0], &recordingComm{ - Communication: NoopComm(nodes), + Communication: NewNoopComm(nodes), SentMessages: sentMessages, BroadcastMessages: broadcastMessages, }, bb) diff --git a/replication_timeout_test.go b/replication_timeout_test.go index cbaa6ead..e64e99e8 100644 --- a/replication_timeout_test.go +++ b/replication_timeout_test.go @@ -622,7 +622,8 @@ func TestReplicationResendsFinalizedBlocksThatFailedVerification(t *testing.T) { block := bb.GetBuiltBlock() block.VerificationError = errors.New("block verification failed") - finalization, _ := testutil.NewFinalizationRecord(t, l, e.SignatureAggregator, block, nodes[0:quorum]) + sigAggr := e.SignatureAggregatorCreator(conf.Comm.Nodes()) + finalization, _ := testutil.NewFinalizationRecord(t, l, sigAggr, block, nodes[0:quorum]) // send the finalization to start the replication process e.HandleMessage(&simplex.Message{ @@ -658,7 +659,7 @@ func TestReplicationResendsFinalizedBlocksThatFailedVerification(t *testing.T) { block.Data = append(block.Data, 0) block.ComputeDigest() - finalization, _ = testutil.NewFinalizationRecord(t, l, e.SignatureAggregator, block, nodes[0:quorum]) + finalization, _ = testutil.NewFinalizationRecord(t, l, sigAggr, block, nodes[0:quorum]) replicationResponse = &simplex.ReplicationResponse{ Data: []simplex.QuorumRound{ { diff --git a/testutil/comm.go b/testutil/comm.go index f5fb9cca..a3b5c2d3 100644 --- a/testutil/comm.go +++ b/testutil/comm.go @@ -22,10 +22,10 @@ import ( // - bool: true if the message can be transmitted, false otherwise type MessageFilter func(msg *simplex.Message, from simplex.NodeID, to simplex.NodeID) bool -type NoopComm []simplex.NodeID +type NoopComm simplex.NodeWeights -func (n NoopComm) Nodes() []simplex.NodeID { - return n +func (n NoopComm) Nodes() simplex.NodeWeights { + return simplex.NodeWeights(n) } func (n NoopComm) Send(*simplex.Message, simplex.NodeID) { @@ -51,8 +51,8 @@ func NewTestComm(from simplex.NodeID, net *BasicInMemoryNetwork, messageFilter M } } -func (c *TestComm) Nodes() []simplex.NodeID { - return c.net.nodes +func (c *TestComm) Nodes() simplex.NodeWeights { + return c.net.nodeWeights } func (c *TestComm) Send(msg *simplex.Message, destination simplex.NodeID) { @@ -177,7 +177,7 @@ func (c *TestComm) Broadcast(msg *simplex.Message) { if !c.isMessagePermitted(msg, instance.E.ID) { continue } - // Skip sending the message to yourself or disconnected nodes + // Skip sending the message to yourself or disconnected nodeWeights if bytes.Equal(c.from, instance.E.ID) || c.net.IsDisconnected(instance.E.ID) { continue } @@ -191,6 +191,6 @@ func AllowAllMessages(*simplex.Message, simplex.NodeID, simplex.NodeID) bool { return true } -func NewNoopComm(nodes []simplex.NodeID) NoopComm { - return NoopComm(nodes) +func NewNoopComm(nodes simplex.NodeIDs) NoopComm { + return NoopComm(nodes.EqualWeightedNodeWeights()) } diff --git a/testutil/controlled.go b/testutil/controlled.go index fe45eec5..929835f0 100644 --- a/testutil/controlled.go +++ b/testutil/controlled.go @@ -19,9 +19,9 @@ type ControlledInMemoryNetwork struct { } // NewControlledNetwork creates an in-memory network. Node IDs must be provided before -// adding instances, as nodes require prior knowledge of all participants. -func NewControlledNetwork(t *testing.T, nodes []simplex.NodeID) *ControlledInMemoryNetwork { - simplex.SortNodes(nodes) +// adding instances, as nodeWeights require prior knowledge of all participants. +func NewControlledNetwork(t *testing.T, nodes simplex.NodeIDs) *ControlledInMemoryNetwork { + simplex.SortNodesWeights(nodes.EqualWeightedNodeWeights()) net := &ControlledInMemoryNetwork{ BasicInMemoryNetwork: NewBasicInMemoryNetwork(t, nodes), Instances: make([]*ControlledNode, 0), @@ -72,7 +72,7 @@ func (n *ControlledInMemoryNetwork) AdvanceWithoutLeader(round uint64, laggingNo } for _, n := range n.Instances { - leader := n.E.ID.Equals(simplex.LeaderForRound(n.E.Comm.Nodes(), n.E.Metadata().Round)) + leader := n.E.ID.Equals(simplex.LeaderForRound(n.E.Comm.Nodes().NodesIDs(), n.E.Metadata().Round)) if leader || laggingNodeId.Equals(n.E.ID) { continue } diff --git a/testutil/network.go b/testutil/network.go index fa4555a8..e9ee0f79 100644 --- a/testutil/network.go +++ b/testutil/network.go @@ -14,17 +14,20 @@ import ( ) type BasicInMemoryNetwork struct { - t *testing.T - nodes []simplex.NodeID - lock sync.RWMutex + t *testing.T + nodes []simplex.NodeID + nodeWeights simplex.NodeWeights + lock sync.RWMutex disconnected map[string]struct{} instances []*BasicNode } -func NewBasicInMemoryNetwork(t *testing.T, nodes []simplex.NodeID) *BasicInMemoryNetwork { - simplex.SortNodes(nodes) +func NewBasicInMemoryNetwork(t *testing.T, nodes simplex.NodeIDs) *BasicInMemoryNetwork { + nodeWeights := nodes.EqualWeightedNodeWeights() + simplex.SortNodesWeights(nodeWeights) return &BasicInMemoryNetwork{ t: t, + nodeWeights: nodeWeights, nodes: nodes, disconnected: make(map[string]struct{}), instances: make([]*BasicNode, 0), @@ -99,9 +102,9 @@ func (b *BasicInMemoryNetwork) StartInstances() { b.lock.RLock() defer b.lock.RUnlock() - require.Equal(b.t, len(b.nodes), len(b.instances)) + require.Equal(b.t, len(b.nodeWeights), len(b.instances)) - for i := len(b.nodes) - 1; i >= 0; i-- { + for i := len(b.nodeWeights) - 1; i >= 0; i-- { b.instances[i].Start() } } @@ -142,8 +145,8 @@ func (b *BasicInMemoryNetwork) AddNode(node *BasicNode) { defer b.lock.Unlock() allowed := false - for _, id := range b.nodes { - if bytes.Equal(id, node.E.ID) { + for _, nodeWeight := range b.nodeWeights { + if bytes.Equal(nodeWeight.Node, node.E.ID) { allowed = true break } diff --git a/testutil/node.go b/testutil/node.go index 2cc09eee..adfaf6bd 100644 --- a/testutil/node.go +++ b/testutil/node.go @@ -198,8 +198,8 @@ func UpdateEpochConfig(epochConfig *simplex.EpochConfig, testConfig *TestNodeCon epochConfig.Comm = testConfig.Comm } - if testConfig.SigAggregator != nil { - epochConfig.SignatureAggregator = testConfig.SigAggregator + if testConfig.SigAggregatorCreator != nil { + epochConfig.SignatureAggregatorCreator = testConfig.SigAggregatorCreator } if testConfig.BlockBuilder != nil { @@ -236,7 +236,7 @@ type TestNodeConfig struct { // optional InitialStorage []simplex.VerifiedFinalizedBlock Comm simplex.Communication - SigAggregator simplex.SignatureAggregator + SigAggregatorCreator simplex.SignatureAggregatorCreator ReplicationEnabled bool BlockBuilder *testControlledBlockBuilder diff --git a/testutil/util.go b/testutil/util.go index d8923672..9426e43f 100644 --- a/testutil/util.go +++ b/testutil/util.go @@ -32,7 +32,9 @@ func DefaultTestNodeEpochConfig(t *testing.T, nodeID simplex.NodeID, comm simple Verifier: &testVerifier{}, Storage: storage, BlockBuilder: bb, - SignatureAggregator: &TestSignatureAggregator{N: len(comm.Nodes())}, + SignatureAggregatorCreator: func(weights []simplex.NodeWeight) simplex.SignatureAggregator { + return &TestSignatureAggregator{N: len(weights)} + }, BlockDeserializer: &BlockDeserializer{}, QCDeserializer: &testQCDeserializer{t: t}, StartTime: time.Now(), @@ -120,12 +122,30 @@ func (t *testQCDeserializer) DeserializeQuorumCertificate(bytes []byte) (simplex return TestQC(qc), err } +type TestSignatureAggregatorCreator struct { + Err error + N int + IsQuorumFunc func(signatures []simplex.NodeID) bool +} + + type TestSignatureAggregator struct { Err error N int IsQuorumFunc func(signatures []simplex.NodeID) bool } +func (t *TestSignatureAggregator) AppendSignatures(existing []byte, sigs ...[]byte) ([]byte, error) { + if t.Err != nil { + return nil, t.Err + } + result := append([]byte{}, existing...) + for _, s := range sigs { + result = append(result, s...) + } + return result, nil +} + func (t *TestSignatureAggregator) Aggregate(signatures []simplex.Signature) (simplex.QuorumCertificate, error) { return TestQC(signatures), t.Err } From 54ab6d90dc88185c84f019f8eb8490f0a9bca752 Mon Sep 17 00:00:00 2001 From: Yacov Manevich Date: Mon, 11 May 2026 15:48:53 +0200 Subject: [PATCH 06/16] Remove findFirstSimplexBlock in favor of explicit initialization Signed-off-by: Yacov Manevich --- msm/encoding.go | 4 +- msm/encoding_test.go | 11 +- msm/fake_node_test.go | 13 ++ msm/msm.go | 256 +++++++++++++++++++++------------------ msm/msm_test.go | 145 ++++++++-------------- msm/verification.go | 15 +-- msm/verification_test.go | 3 + 7 files changed, 222 insertions(+), 225 deletions(-) diff --git a/msm/encoding.go b/msm/encoding.go index 64d09cda..1b863f3a 100644 --- a/msm/encoding.go +++ b/msm/encoding.go @@ -265,10 +265,10 @@ type ValidatorSetApproval struct { type ValidatorSetApprovals []ValidatorSetApproval -func (vsa ValidatorSetApprovals) Filter(f func(int, ValidatorSetApproval) bool) ValidatorSetApprovals { +func (vsa ValidatorSetApprovals) Filter(f func(int, ValidatorSetApproval, simplex.Logger) bool, logger simplex.Logger) ValidatorSetApprovals { result := make(ValidatorSetApprovals, 0, len(vsa)) for i, v := range vsa { - if f(i, v) { + if f(i, v, logger) { result = append(result, v) } } diff --git a/msm/encoding_test.go b/msm/encoding_test.go index 4c7bc321..efae8a0a 100644 --- a/msm/encoding_test.go +++ b/msm/encoding_test.go @@ -6,6 +6,8 @@ package metadata import ( "testing" + "github.com/ava-labs/simplex" + "github.com/ava-labs/simplex/testutil" "github.com/stretchr/testify/require" ) @@ -402,22 +404,23 @@ func TestNodeBLSMappingsCompare(t *testing.T) { } func TestValidatorSetApprovalsFilter(t *testing.T) { + logger := testutil.MakeLogger(t) approvals := ValidatorSetApprovals{ {NodeID: nodeID{1}, PChainHeight: 10}, {NodeID: nodeID{2}, PChainHeight: 20}, {NodeID: nodeID{3}, PChainHeight: 30}, } - filtered := approvals.Filter(func(_ int, v ValidatorSetApproval) bool { + filtered := approvals.Filter(func(_ int, v ValidatorSetApproval, _ simplex.Logger) bool { return v.PChainHeight > 15 - }) + }, logger) require.Len(t, filtered, 2) require.Equal(t, uint64(20), filtered[0].PChainHeight) require.Equal(t, uint64(30), filtered[1].PChainHeight) // Filter all - filtered = approvals.Filter(func(int, ValidatorSetApproval) bool { + filtered = approvals.Filter(func(int, ValidatorSetApproval, simplex.Logger) bool { return false - }) + }, logger) require.Empty(t, filtered) } diff --git a/msm/fake_node_test.go b/msm/fake_node_test.go index e4084e12..97af1986 100644 --- a/msm/fake_node_test.go +++ b/msm/fake_node_test.go @@ -18,6 +18,7 @@ import ( func TestFakeNode(t *testing.T) { validatorSetRetriever := validatorSetRetriever{ resultMap: map[uint64]NodeBLSMappings{ + 0: {{BLSKey: []byte{1}, Weight: 1, NodeID: [20]byte{1}}, {BLSKey: []byte{2}, Weight: 1, NodeID: [20]byte{2}}}, 100: {{BLSKey: []byte{1}, Weight: 1, NodeID: [20]byte{1}}, {BLSKey: []byte{2}, Weight: 1, NodeID: [20]byte{2}}}, 200: {{BLSKey: []byte{1}, Weight: 1, NodeID: [20]byte{1}}, {BLSKey: []byte{2}, Weight: 2, NodeID: [20]byte{2}}, {BLSKey: []byte{3}, Weight: 1, NodeID: [20]byte{3}}}, @@ -84,6 +85,7 @@ func TestFakeNode(t *testing.T) { func TestFakeNodeEmptyMempool(t *testing.T) { validatorSetRetriever := validatorSetRetriever{ resultMap: map[uint64]NodeBLSMappings{ + 0: {{BLSKey: []byte{1}, Weight: 1, NodeID: [20]byte{1}}, {BLSKey: []byte{2}, Weight: 1, NodeID: [20]byte{2}}}, 100: {{BLSKey: []byte{1}, Weight: 1, NodeID: [20]byte{1}}, {BLSKey: []byte{2}, Weight: 1, NodeID: [20]byte{2}}}, 200: {{BLSKey: []byte{1}, Weight: 1, NodeID: [20]byte{1}}, {BLSKey: []byte{2}, Weight: 2, NodeID: [20]byte{2}}, {BLSKey: []byte{3}, Weight: 1, NodeID: [20]byte{3}}}, @@ -240,6 +242,17 @@ func newFakeNode(t *testing.T) *fakeNode { return StateMachineBlock{}, nil, fmt.Errorf("block not found") } + fn.sm.FirstEverSimplexBlock = func() *StateMachineBlock { + for _, block := range fn.blocks { + if block.block.Metadata.SimplexEpochInfo.EpochNumber == 0 { + continue + } + return &block.block + } + require.FailNow(t, "block not found") + return nil + } + return fn } diff --git a/msm/msm.go b/msm/msm.go index 796d86c0..dbce6351 100644 --- a/msm/msm.go +++ b/msm/msm.go @@ -6,10 +6,7 @@ package metadata import ( "context" "crypto/sha256" - "errors" "fmt" - "math" - "sort" "time" "github.com/ava-labs/simplex" @@ -105,7 +102,15 @@ type StateMachine struct { SignatureVerifier SignatureVerifier // PChainProgressListener listens for changes in the P-chain height to trigger block building or epoch transitions. PChainProgressListener PChainProgressListener - + // FirstEverSimplexBlock is the first block ever built by Simplex, or nil if Simplex has yet to build a block. + FirstEverSimplexBlock func() *StateMachineBlock + // LastNonSimplexBlockPChainHeight is the P-chain height of the last block built by a non-Simplex proposer. + // It is used to determine the validator set of the first ever Simplex epoch. + LastNonSimplexBlockPChainHeight uint64 + // LastNonSimplexInnerBlock is the inner block of the last block built by a non-Simplex proposer. + LastNonSimplexInnerBlock VMBlock + // GenesisValidatorSet is the validator set used for the genesis block. + GenesisValidatorSet NodeBLSMappings // initialized tracks whether the state machine has been initialized. // This is used to lazily initialize the verifiers. initialized bool @@ -169,7 +174,7 @@ func (sm *StateMachine) BuildBlock(ctx context.Context, simplexMetadata simplex. switch currentState { case stateFirstSimplexBlock: - return sm.buildBlockZero(ctx, parentBlock, simplexMetadataBytes, simplexBlacklistBytes) + return sm.buildBlockZero(parentBlock, simplexMetadataBytes, simplexBlacklistBytes) case stateBuildBlockNormalOp: return sm.buildBlockNormalOp(ctx, parentBlock, simplexMetadataBytes, simplexBlacklistBytes, prevBlockSeq) case stateBuildCollectingApprovals: @@ -211,7 +216,7 @@ func (sm *StateMachine) VerifyBlock(ctx context.Context, block *StateMachineBloc switch currentState { case stateFirstSimplexBlock: - err = sm.verifyBlockZero(ctx, block, prevBlock) + err = sm.verifyBlockZero(block, prevBlock) default: err = sm.verifyNonZeroBlock(ctx, block, prevBlock.Metadata, currentState, seq-1) } @@ -237,7 +242,11 @@ func (sm *StateMachine) init() { }, &pChainReferenceHeightVerifier{}, &epochNumberVerifier{}, + &validationDescriptorVerifier{ + getValidatorSet: sm.GetValidatorSet, + }, &prevSealingBlockHashVerifier{ + firstEverSimplexBlock: sm.FirstEverSimplexBlock, getBlock: sm.GetBlock, latestPersistedHeight: &sm.LatestPersistedHeight, }, @@ -248,9 +257,6 @@ func (sm *StateMachine) init() { &vmBlockSeqVerifier{ getBlock: sm.GetBlock, }, - &validationDescriptorVerifier{ - getValidatorSet: sm.GetValidatorSet, - }, &nextEpochApprovalsVerifier{ getValidatorSet: sm.GetValidatorSet, keyAggregator: sm.KeyAggregator, @@ -353,6 +359,7 @@ func (sm *StateMachine) buildBlockNormalOp(ctx context.Context, parentBlock Stat // Initiate the epoch transition by setting the next P-chain reference height for the new epoch info, // and build a block without an inner block. newSimplexEpochInfo.NextPChainReferenceHeight = pChainHeight + sm.Logger.Debug("Transitioning epoch without building block", zap.Uint64("newPChainRefHeight", pChainHeight)) return sm.wrapBlock(parentBlock, nil, newSimplexEpochInfo, pChainHeight, simplexMetadata, simplexBlacklist), nil case decisionContextCanceled: return nil, ctx.Err() @@ -385,6 +392,11 @@ func (sm *StateMachine) createBlockBuildingDecider(parentBlock StateMachineBlock } if !currentValidatorSet.Equal(newValidatorSet) { + sm.Logger.Debug("Validator set has changed, should transition epoch", + zap.String("currentValidatorSet", fmt.Sprintf("%v", currentValidatorSet.NodeWeights())), + zap.String("newValidatorSet", fmt.Sprintf("%v", newValidatorSet.NodeWeights())), + zap.Uint64("currentPChainRefHeight", parentBlock.Metadata.SimplexEpochInfo.PChainReferenceHeight), + zap.Uint64("newPChainHeight", pChainHeight)) return true, nil } return false, nil @@ -393,7 +405,14 @@ func (sm *StateMachine) createBlockBuildingDecider(parentBlock StateMachineBlock return blockBuildingDecider } -func (sm *StateMachine) buildBlockAndMaybeTransitionEpoch(ctx context.Context, parentBlock StateMachineBlock, simplexMetadata []byte, simplexBlacklist []byte, childBlock VMBlock, decisionToBuildBlock blockBuildingDecision, newSimplexEpochInfo SimplexEpochInfo, pChainHeight uint64) (*StateMachineBlock, error) { +func (sm *StateMachine) buildBlockAndMaybeTransitionEpoch(ctx context.Context, + parentBlock StateMachineBlock, + simplexMetadata []byte, + simplexBlacklist []byte, + childBlock VMBlock, + decisionToBuildBlock blockBuildingDecision, + newSimplexEpochInfo SimplexEpochInfo, + pChainHeight uint64) (*StateMachineBlock, error) { // TODO: This P-chain height should be taken from the ICM epoch childBlock, err := sm.BlockBuilder.BuildBlock(ctx, pChainHeight) if err != nil { @@ -404,6 +423,7 @@ func (sm *StateMachine) buildBlockAndMaybeTransitionEpoch(ctx context.Context, p // We need to also transition to a new epoch, in addition to building an inner block, // so set the next P-chain reference height for the new epoch info. newSimplexEpochInfo.NextPChainReferenceHeight = pChainHeight + sm.Logger.Debug("Transitioning epoch after building block", zap.Uint64("newPChainRefHeight", pChainHeight)) } return sm.wrapBlock(parentBlock, childBlock, newSimplexEpochInfo, pChainHeight, simplexMetadata, simplexBlacklist), nil @@ -411,12 +431,23 @@ func (sm *StateMachine) buildBlockAndMaybeTransitionEpoch(ctx context.Context, p // buildBlockZero builds the first ever block for Simplex, // which is a special block that introduces the first validator set and starts the first epoch. -func (sm *StateMachine) buildBlockZero(ctx context.Context, parentBlock StateMachineBlock, simplexMetadata, simplexBlacklist []byte) (*StateMachineBlock, error) { - pChainHeight := sm.GetPChainHeight() +func (sm *StateMachine) buildBlockZero(parentBlock StateMachineBlock, simplexMetadata, simplexBlacklist []byte) (*StateMachineBlock, error) { + if sm.LastNonSimplexInnerBlock == nil { + sm.Logger.Error("Last non-Simplex inner block is nil, cannot build zero block with correct metadata") + return nil, fmt.Errorf("failed constructing zero block: last non-Simplex inner block is nil") + } - newValidatorSet, err := sm.GetValidatorSet(pChainHeight) - if err != nil { - return nil, err + pChainHeight := sm.LastNonSimplexBlockPChainHeight + + var validatorSet NodeBLSMappings + if sm.LastNonSimplexInnerBlock.Height() == 0 { + validatorSet = sm.GenesisValidatorSet + } else { + var err error + validatorSet, err = sm.GetValidatorSet(pChainHeight) + if err != nil { + return nil, err + } } var prevVMBlockSeq uint64 @@ -428,16 +459,37 @@ func (sm *StateMachine) buildBlockZero(ctx context.Context, parentBlock StateMac sm.Logger.Error("Parent block has no inner block, cannot determine previous VM block sequence for zero block") return nil, fmt.Errorf("failed constructing zero block: parent block has no inner block") } - simplexEpochInfo := constructSimplexZeroBlock(pChainHeight, newValidatorSet, prevVMBlockSeq) - return sm.buildBlockImpatiently(ctx, parentBlock, simplexMetadata, simplexBlacklist, simplexEpochInfo, pChainHeight) + timestamp := sm.LastNonSimplexInnerBlock.Timestamp().UnixMilli() + simplexEpochInfo := constructSimplexZeroBlockSimplexEpochInfo(pChainHeight, validatorSet, prevVMBlockSeq) + + md, err := simplex.ProtocolMetadataFromBytes(simplexMetadata) + if err != nil { + return nil, fmt.Errorf("failed to parse simplex metadata: %w", err) + } + md.Prev = sm.LastNonSimplexInnerBlock.Digest() + md.Seq = sm.LastNonSimplexInnerBlock.Height() + + return &StateMachineBlock{ + Metadata: StateMachineMetadata{ + Timestamp: uint64(timestamp), + SimplexProtocolMetadata: simplexMetadata, + SimplexBlacklist: simplexBlacklist, + SimplexEpochInfo: simplexEpochInfo, + PChainHeight: pChainHeight, + }, + }, nil } -func (sm *StateMachine) verifyBlockZero(ctx context.Context, block *StateMachineBlock, prevBlock StateMachineBlock) error { +func (sm *StateMachine) verifyBlockZero(block *StateMachineBlock, prevBlock StateMachineBlock) error { if block == nil { return fmt.Errorf("block is nil") } + if sm.LastNonSimplexInnerBlock == nil { + return fmt.Errorf("failed verifying zero block: last non-Simplex inner block is nil") + } + simplexEpochInfo := block.Metadata.SimplexEpochInfo if simplexEpochInfo.EpochNumber != 1 { @@ -448,23 +500,23 @@ func (sm *StateMachine) verifyBlockZero(ctx context.Context, block *StateMachine return fmt.Errorf("parent inner block (%s) has no inner block", prevBlock.Digest()) } + pChainHeight := sm.LastNonSimplexBlockPChainHeight prevVMBlockSeq := prevBlock.InnerBlock.Height() - currentPChainHeight := sm.GetPChainHeight() - - if block.Metadata.PChainHeight > currentPChainHeight { - return fmt.Errorf("invalid P-chain height (%d) is too big, expected to be ≤ %d", - block.Metadata.PChainHeight, currentPChainHeight) + if block.Metadata.PChainHeight != pChainHeight { + return fmt.Errorf("invalid P-chain height (%d), expected to be %d", + block.Metadata.PChainHeight, pChainHeight) } - if prevBlock.Metadata.PChainHeight > block.Metadata.PChainHeight { - return fmt.Errorf("invalid P-chain height (%d) is smaller than parent InnerBlock's P-chain height (%d)", - block.Metadata.PChainHeight, prevBlock.Metadata.PChainHeight) - } - - expectedValidatorSet, err := sm.GetValidatorSet(simplexEpochInfo.PChainReferenceHeight) - if err != nil { - return fmt.Errorf("failed to retrieve validator set at height %d: %w", simplexEpochInfo.PChainReferenceHeight, err) + var expectedValidatorSet NodeBLSMappings + if prevBlock.InnerBlock.Height() == 0 { + expectedValidatorSet = sm.GenesisValidatorSet + } else { + var err error + expectedValidatorSet, err = sm.GetValidatorSet(pChainHeight) + if err != nil { + return fmt.Errorf("failed to retrieve validator set at height %d: %w", pChainHeight, err) + } } if simplexEpochInfo.BlockValidationDescriptor == nil { @@ -473,48 +525,31 @@ func (sm *StateMachine) verifyBlockZero(ctx context.Context, block *StateMachine membership := simplexEpochInfo.BlockValidationDescriptor.AggregatedMembership.Members if !NodeBLSMappings(membership).Equal(expectedValidatorSet) { - return fmt.Errorf("invalid BlockValidationDescriptor: should match validator set at P-chain height %d", simplexEpochInfo.PChainReferenceHeight) + return fmt.Errorf("invalid BlockValidationDescriptor: should match validator set at P-chain height %d", pChainHeight) } // If we have compared all fields so far, the rest of the fields we compare by constructing an explicit expected SimplexEpochInfo - expectedSimplexEpochInfo := constructSimplexZeroBlock(simplexEpochInfo.PChainReferenceHeight, expectedValidatorSet, prevVMBlockSeq) + expectedSimplexEpochInfo := constructSimplexZeroBlockSimplexEpochInfo(pChainHeight, expectedValidatorSet, prevVMBlockSeq) if !expectedSimplexEpochInfo.Equal(&simplexEpochInfo) { return fmt.Errorf("invalid SimplexEpochInfo: expected %v, got %v", expectedSimplexEpochInfo, simplexEpochInfo) } - _, err = sm.verifyZeroBlockTimestamp(block, prevBlock) - if err != nil { - return err + // The InnerBlock must match the last non-Simplex inner block. + if block.InnerBlock != nil { + return fmt.Errorf("zero block must not have an inner block") } - - if block.InnerBlock == nil { - return nil + if prevBlock.InnerBlock.Digest() != sm.LastNonSimplexInnerBlock.Digest() { + return fmt.Errorf("zero block inner block digest does not match last non-Simplex inner block digest") } - return block.InnerBlock.Verify(ctx) -} - -func (sm *StateMachine) verifyZeroBlockTimestamp(block *StateMachineBlock, prevBlock StateMachineBlock) (time.Time, error) { - var proposedTime time.Time - if block.InnerBlock != nil { - proposedTime = block.InnerBlock.Timestamp() - } else { - proposedTime = time.UnixMilli(int64(prevBlock.Metadata.Timestamp)) + // The timestamp must equal the last non-Simplex inner block's timestamp. + expectedTimestamp := uint64(sm.LastNonSimplexInnerBlock.Timestamp().UnixMilli()) + if block.Metadata.Timestamp != expectedTimestamp { + return fmt.Errorf("expected timestamp to be %d but got %d", expectedTimestamp, block.Metadata.Timestamp) } - expectedTimestamp := proposedTime.UnixMilli() - if expectedTimestamp != int64(block.Metadata.Timestamp) { - return time.Time{}, fmt.Errorf("expected timestamp to be %d but got %d", expectedTimestamp, int64(block.Metadata.Timestamp)) - } - currentTime := sm.GetTime() - if currentTime.Add(sm.TimeSkewLimit).Before(proposedTime) { - return time.Time{}, fmt.Errorf("proposed block timestamp is too far in the future, current time is %s but got %s", currentTime.String(), proposedTime.String()) - } - if prevBlock.Metadata.Timestamp > block.Metadata.Timestamp { - return time.Time{}, fmt.Errorf("proposed block timestamp is older than parent block's timestamp, parent timestamp is %d but got %d", prevBlock.Metadata.Timestamp, block.Metadata.Timestamp) - } - return proposedTime, nil + return nil } func (sm *StateMachine) buildBlockCollectingApprovals(ctx context.Context, parentBlock StateMachineBlock, simplexMetadata, simplexBlacklist []byte, prevBlockSeq uint64) (*StateMachineBlock, error) { @@ -538,12 +573,14 @@ func (sm *StateMachine) buildBlockCollectingApprovals(ctx context.Context, paren // We retrieve approvals that validators have sent us for the next epoch. // These approvals are signed by validators of the next epoch. approvalsFromPeers := sm.ApprovalsRetriever.RetrieveApprovals() + sm.Logger.Debug("Retrieved approvals from peers", zap.Int("numApprovals", len(approvalsFromPeers))) + nextPChainHeight := newSimplexEpochInfo.NextPChainReferenceHeight prevNextEpochApprovals := parentBlock.Metadata.SimplexEpochInfo.NextEpochApprovals sigAggr := sm.SignatureAggregatorCreator(validators.NodeWeights()) - newApprovals, err := computeNewApprovals(prevNextEpochApprovals, approvalsFromPeers, nextPChainHeight, sigAggr, validators) + newApprovals, err := computeNewApprovals(prevNextEpochApprovals, approvalsFromPeers, nextPChainHeight, sigAggr, validators, sm.Logger) if err != nil { return nil, err } @@ -618,16 +655,11 @@ func (sm *StateMachine) createSealingBlock(ctx context.Context, parentBlock Stat } simplexEpochInfo.PrevSealingBlockHash = prevSealingBlock.Digest() } else { // Else, this is the first epoch, so we use the hash of the first ever Simplex block. - - firstSimplexBlock, err := findFirstSimplexBlock(sm.GetBlock, sm.LatestPersistedHeight+1) - if err != nil { - return nil, fmt.Errorf("failed to find first simplex block: %w", err) + firstSimplexBlock := sm.FirstEverSimplexBlock() + if firstSimplexBlock == nil { + return nil, fmt.Errorf("first ever Simplex block is not set, but attempted to create a sealing block for the first epoch") } - firstSimplexBlockRetrieved, _, err := sm.GetBlock(firstSimplexBlock, [32]byte{}) - if err != nil { - return nil, fmt.Errorf("failed to retrieve first simplex block at height %d: %w", firstSimplexBlock, err) - } - simplexEpochInfo.PrevSealingBlockHash = firstSimplexBlockRetrieved.Digest() + simplexEpochInfo.PrevSealingBlockHash = firstSimplexBlock.Digest() } return sm.buildBlockImpatiently(ctx, parentBlock, simplexMetadata, simplexBlacklist, simplexEpochInfo, pChainHeight) @@ -712,8 +744,8 @@ func (sm *StateMachine) buildBlockEpochSealed(ctx context.Context, parentBlock S return sm.wrapBlock(parentBlock, childBlock, newSimplexEpochInfo, parentBlock.Metadata.PChainHeight, simplexMetadata, simplexBlacklist), nil } -// constructSimplexZeroBlock constructs the SimplexEpochInfo for the zero block, which is the first ever block built by Simplex. -func constructSimplexZeroBlock(pChainHeight uint64, newValidatorSet NodeBLSMappings, prevVMBlockSeq uint64) SimplexEpochInfo { +// constructSimplexZeroBlockSimplexEpochInfo constructs the SimplexEpochInfo for the zero block, which is the first ever block built by Simplex. +func constructSimplexZeroBlockSimplexEpochInfo(pChainHeight uint64, newValidatorSet NodeBLSMappings, prevVMBlockSeq uint64) SimplexEpochInfo { newSimplexEpochInfo := SimplexEpochInfo{ PChainReferenceHeight: pChainHeight, EpochNumber: 1, @@ -739,6 +771,7 @@ func computeNewApprovals( pChainHeight uint64, sigAggr simplex.SignatureAggregator, validators NodeBLSMappings, + logger simplex.Logger, ) (*approvals, error) { if nextEpochApprovals == nil { nextEpochApprovals = &NextEpochApprovals{} @@ -752,13 +785,16 @@ func computeNewApprovals( nodeID2ValidatorIndex[nbm.NodeID] = i } + oldApprovalFromPeersCount := len(approvalsFromPeers) // We have the approvals obtained from peers, but we need to sanitize them by filtering out approvals that are not valid, // such as approvals that do not agree with our candidate auxiliary info digest and P-Chain height, // and approvals that are from nodes that are not in the validator set or have already approved in prior blocks. - approvalsFromPeers = sanitizeApprovals(approvalsFromPeers, pChainHeight, nodeID2ValidatorIndex, oldApprovingNodes) + approvalsFromPeers = sanitizeApprovals(approvalsFromPeers, pChainHeight, nodeID2ValidatorIndex, oldApprovingNodes, logger) + logger.Debug("Santizied approvals after filtering out invalid approvals", zap.Int("numApprovalsBefore", oldApprovalFromPeersCount), zap.Int("numApprovalsAfter", len(approvalsFromPeers))) + // Next we aggregate both previous and new approvals to compute the new aggregated signatures and the new bitmask of approving nodes. - aggregatedSignature, newApprovingNodes, err := computeNewApproverSignaturesAndSigners(nextEpochApprovals, approvalsFromPeers, oldApprovingNodes, nodeID2ValidatorIndex, sigAggr) + aggregatedSignature, newApprovingNodes, err := computeNewApproverSignaturesAndSigners(nextEpochApprovals, approvalsFromPeers, oldApprovingNodes, nodeID2ValidatorIndex, sigAggr, logger) if err != nil { return nil, err } @@ -776,7 +812,14 @@ func computeNewApprovals( // computeNewApproverSignaturesAndSigners computes the signatures of the nodes that approve the next epoch including the previous aggregated signature, // and bitmask of nodes that correspond to those signatures, and aggregates all signatures together. -func computeNewApproverSignaturesAndSigners(nextEpochApprovals *NextEpochApprovals, approvalsFromPeers ValidatorSetApprovals, oldApprovingNodes bitmask, nodeID2ValidatorIndex map[nodeID]int, sigAggr simplex.SignatureAggregator) ([]byte, bitmask, error) { +func computeNewApproverSignaturesAndSigners( + nextEpochApprovals *NextEpochApprovals, + approvalsFromPeers ValidatorSetApprovals, + oldApprovingNodes bitmask, + nodeID2ValidatorIndex map[nodeID]int, + sigAggr simplex.SignatureAggregator, + logger simplex.Logger, + ) ([]byte, bitmask, error) { if nextEpochApprovals == nil { return nil, bitmask{}, fmt.Errorf("next epoch approvals is nil") } @@ -786,6 +829,11 @@ func computeNewApproverSignaturesAndSigners(nextEpochApprovals *NextEpochApprova // We will overwrite the old approving nodes with the new approving nodes, by turning on the bits for the new approvers. newApprovingNodes := oldApprovingNodes.Clone() + logger.Debug("Existing approving nodes bitmask before adding new approvals", + zap.Int("count", oldApprovingNodes.Len())) + + logger.Debug("New approvals from peers that we will consider for aggregation", zap.Int("count", len(approvalsFromPeers))) + for _, approval := range approvalsFromPeers { approvingNodeIndexOfNewApprover, exists := nodeID2ValidatorIndex[approval.NodeID] if !exists { @@ -817,23 +865,32 @@ func computeNewApproverSignaturesAndSigners(nextEpochApprovals *NextEpochApprova // sanitizeApprovals filters out approvals that are not valid by checking if they agree with our candidate auxiliary info digest and P-Chain height, // and if they are from the validator set and haven't already been approved. -func sanitizeApprovals(approvals ValidatorSetApprovals, pChainHeight uint64, nodeID2ValidatorIndex map[nodeID]int, oldApprovingNodes bitmask) ValidatorSetApprovals { - filter1 := approvalsThatAgreeWithAuxInfoAndPChainHeight(pChainHeight) +func sanitizeApprovals(approvals ValidatorSetApprovals, pChainHeight uint64, nodeID2ValidatorIndex map[nodeID]int, oldApprovingNodes bitmask, logger simplex.Logger) ValidatorSetApprovals { + filter1 := approvalsThatAgreeWithPChainHeight(pChainHeight) filter2 := approvalsThatAreInValidatorSetAndHaveNotAlreadyApproved(oldApprovingNodes.Clone(), nodeID2ValidatorIndex) - return approvals.Filter(filter1).Filter(filter2).UniqueByNodeID() + return approvals.Filter(filter1, logger).Filter(filter2, logger).UniqueByNodeID() } -func approvalsThatAgreeWithAuxInfoAndPChainHeight(pChainHeight uint64) func(i int, approval ValidatorSetApproval) bool { - return func(i int, approval ValidatorSetApproval) bool { +func approvalsThatAgreeWithPChainHeight(pChainHeight uint64) func(i int, approval ValidatorSetApproval, logger simplex.Logger) bool { + return func(i int, approval ValidatorSetApproval, logger simplex.Logger) bool { // Pick only approvals that agree with our P-Chain height - return approval.PChainHeight == pChainHeight + ok := approval.PChainHeight == pChainHeight + if !ok { + logger.Debug("Filtering out approval that does not agree with our P-Chain height", + zap.String("nodeID", fmt.Sprintf("%x", approval.NodeID)), + zap.Uint64("approvalPChainHeight", approval.PChainHeight), + zap.Uint64("expectedPChainHeight", pChainHeight)) + } + return ok } } -func approvalsThatAreInValidatorSetAndHaveNotAlreadyApproved(oldApprovingNodes bitmask, nodeID2ValidatorIndex map[nodeID]int) func(i int, approval ValidatorSetApproval) bool { - return func(i int, approval ValidatorSetApproval) bool { +func approvalsThatAreInValidatorSetAndHaveNotAlreadyApproved(oldApprovingNodes bitmask, nodeID2ValidatorIndex map[nodeID]int) func(i int, approval ValidatorSetApproval, logger simplex.Logger) bool { + return func(i int, approval ValidatorSetApproval, logger simplex.Logger) bool { approvingNodeIndexOfNewApprover, exists := nodeID2ValidatorIndex[approval.NodeID] if !exists { + logger.Debug("Filtering out approval from node that is not in the validator set", + zap.String("nodeID", fmt.Sprintf("%x", approval.NodeID))) // If the approving node is not in the validator set, we ignore this approval. return false } @@ -842,39 +899,6 @@ func approvalsThatAreInValidatorSetAndHaveNotAlreadyApproved(oldApprovingNodes b } } -func findFirstSimplexBlock(getBlock BlockRetriever, endHeight uint64) (uint64, error) { - var haltError error - - if endHeight > math.MaxInt-1 { - return 0, fmt.Errorf("endHeight %d is too big, must be at most %d", endHeight, math.MaxInt-1) - } - - firstSimplexBlock := sort.Search(int(endHeight+1), func(i int) bool { - if haltError != nil { - return true - } - block, _, err := getBlock(uint64(i), [32]byte{}) - if errors.Is(err, simplex.ErrBlockNotFound) { - return false - } - if err != nil { - haltError = fmt.Errorf("error retrieving block at height %d: %w", i, err) - return false - } - // The first Simplex block is such that its epoch info isn't the zero value. - return !block.Metadata.SimplexEpochInfo.IsZero() - }) - if haltError != nil { - return 0, haltError - } - - if uint64(firstSimplexBlock) > endHeight { - return 0, fmt.Errorf("no simplex blocks found in range [%d, %d]", 0, endHeight) - } - - return uint64(firstSimplexBlock), nil -} - func computePrevVMBlockSeq(parentBlock StateMachineBlock, prevBlockSeq uint64) uint64 { // Either our parent block has no inner block, in which case we just inherit its previous VM block sequence, if parentBlock.InnerBlock == nil { diff --git a/msm/msm_test.go b/msm/msm_test.go index 6ec03b39..a7ada2fb 100644 --- a/msm/msm_test.go +++ b/msm/msm_test.go @@ -224,28 +224,15 @@ func TestMSMFirstBlockAfterGenesis(t *testing.T) { mutateBlock: func(block *StateMachineBlock) { block.Metadata.PChainHeight = 110 }, - err: "invalid P-chain height (110) is too big", + err: "invalid P-chain height (110), expected to be 100", }, { name: "P-chain height smaller than parent", md: validMD, - configure: func(_ *StateMachine, tc *testConfig) { - tc.blockStore[0] = &outerBlock{ - block: StateMachineBlock{ - InnerBlock: &InnerBlock{TS: time.Now(), Bytes: []byte{1, 2, 3}}, - Metadata: StateMachineMetadata{PChainHeight: 110}, - }, - } + configure: func(sm *StateMachine, tc *testConfig) { + sm.LastNonSimplexBlockPChainHeight = 99 }, - err: "invalid P-chain height (100) is smaller than parent InnerBlock's P-chain height (110)", - }, - { - name: "validator set retrieval fails", - md: validMD, - configure: func(_ *StateMachine, tc *testConfig) { - tc.validatorSetRetriever.err = fmt.Errorf("validator set unavailable") - }, - err: "failed to retrieve validator set", + err: "invalid P-chain height (100), expected to be 99", }, { name: "nil BlockValidationDescriptor", @@ -258,8 +245,8 @@ func TestMSMFirstBlockAfterGenesis(t *testing.T) { { name: "membership mismatch", md: validMD, - configure: func(_ *StateMachine, tc *testConfig) { - tc.validatorSetRetriever.result = NodeBLSMappings{ + configure: func(sm *StateMachine, tc *testConfig) { + sm.GenesisValidatorSet = NodeBLSMappings{ {BLSKey: []byte{1}, Weight: 1}, } }, @@ -337,6 +324,9 @@ func TestMSMFirstSimplexBlockAfterPreSimplexBlocks(t *testing.T) { testConfig1.blockStore[42] = &outerBlock{block: preSimplexParent} testConfig2.blockStore[42] = &outerBlock{block: preSimplexParent} + sm1.LastNonSimplexInnerBlock = testConfig1.blockStore[42].block.InnerBlock + sm2.LastNonSimplexInnerBlock = testConfig1.blockStore[42].block.InnerBlock + testConfig1.blockBuilder.block = &InnerBlock{ TS: time.Now(), BlockHeight: 43, @@ -350,11 +340,6 @@ func TestMSMFirstSimplexBlockAfterPreSimplexBlocks(t *testing.T) { require.NoError(t, sm2.VerifyBlock(context.Background(), block)) require.Equal(t, &StateMachineBlock{ - InnerBlock: &InnerBlock{ - TS: testConfig1.blockBuilder.block.Timestamp(), - BlockHeight: 43, - Bytes: []byte{7, 8, 9}, - }, Metadata: StateMachineMetadata{ Timestamp: uint64(testConfig1.blockBuilder.block.Timestamp().UnixMilli()), PChainHeight: 100, @@ -437,7 +422,7 @@ func TestMSMNormalOp(t *testing.T) { mutateBlock: func(block *StateMachineBlock) { block.Metadata.SimplexEpochInfo.BlockValidationDescriptor = &BlockValidationDescriptor{} }, - err: "failed to find first Simplex block", + err: "expected validator set specified at P-chain height 0 does not match validator set encoded in new block", }, { name: "non-zero sealing block seq", @@ -635,10 +620,18 @@ func TestMSMFullEpochLifecycle(t *testing.T) { tc.blockStore[0] = &outerBlock{block: genesis} tc.blockStore[42] = &outerBlock{block: notGenesis} + sm.LastNonSimplexInnerBlock = testCase.firstBlockBeforeSimplex.InnerBlock + sm.GenesisValidatorSet = validatorSet1 + sm.LastNonSimplexBlockPChainHeight = pChainHeight1 + smVerify, tcVerify := newStateMachine(t) smVerify.GetValidatorSet = getValidatorSet smVerify.GetPChainHeight = getPChainHeight + smVerify.LastNonSimplexInnerBlock = testCase.firstBlockBeforeSimplex.InnerBlock + smVerify.GenesisValidatorSet = validatorSet1 + smVerify.LastNonSimplexBlockPChainHeight = pChainHeight1 + // addBlock adds a block to both block stores so builder and verifier stay in sync. addBlock := func(seq uint64, block StateMachineBlock, fin *simplex.Finalization) { tc.blockStore[seq] = &outerBlock{block: block, finalization: fin} @@ -662,9 +655,8 @@ func TestMSMFullEpochLifecycle(t *testing.T) { block1, err := sm.BuildBlock(context.Background(), md, nil) require.NoError(t, err) require.Equal(t, &StateMachineBlock{ - InnerBlock: nextBlock(1), Metadata: StateMachineMetadata{ - Timestamp: uint64(startTime.Add(1 * time.Millisecond).UnixMilli()), + Timestamp: uint64(startTime.UnixMilli()), PChainHeight: pChainHeight1, SimplexProtocolMetadata: md.Bytes(), SimplexEpochInfo: SimplexEpochInfo{ @@ -701,7 +693,7 @@ func TestMSMFullEpochLifecycle(t *testing.T) { SimplexEpochInfo: SimplexEpochInfo{ PChainReferenceHeight: pChainHeight1, EpochNumber: 1, - PrevVMBlockSeq: baseSeq + 1, + PrevVMBlockSeq: baseSeq, }, }, }, block2) @@ -1056,6 +1048,22 @@ func newStateMachine(t *testing.T) (StateMachine, *testConfig) { } sm := StateMachine{ + GenesisValidatorSet: NodeBLSMappings{{BLSKey: []byte{1}, Weight: 1}, {BLSKey: []byte{2}, Weight: 1}}, + LastNonSimplexBlockPChainHeight: 100, + FirstEverSimplexBlock: func() *StateMachineBlock { + var res *StateMachineBlock + min := uint64(math.MaxUint64) + for seq, block := range testConfig.blockStore { + if block.block.Metadata.SimplexEpochInfo.EpochNumber == 0 { + continue + } + if seq < min { + min = seq + res = &block.block + } + } + return res + }, GetTime: time.Now, TimeSkewLimit: time.Second * 5, Logger: testutil.MakeLogger(t), @@ -1074,6 +1082,7 @@ func newStateMachine(t *testing.T) (StateMachine, *testConfig) { }, GetValidatorSet: testConfig.validatorSetRetriever.getValidatorSet, PChainProgressListener: &noOpPChainListener{}, + LastNonSimplexInnerBlock: genesisBlock.InnerBlock, } return sm, &testConfig } @@ -1182,60 +1191,6 @@ func TestComputePrevVMBlockSeq(t *testing.T) { }) } -func TestFindFirstSimplexBlock(t *testing.T) { - t.Run("endHeight too big", func(t *testing.T) { - getBlock := func(_ uint64, _ [32]byte) (StateMachineBlock, *simplex.Finalization, error) { - return StateMachineBlock{}, nil, nil - } - _, err := findFirstSimplexBlock(getBlock, math.MaxUint64) - require.ErrorContains(t, err, fmt.Sprintf(" is too big, must be at most %d", math.MaxInt64-1)) - }) - - t.Run("found at height 3", func(t *testing.T) { - getBlock := func(seq uint64, _ [32]byte) (StateMachineBlock, *simplex.Finalization, error) { - if seq < 3 { - return StateMachineBlock{}, nil, nil - } - return StateMachineBlock{ - Metadata: StateMachineMetadata{SimplexEpochInfo: SimplexEpochInfo{EpochNumber: 1}}, - }, nil, nil - } - result, err := findFirstSimplexBlock(getBlock, 5) - require.NoError(t, err) - require.Equal(t, uint64(3), result) - }) - - t.Run("no simplex blocks found", func(t *testing.T) { - getBlock := func(_ uint64, _ [32]byte) (StateMachineBlock, *simplex.Finalization, error) { - return StateMachineBlock{}, nil, nil - } - _, err := findFirstSimplexBlock(getBlock, 5) - require.ErrorContains(t, err, "no simplex blocks found") - }) - - t.Run("block not found errors are skipped", func(t *testing.T) { - getBlock := func(seq uint64, _ [32]byte) (StateMachineBlock, *simplex.Finalization, error) { - if seq < 2 { - return StateMachineBlock{}, nil, simplex.ErrBlockNotFound - } - return StateMachineBlock{ - Metadata: StateMachineMetadata{SimplexEpochInfo: SimplexEpochInfo{EpochNumber: 1}}, - }, nil, nil - } - result, err := findFirstSimplexBlock(getBlock, 5) - require.NoError(t, err) - require.Equal(t, uint64(2), result) - }) - - t.Run("retrieval error propagated", func(t *testing.T) { - getBlock := func(_ uint64, _ [32]byte) (StateMachineBlock, *simplex.Finalization, error) { - return StateMachineBlock{}, nil, fmt.Errorf("disk error") - } - _, err := findFirstSimplexBlock(getBlock, 5) - require.ErrorContains(t, err, "disk error") - }) -} - func TestSanitizeApprovals(t *testing.T) { node0 := nodeID{0} node1 := nodeID{1} @@ -1248,13 +1203,15 @@ func TestSanitizeApprovals(t *testing.T) { node2: 2, } + logger := testutil.MakeLogger(t) + t.Run("filters by p-chain height", func(t *testing.T) { approvals := ValidatorSetApprovals{ {NodeID: node0, PChainHeight: 100}, {NodeID: node1, PChainHeight: 200}, } oldApproving := bitmaskFromBytes(nil) - result := sanitizeApprovals(approvals, 100, nodeID2Index, oldApproving) + result := sanitizeApprovals(approvals, 100, nodeID2Index, oldApproving, logger) require.Len(t, result, 1) require.Equal(t, node0, result[0].NodeID) }) @@ -1265,7 +1222,7 @@ func TestSanitizeApprovals(t *testing.T) { {NodeID: node1, PChainHeight: 100}, } oldApproving := bitmaskFromBytes([]byte{1}) - result := sanitizeApprovals(approvals, 100, nodeID2Index, oldApproving) + result := sanitizeApprovals(approvals, 100, nodeID2Index, oldApproving, logger) require.Len(t, result, 1) require.Equal(t, node1, result[0].NodeID) }) @@ -1276,7 +1233,7 @@ func TestSanitizeApprovals(t *testing.T) { {NodeID: node2, PChainHeight: 100}, } oldApproving := bitmaskFromBytes(nil) - result := sanitizeApprovals(approvals, 100, nodeID2Index, oldApproving) + result := sanitizeApprovals(approvals, 100, nodeID2Index, oldApproving, logger) require.Len(t, result, 1) require.Equal(t, node2, result[0].NodeID) }) @@ -1287,7 +1244,7 @@ func TestSanitizeApprovals(t *testing.T) { {NodeID: node0, PChainHeight: 100}, } oldApproving := bitmaskFromBytes(nil) - result := sanitizeApprovals(approvals, 100, nodeID2Index, oldApproving) + result := sanitizeApprovals(approvals, 100, nodeID2Index, oldApproving, logger) require.Len(t, result, 1) }) } @@ -1333,6 +1290,8 @@ func TestComputeNewApproverSignaturesAndSigners(t *testing.T) { node2: 2, } + logger := testutil.MakeLogger(t) + t.Run("duplicate peer with already-approved node does not double-aggregate", func(t *testing.T) { // node0 is already in the previous approvals (bit 0 set). A duplicate peer // entry for node0 must not append node0's signature to the new aggregate @@ -1348,7 +1307,7 @@ func TestComputeNewApproverSignaturesAndSigners(t *testing.T) { {NodeID: node0, Signature: []byte("sig0")}, } - aggSig, newApproving, err := computeNewApproverSignaturesAndSigners(prevApprovals, peers, oldApproving, nodeID2Index, concatAggregator{}) + aggSig, newApproving, err := computeNewApproverSignaturesAndSigners(prevApprovals, peers, oldApproving, nodeID2Index, concatAggregator{}, logger) require.NoError(t, err) require.True(t, newApproving.Contains(0)) require.Equal(t, 1, newApproving.Len()) @@ -1364,7 +1323,7 @@ func TestComputeNewApproverSignaturesAndSigners(t *testing.T) { {NodeID: node1, Signature: []byte("sig1")}, } - _, _, err := computeNewApproverSignaturesAndSigners(nil, peers, oldApproving, nodeID2Index, concatAggregator{}) + _, _, err := computeNewApproverSignaturesAndSigners(nil, peers, oldApproving, nodeID2Index, concatAggregator{}, logger) require.Error(t, err) }) @@ -1377,7 +1336,7 @@ func TestComputeNewApproverSignaturesAndSigners(t *testing.T) { {NodeID: node1, Signature: []byte("sig1")}, } - aggSig, newApproving, err := computeNewApproverSignaturesAndSigners(prevApprovals, peers, oldApproving, nodeID2Index, concatAggregator{}) + aggSig, newApproving, err := computeNewApproverSignaturesAndSigners(prevApprovals, peers, oldApproving, nodeID2Index, concatAggregator{}, logger) require.NoError(t, err) require.True(t, newApproving.Contains(0)) require.True(t, newApproving.Contains(1)) @@ -1396,7 +1355,7 @@ func TestComputeNewApproverSignaturesAndSigners(t *testing.T) { {NodeID: node2, Signature: []byte("sig2")}, } - aggSig, newApproving, err := computeNewApproverSignaturesAndSigners(prevApprovals, peers, oldApproving, nodeID2Index, concatAggregator{}) + aggSig, newApproving, err := computeNewApproverSignaturesAndSigners(prevApprovals, peers, oldApproving, nodeID2Index, concatAggregator{}, logger) require.NoError(t, err) require.True(t, newApproving.Contains(0)) // preserved from old require.True(t, newApproving.Contains(2)) // newly added @@ -1411,7 +1370,7 @@ func TestComputeNewApproverSignaturesAndSigners(t *testing.T) { } oldApproving := bitmaskFromBytes([]byte{1}) - aggSig, newApproving, err := computeNewApproverSignaturesAndSigners(prevApprovals, nil, oldApproving, nodeID2Index, concatAggregator{}) + aggSig, newApproving, err := computeNewApproverSignaturesAndSigners(prevApprovals, nil, oldApproving, nodeID2Index, concatAggregator{}, logger) require.NoError(t, err) require.True(t, newApproving.Contains(0)) require.Equal(t, []byte("existing"), aggSig) @@ -1427,7 +1386,7 @@ func TestComputeNewApproverSignaturesAndSigners(t *testing.T) { {NodeID: node0, Signature: []byte("sig0")}, } - aggSig, newApproving, err := computeNewApproverSignaturesAndSigners(prevApprovals, peers, oldApproving, nodeID2Index, concatAggregator{}) + aggSig, newApproving, err := computeNewApproverSignaturesAndSigners(prevApprovals, peers, oldApproving, nodeID2Index, concatAggregator{}, logger) require.NoError(t, err) require.True(t, newApproving.Contains(0)) require.Equal(t, 1, newApproving.Len()) @@ -1441,7 +1400,7 @@ func TestComputeNewApproverSignaturesAndSigners(t *testing.T) { {NodeID: node0, Signature: []byte("sig0")}, } - _, _, err := computeNewApproverSignaturesAndSigners(prevApprovals, peers, oldApproving, nodeID2Index, failingAggregator{}) + _, _, err := computeNewApproverSignaturesAndSigners(prevApprovals, peers, oldApproving, nodeID2Index, failingAggregator{}, logger) require.ErrorContains(t, err, "aggregation failed") }) } diff --git a/msm/verification.go b/msm/verification.go index ba8ef333..532cf57c 100644 --- a/msm/verification.go +++ b/msm/verification.go @@ -410,6 +410,7 @@ func (t *timestampVerifier) Verify(in verificationInput) error { type prevSealingBlockHashVerifier struct { getBlock BlockRetriever latestPersistedHeight *uint64 + firstEverSimplexBlock func() *StateMachineBlock } func (p *prevSealingBlockHashVerifier) Verify(in verificationInput) error { @@ -417,17 +418,11 @@ func (p *prevSealingBlockHashVerifier) Verify(in verificationInput) error { // Sealing block of the first epoch must point to the first ever Simplex block as the previous sealing block. if prev.EpochNumber == 1 && in.nextBlockType == BlockTypeSealing { - firstEverSimplexBlockSeq, err := findFirstSimplexBlock(p.getBlock, *p.latestPersistedHeight+1) - if err != nil { - return fmt.Errorf("failed to find first Simplex block: %w", err) + firstSimplexBlock := p.firstEverSimplexBlock() + if firstSimplexBlock == nil { + return fmt.Errorf("first ever simplex block sequence number is not set but verifying a sealing block of the first epoch") } - - block, _, err := p.getBlock(firstEverSimplexBlockSeq, [32]byte{}) - if err != nil { - return fmt.Errorf("failed retrieving first ever simplex block %d: %w", firstEverSimplexBlockSeq, err) - } - - hash := block.Digest() + hash := firstSimplexBlock.Digest() if !bytes.Equal(in.proposedBlockMD.SimplexEpochInfo.PrevSealingBlockHash[:], hash[:]) { return fmt.Errorf("expected prev sealing block hash of the first ever simplex block to be %x but got %x", hash, in.proposedBlockMD.SimplexEpochInfo.PrevSealingBlockHash) } diff --git a/msm/verification_test.go b/msm/verification_test.go index b515a4f1..898723dc 100644 --- a/msm/verification_test.go +++ b/msm/verification_test.go @@ -379,6 +379,9 @@ func TestPrevSealingBlockHashVerifier(t *testing.T) { v := &prevSealingBlockHashVerifier{ getBlock: bs.getBlock, latestPersistedHeight: &latestPersisted, + firstEverSimplexBlock: func() *StateMachineBlock { + return &firstSimplexBlock + }, } err := v.Verify(verificationInput{ nextBlockType: tc.nextBlockType, From e197d396d22fec7fb9f3862868d58a6d0e490ab5 Mon Sep 17 00:00:00 2001 From: Yacov Manevich Date: Mon, 11 May 2026 23:33:43 +0200 Subject: [PATCH 07/16] Create explicit constructor for MSM Signed-off-by: Yacov Manevich --- msm/fake_node_test.go | 2 +- msm/msm.go | 74 ++++++++++++++++++++++++++----------------- msm/msm_test.go | 12 ++++--- 3 files changed, 53 insertions(+), 35 deletions(-) diff --git a/msm/fake_node_test.go b/msm/fake_node_test.go index 97af1986..a0cf2b94 100644 --- a/msm/fake_node_test.go +++ b/msm/fake_node_test.go @@ -178,7 +178,7 @@ type blockState struct { type fakeNode struct { t *testing.T - sm StateMachine + sm *StateMachine mempoolEmpty bool // blocks holds notarized blocks in order. Finalized blocks always form a // prefix: all finalized entries precede all non-finalized entries. diff --git a/msm/msm.go b/msm/msm.go index dbce6351..cdbd93bc 100644 --- a/msm/msm.go +++ b/msm/msm.go @@ -71,6 +71,15 @@ type BlockBuilder interface { // StateMachine manages block building and verification across epoch transitions. type StateMachine struct { + // verifiers is the list of verifiers used to verify proposed blocks. + // Each verifier is responsible for verifying a specific aspect of the block's metadata. + verifiers []verifier + + *Config +} + +// Config contains the dependencies and configuration parameters needed to initialize the StateMachine. +type Config struct { // LatestPersistedHeight is the height of the most recently persisted block. LatestPersistedHeight uint64 // MaxBlockBuildingWaitTime is the maximum duration to wait for the VM to build a block @@ -111,13 +120,6 @@ type StateMachine struct { LastNonSimplexInnerBlock VMBlock // GenesisValidatorSet is the validator set used for the genesis block. GenesisValidatorSet NodeBLSMappings - // initialized tracks whether the state machine has been initialized. - // This is used to lazily initialize the verifiers. - initialized bool - - // verifiers is the list of verifiers used to verify proposed blocks. - // Each verifier is responsible for verifying a specific aspect of the block's metadata. - verifiers []verifier } type state uint8 @@ -129,10 +131,14 @@ const ( stateBuildBlockEpochSealed ) +func NewStateMachine(config *Config) *StateMachine { + sm := StateMachine{Config: config} + sm.init() + return &sm +} + // BuildBlock constructs the next block on top of the given parent block, and passes in the provided simplex metadata and blacklist. func (sm *StateMachine) BuildBlock(ctx context.Context, simplexMetadata simplex.ProtocolMetadata, simplexBlacklist *simplex.Blacklist) (*StateMachineBlock, error) { - sm.maybeInit() - // The zero sequence number is reserved for the genesis block, which should never be built. if simplexMetadata.Seq == 0 { return nil, fmt.Errorf("invalid ProtocolMetadata sequence number: should be > 0, got %d", simplexMetadata.Seq) @@ -189,8 +195,6 @@ func (sm *StateMachine) BuildBlock(ctx context.Context, simplexMetadata simplex. // VerifyBlock validates a proposed block by checking its metadata, epoch info, // and inner block against the previous block and the current state. func (sm *StateMachine) VerifyBlock(ctx context.Context, block *StateMachineBlock) error { - sm.maybeInit() - if block == nil { return fmt.Errorf("InnerBlock is nil") } @@ -223,45 +227,57 @@ func (sm *StateMachine) VerifyBlock(ctx context.Context, block *StateMachineBloc return err } -func (sm *StateMachine) maybeInit() { - if sm.initialized { - return - } - sm.init() - sm.initialized = true -} - func (sm *StateMachine) init() { sm.verifiers = []verifier{ &pChainHeightVerifier{ - getPChainHeight: sm.GetPChainHeight, + getPChainHeight: func() uint64 { + return sm.Config.GetPChainHeight() + }, }, ×tampVerifier{ timeSkewLimit: sm.TimeSkewLimit, - getTime: sm.GetTime, + getTime: func() time.Time { + return sm.Config.GetTime() + }, }, &pChainReferenceHeightVerifier{}, &epochNumberVerifier{}, &validationDescriptorVerifier{ - getValidatorSet: sm.GetValidatorSet, + getValidatorSet: func(pChainHeight uint64) (NodeBLSMappings, error) { + return sm.Config.GetValidatorSet(pChainHeight) + }, }, &prevSealingBlockHashVerifier{ - firstEverSimplexBlock: sm.FirstEverSimplexBlock, - getBlock: sm.GetBlock, + firstEverSimplexBlock: func() *StateMachineBlock { + return sm.Config.FirstEverSimplexBlock() + }, + getBlock: func(seq uint64, digest [32]byte) (StateMachineBlock, *simplex.Finalization, error) { + return sm.Config.GetBlock(seq, digest) + }, latestPersistedHeight: &sm.LatestPersistedHeight, }, &nextPChainReferenceHeightVerifier{ - getPChainHeight: sm.GetPChainHeight, - getValidatorSet: sm.GetValidatorSet, + getPChainHeight: func() uint64 { + return sm.Config.GetPChainHeight() + }, + getValidatorSet: func(pChainHeight uint64) (NodeBLSMappings, error) { + return sm.Config.GetValidatorSet(pChainHeight) + }, }, &vmBlockSeqVerifier{ - getBlock: sm.GetBlock, + getBlock: func(seq uint64, digest [32]byte) (StateMachineBlock, *simplex.Finalization, error) { + return sm.Config.GetBlock(seq, digest) + }, }, &nextEpochApprovalsVerifier{ - getValidatorSet: sm.GetValidatorSet, + getValidatorSet: func(pChainHeight uint64) (NodeBLSMappings, error) { + return sm.Config.GetValidatorSet(pChainHeight) + }, keyAggregator: sm.KeyAggregator, sigVerifier: sm.SignatureVerifier, - sigAggregatorCreator: sm.SignatureAggregatorCreator, + sigAggregatorCreator: func(weights []simplex.NodeWeight) simplex.SignatureAggregator { + return sm.Config.SignatureAggregatorCreator(weights) + }, }, &sealingBlockSeqVerifier{}, } diff --git a/msm/msm_test.go b/msm/msm_test.go index a7ada2fb..b139b357 100644 --- a/msm/msm_test.go +++ b/msm/msm_test.go @@ -274,7 +274,7 @@ func TestMSMFirstBlockAfterGenesis(t *testing.T) { } if testCase.configure != nil { - testCase.configure(&sm2, testConfig2) + testCase.configure(sm2, testConfig2) } block, err := sm1.BuildBlock(context.Background(), testCase.md, nil) @@ -498,8 +498,8 @@ func TestMSMNormalOp(t *testing.T) { } if testCase.setup != nil { - testCase.setup(&sm1, testConfig1) - testCase.setup(&sm2, testConfig2) + testCase.setup(sm1, testConfig1) + testCase.setup(sm2, testConfig2) } block1, err := sm1.BuildBlock(context.Background(), *md, &blacklist) @@ -1037,7 +1037,7 @@ type testConfig struct { validatorSetRetriever validatorSetRetriever } -func newStateMachine(t *testing.T) (StateMachine, *testConfig) { +func newStateMachine(t *testing.T) (*StateMachine, *testConfig) { bs := make(blockStore) bs[0] = &outerBlock{block: genesisBlock} @@ -1047,7 +1047,7 @@ func newStateMachine(t *testing.T) (StateMachine, *testConfig) { {BLSKey: []byte{1}, Weight: 1}, {BLSKey: []byte{2}, Weight: 1}, } - sm := StateMachine{ + smConfig := Config{ GenesisValidatorSet: NodeBLSMappings{{BLSKey: []byte{1}, Weight: 1}, {BLSKey: []byte{2}, Weight: 1}}, LastNonSimplexBlockPChainHeight: 100, FirstEverSimplexBlock: func() *StateMachineBlock { @@ -1084,6 +1084,8 @@ func newStateMachine(t *testing.T) (StateMachine, *testConfig) { PChainProgressListener: &noOpPChainListener{}, LastNonSimplexInnerBlock: genesisBlock.InnerBlock, } + + sm := NewStateMachine(&smConfig) return sm, &testConfig } From cdd6dd137ff53a53b2cd25d9b22e13849da19887 Mon Sep 17 00:00:00 2001 From: Yacov Manevich Date: Mon, 11 May 2026 23:52:21 +0200 Subject: [PATCH 08/16] fix spelling errors Signed-off-by: Yacov Manevich --- msm/msm.go | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/msm/msm.go b/msm/msm.go index cdbd93bc..775a7beb 100644 --- a/msm/msm.go +++ b/msm/msm.go @@ -683,8 +683,7 @@ func (sm *StateMachine) createSealingBlock(ctx context.Context, parentBlock Stat // wrapBlock creates a new StateMachineBlock by wrapping the VM block (if applicable) and adding the appropriate metadata. func (sm *StateMachine) wrapBlock(parentBlock StateMachineBlock, childBlock VMBlock, newSimplexEpochInfo SimplexEpochInfo, pChainHeight uint64, simplexMetadata, simplexBlacklist []byte) *StateMachineBlock { - parentMetadata := parentBlock.Metadata - timestamp := parentMetadata.Timestamp + timestamp := parentBlock.Metadata.Timestamp hasChildBlock := childBlock != nil @@ -931,7 +930,7 @@ var ( func ensureNextEpochApprovalsSignersSupersetOfApprovalsOfPrevBlock(prev SimplexEpochInfo, next SimplexEpochInfo) error { if prev.NextEpochApprovals == nil { - // Condition satisifed vacuously. + // Condition satisfied vacuously. return nil } // Else, prev.NextEpochApprovals is not nil. From c04eb9eaf840b07115a1db4aa64aa04431049241 Mon Sep 17 00:00:00 2001 From: Yacov Manevich Date: Tue, 12 May 2026 16:42:34 +0200 Subject: [PATCH 09/16] nits + gofmt Signed-off-by: Yacov Manevich --- api.go | 2 +- epoch.go | 4 ++-- global.go | 4 ++-- msm/encoding.go | 26 +++++++++++++++++++++++++- msm/fake_node_test.go | 4 ++-- msm/msm.go | 40 ++++++++-------------------------------- msm/msm_test.go | 40 ++++++++++++++++------------------------ msm/verification.go | 8 ++++---- testutil/network.go | 8 ++++---- testutil/node.go | 10 +++++----- testutil/util.go | 7 +++---- 11 files changed, 72 insertions(+), 81 deletions(-) diff --git a/api.go b/api.go index 2e01b78d..fdf5db39 100644 --- a/api.go +++ b/api.go @@ -154,7 +154,7 @@ func (nws NodeWeights) NodesIDs() []NodeID { // NodeWeight is a struct that pairs a node with its weight in the signature aggregator. type NodeWeight struct { - Node NodeID + Node NodeID Weight uint64 } diff --git a/epoch.go b/epoch.go index 220abce8..d6a4466d 100644 --- a/epoch.go +++ b/epoch.go @@ -83,8 +83,8 @@ type EpochConfig struct { type Epoch struct { EpochConfig // Runtime - signatureAggregator SignatureAggregator - oneTimeVerifier *OneTimeVerifier + signatureAggregator SignatureAggregator + oneTimeVerifier *OneTimeVerifier buildBlockScheduler *BasicScheduler blockVerificationScheduler *BlockDependencyManager lock sync.Mutex diff --git a/global.go b/global.go index 6e70a0a6..b632acea 100644 --- a/global.go +++ b/global.go @@ -56,9 +56,9 @@ func (nodes NodeIDs) EqualWeightedNodeWeights() NodeWeights { weights := make(NodeWeights, len(nodes)) for i, node := range nodes { weights[i] = NodeWeight{ - Node: node, + Node: node, Weight: 1, } } return weights -} \ No newline at end of file +} diff --git a/msm/encoding.go b/msm/encoding.go index 1b863f3a..414dd241 100644 --- a/msm/encoding.go +++ b/msm/encoding.go @@ -109,6 +109,30 @@ func (sei *SimplexEpochInfo) Equal(other *SimplexEpochInfo) bool { return true } +func (sei *SimplexEpochInfo) NextState() state { + prevBlockSimplexEpochInfo := sei + // If this is the first ever epoch, then this is also the first ever block to be built by Simplex. + if prevBlockSimplexEpochInfo.EpochNumber == 0 { + return stateFirstSimplexBlock + } + + // If we don't have a next P-chain preference height, it means we are not transitioning to a new epoch just yet. + if prevBlockSimplexEpochInfo.NextPChainReferenceHeight == 0 { + return stateBuildBlockNormalOp + } + + // If the previous block has a sealing block sequence, it's a Telock. + // If it has a block validation descriptor, it's a sealing block. + // Either way, the epoch has been sealed. + if prevBlockSimplexEpochInfo.SealingBlockSeq > 0 || prevBlockSimplexEpochInfo.BlockValidationDescriptor != nil { + return stateBuildBlockEpochSealed + } + + // In any other case, NextPChainReferenceHeight > 0 but the previous block is not a Telock or sealing block, + // it means we are in the process of collecting approvals for the next epoch. + return stateBuildCollectingApprovals +} + type NodeBLSMapping struct { NodeID nodeID `canoto:"fixed bytes,1"` BLSKey []byte `canoto:"bytes,2"` @@ -203,7 +227,7 @@ func (nbms NodeBLSMappings) NodeWeights() simplex.NodeWeights { nodeWeights := make(simplex.NodeWeights, len(nbms)) for i, nbm := range nbms { nodeWeights[i] = simplex.NodeWeight{ - Node: nbm.NodeID[:], + Node: nbm.NodeID[:], Weight: nbm.Weight, } } diff --git a/msm/fake_node_test.go b/msm/fake_node_test.go index a0cf2b94..58eefb30 100644 --- a/msm/fake_node_test.go +++ b/msm/fake_node_test.go @@ -18,7 +18,7 @@ import ( func TestFakeNode(t *testing.T) { validatorSetRetriever := validatorSetRetriever{ resultMap: map[uint64]NodeBLSMappings{ - 0: {{BLSKey: []byte{1}, Weight: 1, NodeID: [20]byte{1}}, {BLSKey: []byte{2}, Weight: 1, NodeID: [20]byte{2}}}, + 0: {{BLSKey: []byte{1}, Weight: 1, NodeID: [20]byte{1}}, {BLSKey: []byte{2}, Weight: 1, NodeID: [20]byte{2}}}, 100: {{BLSKey: []byte{1}, Weight: 1, NodeID: [20]byte{1}}, {BLSKey: []byte{2}, Weight: 1, NodeID: [20]byte{2}}}, 200: {{BLSKey: []byte{1}, Weight: 1, NodeID: [20]byte{1}}, {BLSKey: []byte{2}, Weight: 2, NodeID: [20]byte{2}}, {BLSKey: []byte{3}, Weight: 1, NodeID: [20]byte{3}}}, @@ -85,7 +85,7 @@ func TestFakeNode(t *testing.T) { func TestFakeNodeEmptyMempool(t *testing.T) { validatorSetRetriever := validatorSetRetriever{ resultMap: map[uint64]NodeBLSMappings{ - 0: {{BLSKey: []byte{1}, Weight: 1, NodeID: [20]byte{1}}, {BLSKey: []byte{2}, Weight: 1, NodeID: [20]byte{2}}}, + 0: {{BLSKey: []byte{1}, Weight: 1, NodeID: [20]byte{1}}, {BLSKey: []byte{2}, Weight: 1, NodeID: [20]byte{2}}}, 100: {{BLSKey: []byte{1}, Weight: 1, NodeID: [20]byte{1}}, {BLSKey: []byte{2}, Weight: 1, NodeID: [20]byte{2}}}, 200: {{BLSKey: []byte{1}, Weight: 1, NodeID: [20]byte{1}}, {BLSKey: []byte{2}, Weight: 2, NodeID: [20]byte{2}}, {BLSKey: []byte{3}, Weight: 1, NodeID: [20]byte{3}}}, diff --git a/msm/msm.go b/msm/msm.go index 775a7beb..524527c0 100644 --- a/msm/msm.go +++ b/msm/msm.go @@ -125,7 +125,7 @@ type Config struct { type state uint8 const ( - stateFirstSimplexBlock state = iota + stateFirstSimplexBlock state = iota + 1 stateBuildBlockNormalOp stateBuildCollectingApprovals stateBuildBlockEpochSealed @@ -173,7 +173,7 @@ func (sm *StateMachine) BuildBlock(ctx context.Context, simplexMetadata simplex. // In order to know where in the epoch change process we are, // we identify the current state by looking at the parent block's epoch info. - currentState := identifyCurrentState(parentBlock.Metadata.SimplexEpochInfo) + currentState := parentBlock.Metadata.SimplexEpochInfo.NextState() simplexMetadataBytes := simplexMetadata.Bytes() prevBlockSeq := simplexMetadata.Seq - 1 @@ -216,7 +216,7 @@ func (sm *StateMachine) VerifyBlock(ctx context.Context, block *StateMachineBloc } prevMD := prevBlock.Metadata - currentState := identifyCurrentState(prevMD.SimplexEpochInfo) + currentState := prevMD.SimplexEpochInfo.NextState() switch currentState { case stateFirstSimplexBlock: @@ -273,8 +273,8 @@ func (sm *StateMachine) init() { getValidatorSet: func(pChainHeight uint64) (NodeBLSMappings, error) { return sm.Config.GetValidatorSet(pChainHeight) }, - keyAggregator: sm.KeyAggregator, - sigVerifier: sm.SignatureVerifier, + keyAggregator: sm.KeyAggregator, + sigVerifier: sm.SignatureVerifier, sigAggregatorCreator: func(weights []simplex.NodeWeight) simplex.SignatureAggregator { return sm.Config.SignatureAggregatorCreator(weights) }, @@ -323,29 +323,6 @@ func (sm *StateMachine) verifyNonZeroBlock(ctx context.Context, block *StateMach return block.InnerBlock.Verify(ctx) } -func identifyCurrentState(prevBlockSimplexEpochInfo SimplexEpochInfo) state { - // If this is the first ever epoch, then this is also the first ever block to be built by Simplex. - if prevBlockSimplexEpochInfo.EpochNumber == 0 { - return stateFirstSimplexBlock - } - - // If we don't have a next P-chain preference height, it means we are not transitioning to a new epoch just yet. - if prevBlockSimplexEpochInfo.NextPChainReferenceHeight == 0 { - return stateBuildBlockNormalOp - } - - // If the previous block has a sealing block sequence, it's a Telock. - // If it has a block validation descriptor, it's a sealing block. - // Either way, the epoch has been sealed. - if prevBlockSimplexEpochInfo.SealingBlockSeq > 0 || prevBlockSimplexEpochInfo.BlockValidationDescriptor != nil { - return stateBuildBlockEpochSealed - } - - // In any other case, NextPChainReferenceHeight > 0 but the previous block is not a Telock or sealing block, - // it means we are in the process of collecting approvals for the next epoch. - return stateBuildCollectingApprovals -} - // buildBlockNormalOp builds a block while not trying to transition to a new epoch. func (sm *StateMachine) buildBlockNormalOp(ctx context.Context, parentBlock StateMachineBlock, simplexMetadata, simplexBlacklist []byte, prevBlockSeq uint64) (*StateMachineBlock, error) { // Since in the previous block, we were not transitioning to a new epoch, @@ -488,7 +465,7 @@ func (sm *StateMachine) buildBlockZero(parentBlock StateMachineBlock, simplexMet return &StateMachineBlock{ Metadata: StateMachineMetadata{ - Timestamp: uint64(timestamp), + Timestamp: uint64(timestamp), SimplexProtocolMetadata: simplexMetadata, SimplexBlacklist: simplexBlacklist, SimplexEpochInfo: simplexEpochInfo, @@ -615,7 +592,7 @@ func (sm *StateMachine) buildBlockCollectingApprovals(ctx context.Context, paren // in which case we just carry over the approvals we have so far to the next block, // so that eventually we'll have enough approvals to seal the epoch. if !newApprovals.canSeal { - sm.Logger.Debug("Not enough approvals to seal epoch, building block without sealing the epoch",) + sm.Logger.Debug("Not enough approvals to seal epoch, building block without sealing the epoch") return sm.buildBlockImpatiently(ctx, parentBlock, simplexMetadata, simplexBlacklist, newSimplexEpochInfo, pChainHeight) } @@ -807,7 +784,6 @@ func computeNewApprovals( approvalsFromPeers = sanitizeApprovals(approvalsFromPeers, pChainHeight, nodeID2ValidatorIndex, oldApprovingNodes, logger) logger.Debug("Santizied approvals after filtering out invalid approvals", zap.Int("numApprovalsBefore", oldApprovalFromPeersCount), zap.Int("numApprovalsAfter", len(approvalsFromPeers))) - // Next we aggregate both previous and new approvals to compute the new aggregated signatures and the new bitmask of approving nodes. aggregatedSignature, newApprovingNodes, err := computeNewApproverSignaturesAndSigners(nextEpochApprovals, approvalsFromPeers, oldApprovingNodes, nodeID2ValidatorIndex, sigAggr, logger) if err != nil { @@ -834,7 +810,7 @@ func computeNewApproverSignaturesAndSigners( nodeID2ValidatorIndex map[nodeID]int, sigAggr simplex.SignatureAggregator, logger simplex.Logger, - ) ([]byte, bitmask, error) { +) ([]byte, bitmask, error) { if nextEpochApprovals == nil { return nil, bitmask{}, fmt.Errorf("next epoch approvals is nil") } diff --git a/msm/msm_test.go b/msm/msm_test.go index b139b357..e118783d 100644 --- a/msm/msm_test.go +++ b/msm/msm_test.go @@ -162,7 +162,7 @@ var ( } ) -func TestMSMFirstBlockAfterGenesis(t *testing.T) { +func TestMSMBuildAndVerifyBlocksAfterGenesis(t *testing.T) { validMD := simplex.ProtocolMetadata{ Round: 1, Seq: 1, @@ -182,7 +182,7 @@ func TestMSMFirstBlockAfterGenesis(t *testing.T) { md: validMD, }, { - name: "trying to build a genesis block", + name: "verifying a genesis block", md: validMD, mutateBlock: func(block *StateMachineBlock) { md, err := simplex.ProtocolMetadataFromBytes(block.Metadata.SimplexProtocolMetadata) @@ -262,17 +262,9 @@ func TestMSMFirstBlockAfterGenesis(t *testing.T) { }, } { t.Run(testCase.name, func(t *testing.T) { - sm1, testConfig1 := newStateMachine(t) + sm1, _ := newStateMachine(t) sm2, testConfig2 := newStateMachine(t) - testConfig1.blockStore[0] = &outerBlock{ - block: genesisBlock, - } - - testConfig2.blockStore[0] = &outerBlock{ - block: genesisBlock, - } - if testCase.configure != nil { testCase.configure(sm2, testConfig2) } @@ -1048,7 +1040,7 @@ func newStateMachine(t *testing.T) (*StateMachine, *testConfig) { } smConfig := Config{ - GenesisValidatorSet: NodeBLSMappings{{BLSKey: []byte{1}, Weight: 1}, {BLSKey: []byte{2}, Weight: 1}}, + GenesisValidatorSet: NodeBLSMappings{{BLSKey: []byte{1}, Weight: 1}, {BLSKey: []byte{2}, Weight: 1}}, LastNonSimplexBlockPChainHeight: 100, FirstEverSimplexBlock: func() *StateMachineBlock { var res *StateMachineBlock @@ -1064,24 +1056,24 @@ func newStateMachine(t *testing.T) (*StateMachine, *testConfig) { } return res }, - GetTime: time.Now, - TimeSkewLimit: time.Second * 5, - Logger: testutil.MakeLogger(t), - GetBlock: testConfig.blockStore.getBlock, - MaxBlockBuildingWaitTime: time.Second, - ApprovalsRetriever: &testConfig.approvalsRetriever, - SignatureVerifier: &testConfig.signatureVerifier, + GetTime: time.Now, + TimeSkewLimit: time.Second * 5, + Logger: testutil.MakeLogger(t), + GetBlock: testConfig.blockStore.getBlock, + MaxBlockBuildingWaitTime: time.Second, + ApprovalsRetriever: &testConfig.approvalsRetriever, + SignatureVerifier: &testConfig.signatureVerifier, SignatureAggregatorCreator: newSignatureAggregatorCreator(), - BlockBuilder: &testConfig.blockBuilder, - KeyAggregator: &testConfig.keyAggregator, + BlockBuilder: &testConfig.blockBuilder, + KeyAggregator: &testConfig.keyAggregator, GetPChainHeight: func() uint64 { return 100 }, GetUpgrades: func() any { return nil }, - GetValidatorSet: testConfig.validatorSetRetriever.getValidatorSet, - PChainProgressListener: &noOpPChainListener{}, + GetValidatorSet: testConfig.validatorSetRetriever.getValidatorSet, + PChainProgressListener: &noOpPChainListener{}, LastNonSimplexInnerBlock: genesisBlock.InnerBlock, } @@ -1123,7 +1115,7 @@ func TestIdentifyCurrentState(t *testing.T) { }, } { t.Run(tc.name, func(t *testing.T) { - result := identifyCurrentState(tc.input) + result := tc.input.NextState() require.Equal(t, tc.expected, result) }) } diff --git a/msm/verification.go b/msm/verification.go index 532cf57c..d9d30e83 100644 --- a/msm/verification.go +++ b/msm/verification.go @@ -64,10 +64,10 @@ func (vd *validationDescriptorVerifier) verifyEmptyValidationDescriptor(_ Simple } type nextEpochApprovalsVerifier struct { - sigVerifier SignatureVerifier - getValidatorSet ValidatorSetRetriever - keyAggregator KeyAggregator - sigAggregatorCreator simplex.SignatureAggregatorCreator + sigVerifier SignatureVerifier + getValidatorSet ValidatorSetRetriever + keyAggregator KeyAggregator + sigAggregatorCreator simplex.SignatureAggregatorCreator } func (nv *nextEpochApprovalsVerifier) Verify(in verificationInput) error { diff --git a/testutil/network.go b/testutil/network.go index e9ee0f79..6d908af8 100644 --- a/testutil/network.go +++ b/testutil/network.go @@ -14,10 +14,10 @@ import ( ) type BasicInMemoryNetwork struct { - t *testing.T - nodes []simplex.NodeID - nodeWeights simplex.NodeWeights - lock sync.RWMutex + t *testing.T + nodes []simplex.NodeID + nodeWeights simplex.NodeWeights + lock sync.RWMutex disconnected map[string]struct{} instances []*BasicNode } diff --git a/testutil/node.go b/testutil/node.go index adfaf6bd..21784a1e 100644 --- a/testutil/node.go +++ b/testutil/node.go @@ -234,11 +234,11 @@ func UpdateEpochConfig(epochConfig *simplex.EpochConfig, testConfig *TestNodeCon // NodeConfig type TestNodeConfig struct { // optional - InitialStorage []simplex.VerifiedFinalizedBlock - Comm simplex.Communication - SigAggregatorCreator simplex.SignatureAggregatorCreator - ReplicationEnabled bool - BlockBuilder *testControlledBlockBuilder + InitialStorage []simplex.VerifiedFinalizedBlock + Comm simplex.Communication + SigAggregatorCreator simplex.SignatureAggregatorCreator + ReplicationEnabled bool + BlockBuilder *testControlledBlockBuilder // Long Running Tests MaxRoundWindow uint64 diff --git a/testutil/util.go b/testutil/util.go index 9426e43f..f0232059 100644 --- a/testutil/util.go +++ b/testutil/util.go @@ -35,9 +35,9 @@ func DefaultTestNodeEpochConfig(t *testing.T, nodeID simplex.NodeID, comm simple SignatureAggregatorCreator: func(weights []simplex.NodeWeight) simplex.SignatureAggregator { return &TestSignatureAggregator{N: len(weights)} }, - BlockDeserializer: &BlockDeserializer{}, - QCDeserializer: &testQCDeserializer{t: t}, - StartTime: time.Now(), + BlockDeserializer: &BlockDeserializer{}, + QCDeserializer: &testQCDeserializer{t: t}, + StartTime: time.Now(), } return conf, wal, storage } @@ -128,7 +128,6 @@ type TestSignatureAggregatorCreator struct { IsQuorumFunc func(signatures []simplex.NodeID) bool } - type TestSignatureAggregator struct { Err error N int From b49e577b46720a6b64eabf07a9985905108eef9f Mon Sep 17 00:00:00 2001 From: Yacov Manevich Date: Tue, 12 May 2026 16:45:57 +0200 Subject: [PATCH 10/16] Move around definitions Signed-off-by: Yacov Manevich --- msm/misc_test.go | 84 ++++++++++++++++++++++++++++++++++++++++ msm/msm.go | 14 +++++++ msm/verification.go | 13 ------- msm/verification_test.go | 80 -------------------------------------- 4 files changed, 98 insertions(+), 93 deletions(-) diff --git a/msm/misc_test.go b/msm/misc_test.go index b899aa6a..8a1c0ada 100644 --- a/msm/misc_test.go +++ b/msm/misc_test.go @@ -4,9 +4,14 @@ package metadata import ( + "context" + "crypto/sha256" + "fmt" "math" "testing" + "time" + "github.com/ava-labs/simplex" "github.com/stretchr/testify/require" ) @@ -142,3 +147,82 @@ func TestBitmask(t *testing.T) { require.False(t, cloned.Contains(7)) }) } + + +// Test helpers + +type testBlockStore map[uint64]StateMachineBlock + +func (bs testBlockStore) getBlock(seq uint64, _ [32]byte) (StateMachineBlock, *simplex.Finalization, error) { + blk, ok := bs[seq] + if !ok { + return StateMachineBlock{}, nil, fmt.Errorf("%w: block %d", simplex.ErrBlockNotFound, seq) + } + return blk, nil, nil +} + +type testVMBlock struct { + bytes []byte + height uint64 +} + +func (b *testVMBlock) Digest() [32]byte { + return sha256.Sum256(b.bytes) +} + +func (b *testVMBlock) Height() uint64 { + return b.height +} + +func (b *testVMBlock) Timestamp() time.Time { + return time.Now() +} + +func (b *testVMBlock) Verify(_ context.Context) error { + return nil +} + +type testSigVerifier struct { + err error +} + +func (sv *testSigVerifier) VerifySignature(_, _, _ []byte) error { + return sv.err +} + +type testKeyAggregator struct { + err error +} + +func (ka *testKeyAggregator) AggregateKeys(keys ...[]byte) ([]byte, error) { + if ka.err != nil { + return nil, ka.err + } + var agg []byte + for _, k := range keys { + agg = append(agg, k...) + } + return agg, nil +} + +type InnerBlock struct { + TS time.Time + BlockHeight uint64 + Bytes []byte +} + +func (i *InnerBlock) Digest() [32]byte { + return sha256.Sum256(i.Bytes) +} + +func (i *InnerBlock) Height() uint64 { + return i.BlockHeight +} + +func (i *InnerBlock) Timestamp() time.Time { + return i.TS +} + +func (i *InnerBlock) Verify(_ context.Context) error { + return nil +} diff --git a/msm/msm.go b/msm/msm.go index 524527c0..bf434aee 100644 --- a/msm/msm.go +++ b/msm/msm.go @@ -69,6 +69,20 @@ type BlockBuilder interface { WaitForPendingBlock(ctx context.Context) } +type verificationInput struct { + prevMD StateMachineMetadata + proposedBlockMD StateMachineMetadata + hasInnerBlock bool + innerBlockTimestamp time.Time // only set when hasInnerBlock is true + prevBlockSeq uint64 + nextBlockType BlockType + state state +} + +type verifier interface { + Verify(in verificationInput) error +} + // StateMachine manages block building and verification across epoch transitions. type StateMachine struct { // verifiers is the list of verifiers used to verify proposed blocks. diff --git a/msm/verification.go b/msm/verification.go index d9d30e83..2349b13d 100644 --- a/msm/verification.go +++ b/msm/verification.go @@ -12,19 +12,6 @@ import ( "github.com/ava-labs/simplex" ) -type verificationInput struct { - prevMD StateMachineMetadata - proposedBlockMD StateMachineMetadata - hasInnerBlock bool - innerBlockTimestamp time.Time // only set when hasInnerBlock is true - prevBlockSeq uint64 - nextBlockType BlockType - state state -} - -type verifier interface { - Verify(in verificationInput) error -} type validationDescriptorVerifier struct { getValidatorSet ValidatorSetRetriever } diff --git a/msm/verification_test.go b/msm/verification_test.go index 898723dc..22c6b3c7 100644 --- a/msm/verification_test.go +++ b/msm/verification_test.go @@ -4,8 +4,6 @@ package metadata import ( - "context" - "crypto/sha256" "fmt" "testing" "time" @@ -940,81 +938,3 @@ func TestSealingBlockSeqVerifier(t *testing.T) { }) } } - -// Test helpers - -type testBlockStore map[uint64]StateMachineBlock - -func (bs testBlockStore) getBlock(seq uint64, _ [32]byte) (StateMachineBlock, *simplex.Finalization, error) { - blk, ok := bs[seq] - if !ok { - return StateMachineBlock{}, nil, fmt.Errorf("%w: block %d", simplex.ErrBlockNotFound, seq) - } - return blk, nil, nil -} - -type testVMBlock struct { - bytes []byte - height uint64 -} - -func (b *testVMBlock) Digest() [32]byte { - return sha256.Sum256(b.bytes) -} - -func (b *testVMBlock) Height() uint64 { - return b.height -} - -func (b *testVMBlock) Timestamp() time.Time { - return time.Now() -} - -func (b *testVMBlock) Verify(_ context.Context) error { - return nil -} - -type testSigVerifier struct { - err error -} - -func (sv *testSigVerifier) VerifySignature(_, _, _ []byte) error { - return sv.err -} - -type testKeyAggregator struct { - err error -} - -func (ka *testKeyAggregator) AggregateKeys(keys ...[]byte) ([]byte, error) { - if ka.err != nil { - return nil, ka.err - } - var agg []byte - for _, k := range keys { - agg = append(agg, k...) - } - return agg, nil -} - -type InnerBlock struct { - TS time.Time - BlockHeight uint64 - Bytes []byte -} - -func (i *InnerBlock) Digest() [32]byte { - return sha256.Sum256(i.Bytes) -} - -func (i *InnerBlock) Height() uint64 { - return i.BlockHeight -} - -func (i *InnerBlock) Timestamp() time.Time { - return i.TS -} - -func (i *InnerBlock) Verify(_ context.Context) error { - return nil -} From a955d7ad8c777ac486e30a36c92b37be749448f2 Mon Sep 17 00:00:00 2001 From: Yacov Manevich Date: Tue, 12 May 2026 16:57:22 +0200 Subject: [PATCH 11/16] Remove verification to introduce it later Signed-off-by: Yacov Manevich --- msm/msm.go | 57 --- msm/msm_test.go | 92 ---- msm/verification.go | 496 --------------------- msm/verification_test.go | 940 --------------------------------------- 4 files changed, 1585 deletions(-) delete mode 100644 msm/verification.go delete mode 100644 msm/verification_test.go diff --git a/msm/msm.go b/msm/msm.go index bf434aee..d74727f7 100644 --- a/msm/msm.go +++ b/msm/msm.go @@ -147,7 +147,6 @@ const ( func NewStateMachine(config *Config) *StateMachine { sm := StateMachine{Config: config} - sm.init() return &sm } @@ -241,62 +240,6 @@ func (sm *StateMachine) VerifyBlock(ctx context.Context, block *StateMachineBloc return err } -func (sm *StateMachine) init() { - sm.verifiers = []verifier{ - &pChainHeightVerifier{ - getPChainHeight: func() uint64 { - return sm.Config.GetPChainHeight() - }, - }, - ×tampVerifier{ - timeSkewLimit: sm.TimeSkewLimit, - getTime: func() time.Time { - return sm.Config.GetTime() - }, - }, - &pChainReferenceHeightVerifier{}, - &epochNumberVerifier{}, - &validationDescriptorVerifier{ - getValidatorSet: func(pChainHeight uint64) (NodeBLSMappings, error) { - return sm.Config.GetValidatorSet(pChainHeight) - }, - }, - &prevSealingBlockHashVerifier{ - firstEverSimplexBlock: func() *StateMachineBlock { - return sm.Config.FirstEverSimplexBlock() - }, - getBlock: func(seq uint64, digest [32]byte) (StateMachineBlock, *simplex.Finalization, error) { - return sm.Config.GetBlock(seq, digest) - }, - latestPersistedHeight: &sm.LatestPersistedHeight, - }, - &nextPChainReferenceHeightVerifier{ - getPChainHeight: func() uint64 { - return sm.Config.GetPChainHeight() - }, - getValidatorSet: func(pChainHeight uint64) (NodeBLSMappings, error) { - return sm.Config.GetValidatorSet(pChainHeight) - }, - }, - &vmBlockSeqVerifier{ - getBlock: func(seq uint64, digest [32]byte) (StateMachineBlock, *simplex.Finalization, error) { - return sm.Config.GetBlock(seq, digest) - }, - }, - &nextEpochApprovalsVerifier{ - getValidatorSet: func(pChainHeight uint64) (NodeBLSMappings, error) { - return sm.Config.GetValidatorSet(pChainHeight) - }, - keyAggregator: sm.KeyAggregator, - sigVerifier: sm.SignatureVerifier, - sigAggregatorCreator: func(weights []simplex.NodeWeight) simplex.SignatureAggregator { - return sm.Config.SignatureAggregatorCreator(weights) - }, - }, - &sealingBlockSeqVerifier{}, - } -} - func (sm *StateMachine) verifyNonZeroBlock(ctx context.Context, block *StateMachineBlock, prevBlockMD StateMachineMetadata, state state, prevSeq uint64) error { blockType := IdentifyBlockType(block.Metadata, prevBlockMD, prevSeq) sm.Logger.Debug("Identified block type", diff --git a/msm/msm_test.go b/msm/msm_test.go index e118783d..405725b5 100644 --- a/msm/msm_test.go +++ b/msm/msm_test.go @@ -359,8 +359,6 @@ func TestMSMNormalOp(t *testing.T) { for _, testCase := range []struct { name string setup func(*StateMachine, *testConfig) - mutateBlock func(*StateMachineBlock) - err string expectedPChainHeight uint64 expectedNextPChainRefHeight uint64 }{ @@ -368,82 +366,6 @@ func TestMSMNormalOp(t *testing.T) { name: "correct information", expectedPChainHeight: 100, }, - { - name: "trying to build a genesis block", - mutateBlock: func(block *StateMachineBlock) { - md, err := simplex.ProtocolMetadataFromBytes(block.Metadata.SimplexProtocolMetadata) - require.NoError(t, err) - md.Seq = 0 - block.Metadata.SimplexProtocolMetadata = md.Bytes() - }, - err: "attempted to build a genesis inner block", - }, - { - name: "previous block not found", - mutateBlock: func(block *StateMachineBlock) { - md, err := simplex.ProtocolMetadataFromBytes(block.Metadata.SimplexProtocolMetadata) - require.NoError(t, err) - md.Seq = 999 - block.Metadata.SimplexProtocolMetadata = md.Bytes() - }, - err: "failed to retrieve previous (998) inner block", - }, - { - name: "P-chain height too big", - mutateBlock: func(block *StateMachineBlock) { - block.Metadata.PChainHeight = 110 - }, - err: "invalid P-chain height (110) is too big", - }, - { - name: "P-chain height smaller than parent", - mutateBlock: func(block *StateMachineBlock) { - block.Metadata.PChainHeight = 0 - }, - err: "invalid P-chain height (0) is smaller than parent block's P-chain height (100)", - }, - { - name: "wrong epoch number", - mutateBlock: func(block *StateMachineBlock) { - block.Metadata.SimplexEpochInfo.EpochNumber = 2 - }, - err: "expected epoch number to be 1 but got 2", - }, - { - name: "non-nil BlockValidationDescriptor", - mutateBlock: func(block *StateMachineBlock) { - block.Metadata.SimplexEpochInfo.BlockValidationDescriptor = &BlockValidationDescriptor{} - }, - err: "expected validator set specified at P-chain height 0 does not match validator set encoded in new block", - }, - { - name: "non-zero sealing block seq", - mutateBlock: func(block *StateMachineBlock) { - block.Metadata.SimplexEpochInfo.SealingBlockSeq = 5 - }, - err: "expected sealing block sequence number to be 0 but got 5", - }, - { - name: "wrong PChainReferenceHeight", - mutateBlock: func(block *StateMachineBlock) { - block.Metadata.SimplexEpochInfo.PChainReferenceHeight = 50 - }, - err: "expected P-chain reference height to be 100 but got 50", - }, - { - name: "non-empty PrevSealingBlockHash", - mutateBlock: func(block *StateMachineBlock) { - block.Metadata.SimplexEpochInfo.PrevSealingBlockHash = [32]byte{1, 2, 3} - }, - err: "expected prev sealing block hash of a non sealing block to be empty", - }, - { - name: "wrong PrevVMBlockSeq", - mutateBlock: func(block *StateMachineBlock) { - block.Metadata.SimplexEpochInfo.PrevVMBlockSeq = 999 - }, - err: "expected PrevVMBlockSeq to be", - }, { name: "validator set change detected", setup: func(sm *StateMachine, tc *testConfig) { @@ -459,11 +381,9 @@ func TestMSMNormalOp(t *testing.T) { t.Run(testCase.name, func(t *testing.T) { chain := makeChain(t, 5, 10) sm1, testConfig1 := newStateMachine(t) - sm2, testConfig2 := newStateMachine(t) for i, block := range chain { testConfig1.blockStore[uint64(i)] = &outerBlock{block: block} - testConfig2.blockStore[uint64(i)] = &outerBlock{block: block} } lastBlock := chain[len(chain)-1] @@ -491,24 +411,12 @@ func TestMSMNormalOp(t *testing.T) { if testCase.setup != nil { testCase.setup(sm1, testConfig1) - testCase.setup(sm2, testConfig2) } block1, err := sm1.BuildBlock(context.Background(), *md, &blacklist) require.NoError(t, err) require.NotNil(t, block1) - if testCase.mutateBlock != nil { - testCase.mutateBlock(block1) - } - - err = sm2.VerifyBlock(context.Background(), block1) - if testCase.err != "" { - require.ErrorContains(t, err, testCase.err) - return - } - require.NoError(t, err) - require.Equal(t, &StateMachineBlock{ InnerBlock: &InnerBlock{ TS: blockTime, diff --git a/msm/verification.go b/msm/verification.go deleted file mode 100644 index 2349b13d..00000000 --- a/msm/verification.go +++ /dev/null @@ -1,496 +0,0 @@ -// Copyright (C) 2019-2025, Ava Labs, Inc. All rights reserved. -// See the file LICENSE for licensing terms. - -package metadata - -import ( - "bytes" - "encoding/binary" - "fmt" - "time" - - "github.com/ava-labs/simplex" -) - -type validationDescriptorVerifier struct { - getValidatorSet ValidatorSetRetriever -} - -func (vd *validationDescriptorVerifier) Verify(in verificationInput) error { - prev, next := in.prevMD.SimplexEpochInfo, in.proposedBlockMD.SimplexEpochInfo - switch in.nextBlockType { - case BlockTypeSealing: - return vd.verifySealingBlock(prev, next) - default: - return vd.verifyEmptyValidationDescriptor(prev, next) - } -} - -func (vd *validationDescriptorVerifier) verifySealingBlock(prev SimplexEpochInfo, next SimplexEpochInfo) error { - validators, err := vd.getValidatorSet(prev.NextPChainReferenceHeight) - if err != nil { - return err - } - - if next.BlockValidationDescriptor == nil { - return fmt.Errorf("validation descriptor should not be nil for a sealing block") - } - - if !validators.Equal(next.BlockValidationDescriptor.AggregatedMembership.Members) { - return fmt.Errorf("expected validator set specified at P-chain height %d does not match validator set encoded in new block", next.NextPChainReferenceHeight) - } - - return nil -} - -func (vd *validationDescriptorVerifier) verifyEmptyValidationDescriptor(_ SimplexEpochInfo, next SimplexEpochInfo) error { - if next.BlockValidationDescriptor != nil { - return fmt.Errorf("block validation descriptor should be nil but got %v", next.BlockValidationDescriptor) - } - return nil -} - -type nextEpochApprovalsVerifier struct { - sigVerifier SignatureVerifier - getValidatorSet ValidatorSetRetriever - keyAggregator KeyAggregator - sigAggregatorCreator simplex.SignatureAggregatorCreator -} - -func (nv *nextEpochApprovalsVerifier) Verify(in verificationInput) error { - prev, next := in.prevMD.SimplexEpochInfo, in.proposedBlockMD.SimplexEpochInfo - - switch in.nextBlockType { - case BlockTypeSealing: - return nv.verifySealingBlock(prev, next) - case BlockTypeNormal: - return nv.verifyNormal(prev, next) - default: - return nv.verifyEmptyNextEpochApprovals(prev, next) - } -} - -func (nv *nextEpochApprovalsVerifier) verifySealingBlock(prev SimplexEpochInfo, next SimplexEpochInfo) error { - if next.NextEpochApprovals == nil { - return fmt.Errorf("next epoch approvals should not be nil for a sealing block") - } - - validators, err := nv.getValidatorSet(prev.NextPChainReferenceHeight) - if err != nil { - return err - } - - err = nv.verifySignature(prev, next, validators) - if err != nil { - return err - } - - approvingNodes := bitmaskFromBytes(next.NextEpochApprovals.NodeIDs) - sigAggr := nv.sigAggregatorCreator(validators.NodeWeights()) - canSeal := sigAggr.IsQuorum(validators.SelectSubset(approvingNodes)) - - if !canSeal { - return fmt.Errorf("not enough approvals to seal block") - } - - return nil -} - -func (nv *nextEpochApprovalsVerifier) verifyNormal(prev SimplexEpochInfo, next SimplexEpochInfo) error { - if prev.NextPChainReferenceHeight == 0 { - return nil - } - - // Otherwise, prev.NextPChainReferenceHeight > 0, so this means we're collecting approvals - - if next.NextEpochApprovals == nil { - // The node that proposed the block should have included at least its own approval. - return fmt.Errorf("next epoch approvals should not be nil when collecting approvals") - } - - validators, err := nv.getValidatorSet(prev.NextPChainReferenceHeight) - if err != nil { - return err - } - - err = nv.verifySignature(prev, next, validators) - if err != nil { - return err - } - - // A node cannot remove other nodes' approvals, only add its own approval if it wasn't included in the previous block. - // So the set of signers in next.NextEpochApprovals should be a superset of the set of signers in prev.NextEpochApprovals. - if err := areNextEpochApprovalsSignersSupersetOfApprovalsOfPrevBlock(prev, next); err != nil { - return err - } - - return nil -} - -func (nv *nextEpochApprovalsVerifier) verifyEmptyNextEpochApprovals(_ SimplexEpochInfo, next SimplexEpochInfo) error { - if next.NextEpochApprovals != nil { - return fmt.Errorf("next epoch approvals should be nil but got %v", next.NextEpochApprovals) - } - return nil -} - -func (nv *nextEpochApprovalsVerifier) verifySignature(prev SimplexEpochInfo, next SimplexEpochInfo, validators NodeBLSMappings) error { - // First figure out which validators are approving the next epoch by looking at the bitmask of approving nodes, - // and then aggregate their public keys together to verify the signature. - - nodeIDsBitmask := next.NextEpochApprovals.NodeIDs - aggPK, err := nv.aggregatePubKeysForBitmask(nodeIDsBitmask, validators) - if err != nil { - return err - } - - message := nv.createMessageToBeVerified(prev) - - if err := nv.sigVerifier.VerifySignature(next.NextEpochApprovals.Signature, message, aggPK); err != nil { - return fmt.Errorf("failed to verify signature: %w", err) - } - return nil -} - -func (nv *nextEpochApprovalsVerifier) createMessageToBeVerified(prev SimplexEpochInfo) []byte { - pChainHeightBuff := pChainNextReferenceHeightAsBytes(prev) - - var bb bytes.Buffer - bb.Write(pChainHeightBuff) - - message := bb.Bytes() - return message -} - -func (nv *nextEpochApprovalsVerifier) aggregatePubKeysForBitmask(nodeIDsBitmask []byte, validators NodeBLSMappings) ([]byte, error) { - approvingNodes := bitmaskFromBytes(nodeIDsBitmask) - publicKeys := make([][]byte, 0, len(validators)) - for i := range validators { - if !approvingNodes.Contains(i) { - continue - } - publicKeys = append(publicKeys, validators[i].BLSKey) - } - - aggPK, err := nv.keyAggregator.AggregateKeys(publicKeys...) - if err != nil { - return nil, fmt.Errorf("failed to aggregate public keys: %w", err) - } - return aggPK, nil -} - -func pChainNextReferenceHeightAsBytes(prev SimplexEpochInfo) []byte { - pChainHeight := prev.NextPChainReferenceHeight - pChainHeightBuff := make([]byte, 8) - binary.BigEndian.PutUint64(pChainHeightBuff, pChainHeight) - return pChainHeightBuff -} - -type nextPChainReferenceHeightVerifier struct { - getValidatorSet ValidatorSetRetriever - getPChainHeight func() uint64 -} - -func (n *nextPChainReferenceHeightVerifier) Verify(in verificationInput) error { - prev, next := in.prevMD.SimplexEpochInfo, in.proposedBlockMD.SimplexEpochInfo - switch in.nextBlockType { - case BlockTypeTelock, BlockTypeSealing: - if prev.NextPChainReferenceHeight != next.NextPChainReferenceHeight { - return fmt.Errorf("expected P-chain reference height to be %d but got %d", prev.NextPChainReferenceHeight, next.NextPChainReferenceHeight) - } - case BlockTypeNormal: - return n.verifyNextPChainRefHeightNormal(in.prevMD, prev, next) - case BlockTypeNewEpoch: - if next.NextPChainReferenceHeight != 0 { - return fmt.Errorf("expected P-chain reference height to be 0 but got %d", next.NextPChainReferenceHeight) - } - default: - return fmt.Errorf("unknown block type: %d", in.nextBlockType) - } - return nil -} - -func (n *nextPChainReferenceHeightVerifier) verifyNextPChainRefHeightNormal(prevMD StateMachineMetadata, prev SimplexEpochInfo, next SimplexEpochInfo) error { - // Next P-chain height can only increase, not decrease. - if next.NextPChainReferenceHeight > 0 && prev.PChainReferenceHeight > next.NextPChainReferenceHeight { - return fmt.Errorf("expected P-chain reference height to be non-decreasing, "+ - "but the previous P-chain reference height is %d and the proposed P-chain reference height is %d", prev.PChainReferenceHeight, next.NextPChainReferenceHeight) - } - - // If the previous block already has a next P-chain reference height, - // we should keep the same next P-chain reference height until we reach it. - if prev.NextPChainReferenceHeight > 0 { - if next.NextPChainReferenceHeight != prev.NextPChainReferenceHeight { - return fmt.Errorf("expected P-chain reference height to be %d but got %d", prev.NextPChainReferenceHeight, next.NextPChainReferenceHeight) - } - return nil - } - - // If we reached here, then prev.NextPChainReferenceHeight == 0. - // It might be that this block is the first block that has set the next P-chain reference height for the epoch, - // so check if it has done so correctly by observing whether the validator set has indeed changed. - - currentValidatorSet, err := n.getValidatorSet(prevMD.SimplexEpochInfo.PChainReferenceHeight) - if err != nil { - return err - } - - newValidatorSet, err := n.getValidatorSet(next.NextPChainReferenceHeight) - if err != nil { - return err - } - - // If the validator set doesn't change, we shouldn't have increased the next P-chain reference height. - if currentValidatorSet.Equal(newValidatorSet) && next.NextPChainReferenceHeight > 0 { - return fmt.Errorf("validator set at proposed next P-chain reference height %d is the same as "+ - "validator set at previous block's P-chain reference height %d,"+ - "so expected next P-chain reference height to remain the same but got %d", - next.NextPChainReferenceHeight, prev.PChainReferenceHeight, next.NextPChainReferenceHeight) - } - - // Else, either the validator set has changed, or the next P-chain reference height is still 0. - // Both of these cases are fine, but we should verify that we have observed the next P-chain reference height if it is > 0. - - pChainHeight := n.getPChainHeight() - - if pChainHeight < next.NextPChainReferenceHeight { - return fmt.Errorf("haven't reached P-chain height %d yet, current P-chain height is only %d", next.NextPChainReferenceHeight, pChainHeight) - } - - return nil -} - -type epochNumberVerifier struct{} - -func (e *epochNumberVerifier) Verify(in verificationInput) error { - prev, next := in.prevMD.SimplexEpochInfo, in.proposedBlockMD.SimplexEpochInfo - - // An epoch number of 0 means this is not a Simplex block, so the next block should be the first Simplex block with epoch number 1. - if in.prevMD.SimplexEpochInfo.EpochNumber == 0 && in.proposedBlockMD.SimplexEpochInfo.EpochNumber != 1 { - return fmt.Errorf("expected epoch number of the first block created to be 1 but got %d", next.EpochNumber) - } - - // The only time in which we should increase the epoch number is when we have a block that marks the start of a new epoch. - switch in.nextBlockType { - case BlockTypeNewEpoch: - // TODO: we have to make sure that Telocks are pruned before moving to a new epoch, otherwise we hit a false negative below. - if in.prevBlockSeq != next.EpochNumber { - return fmt.Errorf("expected epoch number to be %d but got %d", in.prevBlockSeq, next.EpochNumber) - } - default: - if prev.EpochNumber != next.EpochNumber { - return fmt.Errorf("expected epoch number to be %d but got %d", prev.EpochNumber, next.EpochNumber) - } - } - return nil -} - -type sealingBlockSeqVerifier struct{} - -func (s *sealingBlockSeqVerifier) Verify(in verificationInput) error { - prev, next := in.prevMD.SimplexEpochInfo, in.proposedBlockMD.SimplexEpochInfo - - // A block should only have a sealing block if it is a Telock. - switch in.nextBlockType { - case BlockTypeNewEpoch, BlockTypeNormal, BlockTypeSealing: - if next.SealingBlockSeq != 0 { - return fmt.Errorf("expected sealing block sequence number to be 0 but got %d", next.SealingBlockSeq) - } - case BlockTypeTelock: - // This is not the first Telock, make sure the sealing block sequence number doesn't change. - - // prev.SealingBlockSeq > 0 means the previous block is a Telock. - if prev.SealingBlockSeq > 0 && next.SealingBlockSeq != prev.SealingBlockSeq { - return fmt.Errorf("expected sealing block sequence number to be %d but got %d", prev.SealingBlockSeq, next.SealingBlockSeq) - } - - // Else, either this is the first Telock, or the previous block's sealing block sequence is equal to this block's sealing block sequence. - - // We need to check the first case has a valid sealing block sequence, as the second case is fine by definition. - if prev.BlockValidationDescriptor != nil { - md, err := simplex.ProtocolMetadataFromBytes(in.prevMD.SimplexProtocolMetadata) - if err != nil { - return fmt.Errorf("failed parsing protocol metadata: %w", err) - } - if next.SealingBlockSeq != md.Seq { - return fmt.Errorf("expected sealing block sequence number to be %d but got %d", md.Seq, next.SealingBlockSeq) - } - } - default: - return fmt.Errorf("unknown block type: %d", in.nextBlockType) - } - - return nil -} - -type pChainHeightVerifier struct { - getPChainHeight func() uint64 -} - -func (p *pChainHeightVerifier) Verify(in verificationInput) error { - currentPChainHeight := p.getPChainHeight() - - if in.proposedBlockMD.PChainHeight > currentPChainHeight { - return fmt.Errorf("invalid P-chain height (%d) is too big, expected to be ≤ %d", - in.proposedBlockMD.PChainHeight, currentPChainHeight) - } - - if in.prevMD.PChainHeight > in.proposedBlockMD.PChainHeight { - return fmt.Errorf("invalid P-chain height (%d) is smaller than parent block's P-chain height (%d)", - in.proposedBlockMD.PChainHeight, in.prevMD.PChainHeight) - } - - return nil -} - -type pChainReferenceHeightVerifier struct{} - -func (p *pChainReferenceHeightVerifier) Verify(in verificationInput) error { - prev, next := in.prevMD.SimplexEpochInfo, in.proposedBlockMD.SimplexEpochInfo - - switch in.nextBlockType { - case BlockTypeNewEpoch: - if prev.NextPChainReferenceHeight != next.PChainReferenceHeight { - return fmt.Errorf("expected P-chain reference height of the first block of epoch %d to be %d but got %d", - prev.SealingBlockSeq, prev.NextPChainReferenceHeight, next.PChainReferenceHeight) - } - default: - if prev.PChainReferenceHeight != next.PChainReferenceHeight { - return fmt.Errorf("expected P-chain reference height to be %d but got %d", prev.PChainReferenceHeight, next.PChainReferenceHeight) - } - } - - return nil -} - -type timestampVerifier struct { - getTime func() time.Time - timeSkewLimit time.Duration -} - -func (t *timestampVerifier) Verify(in verificationInput) error { - if !in.hasInnerBlock { - // If no inner block, the timestamp is inherited from the parent block. - if in.proposedBlockMD.Timestamp != in.prevMD.Timestamp { - return fmt.Errorf("block without inner block should inherit parent timestamp %d but got %d", in.prevMD.Timestamp, in.proposedBlockMD.Timestamp) - } - } else { - // If there is an inner block, the timestamp should be the same as the inner block's timestamp. - if in.proposedBlockMD.Timestamp != uint64(in.innerBlockTimestamp.UnixMilli()) { - return fmt.Errorf("block timestamp %d does not match inner block timestamp %d", in.proposedBlockMD.Timestamp, in.innerBlockTimestamp.UnixMilli()) - } - } - - timestamp := time.UnixMilli(int64(in.proposedBlockMD.Timestamp)) - - currentTime := t.getTime() - if currentTime.Add(t.timeSkewLimit).Before(timestamp) { - return fmt.Errorf("proposed block timestamp is too far in the future, current time is %v but got %v", currentTime, timestamp) - } - - if in.prevMD.Timestamp > in.proposedBlockMD.Timestamp { - return fmt.Errorf("proposed block timestamp is older than parent block's timestamp, parent timestamp is %d but got %d", in.prevMD.Timestamp, in.proposedBlockMD.Timestamp) - } - return nil -} - -type prevSealingBlockHashVerifier struct { - getBlock BlockRetriever - latestPersistedHeight *uint64 - firstEverSimplexBlock func() *StateMachineBlock -} - -func (p *prevSealingBlockHashVerifier) Verify(in verificationInput) error { - prev, _ := in.prevMD.SimplexEpochInfo, in.proposedBlockMD.SimplexEpochInfo - - // Sealing block of the first epoch must point to the first ever Simplex block as the previous sealing block. - if prev.EpochNumber == 1 && in.nextBlockType == BlockTypeSealing { - firstSimplexBlock := p.firstEverSimplexBlock() - if firstSimplexBlock == nil { - return fmt.Errorf("first ever simplex block sequence number is not set but verifying a sealing block of the first epoch") - } - hash := firstSimplexBlock.Digest() - if !bytes.Equal(in.proposedBlockMD.SimplexEpochInfo.PrevSealingBlockHash[:], hash[:]) { - return fmt.Errorf("expected prev sealing block hash of the first ever simplex block to be %x but got %x", hash, in.proposedBlockMD.SimplexEpochInfo.PrevSealingBlockHash) - } - - return nil - } - - // Otherwise, we can only have a previous sealing block hash if this is a sealing block, - // and in that case, the previous sealing block hash should match the hash of the sealing block of the previous epoch. - - switch in.nextBlockType { - case BlockTypeSealing: - prevSealingBlock, _, err := p.getBlock(in.prevMD.SimplexEpochInfo.EpochNumber, [32]byte{}) - if err != nil { - return fmt.Errorf("failed retrieving block: %w", err) - } - hash := prevSealingBlock.Digest() - if !bytes.Equal(in.proposedBlockMD.SimplexEpochInfo.PrevSealingBlockHash[:], hash[:]) { - return fmt.Errorf("expected prev sealing block hash to be %x but got %x", hash, in.proposedBlockMD.SimplexEpochInfo.PrevSealingBlockHash) - } - default: // non-sealing blocks should have an empty previous sealing block hash - if in.proposedBlockMD.SimplexEpochInfo.PrevSealingBlockHash != [32]byte{} { - return fmt.Errorf("expected prev sealing block hash of a non sealing block to be empty but got %x", in.proposedBlockMD.SimplexEpochInfo.PrevSealingBlockHash) - } - } - - return nil -} - -type vmBlockSeqVerifier struct { - getBlock BlockRetriever -} - -func (v *vmBlockSeqVerifier) Verify(in verificationInput) error { - prev, next := in.prevMD.SimplexEpochInfo, in.proposedBlockMD.SimplexEpochInfo - - // If this is the first ever Simplex block, the PrevVMBlockSeq is simply the seq of the previous block. - if prev.EpochNumber == 0 { - if next.PrevVMBlockSeq != in.prevBlockSeq { - return fmt.Errorf("expected PrevVMBlockSeq to be %d but got %d", in.prevBlockSeq, next.PrevVMBlockSeq) - } - return nil - } - - md, err := simplex.ProtocolMetadataFromBytes(in.proposedBlockMD.SimplexProtocolMetadata) - if err != nil { - return fmt.Errorf("failed parsing protocol metadata: %w", err) - } - - // Else, if the previous block has an inner block, we point to it. - // Otherwise, we point to the parent block's previous VM block seq. - prevBlock, _, err := v.getBlock(in.prevBlockSeq, md.Prev) - if err != nil { - return fmt.Errorf("failed retrieving block: %w", err) - } - - expectedPrevVMBlockSeq := in.prevMD.SimplexEpochInfo.PrevVMBlockSeq - - if prevBlock.InnerBlock != nil { - expectedPrevVMBlockSeq = in.prevBlockSeq - } - - if next.PrevVMBlockSeq != expectedPrevVMBlockSeq { - return fmt.Errorf("expected PrevVMBlockSeq to be %d but got %d", expectedPrevVMBlockSeq, next.PrevVMBlockSeq) - } - - return nil -} - -func areNextEpochApprovalsSignersSupersetOfApprovalsOfPrevBlock(prev SimplexEpochInfo, next SimplexEpochInfo) error { - if prev.NextEpochApprovals == nil { - return nil - } - // Make sure that previous signers are still there. - prevSigners := bitmaskFromBytes(prev.NextEpochApprovals.NodeIDs) - nextSigners := bitmaskFromBytes(next.NextEpochApprovals.NodeIDs) - // Remove all bits in nextSigners from prevSigners - prevSigners.Difference(&nextSigners) - // If we have some bits left, it means there was a bit in prevSigners that wasn't in nextSigners - if prevSigners.Len() > 0 { - return fmt.Errorf("some signers from parent block are missing from next epoch approvals of proposed block") - } - return nil -} diff --git a/msm/verification_test.go b/msm/verification_test.go deleted file mode 100644 index 22c6b3c7..00000000 --- a/msm/verification_test.go +++ /dev/null @@ -1,940 +0,0 @@ -// Copyright (C) 2019-2025, Ava Labs, Inc. All rights reserved. -// See the file LICENSE for licensing terms. - -package metadata - -import ( - "fmt" - "testing" - "time" - - "github.com/ava-labs/simplex" - "github.com/stretchr/testify/require" -) - -func TestPChainHeightVerifier(t *testing.T) { - for _, tc := range []struct { - name string - pChainHeight uint64 - prevHeight uint64 - nextHeight uint64 - err string - }{ - { - name: "valid height", - pChainHeight: 200, - prevHeight: 100, - nextHeight: 150, - }, - { - name: "height equal to current", - pChainHeight: 200, - prevHeight: 100, - nextHeight: 200, - }, - { - name: "height too big", - pChainHeight: 100, - prevHeight: 50, - nextHeight: 150, - err: "invalid P-chain height (150) is too big, expected to be ≤ 100", - }, - { - name: "height smaller than parent", - pChainHeight: 200, - prevHeight: 150, - nextHeight: 100, - err: "invalid P-chain height (100) is smaller than parent block's P-chain height (150)", - }, - { - name: "height equal to parent", - pChainHeight: 200, - prevHeight: 100, - nextHeight: 100, - }, - } { - t.Run(tc.name, func(t *testing.T) { - v := &pChainHeightVerifier{ - getPChainHeight: func() uint64 { return tc.pChainHeight }, - } - err := v.Verify(verificationInput{ - prevMD: StateMachineMetadata{PChainHeight: tc.prevHeight}, - proposedBlockMD: StateMachineMetadata{PChainHeight: tc.nextHeight}, - }) - if tc.err != "" { - require.EqualError(t, err, tc.err) - } else { - require.NoError(t, err) - } - }) - } -} - -func TestTimestampVerifier(t *testing.T) { - now := time.Now() - - timeSkewLimit := 5 * time.Second - - futureTime := now.Add(10 * time.Second) - - for _, tc := range []struct { - name string - hasInnerBlock bool - innerBlockTimestamp time.Time - timestamp uint64 - parentTimestamp uint64 - err string - }{ - { - name: "valid timestamp with inner block", - hasInnerBlock: true, - innerBlockTimestamp: now, - timestamp: uint64(now.UnixMilli()), - }, - { - name: "metadata timestamp does not match inner block", - hasInnerBlock: true, - innerBlockTimestamp: now, - timestamp: uint64(now.UnixMilli()) + 100, - err: fmt.Sprintf("block timestamp %d does not match inner block timestamp %d", uint64(now.UnixMilli())+100, now.UnixMilli()), - }, - { - name: "timestamp too far in the future", - hasInnerBlock: true, - innerBlockTimestamp: futureTime, - timestamp: uint64(futureTime.UnixMilli()), - err: fmt.Sprintf("proposed block timestamp is too far in the future, current time is %v but got %v", now, time.UnixMilli(futureTime.UnixMilli())), - }, - { - name: "timestamp older than parent", - hasInnerBlock: true, - innerBlockTimestamp: now, - timestamp: uint64(now.UnixMilli()), - parentTimestamp: uint64(now.UnixMilli()) + 10, - err: fmt.Sprintf("proposed block timestamp is older than parent block's timestamp, parent timestamp is %d but got %d", uint64(now.UnixMilli())+10, uint64(now.UnixMilli())), - }, - { - name: "no inner block inherits parent timestamp", - hasInnerBlock: false, - timestamp: uint64(now.UnixMilli()), - parentTimestamp: uint64(now.UnixMilli()), - }, - { - name: "no inner block with different timestamp than parent", - hasInnerBlock: false, - timestamp: uint64(now.UnixMilli()) + 100, - parentTimestamp: uint64(now.UnixMilli()), - err: fmt.Sprintf("block without inner block should inherit parent timestamp %d but got %d", uint64(now.UnixMilli()), uint64(now.UnixMilli())+100), - }, - } { - t.Run(tc.name, func(t *testing.T) { - v := ×tampVerifier{ - getTime: func() time.Time { return now }, - timeSkewLimit: timeSkewLimit, - } - err := v.Verify(verificationInput{ - hasInnerBlock: tc.hasInnerBlock, - innerBlockTimestamp: tc.innerBlockTimestamp, - proposedBlockMD: StateMachineMetadata{Timestamp: tc.timestamp}, - prevMD: StateMachineMetadata{Timestamp: tc.parentTimestamp}, - }) - if tc.err != "" { - require.EqualError(t, err, tc.err) - } else { - require.NoError(t, err) - } - }) - } -} - -func TestPChainReferenceHeightVerifier(t *testing.T) { - for _, tc := range []struct { - name string - nextBlockType BlockType - prev SimplexEpochInfo - next SimplexEpochInfo - err string - }{ - { - name: "new epoch block matching prev NextPChainReferenceHeight", - nextBlockType: BlockTypeNewEpoch, - prev: SimplexEpochInfo{NextPChainReferenceHeight: 200, SealingBlockSeq: 5}, - next: SimplexEpochInfo{PChainReferenceHeight: 200}, - }, - { - name: "new epoch block not matching prev NextPChainReferenceHeight", - nextBlockType: BlockTypeNewEpoch, - prev: SimplexEpochInfo{NextPChainReferenceHeight: 200, SealingBlockSeq: 5}, - next: SimplexEpochInfo{PChainReferenceHeight: 100}, - err: "expected P-chain reference height of the first block of epoch 5 to be 200 but got 100", - }, - { - name: "normal block matching prev PChainReferenceHeight", - nextBlockType: BlockTypeNormal, - prev: SimplexEpochInfo{PChainReferenceHeight: 100}, - next: SimplexEpochInfo{PChainReferenceHeight: 100}, - }, - { - name: "normal block not matching prev PChainReferenceHeight", - nextBlockType: BlockTypeNormal, - prev: SimplexEpochInfo{PChainReferenceHeight: 100}, - next: SimplexEpochInfo{PChainReferenceHeight: 200}, - err: "expected P-chain reference height to be 100 but got 200", - }, - { - name: "sealing block matching prev PChainReferenceHeight", - nextBlockType: BlockTypeSealing, - prev: SimplexEpochInfo{PChainReferenceHeight: 100}, - next: SimplexEpochInfo{PChainReferenceHeight: 100}, - }, - { - name: "telock block matching prev PChainReferenceHeight", - nextBlockType: BlockTypeTelock, - prev: SimplexEpochInfo{PChainReferenceHeight: 100}, - next: SimplexEpochInfo{PChainReferenceHeight: 100}, - }, - } { - t.Run(tc.name, func(t *testing.T) { - v := &pChainReferenceHeightVerifier{} - err := v.Verify(verificationInput{ - nextBlockType: tc.nextBlockType, - prevMD: StateMachineMetadata{SimplexEpochInfo: tc.prev}, - proposedBlockMD: StateMachineMetadata{SimplexEpochInfo: tc.next}, - }) - if tc.err != "" { - require.EqualError(t, err, tc.err) - } else { - require.NoError(t, err) - } - }) - } -} - -func TestEpochNumberVerifier(t *testing.T) { - for _, tc := range []struct { - name string - nextBlockType BlockType - prevBlockSeq uint64 - prev SimplexEpochInfo - next SimplexEpochInfo - err string - }{ - { - name: "prev epoch 0 with wrong next epoch", - nextBlockType: BlockTypeNormal, - prev: SimplexEpochInfo{EpochNumber: 0}, - next: SimplexEpochInfo{EpochNumber: 5}, - err: "expected epoch number of the first block created to be 1 but got 5", - }, - { - name: "new epoch block matching sealing seq", - nextBlockType: BlockTypeNewEpoch, - prevBlockSeq: 10, - prev: SimplexEpochInfo{EpochNumber: 1}, - next: SimplexEpochInfo{EpochNumber: 10}, - }, - { - name: "new epoch block not matching sealing seq", - nextBlockType: BlockTypeNewEpoch, - prevBlockSeq: 10, - prev: SimplexEpochInfo{EpochNumber: 1}, - next: SimplexEpochInfo{EpochNumber: 5}, - err: "expected epoch number to be 10 but got 5", - }, - { - name: "normal block same epoch", - nextBlockType: BlockTypeNormal, - prev: SimplexEpochInfo{EpochNumber: 3}, - next: SimplexEpochInfo{EpochNumber: 3}, - }, - { - name: "normal block different epoch", - nextBlockType: BlockTypeNormal, - prev: SimplexEpochInfo{EpochNumber: 3}, - next: SimplexEpochInfo{EpochNumber: 4}, - err: "expected epoch number to be 3 but got 4", - }, - { - name: "sealing block same epoch", - nextBlockType: BlockTypeSealing, - prev: SimplexEpochInfo{EpochNumber: 2}, - next: SimplexEpochInfo{EpochNumber: 2}, - }, - { - name: "telock block same epoch", - nextBlockType: BlockTypeTelock, - prev: SimplexEpochInfo{EpochNumber: 2}, - next: SimplexEpochInfo{EpochNumber: 2}, - }, - } { - t.Run(tc.name, func(t *testing.T) { - v := &epochNumberVerifier{} - err := v.Verify(verificationInput{ - nextBlockType: tc.nextBlockType, - prevBlockSeq: tc.prevBlockSeq, - prevMD: StateMachineMetadata{SimplexEpochInfo: tc.prev}, - proposedBlockMD: StateMachineMetadata{SimplexEpochInfo: tc.next}, - }) - if tc.err != "" { - require.EqualError(t, err, tc.err) - } else { - require.NoError(t, err) - } - }) - } -} - -func TestPrevSealingBlockHashVerifier(t *testing.T) { - // A simplex block (EpochNumber > 0) so findFirstSimplexBlock can locate it. - firstSimplexBlock := StateMachineBlock{ - InnerBlock: &testVMBlock{bytes: []byte{1, 2, 3}}, - Metadata: StateMachineMetadata{SimplexEpochInfo: SimplexEpochInfo{EpochNumber: 1}}, - } - firstSimplexBlockHash := firstSimplexBlock.Digest() - - // A block used for epoch >1 sealing lookups. - prevSealingBlock := StateMachineBlock{ - InnerBlock: &testVMBlock{bytes: []byte{4, 5, 6}}, - Metadata: StateMachineMetadata{SimplexEpochInfo: SimplexEpochInfo{EpochNumber: 5}}, - } - prevSealingBlockHash := prevSealingBlock.Digest() - - bs := make(testBlockStore) - bs[1] = firstSimplexBlock - bs[5] = prevSealingBlock - latestPersisted := uint64(1) - - for _, tc := range []struct { - name string - nextBlockType BlockType - prev SimplexEpochInfo - next SimplexEpochInfo - err string - }{ - { - name: "epoch 1 sealing block with correct hash", - nextBlockType: BlockTypeSealing, - prev: SimplexEpochInfo{EpochNumber: 1}, - next: SimplexEpochInfo{ - PrevSealingBlockHash: firstSimplexBlockHash, - }, - }, - { - name: "epoch 1 sealing block with wrong hash", - nextBlockType: BlockTypeSealing, - prev: SimplexEpochInfo{EpochNumber: 1}, - next: SimplexEpochInfo{ - PrevSealingBlockHash: [32]byte{9, 9, 9}, - }, - err: fmt.Sprintf("expected prev sealing block hash of the first ever simplex block to be %x but got %x", firstSimplexBlockHash, [32]byte{9, 9, 9}), - }, - { - name: "epoch >1 sealing block with correct hash", - nextBlockType: BlockTypeSealing, - prev: SimplexEpochInfo{EpochNumber: 5}, - next: SimplexEpochInfo{ - PrevSealingBlockHash: prevSealingBlockHash, - }, - }, - { - name: "epoch >1 sealing block with wrong hash", - nextBlockType: BlockTypeSealing, - prev: SimplexEpochInfo{EpochNumber: 5}, - next: SimplexEpochInfo{ - PrevSealingBlockHash: [32]byte{9, 9, 9}, - }, - err: fmt.Sprintf("expected prev sealing block hash to be %x but got %x", prevSealingBlockHash, [32]byte{9, 9, 9}), - }, - { - name: "non-sealing block with empty hash", - nextBlockType: BlockTypeNormal, - prev: SimplexEpochInfo{EpochNumber: 1}, - next: SimplexEpochInfo{}, - }, - { - name: "non-sealing block with non-empty hash", - nextBlockType: BlockTypeNormal, - prev: SimplexEpochInfo{EpochNumber: 1}, - next: SimplexEpochInfo{ - PrevSealingBlockHash: [32]byte{1}, - }, - err: fmt.Sprintf("expected prev sealing block hash of a non sealing block to be empty but got %x", [32]byte{1}), - }, - { - name: "telock block with empty hash", - nextBlockType: BlockTypeTelock, - prev: SimplexEpochInfo{EpochNumber: 2}, - next: SimplexEpochInfo{}, - }, - { - name: "new epoch block with empty hash", - nextBlockType: BlockTypeNewEpoch, - prev: SimplexEpochInfo{EpochNumber: 2}, - next: SimplexEpochInfo{}, - }, - } { - t.Run(tc.name, func(t *testing.T) { - v := &prevSealingBlockHashVerifier{ - getBlock: bs.getBlock, - latestPersistedHeight: &latestPersisted, - firstEverSimplexBlock: func() *StateMachineBlock { - return &firstSimplexBlock - }, - } - err := v.Verify(verificationInput{ - nextBlockType: tc.nextBlockType, - prevMD: StateMachineMetadata{SimplexEpochInfo: tc.prev}, - proposedBlockMD: StateMachineMetadata{SimplexEpochInfo: tc.next}, - }) - if tc.err != "" { - require.EqualError(t, err, tc.err) - } else { - require.NoError(t, err) - } - }) - } -} - -func TestNextPChainReferenceHeightVerifier(t *testing.T) { - validators1 := NodeBLSMappings{{BLSKey: []byte{1}, Weight: 1}} - validators2 := NodeBLSMappings{{BLSKey: []byte{2}, Weight: 1}} - - for _, tc := range []struct { - name string - nextBlockType BlockType - prev SimplexEpochInfo - prevPChainRef uint64 - next SimplexEpochInfo - getValidator ValidatorSetRetriever - pChainHeight uint64 - err string - }{ - { - name: "telock block matching height", - nextBlockType: BlockTypeTelock, - prev: SimplexEpochInfo{NextPChainReferenceHeight: 200}, - next: SimplexEpochInfo{NextPChainReferenceHeight: 200}, - }, - { - name: "telock block mismatched height", - nextBlockType: BlockTypeTelock, - prev: SimplexEpochInfo{NextPChainReferenceHeight: 200}, - next: SimplexEpochInfo{NextPChainReferenceHeight: 300}, - err: "expected P-chain reference height to be 200 but got 300", - }, - { - name: "sealing block matching height", - nextBlockType: BlockTypeSealing, - prev: SimplexEpochInfo{NextPChainReferenceHeight: 200}, - next: SimplexEpochInfo{NextPChainReferenceHeight: 200}, - }, - { - name: "sealing block mismatched height", - nextBlockType: BlockTypeSealing, - prev: SimplexEpochInfo{NextPChainReferenceHeight: 200}, - next: SimplexEpochInfo{NextPChainReferenceHeight: 100}, - err: "expected P-chain reference height to be 200 but got 100", - }, - { - name: "normal block prev already has next height set", - nextBlockType: BlockTypeNormal, - prev: SimplexEpochInfo{NextPChainReferenceHeight: 200}, - next: SimplexEpochInfo{NextPChainReferenceHeight: 200}, - }, - { - name: "normal block prev already has next height set mismatch", - nextBlockType: BlockTypeNormal, - prev: SimplexEpochInfo{NextPChainReferenceHeight: 200}, - next: SimplexEpochInfo{NextPChainReferenceHeight: 300}, - err: "expected P-chain reference height to be 200 but got 300", - }, - { - name: "normal block next p-chain reference height less than current", - nextBlockType: BlockTypeNormal, - prev: SimplexEpochInfo{PChainReferenceHeight: 200}, - next: SimplexEpochInfo{NextPChainReferenceHeight: 100}, - err: "expected P-chain reference height to be non-decreasing, but the previous P-chain reference height is 200 and the proposed P-chain reference height is 100", - }, - { - name: "normal block same validator set with non-zero next height", - nextBlockType: BlockTypeNormal, - prev: SimplexEpochInfo{PChainReferenceHeight: 100}, - next: SimplexEpochInfo{NextPChainReferenceHeight: 200}, - getValidator: func(h uint64) (NodeBLSMappings, error) { return validators1, nil }, - err: "validator set at proposed next P-chain reference height 200 is the same as validator set at previous block's P-chain reference height 100,so expected next P-chain reference height to remain the same but got 200", - }, - { - name: "normal block no validator change and next height is zero", - nextBlockType: BlockTypeNormal, - prev: SimplexEpochInfo{PChainReferenceHeight: 100}, - next: SimplexEpochInfo{NextPChainReferenceHeight: 0}, - getValidator: func(h uint64) (NodeBLSMappings, error) { return validators1, nil }, - }, - { - name: "normal block validator change detected and p-chain height reached", - nextBlockType: BlockTypeNormal, - prev: SimplexEpochInfo{PChainReferenceHeight: 100}, - next: SimplexEpochInfo{NextPChainReferenceHeight: 200}, - getValidator: func(h uint64) (NodeBLSMappings, error) { - if h == 200 { - return validators2, nil - } - return validators1, nil - }, - pChainHeight: 200, - }, - { - name: "normal block validator change but p-chain height not reached", - nextBlockType: BlockTypeNormal, - prev: SimplexEpochInfo{PChainReferenceHeight: 100}, - next: SimplexEpochInfo{NextPChainReferenceHeight: 200}, - getValidator: func(h uint64) (NodeBLSMappings, error) { - if h == 200 { - return validators2, nil - } - return validators1, nil - }, - pChainHeight: 150, - err: "haven't reached P-chain height 200 yet, current P-chain height is only 150", - }, - { - name: "new epoch block with zero next height", - nextBlockType: BlockTypeNewEpoch, - next: SimplexEpochInfo{NextPChainReferenceHeight: 0}, - }, - { - name: "new epoch block with non-zero next height", - nextBlockType: BlockTypeNewEpoch, - next: SimplexEpochInfo{NextPChainReferenceHeight: 100}, - err: "expected P-chain reference height to be 0 but got 100", - }, - } { - t.Run(tc.name, func(t *testing.T) { - v := &nextPChainReferenceHeightVerifier{ - getValidatorSet: tc.getValidator, - getPChainHeight: func() uint64 { return tc.pChainHeight }, - } - err := v.Verify(verificationInput{ - nextBlockType: tc.nextBlockType, - prevMD: StateMachineMetadata{SimplexEpochInfo: tc.prev}, - proposedBlockMD: StateMachineMetadata{SimplexEpochInfo: tc.next}, - }) - if tc.err != "" { - require.EqualError(t, err, tc.err) - } else { - require.NoError(t, err) - } - }) - } -} - -func TestVMBlockSeqVerifier(t *testing.T) { - prevMDBytes := (&simplex.ProtocolMetadata{Seq: 5, Prev: [32]byte{1}}).Bytes() - proposedMDBytes := (&simplex.ProtocolMetadata{Seq: 6, Prev: [32]byte{2}}).Bytes() - - blockWithInner := StateMachineBlock{ - InnerBlock: &testVMBlock{bytes: []byte{1}}, - } - blockWithoutInner := StateMachineBlock{} - - for _, tc := range []struct { - name string - prev SimplexEpochInfo - prevMD StateMachineMetadata - next SimplexEpochInfo - prevBlockSeq uint64 - block StateMachineBlock - err string - }{ - { - name: "first simplex block matching seq", - prev: SimplexEpochInfo{EpochNumber: 0}, - next: SimplexEpochInfo{PrevVMBlockSeq: 42}, - prevBlockSeq: 42, - }, - { - name: "first simplex block wrong seq", - prev: SimplexEpochInfo{EpochNumber: 0}, - next: SimplexEpochInfo{PrevVMBlockSeq: 10}, - prevBlockSeq: 42, - err: "expected PrevVMBlockSeq to be 42 but got 10", - }, - { - name: "prev block has block", - prev: SimplexEpochInfo{EpochNumber: 1, PrevVMBlockSeq: 3}, - prevMD: StateMachineMetadata{SimplexProtocolMetadata: prevMDBytes, SimplexEpochInfo: SimplexEpochInfo{EpochNumber: 1, PrevVMBlockSeq: 3}}, - next: SimplexEpochInfo{PrevVMBlockSeq: 4}, - prevBlockSeq: 4, - block: blockWithInner, - }, - { - name: "prev block has block wrong seq", - prev: SimplexEpochInfo{EpochNumber: 1, PrevVMBlockSeq: 3}, - prevMD: StateMachineMetadata{SimplexProtocolMetadata: prevMDBytes, SimplexEpochInfo: SimplexEpochInfo{EpochNumber: 1, PrevVMBlockSeq: 3}}, - next: SimplexEpochInfo{PrevVMBlockSeq: 99}, - prevBlockSeq: 4, - block: blockWithInner, - err: "expected PrevVMBlockSeq to be 4 but got 99", - }, - { - name: "prev block has no block uses parent PrevVMBlockSeq", - prev: SimplexEpochInfo{EpochNumber: 1, PrevVMBlockSeq: 3}, - prevMD: StateMachineMetadata{SimplexProtocolMetadata: prevMDBytes, SimplexEpochInfo: SimplexEpochInfo{EpochNumber: 1, PrevVMBlockSeq: 3}}, - next: SimplexEpochInfo{PrevVMBlockSeq: 3}, - prevBlockSeq: 4, - block: blockWithoutInner, - }, - { - name: "prev block has no block wrong seq", - prev: SimplexEpochInfo{EpochNumber: 1, PrevVMBlockSeq: 3}, - prevMD: StateMachineMetadata{SimplexProtocolMetadata: prevMDBytes, SimplexEpochInfo: SimplexEpochInfo{EpochNumber: 1, PrevVMBlockSeq: 3}}, - next: SimplexEpochInfo{PrevVMBlockSeq: 99}, - prevBlockSeq: 4, - block: blockWithoutInner, - err: "expected PrevVMBlockSeq to be 3 but got 99", - }, - } { - t.Run(tc.name, func(t *testing.T) { - bs := make(testBlockStore) - bs[tc.prevBlockSeq] = tc.block - - v := &vmBlockSeqVerifier{ - getBlock: bs.getBlock, - } - - prevMD := tc.prevMD - if prevMD.SimplexEpochInfo.EpochNumber == 0 && tc.prev.EpochNumber == 0 { - prevMD.SimplexEpochInfo = tc.prev - } - - err := v.Verify(verificationInput{ - prevMD: prevMD, - proposedBlockMD: StateMachineMetadata{SimplexEpochInfo: tc.next, SimplexProtocolMetadata: proposedMDBytes}, - prevBlockSeq: tc.prevBlockSeq, - }) - if tc.err != "" { - require.EqualError(t, err, tc.err) - } else { - require.NoError(t, err) - } - }) - } -} - -func TestValidationDescriptorVerifier(t *testing.T) { - validators := NodeBLSMappings{ - {BLSKey: []byte{1}, Weight: 1}, - {BLSKey: []byte{2}, Weight: 1}, - } - - otherValidators := NodeBLSMappings{ - {BLSKey: []byte{3}, Weight: 1}, - } - - for _, tc := range []struct { - name string - nextBlockType BlockType - next SimplexEpochInfo - getValidator ValidatorSetRetriever - err string - }{ - { - name: "sealing block with matching validators", - nextBlockType: BlockTypeSealing, - next: SimplexEpochInfo{ - NextPChainReferenceHeight: 100, - BlockValidationDescriptor: &BlockValidationDescriptor{ - AggregatedMembership: AggregatedMembership{Members: validators}, - }, - }, - getValidator: func(h uint64) (NodeBLSMappings, error) { return validators, nil }, - }, - { - name: "sealing block with mismatching validators", - nextBlockType: BlockTypeSealing, - next: SimplexEpochInfo{ - NextPChainReferenceHeight: 100, - BlockValidationDescriptor: &BlockValidationDescriptor{ - AggregatedMembership: AggregatedMembership{Members: otherValidators}, - }, - }, - getValidator: func(h uint64) (NodeBLSMappings, error) { return validators, nil }, - err: "expected validator set specified at P-chain height 100 does not match validator set encoded in new block", - }, - { - name: "sealing block with validator retrieval error", - nextBlockType: BlockTypeSealing, - next: SimplexEpochInfo{ - NextPChainReferenceHeight: 100, - BlockValidationDescriptor: &BlockValidationDescriptor{}, - }, - getValidator: func(h uint64) (NodeBLSMappings, error) { return nil, fmt.Errorf("unavailable") }, - err: "unavailable", - }, - { - name: "normal block with nil descriptor", - nextBlockType: BlockTypeNormal, - next: SimplexEpochInfo{}, - }, - { - name: "normal block with non-nil descriptor", - nextBlockType: BlockTypeNormal, - next: SimplexEpochInfo{ - BlockValidationDescriptor: &BlockValidationDescriptor{}, - }, - err: "block validation descriptor should be nil but got &{{[] {0}} {0}}", - }, - { - name: "telock block with nil descriptor", - nextBlockType: BlockTypeTelock, - next: SimplexEpochInfo{}, - }, - { - name: "new epoch block with nil descriptor", - nextBlockType: BlockTypeNewEpoch, - next: SimplexEpochInfo{}, - }, - } { - t.Run(tc.name, func(t *testing.T) { - v := &validationDescriptorVerifier{ - getValidatorSet: tc.getValidator, - } - err := v.Verify(verificationInput{ - nextBlockType: tc.nextBlockType, - proposedBlockMD: StateMachineMetadata{SimplexEpochInfo: tc.next}, - }) - if tc.err != "" { - require.EqualError(t, err, tc.err) - } else { - require.NoError(t, err) - } - }) - } -} - -func TestNextEpochApprovalsVerifier(t *testing.T) { - validators := NodeBLSMappings{ - {BLSKey: []byte{1}, Weight: 1}, - {BLSKey: []byte{2}, Weight: 1}, - {BLSKey: []byte{3}, Weight: 1}, - } - - for _, tc := range []struct { - name string - nextBlockType BlockType - prev SimplexEpochInfo - next SimplexEpochInfo - getValidator ValidatorSetRetriever - sigVerifier SignatureVerifier - keyAggregator KeyAggregator - err string - }{ - { - name: "sealing block with nil approvals", - nextBlockType: BlockTypeSealing, - next: SimplexEpochInfo{}, - err: "next epoch approvals should not be nil for a sealing block", - }, - { - name: "sealing block with validator retrieval error", - nextBlockType: BlockTypeSealing, - next: SimplexEpochInfo{ - NextPChainReferenceHeight: 100, - NextEpochApprovals: &NextEpochApprovals{NodeIDs: []byte{7}, Signature: []byte("sig")}, - }, - getValidator: func(h uint64) (NodeBLSMappings, error) { return nil, fmt.Errorf("unavailable") }, - err: "unavailable", - }, - { - name: "sealing block not enough approvals", - nextBlockType: BlockTypeSealing, - next: SimplexEpochInfo{ - NextPChainReferenceHeight: 100, - NextEpochApprovals: &NextEpochApprovals{NodeIDs: []byte{1}, Signature: []byte("sig")}, - }, - getValidator: func(h uint64) (NodeBLSMappings, error) { return validators, nil }, - sigVerifier: &testSigVerifier{}, - keyAggregator: &testKeyAggregator{}, - err: "not enough approvals to seal block", - }, - { - name: "sealing block enough approvals", - nextBlockType: BlockTypeSealing, - next: SimplexEpochInfo{ - NextPChainReferenceHeight: 100, - NextEpochApprovals: &NextEpochApprovals{NodeIDs: []byte{7}, Signature: []byte("sig")}, - }, - getValidator: func(h uint64) (NodeBLSMappings, error) { return validators, nil }, - sigVerifier: &testSigVerifier{}, - keyAggregator: &testKeyAggregator{}, - }, - { - name: "normal block no validator change", - nextBlockType: BlockTypeNormal, - prev: SimplexEpochInfo{NextPChainReferenceHeight: 0}, - next: SimplexEpochInfo{}, - }, - { - name: "normal block collecting approvals with nil approvals", - nextBlockType: BlockTypeNormal, - prev: SimplexEpochInfo{NextPChainReferenceHeight: 100}, - next: SimplexEpochInfo{NextPChainReferenceHeight: 100}, - err: "next epoch approvals should not be nil when collecting approvals", - }, - { - name: "normal block collecting approvals valid", - nextBlockType: BlockTypeNormal, - prev: SimplexEpochInfo{ - NextPChainReferenceHeight: 100, - PChainReferenceHeight: 50, - }, - next: SimplexEpochInfo{ - NextPChainReferenceHeight: 100, - NextEpochApprovals: &NextEpochApprovals{NodeIDs: []byte{1}, Signature: []byte("sig")}, - }, - getValidator: func(h uint64) (NodeBLSMappings, error) { return validators, nil }, - sigVerifier: &testSigVerifier{}, - keyAggregator: &testKeyAggregator{}, - }, - { - name: "normal block collecting approvals signers not superset of prev", - nextBlockType: BlockTypeNormal, - prev: SimplexEpochInfo{ - NextPChainReferenceHeight: 100, - PChainReferenceHeight: 50, - NextEpochApprovals: &NextEpochApprovals{NodeIDs: []byte{3}, Signature: []byte("sig")}, // bits 0,1 - }, - next: SimplexEpochInfo{ - NextPChainReferenceHeight: 100, - NextEpochApprovals: &NextEpochApprovals{NodeIDs: []byte{1}, Signature: []byte("sig")}, // bit 0 only - }, - getValidator: func(h uint64) (NodeBLSMappings, error) { return validators, nil }, - sigVerifier: &testSigVerifier{}, - keyAggregator: &testKeyAggregator{}, - err: "some signers from parent block are missing from next epoch approvals of proposed block", - }, - { - name: "telock block with nil approvals", - nextBlockType: BlockTypeTelock, - next: SimplexEpochInfo{}, - }, - { - name: "telock block with non-nil approvals", - nextBlockType: BlockTypeTelock, - next: SimplexEpochInfo{ - NextEpochApprovals: &NextEpochApprovals{}, - }, - err: "next epoch approvals should be nil but got &{[] [] {0}}", - }, - { - name: "new epoch block with nil approvals", - nextBlockType: BlockTypeNewEpoch, - next: SimplexEpochInfo{}, - }, - { - name: "new epoch block with non-nil approvals", - nextBlockType: BlockTypeNewEpoch, - next: SimplexEpochInfo{ - NextEpochApprovals: &NextEpochApprovals{}, - }, - err: "next epoch approvals should be nil but got &{[] [] {0}}", - }, - } { - t.Run(tc.name, func(t *testing.T) { - v := &nextEpochApprovalsVerifier{ - sigVerifier: tc.sigVerifier, - getValidatorSet: tc.getValidator, - keyAggregator: tc.keyAggregator, - sigAggregatorCreator: newSignatureAggregatorCreator(), - } - err := v.Verify(verificationInput{ - nextBlockType: tc.nextBlockType, - prevMD: StateMachineMetadata{SimplexEpochInfo: tc.prev}, - proposedBlockMD: StateMachineMetadata{SimplexEpochInfo: tc.next}, - }) - if tc.err != "" { - require.EqualError(t, err, tc.err) - } else { - require.NoError(t, err) - } - }) - } -} - -func TestSealingBlockSeqVerifier(t *testing.T) { - prevProtocolMD := (&simplex.ProtocolMetadata{Seq: 5}).Bytes() - - for _, tc := range []struct { - name string - nextBlockType BlockType - prev SimplexEpochInfo - prevMD StateMachineMetadata - next SimplexEpochInfo - err string - }{ - { - name: "normal block with zero sealing seq", - nextBlockType: BlockTypeNormal, - next: SimplexEpochInfo{SealingBlockSeq: 0}, - }, - { - name: "normal block with non-zero sealing seq", - nextBlockType: BlockTypeNormal, - next: SimplexEpochInfo{SealingBlockSeq: 5}, - err: "expected sealing block sequence number to be 0 but got 5", - }, - { - name: "new epoch block with zero sealing seq", - nextBlockType: BlockTypeNewEpoch, - next: SimplexEpochInfo{SealingBlockSeq: 0}, - }, - { - name: "new epoch block with non-zero sealing seq", - nextBlockType: BlockTypeNewEpoch, - next: SimplexEpochInfo{SealingBlockSeq: 3}, - err: "expected sealing block sequence number to be 0 but got 3", - }, - { - name: "telock block matching prev sealing seq", - nextBlockType: BlockTypeTelock, - prev: SimplexEpochInfo{SealingBlockSeq: 10}, - next: SimplexEpochInfo{SealingBlockSeq: 10}, - }, - { - name: "telock block mismatching prev sealing seq", - nextBlockType: BlockTypeTelock, - prev: SimplexEpochInfo{SealingBlockSeq: 10}, - next: SimplexEpochInfo{SealingBlockSeq: 11}, - err: "expected sealing block sequence number to be 10 but got 11", - }, - { - name: "sealing block with zero seq", - nextBlockType: BlockTypeSealing, - prevMD: StateMachineMetadata{SimplexProtocolMetadata: prevProtocolMD}, - next: SimplexEpochInfo{SealingBlockSeq: 0}, - }, - { - name: "sealing block with non-zero seq", - nextBlockType: BlockTypeSealing, - prevMD: StateMachineMetadata{SimplexProtocolMetadata: prevProtocolMD}, - next: SimplexEpochInfo{SealingBlockSeq: 10}, - err: "expected sealing block sequence number to be 0 but got 10", - }, - } { - t.Run(tc.name, func(t *testing.T) { - v := &sealingBlockSeqVerifier{} - prevMD := tc.prevMD - prevMD.SimplexEpochInfo = tc.prev - err := v.Verify(verificationInput{ - nextBlockType: tc.nextBlockType, - prevMD: prevMD, - proposedBlockMD: StateMachineMetadata{SimplexEpochInfo: tc.next}, - }) - if tc.err != "" { - require.EqualError(t, err, tc.err) - } else { - require.NoError(t, err) - } - }) - } -} From dff27ead6d0098387cd1395544675f4b6d507166 Mon Sep 17 00:00:00 2001 From: Yacov Manevich Date: Tue, 12 May 2026 17:13:51 +0200 Subject: [PATCH 12/16] Move helpers to misc_test.go Signed-off-by: Yacov Manevich --- msm/misc_test.go | 348 +++++++++++++++++++++++++++++++++++++++++------ msm/msm.go | 30 ++-- msm/msm_test.go | 320 ------------------------------------------- 3 files changed, 322 insertions(+), 376 deletions(-) diff --git a/msm/misc_test.go b/msm/misc_test.go index 8a1c0ada..91325712 100644 --- a/msm/misc_test.go +++ b/msm/misc_test.go @@ -4,14 +4,19 @@ package metadata import ( + "bytes" "context" + "crypto/rand" "crypto/sha256" + "encoding/asn1" "fmt" + "maps" "math" "testing" "time" "github.com/ava-labs/simplex" + "github.com/ava-labs/simplex/testutil" "github.com/stretchr/testify/require" ) @@ -148,81 +153,342 @@ func TestBitmask(t *testing.T) { }) } - // Test helpers -type testBlockStore map[uint64]StateMachineBlock +type InnerBlock struct { + TS time.Time + BlockHeight uint64 + Bytes []byte +} -func (bs testBlockStore) getBlock(seq uint64, _ [32]byte) (StateMachineBlock, *simplex.Finalization, error) { - blk, ok := bs[seq] - if !ok { - return StateMachineBlock{}, nil, fmt.Errorf("%w: block %d", simplex.ErrBlockNotFound, seq) - } - return blk, nil, nil +func (i *InnerBlock) Digest() [32]byte { + return sha256.Sum256(i.Bytes) } -type testVMBlock struct { - bytes []byte +func (i *InnerBlock) Height() uint64 { + return i.BlockHeight +} + +func (i *InnerBlock) Timestamp() time.Time { + return i.TS +} + +func (i *InnerBlock) Verify(_ context.Context) error { + return nil +} + +// fakeVMBlock is a minimal VMBlock implementation for tests. +type fakeVMBlock struct { height uint64 } -func (b *testVMBlock) Digest() [32]byte { - return sha256.Sum256(b.bytes) +func (f *fakeVMBlock) Digest() [32]byte { return [32]byte{} } +func (f *fakeVMBlock) Height() uint64 { return f.height } +func (f *fakeVMBlock) Timestamp() time.Time { return time.Time{} } +func (f *fakeVMBlock) Verify(_ context.Context) error { return nil } + +type outerBlock struct { + finalization *simplex.Finalization + block StateMachineBlock +} + +type blockStore map[uint64]*outerBlock + +func (bs blockStore) clone() blockStore { + newStore := make(blockStore) + maps.Copy(newStore, bs) + return newStore } -func (b *testVMBlock) Height() uint64 { - return b.height +func (bs blockStore) getBlock(seq uint64, _ [32]byte) (StateMachineBlock, *simplex.Finalization, error) { + blk, exits := bs[seq] + if !exits { + return StateMachineBlock{}, nil, fmt.Errorf("%w: block %d not found", simplex.ErrBlockNotFound, seq) + } + return blk.block, blk.finalization, nil } -func (b *testVMBlock) Timestamp() time.Time { - return time.Now() +type approvalsRetriever struct { + result ValidatorSetApprovals } -func (b *testVMBlock) Verify(_ context.Context) error { - return nil +func (a approvalsRetriever) RetrieveApprovals() ValidatorSetApprovals { + return a.result } -type testSigVerifier struct { +type signatureVerifier struct { err error } -func (sv *testSigVerifier) VerifySignature(_, _, _ []byte) error { +func (sv *signatureVerifier) VerifySignature(signature []byte, message []byte, publicKey []byte) error { return sv.err } -type testKeyAggregator struct { - err error +type signatureAggregator struct { + weightByNodeID map[string]uint64 + totalWeight uint64 +} + +type aggregatrdSignature struct { + Signatures [][]byte +} + +func (sv *signatureAggregator) Aggregate([]simplex.Signature) (simplex.QuorumCertificate, error) { + panic("unused in tests") } -func (ka *testKeyAggregator) AggregateKeys(keys ...[]byte) ([]byte, error) { - if ka.err != nil { - return nil, ka.err +func (sv *signatureAggregator) AppendSignatures(existing []byte, sigs ...[]byte) ([]byte, error) { + all := make([][]byte, 0, len(sigs)+1) + all = append(all, sigs...) + if len(existing) > 0 { + all = append(all, existing) } - var agg []byte - for _, k := range keys { - agg = append(agg, k...) + return asn1.Marshal(aggregatrdSignature{Signatures: all}) +} + +func (sv *signatureAggregator) IsQuorum(signers []simplex.NodeID) bool { + var sum uint64 + for _, signer := range signers { + sum += sv.weightByNodeID[string(signer)] } - return agg, nil + return sum*3 > sv.totalWeight*2 } -type InnerBlock struct { - TS time.Time - BlockHeight uint64 - Bytes []byte +func newSignatureAggregatorCreator() simplex.SignatureAggregatorCreator { + return func(weights []simplex.NodeWeight) simplex.SignatureAggregator { + s := &signatureAggregator{weightByNodeID: make(map[string]uint64, len(weights))} + for _, nw := range weights { + s.weightByNodeID[string(nw.Node)] = nw.Weight + s.totalWeight += nw.Weight + } + return s + } } -func (i *InnerBlock) Digest() [32]byte { - return sha256.Sum256(i.Bytes) +type noOpPChainListener struct{} + +func (n *noOpPChainListener) WaitForProgress(ctx context.Context, _ uint64) error { + <-ctx.Done() + return ctx.Err() } -func (i *InnerBlock) Height() uint64 { - return i.BlockHeight +type blockBuilder struct { + block VMBlock + err error } -func (i *InnerBlock) Timestamp() time.Time { - return i.TS +func (bb *blockBuilder) WaitForPendingBlock(_ context.Context) { + // Block is always ready in tests. } -func (i *InnerBlock) Verify(_ context.Context) error { - return nil +func (bb *blockBuilder) BuildBlock(_ context.Context, _ uint64) (VMBlock, error) { + return bb.block, bb.err +} + +type validatorSetRetriever struct { + result NodeBLSMappings + resultMap map[uint64]NodeBLSMappings + err error +} + +func (vsr *validatorSetRetriever) getValidatorSet(height uint64) (NodeBLSMappings, error) { + if vsr.resultMap != nil { + if result, ok := vsr.resultMap[height]; ok { + return result, vsr.err + } + } + return vsr.result, vsr.err +} + +type keyAggregator struct{} + +func (ka *keyAggregator) AggregateKeys(keys ...[]byte) ([]byte, error) { + aggregated := make([]byte, 0) + for _, key := range keys { + aggregated = append(aggregated, key...) + } + return aggregated, nil +} + +var ( + genesisBlock = StateMachineBlock{ + // Genesis block metadata has all zero values + InnerBlock: &InnerBlock{ + TS: time.Now(), + Bytes: []byte{1, 2, 3}, + }, + } +) + +type dynamicApprovalsRetriever struct { + approvals *ValidatorSetApprovals +} + +func (d *dynamicApprovalsRetriever) RetrieveApprovals() ValidatorSetApprovals { + return *d.approvals +} + +func makeChain(t *testing.T, simplexStartHeight uint64, endHeight uint64) []StateMachineBlock { + startTime := time.Now().Add(-time.Duration(endHeight+2) * time.Second) + blocks := make([]StateMachineBlock, 0, endHeight+1) + var round, seq uint64 + for h := uint64(0); h <= endHeight; h++ { + index := len(blocks) + + if h == 0 { + blocks = append(blocks, genesisBlock) + continue + } + + if h < simplexStartHeight { + blocks = append(blocks, makeNonSimplexBlock(t, simplexStartHeight, startTime, h)) + continue + } + + seq = uint64(index) + + blocks = append(blocks, makeNormalSimplexBlock(t, index, blocks, startTime, h, round, seq)) + round++ + } + return blocks +} + +func makeNormalSimplexBlock(t *testing.T, index int, blocks []StateMachineBlock, start time.Time, h uint64, round uint64, seq uint64) StateMachineBlock { + content := make([]byte, 10) + _, err := rand.Read(content) + require.NoError(t, err) + + prev := genesisBlock.Digest() + if index > 0 { + prev = blocks[index-1].Digest() + } + + return StateMachineBlock{ + InnerBlock: &InnerBlock{ + TS: start.Add(time.Duration(h) * time.Second), + BlockHeight: h, + Bytes: []byte{1, 2, 3}, + }, + Metadata: StateMachineMetadata{ + PChainHeight: 100, + SimplexProtocolMetadata: (&simplex.ProtocolMetadata{ + Round: round, + Seq: seq, + Epoch: 1, + Prev: prev, + }).Bytes(), + SimplexEpochInfo: SimplexEpochInfo{ + PrevSealingBlockHash: [32]byte{}, + PChainReferenceHeight: 100, + EpochNumber: 1, + PrevVMBlockSeq: uint64(index), + }, + }, + } +} + +func makeNonSimplexBlock(t *testing.T, startHeight uint64, start time.Time, h uint64) StateMachineBlock { + content := make([]byte, 10) + _, err := rand.Read(content) + require.NoError(t, err) + + return StateMachineBlock{ + InnerBlock: &InnerBlock{ + TS: start.Add(time.Duration(h-startHeight) * time.Second), + BlockHeight: h, + Bytes: []byte{1, 2, 3}, + }, + } +} + +type testConfig struct { + blockStore blockStore + approvalsRetriever approvalsRetriever + signatureVerifier signatureVerifier + signatureAggregator signatureAggregator + blockBuilder blockBuilder + keyAggregator keyAggregator + validatorSetRetriever validatorSetRetriever +} + +func newStateMachine(t *testing.T) (*StateMachine, *testConfig) { + bs := make(blockStore) + bs[0] = &outerBlock{block: genesisBlock} + + var testConfig testConfig + testConfig.blockStore = bs + testConfig.validatorSetRetriever.result = NodeBLSMappings{ + {BLSKey: []byte{1}, Weight: 1}, {BLSKey: []byte{2}, Weight: 1}, + } + + smConfig := Config{ + GenesisValidatorSet: NodeBLSMappings{{BLSKey: []byte{1}, Weight: 1}, {BLSKey: []byte{2}, Weight: 1}}, + LastNonSimplexBlockPChainHeight: 100, + FirstEverSimplexBlock: func() *StateMachineBlock { + var res *StateMachineBlock + min := uint64(math.MaxUint64) + for seq, block := range testConfig.blockStore { + if block.block.Metadata.SimplexEpochInfo.EpochNumber == 0 { + continue + } + if seq < min { + min = seq + res = &block.block + } + } + return res + }, + GetTime: time.Now, + TimeSkewLimit: time.Second * 5, + Logger: testutil.MakeLogger(t), + GetBlock: testConfig.blockStore.getBlock, + MaxBlockBuildingWaitTime: time.Second, + ApprovalsRetriever: &testConfig.approvalsRetriever, + SignatureVerifier: &testConfig.signatureVerifier, + SignatureAggregatorCreator: newSignatureAggregatorCreator(), + BlockBuilder: &testConfig.blockBuilder, + KeyAggregator: &testConfig.keyAggregator, + GetPChainHeight: func() uint64 { + return 100 + }, + GetUpgrades: func() any { + return nil + }, + GetValidatorSet: testConfig.validatorSetRetriever.getValidatorSet, + PChainProgressListener: &noOpPChainListener{}, + LastNonSimplexInnerBlock: genesisBlock.InnerBlock, + } + + sm := NewStateMachine(&smConfig) + return sm, &testConfig +} + +// concatAggregator concatenates signatures for easy verification in tests. +type concatAggregator struct{} + +func (concatAggregator) Aggregate([]simplex.Signature) (simplex.QuorumCertificate, error) { + panic("unused in tests") +} + +func (concatAggregator) AppendSignatures(existing []byte, sigs ...[]byte) ([]byte, error) { + result := bytes.Join(sigs, nil) + return append(result, existing...), nil +} + +func (concatAggregator) IsQuorum([]simplex.NodeID) bool { + return false +} + +type failingAggregator struct{} + +func (failingAggregator) Aggregate([]simplex.Signature) (simplex.QuorumCertificate, error) { + panic("unused in tests") +} + +func (failingAggregator) AppendSignatures([]byte, ...[]byte) ([]byte, error) { + return nil, fmt.Errorf("aggregation failed") +} + +func (failingAggregator) IsQuorum([]simplex.NodeID) bool { + return false } diff --git a/msm/msm.go b/msm/msm.go index d74727f7..94e2801e 100644 --- a/msm/msm.go +++ b/msm/msm.go @@ -151,45 +151,45 @@ func NewStateMachine(config *Config) *StateMachine { } // BuildBlock constructs the next block on top of the given parent block, and passes in the provided simplex metadata and blacklist. -func (sm *StateMachine) BuildBlock(ctx context.Context, simplexMetadata simplex.ProtocolMetadata, simplexBlacklist *simplex.Blacklist) (*StateMachineBlock, error) { +func (sm *StateMachine) BuildBlock(ctx context.Context, metadata simplex.ProtocolMetadata, blacklist *simplex.Blacklist) (*StateMachineBlock, error) { // The zero sequence number is reserved for the genesis block, which should never be built. - if simplexMetadata.Seq == 0 { - return nil, fmt.Errorf("invalid ProtocolMetadata sequence number: should be > 0, got %d", simplexMetadata.Seq) + if metadata.Seq == 0 { + return nil, fmt.Errorf("invalid ProtocolMetadata sequence number: should be > 0, got %d", metadata.Seq) } - parentBlock, _, err := sm.GetBlock(simplexMetadata.Seq-1, simplexMetadata.Prev) + parentBlock, _, err := sm.GetBlock(metadata.Seq-1, metadata.Prev) if err != nil { - return nil, fmt.Errorf("failed retrieving parent block at height %d with digest %s: %w", simplexMetadata.Seq-1, simplexMetadata.Prev.String(), err) + return nil, fmt.Errorf("failed retrieving parent block at height %d with digest %s: %w", metadata.Seq-1, metadata.Prev.String(), err) } start := time.Now() sm.Logger.Debug("Building block", - zap.Uint64("seq", simplexMetadata.Seq), - zap.Uint64("epoch", simplexMetadata.Epoch), - zap.Stringer("prevHash", simplexMetadata.Prev)) + zap.Uint64("seq", metadata.Seq), + zap.Uint64("epoch", metadata.Epoch), + zap.Stringer("prevHash", metadata.Prev)) defer func() { elapsed := time.Since(start) sm.Logger.Debug("Built block", - zap.Uint64("seq", simplexMetadata.Seq), - zap.Uint64("epoch", simplexMetadata.Epoch), - zap.Stringer("prevHash", simplexMetadata.Prev), + zap.Uint64("seq", metadata.Seq), + zap.Uint64("epoch", metadata.Epoch), + zap.Stringer("prevHash", metadata.Prev), zap.Duration("elapsed", elapsed), ) }() var simplexBlacklistBytes []byte - if simplexBlacklist != nil { - simplexBlacklistBytes = simplexBlacklist.Bytes() + if blacklist != nil { + simplexBlacklistBytes = blacklist.Bytes() } // In order to know where in the epoch change process we are, // we identify the current state by looking at the parent block's epoch info. currentState := parentBlock.Metadata.SimplexEpochInfo.NextState() - simplexMetadataBytes := simplexMetadata.Bytes() - prevBlockSeq := simplexMetadata.Seq - 1 + simplexMetadataBytes := metadata.Bytes() + prevBlockSeq := metadata.Seq - 1 switch currentState { case stateFirstSimplexBlock: diff --git a/msm/msm_test.go b/msm/msm_test.go index 405725b5..775e41e6 100644 --- a/msm/msm_test.go +++ b/msm/msm_test.go @@ -4,13 +4,9 @@ package metadata import ( - "bytes" "context" "crypto/rand" - "encoding/asn1" "fmt" - "maps" - "math" "testing" "time" @@ -19,149 +15,6 @@ import ( "github.com/stretchr/testify/require" ) -// fakeVMBlock is a minimal VMBlock implementation for tests. -type fakeVMBlock struct { - height uint64 -} - -func (f *fakeVMBlock) Digest() [32]byte { return [32]byte{} } -func (f *fakeVMBlock) Height() uint64 { return f.height } -func (f *fakeVMBlock) Timestamp() time.Time { return time.Time{} } -func (f *fakeVMBlock) Verify(_ context.Context) error { return nil } - -type outerBlock struct { - finalization *simplex.Finalization - block StateMachineBlock -} - -type blockStore map[uint64]*outerBlock - -func (bs blockStore) clone() blockStore { - newStore := make(blockStore) - maps.Copy(newStore, bs) - return newStore -} - -func (bs blockStore) getBlock(seq uint64, _ [32]byte) (StateMachineBlock, *simplex.Finalization, error) { - blk, exits := bs[seq] - if !exits { - return StateMachineBlock{}, nil, fmt.Errorf("%w: block %d not found", simplex.ErrBlockNotFound, seq) - } - return blk.block, blk.finalization, nil -} - -type approvalsRetriever struct { - result ValidatorSetApprovals -} - -func (a approvalsRetriever) RetrieveApprovals() ValidatorSetApprovals { - return a.result -} - -type signatureVerifier struct { - err error -} - -func (sv *signatureVerifier) VerifySignature(signature []byte, message []byte, publicKey []byte) error { - return sv.err -} - -type signatureAggregator struct { - weightByNodeID map[string]uint64 - totalWeight uint64 -} - -type aggregatrdSignature struct { - Signatures [][]byte -} - -func (sv *signatureAggregator) Aggregate([]simplex.Signature) (simplex.QuorumCertificate, error) { - panic("unused in tests") -} - -func (sv *signatureAggregator) AppendSignatures(existing []byte, sigs ...[]byte) ([]byte, error) { - all := make([][]byte, 0, len(sigs)+1) - all = append(all, sigs...) - if len(existing) > 0 { - all = append(all, existing) - } - return asn1.Marshal(aggregatrdSignature{Signatures: all}) -} - -func (sv *signatureAggregator) IsQuorum(signers []simplex.NodeID) bool { - var sum uint64 - for _, signer := range signers { - sum += sv.weightByNodeID[string(signer)] - } - return sum*3 > sv.totalWeight*2 -} - -func newSignatureAggregatorCreator() simplex.SignatureAggregatorCreator { - return func(weights []simplex.NodeWeight) simplex.SignatureAggregator { - s := &signatureAggregator{weightByNodeID: make(map[string]uint64, len(weights))} - for _, nw := range weights { - s.weightByNodeID[string(nw.Node)] = nw.Weight - s.totalWeight += nw.Weight - } - return s - } -} - -type noOpPChainListener struct{} - -func (n *noOpPChainListener) WaitForProgress(ctx context.Context, _ uint64) error { - <-ctx.Done() - return ctx.Err() -} - -type blockBuilder struct { - block VMBlock - err error -} - -func (bb *blockBuilder) WaitForPendingBlock(_ context.Context) { - // Block is always ready in tests. -} - -func (bb *blockBuilder) BuildBlock(_ context.Context, _ uint64) (VMBlock, error) { - return bb.block, bb.err -} - -type validatorSetRetriever struct { - result NodeBLSMappings - resultMap map[uint64]NodeBLSMappings - err error -} - -func (vsr *validatorSetRetriever) getValidatorSet(height uint64) (NodeBLSMappings, error) { - if vsr.resultMap != nil { - if result, ok := vsr.resultMap[height]; ok { - return result, vsr.err - } - } - return vsr.result, vsr.err -} - -type keyAggregator struct{} - -func (ka *keyAggregator) AggregateKeys(keys ...[]byte) ([]byte, error) { - aggregated := make([]byte, 0) - for _, key := range keys { - aggregated = append(aggregated, key...) - } - return aggregated, nil -} - -var ( - genesisBlock = StateMachineBlock{ - // Genesis block metadata has all zero values - InnerBlock: &InnerBlock{ - TS: time.Now(), - Bytes: []byte{1, 2, 3}, - }, - } -) - func TestMSMBuildAndVerifyBlocksAfterGenesis(t *testing.T) { validMD := simplex.ProtocolMetadata{ Round: 1, @@ -846,149 +699,6 @@ func TestMSMFullEpochLifecycle(t *testing.T) { } } -type dynamicApprovalsRetriever struct { - approvals *ValidatorSetApprovals -} - -func (d *dynamicApprovalsRetriever) RetrieveApprovals() ValidatorSetApprovals { - return *d.approvals -} - -func makeChain(t *testing.T, simplexStartHeight uint64, endHeight uint64) []StateMachineBlock { - startTime := time.Now().Add(-time.Duration(endHeight+2) * time.Second) - blocks := make([]StateMachineBlock, 0, endHeight+1) - var round, seq uint64 - for h := uint64(0); h <= endHeight; h++ { - index := len(blocks) - - if h == 0 { - blocks = append(blocks, genesisBlock) - continue - } - - if h < simplexStartHeight { - blocks = append(blocks, makeNonSimplexBlock(t, simplexStartHeight, startTime, h)) - continue - } - - seq = uint64(index) - - blocks = append(blocks, makeNormalSimplexBlock(t, index, blocks, startTime, h, round, seq)) - round++ - } - return blocks -} - -func makeNormalSimplexBlock(t *testing.T, index int, blocks []StateMachineBlock, start time.Time, h uint64, round uint64, seq uint64) StateMachineBlock { - content := make([]byte, 10) - _, err := rand.Read(content) - require.NoError(t, err) - - prev := genesisBlock.Digest() - if index > 0 { - prev = blocks[index-1].Digest() - } - - return StateMachineBlock{ - InnerBlock: &InnerBlock{ - TS: start.Add(time.Duration(h) * time.Second), - BlockHeight: h, - Bytes: []byte{1, 2, 3}, - }, - Metadata: StateMachineMetadata{ - PChainHeight: 100, - SimplexProtocolMetadata: (&simplex.ProtocolMetadata{ - Round: round, - Seq: seq, - Epoch: 1, - Prev: prev, - }).Bytes(), - SimplexEpochInfo: SimplexEpochInfo{ - PrevSealingBlockHash: [32]byte{}, - PChainReferenceHeight: 100, - EpochNumber: 1, - PrevVMBlockSeq: uint64(index), - }, - }, - } -} - -func makeNonSimplexBlock(t *testing.T, startHeight uint64, start time.Time, h uint64) StateMachineBlock { - content := make([]byte, 10) - _, err := rand.Read(content) - require.NoError(t, err) - - return StateMachineBlock{ - InnerBlock: &InnerBlock{ - TS: start.Add(time.Duration(h-startHeight) * time.Second), - BlockHeight: h, - Bytes: []byte{1, 2, 3}, - }, - } -} - -type testConfig struct { - blockStore blockStore - approvalsRetriever approvalsRetriever - signatureVerifier signatureVerifier - signatureAggregator signatureAggregator - blockBuilder blockBuilder - keyAggregator keyAggregator - validatorSetRetriever validatorSetRetriever -} - -func newStateMachine(t *testing.T) (*StateMachine, *testConfig) { - bs := make(blockStore) - bs[0] = &outerBlock{block: genesisBlock} - - var testConfig testConfig - testConfig.blockStore = bs - testConfig.validatorSetRetriever.result = NodeBLSMappings{ - {BLSKey: []byte{1}, Weight: 1}, {BLSKey: []byte{2}, Weight: 1}, - } - - smConfig := Config{ - GenesisValidatorSet: NodeBLSMappings{{BLSKey: []byte{1}, Weight: 1}, {BLSKey: []byte{2}, Weight: 1}}, - LastNonSimplexBlockPChainHeight: 100, - FirstEverSimplexBlock: func() *StateMachineBlock { - var res *StateMachineBlock - min := uint64(math.MaxUint64) - for seq, block := range testConfig.blockStore { - if block.block.Metadata.SimplexEpochInfo.EpochNumber == 0 { - continue - } - if seq < min { - min = seq - res = &block.block - } - } - return res - }, - GetTime: time.Now, - TimeSkewLimit: time.Second * 5, - Logger: testutil.MakeLogger(t), - GetBlock: testConfig.blockStore.getBlock, - MaxBlockBuildingWaitTime: time.Second, - ApprovalsRetriever: &testConfig.approvalsRetriever, - SignatureVerifier: &testConfig.signatureVerifier, - SignatureAggregatorCreator: newSignatureAggregatorCreator(), - BlockBuilder: &testConfig.blockBuilder, - KeyAggregator: &testConfig.keyAggregator, - GetPChainHeight: func() uint64 { - return 100 - }, - GetUpgrades: func() any { - return nil - }, - GetValidatorSet: testConfig.validatorSetRetriever.getValidatorSet, - PChainProgressListener: &noOpPChainListener{}, - LastNonSimplexInnerBlock: genesisBlock.InnerBlock, - } - - sm := NewStateMachine(&smConfig) - return sm, &testConfig -} - func TestIdentifyCurrentState(t *testing.T) { bvd := &BlockValidationDescriptor{} for _, tc := range []struct { @@ -1151,36 +861,6 @@ func TestSanitizeApprovals(t *testing.T) { }) } -// concatAggregator concatenates signatures for easy verification in tests. -type concatAggregator struct{} - -func (concatAggregator) Aggregate([]simplex.Signature) (simplex.QuorumCertificate, error) { - panic("unused in tests") -} - -func (concatAggregator) AppendSignatures(existing []byte, sigs ...[]byte) ([]byte, error) { - result := bytes.Join(sigs, nil) - return append(result, existing...), nil -} - -func (concatAggregator) IsQuorum([]simplex.NodeID) bool { - return false -} - -type failingAggregator struct{} - -func (failingAggregator) Aggregate([]simplex.Signature) (simplex.QuorumCertificate, error) { - panic("unused in tests") -} - -func (failingAggregator) AppendSignatures([]byte, ...[]byte) ([]byte, error) { - return nil, fmt.Errorf("aggregation failed") -} - -func (failingAggregator) IsQuorum([]simplex.NodeID) bool { - return false -} - func TestComputeNewApproverSignaturesAndSigners(t *testing.T) { node0 := nodeID{0} node1 := nodeID{1} From f2b6f7ece4c17825354c4fbeec6a49b36c517d1c Mon Sep 17 00:00:00 2001 From: Yacov Manevich Date: Wed, 13 May 2026 00:09:17 +0200 Subject: [PATCH 13/16] Address code review comments Signed-off-by: Yacov Manevich --- api.go | 16 ++++---- blacklist.go | 2 +- blacklist_test.go | 4 +- epoch.go | 12 +++--- epoch_failover_test.go | 14 +++---- epoch_test.go | 14 +++---- global.go | 6 +-- msm/build_decision.go | 47 +++++++---------------- msm/build_decision_test.go | 37 ++++++++---------- msm/encoding.go | 6 +-- msm/misc_test.go | 6 ++- msm/msm.go | 78 ++++++++++++-------------------------- pos_test.go | 2 +- testutil/comm.go | 10 ++--- testutil/controlled.go | 4 +- testutil/network.go | 6 +-- testutil/util.go | 2 +- 17 files changed, 107 insertions(+), 159 deletions(-) diff --git a/api.go b/api.go index fdf5db39..eef03e5a 100644 --- a/api.go +++ b/api.go @@ -58,7 +58,7 @@ type Storage interface { type Communication interface { // Nodes returns all nodes that participate in the epoch. - Nodes() NodeWeights + Nodes() Nodes // Send sends a message to the given destination node Send(msg *Message, destination NodeID) @@ -140,11 +140,11 @@ type SignatureAggregator interface { IsQuorum([]NodeID) bool } -// NodeWeights is a list of NodeWeight elements. -type NodeWeights []NodeWeight +// Nodes is a list of Node elements. +type Nodes []Node -// NodesIDs returns the NodeIDs of the nodes in the NodeWeights. -func (nws NodeWeights) NodesIDs() []NodeID { +// NodeIDs returns the NodeIDs of the nodes in the Nodes. +func (nws Nodes) NodeIDs() []NodeID { nodes := make([]NodeID, len(nws)) for i, nw := range nws { nodes[i] = nw.Node @@ -152,11 +152,11 @@ func (nws NodeWeights) NodesIDs() []NodeID { return nodes } -// NodeWeight is a struct that pairs a node with its weight in the signature aggregator. -type NodeWeight struct { +// Node is a struct that pairs a node with its weight in the signature aggregator. +type Node struct { Node NodeID Weight uint64 } // SignatureAggregatorCreator creates a SignatureAggregator from a list of nodes and their weights. -type SignatureAggregatorCreator func([]NodeWeight) SignatureAggregator +type SignatureAggregatorCreator func([]Node) SignatureAggregator diff --git a/blacklist.go b/blacklist.go index edc7abbb..a80c16d0 100644 --- a/blacklist.go +++ b/blacklist.go @@ -206,7 +206,7 @@ func (bl *Blacklist) ApplyUpdates(updates []BlacklistUpdate, round uint64) Black } // garbageCollectSuspectedNodes returns a new list of suspected nodes for the given round. -// NodesIDs that are no longer suspected or have been redeemed, will not be included in the returned suspected nodes. +// NodeIDs that are no longer suspected or have been redeemed, will not be included in the returned suspected nodes. // It will also garbage-collect any redeem votes from past orbits, unless hey have surpassed the threshold of f+1. // It does not modify the current blacklist. func (bl *Blacklist) garbageCollectSuspectedNodes(round uint64) SuspectedNodes { diff --git a/blacklist_test.go b/blacklist_test.go index faee4080..0a477234 100644 --- a/blacklist_test.go +++ b/blacklist_test.go @@ -465,8 +465,8 @@ func TestComputeBlacklistUpdates(t *testing.T) { func TestAdvanceRound(t *testing.T) { nodes := []uint16{0, 1, 2, 3} - // NodesIDs 0, 2 are suspected. - // NodesIDs 1 and 3 are not suspected. + // NodeIDs 0, 2 are suspected. + // NodeIDs 1 and 3 are not suspected. // Node 2 can be redeemed. suspectedNodesBefore := SuspectedNodes{ {NodeIndex: 0, SuspectingCount: 2, OrbitSuspected: 1, RedeemingCount: 1, OrbitToRedeem: 1}, diff --git a/epoch.go b/epoch.go index d6a4466d..157d7b7c 100644 --- a/epoch.go +++ b/epoch.go @@ -95,7 +95,7 @@ type Epoch struct { blockBuilderCtx context.Context blockBuilderCancelFunc context.CancelFunc nodes NodeIDs - nodeWeights NodeWeights + nodeWeights Nodes eligibleNodeIDs map[string]struct{} rounds map[uint64]*Round emptyVotes map[uint64]*EmptyVoteSet @@ -201,8 +201,8 @@ func (e *Epoch) init() error { e.blockBuilderCtx = context.Background() e.blockBuilderCancelFunc = func() {} e.nodeWeights = e.Comm.Nodes() - SortNodesWeights(e.nodeWeights) - e.nodes = e.nodeWeights.NodesIDs() + SortNodes(e.nodeWeights) + e.nodes = e.nodeWeights.NodeIDs() e.timedOutRounds = make(map[uint16]uint64, len(e.nodes)) e.redeemedRounds = make(map[uint16]uint64, len(e.nodes)) e.rounds = make(map[uint64]*Round) @@ -3424,9 +3424,9 @@ func (e *Epoch) nextSeqToCommit() uint64 { return e.Storage.NumBlocks() } -// SortNodesWeights sorts the nodes in place by their byte representations. -func SortNodesWeights(nodes NodeWeights) { - slices.SortFunc(nodes, func(a, b NodeWeight) int { +// SortNodes sorts the nodes in place by their byte representations. +func SortNodes(nodes Nodes) { + slices.SortFunc(nodes, func(a, b Node) int { return bytes.Compare(a.Node[:], b.Node[:]) }) } diff --git a/epoch_failover_test.go b/epoch_failover_test.go index 321acb1f..9db6a9b1 100644 --- a/epoch_failover_test.go +++ b/epoch_failover_test.go @@ -87,7 +87,7 @@ func TestEpochRebroadcastsEmptyVoteAfterBlockProposalReceived(t *testing.T) { bb := testutil.NewTestBlockBuilder() nodes := NodeIDs{{1}, {2}, {3}, {4}} - comm := newRebroadcastComm(nodes.EqualWeightedNodeWeights()) + comm := newRebroadcastComm(nodes.EqualWeightedNodes()) conf, wal, _ := testutil.DefaultTestNodeEpochConfig(t, nodes[3], comm, bb) epochTime := conf.StartTime e, err := NewEpoch(conf) @@ -354,7 +354,7 @@ func TestEpochLeaderRecursivelyFetchNotarizedBlocks(t *testing.T) { recordedMessages := make(chan *Message, 100) - comm := &recordingComm{Communication: testutil.NoopComm(nodes.EqualWeightedNodeWeights()), SentMessages: recordedMessages} + comm := &recordingComm{Communication: testutil.NoopComm(nodes.EqualWeightedNodes()), SentMessages: recordedMessages} conf, wal, _ := testutil.DefaultTestNodeEpochConfig(t, nodes[0], comm, bb) e, err := NewEpoch(conf) @@ -1115,18 +1115,18 @@ func TestEpochBlacklist(t *testing.T) { } type rebroadcastComm struct { - nodes NodeWeights + nodes Nodes emptyVotes chan *EmptyVote } -func newRebroadcastComm(nodes NodeWeights) *rebroadcastComm { +func newRebroadcastComm(nodes Nodes) *rebroadcastComm { return &rebroadcastComm{ nodes: nodes, emptyVotes: make(chan *EmptyVote, 10), } } -func (r *rebroadcastComm) Nodes() NodeWeights { +func (r *rebroadcastComm) Nodes() Nodes { return r.nodes } @@ -1144,7 +1144,7 @@ func TestEpochRebroadcastsEmptyVote(t *testing.T) { bb := testutil.NewTestBlockBuilder() nodes := NodeIDs{{1}, {2}, {3}, {4}} - comm := newRebroadcastComm(nodes.EqualWeightedNodeWeights()) + comm := newRebroadcastComm(nodes.EqualWeightedNodes()) conf, wal, _ := testutil.DefaultTestNodeEpochConfig(t, nodes[3], comm, bb) epochTime := conf.StartTime e, err := NewEpoch(conf) @@ -1227,7 +1227,7 @@ func runCrashAndRestartExecution(t *testing.T, e *Epoch, bb *testutil.TestBlockB // Case 2: t.Run(fmt.Sprintf("%s-with-crash", t.Name()), func(t *testing.T) { - conf, _, _ := testutil.DefaultTestNodeEpochConfig(t, nodes[0].Node, testutil.NewNoopComm(nodes.NodesIDs()), bbAfterCrash) + conf, _, _ := testutil.DefaultTestNodeEpochConfig(t, nodes[0].Node, testutil.NewNoopComm(nodes.NodeIDs()), bbAfterCrash) conf.Storage = cloneStorage conf.WAL = cloneWAL diff --git a/epoch_test.go b/epoch_test.go index 39bb1aab..a7a034b5 100644 --- a/epoch_test.go +++ b/epoch_test.go @@ -203,7 +203,7 @@ func testFinalizeSameSequenceGap(t *testing.T, nodes []NodeID, numEmptyNotarizat } for range numEmptyNotarizations { - leader := LeaderForRound(e.Comm.Nodes().NodesIDs(), e.Metadata().Round) + leader := LeaderForRound(e.Comm.Nodes().NodeIDs(), e.Metadata().Round) if e.ID.Equals(leader) { fVote := advanceWithFinalizeCheck(t, e, recordingComm, bb) finalizeVoteSeqs[fVote.Finalization.Seq] = fVote @@ -237,7 +237,7 @@ func testFinalizeSameSequenceGap(t *testing.T, nodes []NodeID, numEmptyNotarizat verified <- struct{}{} } - leader := LeaderForRound(e.Comm.Nodes().NodesIDs(), 1+numEmptyNotarizations+numNotarizations) + leader := LeaderForRound(e.Comm.Nodes().NodeIDs(), 1+numEmptyNotarizations+numNotarizations) if e.ID.Equals(leader) { return } @@ -446,7 +446,7 @@ func TestEpochIndexFinalization(t *testing.T) { // 1 & 2 sigAggr := e.SignatureAggregatorCreator(conf.Comm.Nodes()) - finalization, _ := testutil.NewFinalizationRecord(t, conf.Logger, sigAggr, firstBlock, e.Comm.Nodes().NodesIDs()) + finalization, _ := testutil.NewFinalizationRecord(t, conf.Logger, sigAggr, firstBlock, e.Comm.Nodes().NodeIDs()) testutil.InjectTestFinalization(t, e, &finalization, nodes[1]) storage.WaitForBlockCommit(2) @@ -751,10 +751,10 @@ func TestEpochStartedTwice(t *testing.T) { } func advanceRoundFromEmpty(t *testing.T, e *Epoch) { - leader := LeaderForRound(e.Comm.Nodes().NodesIDs(), e.Metadata().Round) + leader := LeaderForRound(e.Comm.Nodes().NodeIDs(), e.Metadata().Round) require.False(t, e.ID.Equals(leader), "epoch cannot be the leader for the empty round") - emptyNote := testutil.NewEmptyNotarization(e.Comm.Nodes().NodesIDs(), e.Metadata().Round) + emptyNote := testutil.NewEmptyNotarization(e.Comm.Nodes().NodeIDs(), e.Metadata().Round) err := e.HandleMessage(&Message{ EmptyNotarization: emptyNote, }, leader) @@ -1644,7 +1644,7 @@ func advanceRound(t *testing.T, e *simplex.Epoch, bb *testutil.TestBlockBuilder, nodes := e.Comm.Nodes() quorum := simplex.Quorum(len(nodes)) // leader is the proposer of the new block for the given round - leader := simplex.LeaderForRound(nodes.NodesIDs(), e.Metadata().Round) + leader := simplex.LeaderForRound(nodes.NodeIDs(), e.Metadata().Round) md := e.Metadata() if injectedMD != nil { md = *injectedMD @@ -1678,7 +1678,7 @@ func advanceRound(t *testing.T, e *simplex.Epoch, bb *testutil.TestBlockBuilder, if notarize { // start at one since our node has already voted sigAggr := e.SignatureAggregatorCreator(nodes) - n, err := testutil.NewNotarization(e.Logger, sigAggr, block, nodes.NodesIDs()[0:quorum]) + n, err := testutil.NewNotarization(e.Logger, sigAggr, block, nodes.NodeIDs()[0:quorum]) testutil.InjectTestNotarization(t, e, n, nodes[1].Node) e.WAL.(*testutil.TestWAL).AssertNotarization(block.Metadata.Round) diff --git a/global.go b/global.go index b632acea..771da09a 100644 --- a/global.go +++ b/global.go @@ -52,10 +52,10 @@ func (nodes NodeIDs) IndexOf(id NodeID) int { return -1 } -func (nodes NodeIDs) EqualWeightedNodeWeights() NodeWeights { - weights := make(NodeWeights, len(nodes)) +func (nodes NodeIDs) EqualWeightedNodes() Nodes { + weights := make(Nodes, len(nodes)) for i, node := range nodes { - weights[i] = NodeWeight{ + weights[i] = Node{ Node: node, Weight: 1, } diff --git a/msm/build_decision.go b/msm/build_decision.go index b19ecaf2..f0786114 100644 --- a/msm/build_decision.go +++ b/msm/build_decision.go @@ -14,31 +14,10 @@ import ( // blockBuildingDecision represents the decision of whether we should build a block at the current time, // and if so, whether we should also transition to a new epoch along the way. -type blockBuildingDecision int8 - -const ( - decisionUndefined blockBuildingDecision = iota - decisionBuild // We should build a block, and we don't need to transition to a new epoch. - decisionTransitionEpoch // We should transition to a new epoch immediately, but we don't need to build a block. - decisionBuildAndTransitionEpoch // We should build a block and transition to a new epoch along the way. - decisionContextCanceled -) - -func (bbd blockBuildingDecision) String() string { - switch bbd { - case decisionUndefined: - return "undefined" - case decisionBuild: - return "build block" - case decisionTransitionEpoch: - return "transition epoch" - case decisionBuildAndTransitionEpoch: - return "build block and transition epoch" - case decisionContextCanceled: - return "context canceled" - default: - return "unknown" - } +type blockBuildingDecision struct { + buildInnerBlock bool + transitionEpoch bool + pChainHeight uint64 } // PChainProgressListener listens for changes in the P-chain height. @@ -65,18 +44,18 @@ type blockBuildingDecider struct { // The P-chain height is returned because sampling the P-chain height afterwards might be inconsistent with the decision that was made. func (bbd *blockBuildingDecider) shouldBuildBlock( ctx context.Context, -) (blockBuildingDecision, uint64, error) { +) (blockBuildingDecision, error) { for { pChainHeight := bbd.getPChainHeight() shouldTransitionEpoch, err := bbd.hasValidatorSetChanged(pChainHeight) if err != nil { - return decisionUndefined, 0, err + return blockBuildingDecision{}, err } if shouldTransitionEpoch { // If we should transition to a new epoch, maybe we can also build a block along the way. - return bbd.buildBlockWithEpochTransition(ctx), pChainHeight, nil + return bbd.buildBlockWithEpochTransition(ctx, pChainHeight) } // Else, we don't need to transition to a new epoch, but maybe we should build a block. @@ -85,7 +64,7 @@ func (bbd *blockBuildingDecider) shouldBuildBlock( // If the context was cancelled in the meantime, abandon evaluation. if ctx.Err() != nil { - return decisionContextCanceled, 0, nil + return blockBuildingDecision{}, ctx.Err() } // If we've reached here, either the P-chain height has changed, or a block is ready to be built. @@ -99,7 +78,7 @@ func (bbd *blockBuildingDecider) shouldBuildBlock( // Else, we have reached here because a block is ready to be built, and the P-chain height has not changed, // which means we should build a block. - return decisionBuild, pChainHeight, nil + return blockBuildingDecision{buildInnerBlock: true, pChainHeight: pChainHeight}, nil } } @@ -130,7 +109,7 @@ func (bbd *blockBuildingDecider) waitForPChainChangeOrPendingBlock(ctx context.C // It waits up to a limited amount of time (bbd.maxBlockBuildingWaitTime) for a block to be ready to be built, // and if no block is ready by then, it returns the decision to transition epoch without building a block. // Otherwise, it returns the decision to build a block and transition epoch along the way. -func (bbd *blockBuildingDecider) buildBlockWithEpochTransition(ctx context.Context) blockBuildingDecision { +func (bbd *blockBuildingDecider) buildBlockWithEpochTransition(ctx context.Context, pChainHeight uint64) (blockBuildingDecision, error) { impatientContext, cancel := context.WithTimeout(ctx, bbd.maxBlockBuildingWaitTime) defer cancel() @@ -139,15 +118,15 @@ func (bbd *blockBuildingDecider) buildBlockWithEpochTransition(ctx context.Conte bbd.waitForPendingBlock(impatientContext) if ctx.Err() != nil { - return decisionContextCanceled + return blockBuildingDecision{}, ctx.Err() } if impatientContext.Err() != nil { // We have returned from waitForPendingBlock because impatientContext has timed out, // which means we don't need to build a block. - return decisionTransitionEpoch + return blockBuildingDecision{transitionEpoch: true, pChainHeight: pChainHeight}, nil } // Block is ready to be built - return decisionBuildAndTransitionEpoch + return blockBuildingDecision{buildInnerBlock: true, transitionEpoch: true, pChainHeight: pChainHeight}, nil } diff --git a/msm/build_decision_test.go b/msm/build_decision_test.go index 828aabeb..b1b13402 100644 --- a/msm/build_decision_test.go +++ b/msm/build_decision_test.go @@ -35,10 +35,9 @@ func TestShouldBuildBlock_VMSignalsBlock(t *testing.T) { getPChainHeight: func() uint64 { return 100 }, } - result, pChainHeight, err := bbd.shouldBuildBlock(t.Context()) + decision, err := bbd.shouldBuildBlock(t.Context()) require.NoError(t, err) - require.Equal(t, decisionBuild, result) - require.Equal(t, uint64(100), pChainHeight) + require.Equal(t, blockBuildingDecision{buildInnerBlock: true, pChainHeight: 100}, decision) } func TestShouldBuildBlock_ContextCanceled(t *testing.T) { @@ -60,9 +59,9 @@ func TestShouldBuildBlock_ContextCanceled(t *testing.T) { getPChainHeight: func() uint64 { return 100 }, } - result, _, err := bbd.shouldBuildBlock(ctx) - require.NoError(t, err) - require.Equal(t, decisionContextCanceled, result) + decision, err := bbd.shouldBuildBlock(ctx) + require.ErrorIs(t, err, context.Canceled) + require.Equal(t, blockBuildingDecision{}, decision) } func TestShouldBuildBlock_PChainHeightChangeTriggersEpochTransition(t *testing.T) { @@ -92,10 +91,9 @@ func TestShouldBuildBlock_PChainHeightChangeTriggersEpochTransition(t *testing.T getPChainHeight: func() uint64 { return pChainHeight.Load() }, } - result, resultPChainHeight, err := bbd.shouldBuildBlock(t.Context()) + decision, err := bbd.shouldBuildBlock(t.Context()) require.NoError(t, err) - require.Equal(t, decisionTransitionEpoch, result) - require.Equal(t, uint64(200), resultPChainHeight) + require.Equal(t, blockBuildingDecision{transitionEpoch: true, pChainHeight: 200}, decision) } func TestShouldBuildBlock_PChainHeightChangeButNoEpochTransition(t *testing.T) { @@ -128,10 +126,9 @@ func TestShouldBuildBlock_PChainHeightChangeButNoEpochTransition(t *testing.T) { getPChainHeight: func() uint64 { return pChainHeight.Load() }, } - result, resultPChainHeight, err := bbd.shouldBuildBlock(t.Context()) + decision, err := bbd.shouldBuildBlock(t.Context()) require.NoError(t, err) - require.Equal(t, decisionBuild, result) - require.Equal(t, uint64(200), resultPChainHeight) + require.Equal(t, blockBuildingDecision{buildInnerBlock: true, pChainHeight: 200}, decision) } func TestShouldBuildBlock_EpochTransitionWithVMBlock(t *testing.T) { @@ -148,10 +145,9 @@ func TestShouldBuildBlock_EpochTransitionWithVMBlock(t *testing.T) { getPChainHeight: func() uint64 { return 100 }, } - result, pChainHeight, err := bbd.shouldBuildBlock(t.Context()) + decision, err := bbd.shouldBuildBlock(t.Context()) require.NoError(t, err) - require.Equal(t, decisionBuildAndTransitionEpoch, result) - require.Equal(t, uint64(100), pChainHeight) + require.Equal(t, blockBuildingDecision{buildInnerBlock: true, transitionEpoch: true, pChainHeight: 100}, decision) } func TestShouldBuildBlock_EpochTransitionWithoutVMBlock(t *testing.T) { @@ -170,10 +166,9 @@ func TestShouldBuildBlock_EpochTransitionWithoutVMBlock(t *testing.T) { getPChainHeight: func() uint64 { return 100 }, } - result, pChainHeight, err := bbd.shouldBuildBlock(t.Context()) + decision, err := bbd.shouldBuildBlock(t.Context()) require.NoError(t, err) - require.Equal(t, decisionTransitionEpoch, result) - require.Equal(t, uint64(100), pChainHeight) + require.Equal(t, blockBuildingDecision{transitionEpoch: true, pChainHeight: 100}, decision) } func TestShouldBuildBlock_EpochTransitionContextCanceled(t *testing.T) { @@ -196,7 +191,7 @@ func TestShouldBuildBlock_EpochTransitionContextCanceled(t *testing.T) { getPChainHeight: func() uint64 { return 100 }, } - result, _, err := bbd.shouldBuildBlock(ctx) - require.NoError(t, err) - require.Equal(t, decisionContextCanceled, result) + decision, err := bbd.shouldBuildBlock(ctx) + require.ErrorIs(t, err, context.Canceled) + require.Equal(t, blockBuildingDecision{}, decision) } diff --git a/msm/encoding.go b/msm/encoding.go index 414dd241..258f33f5 100644 --- a/msm/encoding.go +++ b/msm/encoding.go @@ -223,10 +223,10 @@ func (nea *NextEpochApprovals) Equals(other *NextEpochApprovals) bool { type NodeBLSMappings []NodeBLSMapping -func (nbms NodeBLSMappings) NodeWeights() simplex.NodeWeights { - nodeWeights := make(simplex.NodeWeights, len(nbms)) +func (nbms NodeBLSMappings) NodeWeights() simplex.Nodes { + nodeWeights := make(simplex.Nodes, len(nbms)) for i, nbm := range nbms { - nodeWeights[i] = simplex.NodeWeight{ + nodeWeights[i] = simplex.Node{ Node: nbm.NodeID[:], Weight: nbm.Weight, } diff --git a/msm/misc_test.go b/msm/misc_test.go index 91325712..82153a06 100644 --- a/msm/misc_test.go +++ b/msm/misc_test.go @@ -255,7 +255,7 @@ func (sv *signatureAggregator) IsQuorum(signers []simplex.NodeID) bool { } func newSignatureAggregatorCreator() simplex.SignatureAggregatorCreator { - return func(weights []simplex.NodeWeight) simplex.SignatureAggregator { + return func(weights []simplex.Node) simplex.SignatureAggregator { s := &signatureAggregator{weightByNodeID: make(map[string]uint64, len(weights))} for _, nw := range weights { s.weightByNodeID[string(nw.Node)] = nw.Weight @@ -459,7 +459,9 @@ func newStateMachine(t *testing.T) (*StateMachine, *testConfig) { LastNonSimplexInnerBlock: genesisBlock.InnerBlock, } - sm := NewStateMachine(&smConfig) + sm, err := NewStateMachine(&smConfig) + require.NoError(t, err) + return sm, &testConfig } diff --git a/msm/msm.go b/msm/msm.go index 94e2801e..ca24de07 100644 --- a/msm/msm.go +++ b/msm/msm.go @@ -145,9 +145,13 @@ const ( stateBuildBlockEpochSealed ) -func NewStateMachine(config *Config) *StateMachine { +func NewStateMachine(config *Config) (*StateMachine, error) { + if config.LastNonSimplexInnerBlock == nil { + config.Logger.Error("Last non-Simplex inner block is nil, cannot build zero block with correct metadata") + return nil, fmt.Errorf("failed constructing zero block: last non-Simplex inner block is nil") + } sm := StateMachine{Config: config} - return &sm + return &sm, nil } // BuildBlock constructs the next block on top of the given parent block, and passes in the provided simplex metadata and blacklist. @@ -291,31 +295,32 @@ func (sm *StateMachine) buildBlockNormalOp(ctx context.Context, parentBlock Stat } blockBuildingDecider := sm.createBlockBuildingDecider(parentBlock) - decisionToBuildBlock, pChainHeight, err := blockBuildingDecider.shouldBuildBlock(ctx) + decisionToBuildBlock, err := blockBuildingDecider.shouldBuildBlock(ctx) if err != nil { return nil, err } - sm.Logger.Debug("Block building decision", zap.Stringer("decision", decisionToBuildBlock)) + sm.Logger.Debug("Block building decision", + zap.Bool("build inner block", decisionToBuildBlock.buildInnerBlock), + zap.Bool("transition epoch", decisionToBuildBlock.transitionEpoch), + zap.Uint64("P-chain height", decisionToBuildBlock.pChainHeight)) + + if decisionToBuildBlock.transitionEpoch { + sm.Logger.Debug("Transitioning epoch after building block", zap.Uint64("newPChainRefHeight", decisionToBuildBlock.pChainHeight)) + newSimplexEpochInfo.NextPChainReferenceHeight = decisionToBuildBlock.pChainHeight + } - var childBlock VMBlock + var innerBlock VMBlock - switch decisionToBuildBlock { - case decisionBuild, decisionBuildAndTransitionEpoch: - // If we reached here, we need to build a new block, and maybe also transition to a new epoch. - return sm.buildBlockAndMaybeTransitionEpoch(ctx, parentBlock, simplexMetadata, simplexBlacklist, childBlock, decisionToBuildBlock, newSimplexEpochInfo, pChainHeight) - case decisionTransitionEpoch: - // If we reached here, we don't need to build an inner block, yet we need to transition to a new epoch. - // Initiate the epoch transition by setting the next P-chain reference height for the new epoch info, - // and build a block without an inner block. - newSimplexEpochInfo.NextPChainReferenceHeight = pChainHeight - sm.Logger.Debug("Transitioning epoch without building block", zap.Uint64("newPChainRefHeight", pChainHeight)) - return sm.wrapBlock(parentBlock, nil, newSimplexEpochInfo, pChainHeight, simplexMetadata, simplexBlacklist), nil - case decisionContextCanceled: - return nil, ctx.Err() - default: - return nil, fmt.Errorf("unknown block building decision %d", decisionToBuildBlock) + if decisionToBuildBlock.buildInnerBlock { + // TODO: This P-chain height should be taken from the ICM epoch + innerBlock, err = sm.BlockBuilder.BuildBlock(ctx, decisionToBuildBlock.pChainHeight) + if err != nil { + return nil, err + } } + + return sm.wrapBlock(parentBlock, innerBlock, newSimplexEpochInfo, decisionToBuildBlock.pChainHeight, simplexMetadata, simplexBlacklist), nil } func (sm *StateMachine) createBlockBuildingDecider(parentBlock StateMachineBlock) blockBuildingDecider { @@ -355,38 +360,9 @@ func (sm *StateMachine) createBlockBuildingDecider(parentBlock StateMachineBlock return blockBuildingDecider } -func (sm *StateMachine) buildBlockAndMaybeTransitionEpoch(ctx context.Context, - parentBlock StateMachineBlock, - simplexMetadata []byte, - simplexBlacklist []byte, - childBlock VMBlock, - decisionToBuildBlock blockBuildingDecision, - newSimplexEpochInfo SimplexEpochInfo, - pChainHeight uint64) (*StateMachineBlock, error) { - // TODO: This P-chain height should be taken from the ICM epoch - childBlock, err := sm.BlockBuilder.BuildBlock(ctx, pChainHeight) - if err != nil { - return nil, err - } - - if decisionToBuildBlock == decisionBuildAndTransitionEpoch { - // We need to also transition to a new epoch, in addition to building an inner block, - // so set the next P-chain reference height for the new epoch info. - newSimplexEpochInfo.NextPChainReferenceHeight = pChainHeight - sm.Logger.Debug("Transitioning epoch after building block", zap.Uint64("newPChainRefHeight", pChainHeight)) - } - - return sm.wrapBlock(parentBlock, childBlock, newSimplexEpochInfo, pChainHeight, simplexMetadata, simplexBlacklist), nil -} - // buildBlockZero builds the first ever block for Simplex, // which is a special block that introduces the first validator set and starts the first epoch. func (sm *StateMachine) buildBlockZero(parentBlock StateMachineBlock, simplexMetadata, simplexBlacklist []byte) (*StateMachineBlock, error) { - if sm.LastNonSimplexInnerBlock == nil { - sm.Logger.Error("Last non-Simplex inner block is nil, cannot build zero block with correct metadata") - return nil, fmt.Errorf("failed constructing zero block: last non-Simplex inner block is nil") - } - pChainHeight := sm.LastNonSimplexBlockPChainHeight var validatorSet NodeBLSMappings @@ -436,10 +412,6 @@ func (sm *StateMachine) verifyBlockZero(block *StateMachineBlock, prevBlock Stat return fmt.Errorf("block is nil") } - if sm.LastNonSimplexInnerBlock == nil { - return fmt.Errorf("failed verifying zero block: last non-Simplex inner block is nil") - } - simplexEpochInfo := block.Metadata.SimplexEpochInfo if simplexEpochInfo.EpochNumber != 1 { diff --git a/pos_test.go b/pos_test.go index 177492e1..1d1d7f08 100644 --- a/pos_test.go +++ b/pos_test.go @@ -23,7 +23,7 @@ func TestPoS(t *testing.T) { nodes := []simplex.NodeID{{1}, {2}, {3}, {4}} - posSigAggregatorCreator := func(_ []simplex.NodeWeight) simplex.SignatureAggregator { + posSigAggregatorCreator := func(_ []simplex.Node) simplex.SignatureAggregator { return &testutil.TestSignatureAggregator{ IsQuorumFunc: func(signatures []simplex.NodeID) bool { var totalWeight uint64 diff --git a/testutil/comm.go b/testutil/comm.go index a3b5c2d3..71f33191 100644 --- a/testutil/comm.go +++ b/testutil/comm.go @@ -22,10 +22,10 @@ import ( // - bool: true if the message can be transmitted, false otherwise type MessageFilter func(msg *simplex.Message, from simplex.NodeID, to simplex.NodeID) bool -type NoopComm simplex.NodeWeights +type NoopComm simplex.Nodes -func (n NoopComm) Nodes() simplex.NodeWeights { - return simplex.NodeWeights(n) +func (n NoopComm) Nodes() simplex.Nodes { + return simplex.Nodes(n) } func (n NoopComm) Send(*simplex.Message, simplex.NodeID) { @@ -51,7 +51,7 @@ func NewTestComm(from simplex.NodeID, net *BasicInMemoryNetwork, messageFilter M } } -func (c *TestComm) Nodes() simplex.NodeWeights { +func (c *TestComm) Nodes() simplex.Nodes { return c.net.nodeWeights } @@ -192,5 +192,5 @@ func AllowAllMessages(*simplex.Message, simplex.NodeID, simplex.NodeID) bool { } func NewNoopComm(nodes simplex.NodeIDs) NoopComm { - return NoopComm(nodes.EqualWeightedNodeWeights()) + return NoopComm(nodes.EqualWeightedNodes()) } diff --git a/testutil/controlled.go b/testutil/controlled.go index 929835f0..e9a337d3 100644 --- a/testutil/controlled.go +++ b/testutil/controlled.go @@ -21,7 +21,7 @@ type ControlledInMemoryNetwork struct { // NewControlledNetwork creates an in-memory network. Node IDs must be provided before // adding instances, as nodeWeights require prior knowledge of all participants. func NewControlledNetwork(t *testing.T, nodes simplex.NodeIDs) *ControlledInMemoryNetwork { - simplex.SortNodesWeights(nodes.EqualWeightedNodeWeights()) + simplex.SortNodes(nodes.EqualWeightedNodes()) net := &ControlledInMemoryNetwork{ BasicInMemoryNetwork: NewBasicInMemoryNetwork(t, nodes), Instances: make([]*ControlledNode, 0), @@ -72,7 +72,7 @@ func (n *ControlledInMemoryNetwork) AdvanceWithoutLeader(round uint64, laggingNo } for _, n := range n.Instances { - leader := n.E.ID.Equals(simplex.LeaderForRound(n.E.Comm.Nodes().NodesIDs(), n.E.Metadata().Round)) + leader := n.E.ID.Equals(simplex.LeaderForRound(n.E.Comm.Nodes().NodeIDs(), n.E.Metadata().Round)) if leader || laggingNodeId.Equals(n.E.ID) { continue } diff --git a/testutil/network.go b/testutil/network.go index 6d908af8..8ce6ac58 100644 --- a/testutil/network.go +++ b/testutil/network.go @@ -16,15 +16,15 @@ import ( type BasicInMemoryNetwork struct { t *testing.T nodes []simplex.NodeID - nodeWeights simplex.NodeWeights + nodeWeights simplex.Nodes lock sync.RWMutex disconnected map[string]struct{} instances []*BasicNode } func NewBasicInMemoryNetwork(t *testing.T, nodes simplex.NodeIDs) *BasicInMemoryNetwork { - nodeWeights := nodes.EqualWeightedNodeWeights() - simplex.SortNodesWeights(nodeWeights) + nodeWeights := nodes.EqualWeightedNodes() + simplex.SortNodes(nodeWeights) return &BasicInMemoryNetwork{ t: t, nodeWeights: nodeWeights, diff --git a/testutil/util.go b/testutil/util.go index f0232059..4be23174 100644 --- a/testutil/util.go +++ b/testutil/util.go @@ -32,7 +32,7 @@ func DefaultTestNodeEpochConfig(t *testing.T, nodeID simplex.NodeID, comm simple Verifier: &testVerifier{}, Storage: storage, BlockBuilder: bb, - SignatureAggregatorCreator: func(weights []simplex.NodeWeight) simplex.SignatureAggregator { + SignatureAggregatorCreator: func(weights []simplex.Node) simplex.SignatureAggregator { return &TestSignatureAggregator{N: len(weights)} }, BlockDeserializer: &BlockDeserializer{}, From 19001401796daaaa77904bb3dd5fedcdefa93606 Mon Sep 17 00:00:00 2001 From: Yacov Manevich Date: Wed, 13 May 2026 15:37:38 +0200 Subject: [PATCH 14/16] Address code review comments II Signed-off-by: Yacov Manevich --- msm/encoding.go | 16 +++++++++++++--- msm/encoding_test.go | 4 ++-- msm/misc_test.go | 4 ++-- msm/msm.go | 41 +++++++++++++++++++---------------------- 4 files changed, 36 insertions(+), 29 deletions(-) diff --git a/msm/encoding.go b/msm/encoding.go index 258f33f5..88c45995 100644 --- a/msm/encoding.go +++ b/msm/encoding.go @@ -234,6 +234,16 @@ func (nbms NodeBLSMappings) NodeWeights() simplex.Nodes { return nodeWeights } +// IndexByNodeID returns a mapping from NodeID to the validator's index in the set, +// which is the position used by approval bitmasks. +func (nbms NodeBLSMappings) IndexByNodeID() map[nodeID]int { + result := make(map[nodeID]int, len(nbms)) + for i, nbm := range nbms { + result[nbm.NodeID] = i + } + return result +} + func (nbms NodeBLSMappings) SelectSubset(bitmask bitmask) []simplex.NodeID { nodeIDs := make([]simplex.NodeID, 0, len(nbms)) for i, nbm := range nbms { @@ -289,10 +299,10 @@ type ValidatorSetApproval struct { type ValidatorSetApprovals []ValidatorSetApproval -func (vsa ValidatorSetApprovals) Filter(f func(int, ValidatorSetApproval, simplex.Logger) bool, logger simplex.Logger) ValidatorSetApprovals { +func (vsa ValidatorSetApprovals) Filter(f func(ValidatorSetApproval, simplex.Logger) bool, logger simplex.Logger) ValidatorSetApprovals { result := make(ValidatorSetApprovals, 0, len(vsa)) - for i, v := range vsa { - if f(i, v, logger) { + for _, v := range vsa { + if f(v, logger) { result = append(result, v) } } diff --git a/msm/encoding_test.go b/msm/encoding_test.go index efae8a0a..ef259503 100644 --- a/msm/encoding_test.go +++ b/msm/encoding_test.go @@ -411,7 +411,7 @@ func TestValidatorSetApprovalsFilter(t *testing.T) { {NodeID: nodeID{3}, PChainHeight: 30}, } - filtered := approvals.Filter(func(_ int, v ValidatorSetApproval, _ simplex.Logger) bool { + filtered := approvals.Filter(func(v ValidatorSetApproval, _ simplex.Logger) bool { return v.PChainHeight > 15 }, logger) require.Len(t, filtered, 2) @@ -419,7 +419,7 @@ func TestValidatorSetApprovalsFilter(t *testing.T) { require.Equal(t, uint64(30), filtered[1].PChainHeight) // Filter all - filtered = approvals.Filter(func(int, ValidatorSetApproval, simplex.Logger) bool { + filtered = approvals.Filter(func(ValidatorSetApproval, simplex.Logger) bool { return false }, logger) require.Empty(t, filtered) diff --git a/msm/misc_test.go b/msm/misc_test.go index 82153a06..b78d2cd3 100644 --- a/msm/misc_test.go +++ b/msm/misc_test.go @@ -212,7 +212,7 @@ type approvalsRetriever struct { result ValidatorSetApprovals } -func (a approvalsRetriever) RetrieveApprovals() ValidatorSetApprovals { +func (a approvalsRetriever) Approvals() ValidatorSetApprovals { return a.result } @@ -324,7 +324,7 @@ type dynamicApprovalsRetriever struct { approvals *ValidatorSetApprovals } -func (d *dynamicApprovalsRetriever) RetrieveApprovals() ValidatorSetApprovals { +func (d *dynamicApprovalsRetriever) Approvals() ValidatorSetApprovals { return *d.approvals } diff --git a/msm/msm.go b/msm/msm.go index ca24de07..93e1d2a4 100644 --- a/msm/msm.go +++ b/msm/msm.go @@ -38,7 +38,7 @@ func (smb *StateMachineBlock) Digest() [32]byte { // ApprovalsRetriever retrieves the approvals from validators of the next epoch for the epoch change. type ApprovalsRetriever interface { - RetrieveApprovals() ValidatorSetApprovals + Approvals() ValidatorSetApprovals } // KeyAggregator combines multiple public keys into a single aggregated public key. @@ -161,9 +161,11 @@ func (sm *StateMachine) BuildBlock(ctx context.Context, metadata simplex.Protoco return nil, fmt.Errorf("invalid ProtocolMetadata sequence number: should be > 0, got %d", metadata.Seq) } - parentBlock, _, err := sm.GetBlock(metadata.Seq-1, metadata.Prev) + prevBlockSeq := metadata.Seq - 1 + + parentBlock, _, err := sm.GetBlock(prevBlockSeq, metadata.Prev) if err != nil { - return nil, fmt.Errorf("failed retrieving parent block at height %d with digest %s: %w", metadata.Seq-1, metadata.Prev.String(), err) + return nil, fmt.Errorf("failed retrieving parent block at height %d with digest %s: %w", prevBlockSeq, metadata.Prev.String(), err) } start := time.Now() @@ -193,7 +195,6 @@ func (sm *StateMachine) BuildBlock(ctx context.Context, metadata simplex.Protoco currentState := parentBlock.Metadata.SimplexEpochInfo.NextState() simplexMetadataBytes := metadata.Bytes() - prevBlockSeq := metadata.Seq - 1 switch currentState { case stateFirstSimplexBlock: @@ -494,7 +495,7 @@ func (sm *StateMachine) buildBlockCollectingApprovals(ctx context.Context, paren // We retrieve approvals that validators have sent us for the next epoch. // These approvals are signed by validators of the next epoch. - approvalsFromPeers := sm.ApprovalsRetriever.RetrieveApprovals() + approvalsFromPeers := sm.ApprovalsRetriever.Approvals() sm.Logger.Debug("Retrieved approvals from peers", zap.Int("numApprovals", len(approvalsFromPeers))) nextPChainHeight := newSimplexEpochInfo.NextPChainReferenceHeight @@ -528,7 +529,7 @@ func (sm *StateMachine) buildBlockCollectingApprovals(ctx context.Context, paren sm.Logger.Debug("Have enough approvals to seal epoch, building sealing block") // Else, we have enough approvals to seal the epoch, so we create the sealing block. - return sm.createSealingBlock(ctx, parentBlock, simplexMetadata, simplexBlacklist, newSimplexEpochInfo, newApprovals, pChainHeight) + return sm.createSealingBlock(ctx, parentBlock, simplexMetadata, simplexBlacklist, newSimplexEpochInfo, pChainHeight) } // buildBlockImpatiently builds a block by waiting for the VM to build a block until MaxBlockBuildingWaitTime. @@ -554,7 +555,7 @@ func (sm *StateMachine) buildBlockImpatiently(ctx context.Context, parentBlock S return sm.wrapBlock(parentBlock, childBlock, simplexEpochInfo, pChainHeight, simplexMetadata, simplexBlacklist), nil } -func (sm *StateMachine) createSealingBlock(ctx context.Context, parentBlock StateMachineBlock, simplexMetadata []byte, simplexBlacklist []byte, simplexEpochInfo SimplexEpochInfo, newApprovals *approvals, pChainHeight uint64) (*StateMachineBlock, error) { +func (sm *StateMachine) createSealingBlock(ctx context.Context, parentBlock StateMachineBlock, simplexMetadata []byte, simplexBlacklist []byte, simplexEpochInfo SimplexEpochInfo, pChainHeight uint64) (*StateMachineBlock, error) { validators, err := sm.GetValidatorSet(simplexEpochInfo.NextPChainReferenceHeight) if err != nil { return nil, err @@ -687,34 +688,30 @@ func constructSimplexZeroBlockSimplexEpochInfo(pChainHeight uint64, newValidator } func computeNewApprovals( - nextEpochApprovals *NextEpochApprovals, + prevNextEpochApprovals *NextEpochApprovals, approvalsFromPeers ValidatorSetApprovals, pChainHeight uint64, sigAggr simplex.SignatureAggregator, validators NodeBLSMappings, logger simplex.Logger, ) (*approvals, error) { - if nextEpochApprovals == nil { - nextEpochApprovals = &NextEpochApprovals{} + if prevNextEpochApprovals == nil { + prevNextEpochApprovals = &NextEpochApprovals{} } - oldApprovingNodes := bitmaskFromBytes(nextEpochApprovals.NodeIDs) + oldApprovingNodes := bitmaskFromBytes(prevNextEpochApprovals.NodeIDs) - // We map each validator to its relative index in the validator set. - nodeID2ValidatorIndex := make(map[nodeID]int) - for i, nbm := range validators { - nodeID2ValidatorIndex[nbm.NodeID] = i - } + nodeID2ValidatorIndex := validators.IndexByNodeID() oldApprovalFromPeersCount := len(approvalsFromPeers) // We have the approvals obtained from peers, but we need to sanitize them by filtering out approvals that are not valid, // such as approvals that do not agree with our candidate auxiliary info digest and P-Chain height, // and approvals that are from nodes that are not in the validator set or have already approved in prior blocks. approvalsFromPeers = sanitizeApprovals(approvalsFromPeers, pChainHeight, nodeID2ValidatorIndex, oldApprovingNodes, logger) - logger.Debug("Santizied approvals after filtering out invalid approvals", zap.Int("numApprovalsBefore", oldApprovalFromPeersCount), zap.Int("numApprovalsAfter", len(approvalsFromPeers))) + logger.Debug("Sanitized approvals after filtering out invalid approvals", zap.Int("numApprovalsBefore", oldApprovalFromPeersCount), zap.Int("numApprovalsAfter", len(approvalsFromPeers))) // Next we aggregate both previous and new approvals to compute the new aggregated signatures and the new bitmask of approving nodes. - aggregatedSignature, newApprovingNodes, err := computeNewApproverSignaturesAndSigners(nextEpochApprovals, approvalsFromPeers, oldApprovingNodes, nodeID2ValidatorIndex, sigAggr, logger) + aggregatedSignature, newApprovingNodes, err := computeNewApproverSignaturesAndSigners(prevNextEpochApprovals, approvalsFromPeers, oldApprovingNodes, nodeID2ValidatorIndex, sigAggr, logger) if err != nil { return nil, err } @@ -791,8 +788,8 @@ func sanitizeApprovals(approvals ValidatorSetApprovals, pChainHeight uint64, nod return approvals.Filter(filter1, logger).Filter(filter2, logger).UniqueByNodeID() } -func approvalsThatAgreeWithPChainHeight(pChainHeight uint64) func(i int, approval ValidatorSetApproval, logger simplex.Logger) bool { - return func(i int, approval ValidatorSetApproval, logger simplex.Logger) bool { +func approvalsThatAgreeWithPChainHeight(pChainHeight uint64) func(approval ValidatorSetApproval, logger simplex.Logger) bool { + return func(approval ValidatorSetApproval, logger simplex.Logger) bool { // Pick only approvals that agree with our P-Chain height ok := approval.PChainHeight == pChainHeight if !ok { @@ -805,8 +802,8 @@ func approvalsThatAgreeWithPChainHeight(pChainHeight uint64) func(i int, approva } } -func approvalsThatAreInValidatorSetAndHaveNotAlreadyApproved(oldApprovingNodes bitmask, nodeID2ValidatorIndex map[nodeID]int) func(i int, approval ValidatorSetApproval, logger simplex.Logger) bool { - return func(i int, approval ValidatorSetApproval, logger simplex.Logger) bool { +func approvalsThatAreInValidatorSetAndHaveNotAlreadyApproved(oldApprovingNodes bitmask, nodeID2ValidatorIndex map[nodeID]int) func(approval ValidatorSetApproval, logger simplex.Logger) bool { + return func(approval ValidatorSetApproval, logger simplex.Logger) bool { approvingNodeIndexOfNewApprover, exists := nodeID2ValidatorIndex[approval.NodeID] if !exists { logger.Debug("Filtering out approval from node that is not in the validator set", From 5ef3c853985d3a115e01de5f87bd8b112c32b80d Mon Sep 17 00:00:00 2001 From: Yacov Manevich Date: Wed, 13 May 2026 17:16:27 +0200 Subject: [PATCH 15/16] Refine comment about SealingBlockSeq Signed-off-by: Yacov Manevich --- msm/encoding.go | 1 + 1 file changed, 1 insertion(+) diff --git a/msm/encoding.go b/msm/encoding.go index 88c45995..7fd22189 100644 --- a/msm/encoding.go +++ b/msm/encoding.go @@ -62,6 +62,7 @@ type SimplexEpochInfo struct { NextEpochApprovals *NextEpochApprovals `canoto:"pointer,7"` // SealingBlockSeq is the block sequence of the sealing block of the current epoch. // It defines the validator set of the next epoch. + // It is set once the first Telock is built and is copied over to subsequent Telocks. SealingBlockSeq uint64 `canoto:"uint,8"` canotoData canotoData_SimplexEpochInfo From 931a7a978d6a456b2377eaab63a3a6e224cb3c6d Mon Sep 17 00:00:00 2001 From: Yacov Manevich Date: Wed, 13 May 2026 18:13:47 +0200 Subject: [PATCH 16/16] Fix test TestMSMFirstSimplexBlockAfterPreSimplexBlocks Signed-off-by: Yacov Manevich --- msm/msm_test.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/msm/msm_test.go b/msm/msm_test.go index 775e41e6..b04f8d27 100644 --- a/msm/msm_test.go +++ b/msm/msm_test.go @@ -186,7 +186,7 @@ func TestMSMFirstSimplexBlockAfterPreSimplexBlocks(t *testing.T) { require.Equal(t, &StateMachineBlock{ Metadata: StateMachineMetadata{ - Timestamp: uint64(testConfig1.blockBuilder.block.Timestamp().UnixMilli()), + Timestamp: uint64(preSimplexParent.InnerBlock.Timestamp().UnixMilli()), PChainHeight: 100, SimplexProtocolMetadata: md.Bytes(), SimplexEpochInfo: SimplexEpochInfo{