diff --git a/Raft/raft.go b/Raft/raft.go index c0e86aa..1423b82 100644 --- a/Raft/raft.go +++ b/Raft/raft.go @@ -1,6 +1,7 @@ package Raft import ( + "encoding/binary" "fmt" "math/rand" "sync" @@ -194,11 +195,11 @@ func (r *Raft) startElection() { r.votedFor = r.me r.persistLocked() // 持久化 Term 和 votedFor - lastLogIndex := -1 - lastLogTerm := 0 + lastLogIndex := int(r.LastIncludedIndex) + lastLogTerm := int(r.LastIncludedTerm) if len(r.log) > 0 { - lastLogIndex = len(r.log) - 1 - lastLogTerm = r.log[lastLogIndex].Term + lastLogIndex = r.log[len(r.log)-1].Index + lastLogTerm = r.log[len(r.log)-1].Term } args := &RequestVoteArgs{ @@ -246,18 +247,6 @@ func (r *Raft) startElection() { r.mu.Unlock() - // 自己投自己一票 - - // 单节点模式:自己一票就超过半数,直接成为 Leader - if votes > len(r.peers)/2 { - r.mu.Lock() - if r.state == Candidate { - r.becomeLeader() - } - r.mu.Unlock() - return - } - // 等待投票结果或超时 timeout := time.After(500 * time.Millisecond) for j := 0; j < peerCount; j++ { @@ -292,9 +281,15 @@ func (r *Raft) becomeLeader() { fmt.Printf("[RAFT] Becoming Leader, Term=%d\n", r.Term) r.state = Leader + // 计算下一个日志的绝对索引(考虑快照偏移) + nextLogIndex := int(r.LastIncludedIndex) + 1 + if len(r.log) > 0 { + nextLogIndex = r.log[len(r.log)-1].Index + 1 + } + for i := range r.peers { - r.nextIndex[i] = len(r.log) - r.matchIndex[i] = -1 + r.nextIndex[i] = nextLogIndex + r.matchIndex[i] = int(r.LastIncludedIndex) } fmt.Printf("[RAFT] Started heartbeat loop\n") @@ -328,11 +323,43 @@ func (r *Raft) SendHeartBeat() { } prevLogIndex := r.nextIndex[i] - 1 - prevLogTerm := 0 - if prevLogIndex >= 0 && prevLogIndex < len(r.log) { - prevLogTerm = r.log[prevLogIndex].Term + + // 如果 follower 落后太多(prevLogIndex 在快照范围内),发送 InstallSnapshot + if prevLogIndex < int(r.LastIncludedIndex) && r.LastIncludedIndex > 0 { + snapshotData, _, _, err := r.wal.LoadLatestSnapshot() + if err == nil && snapshotData != nil { + snapArgs := &InstallSnapshotArgs{ + Term: r.Term, + LeaderID: r.me, + Data: snapshotData, + LastIncludedIndex: r.LastIncludedIndex, + LastIncludedTerm: r.LastIncludedTerm, + } + r.mu.Unlock() + go func(peerID int, snapArgs *InstallSnapshotArgs) { + reply, err := r.SendInstallSnapshot(r.addrMap[peerID], snapArgs) + if err != nil { + return + } + r.mu.Lock() + defer r.mu.Unlock() + if reply.Success { + r.nextIndex[peerID] = int(r.LastIncludedIndex) + 1 + r.matchIndex[peerID] = int(r.LastIncludedIndex) + } else if reply.Term > r.Term { + r.Term = reply.Term + r.state = Follower + r.votedFor = -1 + r.heartbeatTicker.Stop() + } + }(i, snapArgs) + r.mu.Lock() + continue + } } + prevLogTerm := r.getTermAt(prevLogIndex) + args := &AppendEntriesArgs{ Term: r.Term, LeaderID: r.me, @@ -367,8 +394,8 @@ func (r *Raft) SendHeartBeat() { } if reply.Success { - r.nextIndex[peerID] = len(r.log) - r.matchIndex[peerID] = len(r.log) - 1 + r.nextIndex[peerID] = r.getLastLogIndex() + 1 + r.matchIndex[peerID] = r.getLastLogIndex() r.updateCommitIndex() } else { r.nextIndex[peerID]-- @@ -385,14 +412,23 @@ func (r *Raft) updateCommitIndex() { return } - for n := len(r.log) - 1; n > r.commitIndex; n-- { + // 从后往前遍历日志条目,找到可以提交的 + for i := len(r.log) - 1; i >= 0; i-- { + n := r.log[i].Index + if n <= r.commitIndex { + continue + } + if r.log[i].Term != r.Term { + continue + } + count := 1 - for i := range r.peers { - if i != r.me && r.matchIndex[i] >= n { + for j := range r.peers { + if j != r.me && r.matchIndex[j] >= n { count++ } } - if count > len(r.peers)/2 && r.log[n].Term == r.Term { + if count > len(r.peers)/2 { r.commitIndex = n r.applyCommittedLogs() r.commitCond.Broadcast() @@ -404,8 +440,12 @@ func (r *Raft) updateCommitIndex() { func (r *Raft) applyCommittedLogs() { for r.lastApplied < r.commitIndex { r.lastApplied++ - if r.ApplyCh != nil { - r.ApplyCh <- r.log[r.lastApplied] + // 将绝对索引转换为相对数组索引 + relativeIndex := r.lastApplied - int(r.LastIncludedIndex) - 1 + if relativeIndex >= 0 && relativeIndex < len(r.log) { + if r.ApplyCh != nil { + r.ApplyCh <- r.log[relativeIndex] + } } } @@ -419,30 +459,50 @@ func (r *Raft) checkSnapshotTrigger() { return } - // 如果日志长度超过阈值,触发快照 logLength := len(r.log) - threshold := 1000 // 默认阈值 - keepEntries := 100 // 保留的条目数 - - // 从配置中读取(如果可用) - if config.G.RaftSnapshotThreshold > 0 { - threshold = config.G.RaftSnapshotThreshold + threshold := config.G.RaftSnapshotThreshold + if threshold <= 0 { + threshold = 10000 } - if config.G.RaftSnapshotKeepEntries > 0 { - keepEntries = config.G.RaftSnapshotKeepEntries + keepEntries := config.G.RaftSnapshotKeepEntries + if keepEntries <= 0 { + keepEntries = 100 } if logLength > threshold { - // 计算快照索引:保留最新的 keepEntries 条日志 snapshotIndex := r.commitIndex - keepEntries if snapshotIndex > r.lastSnapshotIndex { - // 这里需要上层应用提供快照数据 - // 实际使用时,应该通过回调或通道请求 FSM 生成快照 - // TODO: 实现快照生成逻辑 + fmt.Printf("[RAFT] Auto-triggering snapshot at index %d (log length=%d, threshold=%d)\n", + snapshotIndex, logLength, threshold) + // 异步调用避免持锁死锁(checkSnapshotTrigger 在持锁上下文中被调用) + go r.TakeSnapshot(snapshotIndex) } } } +// getTermAt 获取指定绝对索引处的日志 term(考虑快照偏移) +func (r *Raft) getTermAt(absIndex int) int { + if absIndex < 0 { + return 0 + } + if absIndex == int(r.LastIncludedIndex) && r.LastIncludedIndex > 0 { + return int(r.LastIncludedTerm) + } + relativeIndex := absIndex - int(r.LastIncludedIndex) - 1 + if relativeIndex >= 0 && relativeIndex < len(r.log) { + return r.log[relativeIndex].Term + } + return 0 +} + +// getLastLogIndex 获取最后一条日志的绝对索引 +func (r *Raft) getLastLogIndex() int { + if len(r.log) > 0 { + return r.log[len(r.log)-1].Index + } + return int(r.LastIncludedIndex) +} + func (r *Raft) AppendEntry(command []byte) (int, error) { r.mu.Lock() defer r.mu.Unlock() @@ -452,13 +512,19 @@ func (r *Raft) AppendEntry(command []byte) (int, error) { return -1, fmt.Errorf("not leader") } + // 计算绝对索引(考虑快照偏移) + lastLogIndex := int(r.LastIncludedIndex) + if len(r.log) > 0 { + lastLogIndex = r.log[len(r.log)-1].Index + } + entry := LogEntry{ - Index: len(r.log), + Index: lastLogIndex + 1, Term: r.Term, Command: command, } r.log = append(r.log, entry) - r.persistLocked() // 持久化日志条目 + r.persistLocked() // 单节点模式:立即提交 if len(r.peers) == 1 { @@ -483,14 +549,46 @@ func (r *Raft) replicateLog() { } prevLogIndex := r.nextIndex[i] - 1 - prevLogTerm := 0 - if prevLogIndex >= 0 && prevLogIndex < len(r.log) { - prevLogTerm = r.log[prevLogIndex].Term + + // 如果 follower 落后太多(prevLogIndex 在快照范围内),发送 InstallSnapshot + if prevLogIndex < int(r.LastIncludedIndex) && r.LastIncludedIndex > 0 { + snapshotData, _, _, err := r.wal.LoadLatestSnapshot() + if err == nil && snapshotData != nil { + snapArgs := &InstallSnapshotArgs{ + Term: r.Term, + LeaderID: r.me, + Data: snapshotData, + LastIncludedIndex: r.LastIncludedIndex, + LastIncludedTerm: r.LastIncludedTerm, + } + go func(peerID int, snapArgs *InstallSnapshotArgs) { + reply, err := r.SendInstallSnapshot(r.addrMap[peerID], snapArgs) + if err != nil { + return + } + r.mu.Lock() + defer r.mu.Unlock() + if reply.Success { + r.nextIndex[peerID] = int(r.LastIncludedIndex) + 1 + r.matchIndex[peerID] = int(r.LastIncludedIndex) + } else if reply.Term > r.Term { + r.Term = reply.Term + r.state = Follower + r.votedFor = -1 + r.heartbeatTicker.Stop() + } + }(i, snapArgs) + } + continue } + prevLogTerm := r.getTermAt(prevLogIndex) + + // 将绝对索引转换为相对数组索引来切片日志 var entries []LogEntry - if r.nextIndex[i] < len(r.log) { - entries = r.log[r.nextIndex[i]:] + relativeStart := r.nextIndex[i] - int(r.LastIncludedIndex) - 1 + if relativeStart < len(r.log) { + entries = r.log[relativeStart:] } args := &AppendEntriesArgs{ @@ -525,8 +623,8 @@ func (r *Raft) replicateLog() { } if reply.Success { - r.nextIndex[peerID] = len(r.log) - r.matchIndex[peerID] = len(r.log) - 1 + r.nextIndex[peerID] = r.getLastLogIndex() + 1 + r.matchIndex[peerID] = r.getLastLogIndex() r.updateCommitIndex() } else { r.nextIndex[peerID]-- @@ -569,7 +667,7 @@ func (r *Raft) GetCommitIndex() int { return r.commitIndex } -func (r *Raft) TakeSnapshot(index int, data []byte) error { +func (r *Raft) TakeSnapshot(index int) error { r.mu.Lock() defer r.mu.Unlock() @@ -581,20 +679,24 @@ func (r *Raft) TakeSnapshot(index int, data []byte) error { return fmt.Errorf("cannot snapshot uncommitted index %d, commitIndex is %d", index, r.commitIndex) } - // 获取快照包含的最后一条日志的 term + // 收集需要放入快照的日志条目(绝对索引 <= index) + var snapshotEntries []LogEntry + relativeEnd := index + 1 - int(r.LastIncludedIndex) + for i := 0; i < relativeEnd && i < len(r.log); i++ { + snapshotEntries = append(snapshotEntries, r.log[i]) + } + + // 获取最后一条的 term var term int - if index == int(r.LastIncludedIndex) { + if len(snapshotEntries) > 0 { + term = snapshotEntries[len(snapshotEntries)-1].Term + } else if index == int(r.LastIncludedIndex) { term = int(r.LastIncludedTerm) - } else { - // 将绝对索引转换为 log 数组的相对索引 - logIndex := index - int(r.LastIncludedIndex) - 1 - if logIndex < 0 || logIndex >= len(r.log) { - return fmt.Errorf("invalid snapshot index %d, LastIncludedIndex=%d, log length=%d", - index, r.LastIncludedIndex, len(r.log)) - } - term = r.log[logIndex].Term } + // 序列化日志条目为快照数据 + data := SerializeLogEntries(snapshotEntries) + // 1. 先保存快照到磁盘 if err := r.wal.SaveSnapshot(data, int64(index), int64(term)); err != nil { return fmt.Errorf("failed to save snapshot: %w", err) @@ -603,14 +705,14 @@ func (r *Raft) TakeSnapshot(index int, data []byte) error { // 2. 删除旧快照 r.wal.DeleteOldSnapshots(int64(index)) - // 3. 截断 WAL 日志(删除快照包含的日志) + // 3. 截断 WAL 日志 if err := r.wal.TruncateLogs(int64(index)); err != nil { return fmt.Errorf("failed to truncate logs: %w", err) } // 4. 清理内存中的日志并重新编号 - newLogStart := index - int(r.LastIncludedIndex) - if newLogStart > 0 && newLogStart <= len(r.log) { + newLogStart := index + 1 - int(r.LastIncludedIndex) + if newLogStart >= 0 && newLogStart <= len(r.log) { r.log = r.log[newLogStart:] for i := range r.log { r.log[i].Index = index + 1 + i @@ -624,9 +726,7 @@ func (r *Raft) TakeSnapshot(index int, data []byte) error { r.LastIncludedIndex = int64(index) r.LastIncludedTerm = int64(term) - fmt.Printf("[RAFT] Snapshot created: Index=%d, Term=%d\n", index, term) - - // 6. 通知 FSM 应用快照 + // 6. 通知 FSM 异步重放快照(日志条目序列化数据) if r.ApplyCh != nil { snapshotEntry := LogEntry{ Index: index, @@ -636,9 +736,9 @@ func (r *Raft) TakeSnapshot(index int, data []byte) error { } select { case r.ApplyCh <- snapshotEntry: - fmt.Printf("[RAFT] Snapshot notification sent to FSM: Index=%d\n", index) + fmt.Printf("[RAFT] Snapshot replay sent to FSM: Index=%d, entries=%d\n", index, len(snapshotEntries)) default: - fmt.Println("[WARN] ApplyCh is full, snapshot notification skipped") + fmt.Println("[WARN] ApplyCh is full, snapshot replay skipped") } } @@ -647,3 +747,69 @@ func (r *Raft) TakeSnapshot(index int, data []byte) error { return nil } + +// SerializeLogEntries 序列化日志条目为字节流(快照数据格式) +func SerializeLogEntries(entries []LogEntry) []byte { + if len(entries) == 0 { + return nil + } + + size := 4 // entry count + for _, e := range entries { + size += 8 + 8 + 8 + len(e.Command) // Index(8) + Term(8) + CmdLen(8) + Command + } + + buf := make([]byte, size) + offset := 0 + binary.BigEndian.PutUint32(buf[offset:], uint32(len(entries))) + offset += 4 + for _, e := range entries { + binary.BigEndian.PutUint64(buf[offset:], uint64(e.Index)) + offset += 8 + binary.BigEndian.PutUint64(buf[offset:], uint64(e.Term)) + offset += 8 + binary.BigEndian.PutUint64(buf[offset:], uint64(len(e.Command))) + offset += 8 + copy(buf[offset:], e.Command) + offset += len(e.Command) + } + return buf +} + +// DeserializeLogEntries 反序列化日志条目 +func DeserializeLogEntries(data []byte) []LogEntry { + if len(data) < 4 { + return nil + } + + offset := 0 + count := binary.BigEndian.Uint32(data[offset:]) + offset += 4 + + entries := make([]LogEntry, 0, count) + for i := uint32(0); i < count; i++ { + if offset+24 > len(data) { + break + } + index := int(binary.BigEndian.Uint64(data[offset:])) + offset += 8 + term := int(binary.BigEndian.Uint64(data[offset:])) + offset += 8 + cmdLen := int(binary.BigEndian.Uint64(data[offset:])) + offset += 8 + + if offset+cmdLen > len(data) { + break + } + cmd := make([]byte, cmdLen) + copy(cmd, data[offset:offset+cmdLen]) + offset += cmdLen + + entries = append(entries, LogEntry{ + Index: index, + Term: term, + Command: cmd, + }) + } + return entries +} diff --git a/Raft/raft_data/raft_log.dat b/Raft/raft_data/raft_log.dat new file mode 100644 index 0000000..7a844b2 Binary files /dev/null and b/Raft/raft_data/raft_log.dat differ diff --git a/Raft/raft_test.go b/Raft/raft_test.go index e316ec4..e89070b 100644 --- a/Raft/raft_test.go +++ b/Raft/raft_test.go @@ -204,8 +204,7 @@ func TestSnapshotCreation(t *testing.T) { r.mu.Unlock() // 创建快照 - snapshotData := []byte("snapshot state") - err := r.TakeSnapshot(1, snapshotData) + err := r.TakeSnapshot(1) if err != nil { t.Fatalf("Failed to create snapshot: %v", err) } @@ -240,8 +239,7 @@ func TestSnapshotPersistence(t *testing.T) { r.mu.Unlock() // 创建快照(包含索引 0 和 1) - snapshotData := []byte("snapshot state at index 1") - err := r.TakeSnapshot(1, snapshotData) + err := r.TakeSnapshot(1) if err != nil { t.Fatalf("Failed to create snapshot: %v", err) } diff --git a/Raft/rpc.go b/Raft/rpc.go index 08afe4c..f9e4d21 100644 --- a/Raft/rpc.go +++ b/Raft/rpc.go @@ -89,15 +89,24 @@ func (r *RaftRPC) RequestVote(args *RequestVoteArgs, reply *RequestVoteReply) er } func (r *RaftRPC) isLogUpToDate(candidateLastIndex, candidateLastTerm int) bool { - if len(r.raft.log) == 0 { + // 当前节点日志为空且无快照时,候选者日志始终是最新的 + if len(r.raft.log) == 0 && r.raft.LastIncludedIndex == 0 { return true } - lastLog := r.raft.log[len(r.raft.log)-1] - if candidateLastTerm > lastLog.Term { + // 获取当前节点的最后日志索引和任期(考虑快照) + lastIndex := int(r.raft.LastIncludedIndex) + lastTerm := int(r.raft.LastIncludedTerm) + if len(r.raft.log) > 0 { + lastLog := r.raft.log[len(r.raft.log)-1] + lastIndex = lastLog.Index + lastTerm = lastLog.Term + } + + if candidateLastTerm > lastTerm { return true } - if candidateLastTerm == lastLog.Term && candidateLastIndex >= len(r.raft.log)-1 { + if candidateLastTerm == lastTerm && candidateLastIndex >= lastIndex { return true } return false @@ -120,29 +129,51 @@ func (r *RaftRPC) AppendEntries(args *AppendEntriesArgs, reply *AppendEntriesRep r.raft.persistLocked() } - if len(r.raft.log) > 0 && (args.PrevLogIndex >= len(r.raft.log) || r.raft.log[args.PrevLogIndex].Term != args.PrevLogTerm) { - reply.Success = false - reply.Term = r.raft.Term - return nil + // 检查 PrevLogIndex 是否匹配(考虑快照偏移) + if args.PrevLogIndex >= 0 { + if args.PrevLogIndex == int(r.raft.LastIncludedIndex) && r.raft.LastIncludedIndex > 0 { + // prevLogIndex 匹配快照,检查 term 是否一致 + if args.PrevLogTerm != int(r.raft.LastIncludedTerm) { + reply.Success = false + reply.Term = r.raft.Term + return nil + } + } else if args.PrevLogIndex > int(r.raft.LastIncludedIndex) { + relativeIndex := args.PrevLogIndex - int(r.raft.LastIncludedIndex) - 1 + if relativeIndex >= len(r.raft.log) || r.raft.log[relativeIndex].Term != args.PrevLogTerm { + reply.Success = false + reply.Term = r.raft.Term + return nil + } + } else { + // PrevLogIndex 小于 LastIncludedIndex,日志不一致 + reply.Success = false + reply.Term = r.raft.Term + return nil + } } - for i, entry := range args.Entries { - logIndex := args.PrevLogIndex + i + 1 - if logIndex < len(r.raft.log) && r.raft.log[logIndex].Term != entry.Term { - r.raft.log = r.raft.log[:logIndex] + // 追加新日志条目 + for _, entry := range args.Entries { + relativeIndex := entry.Index - int(r.raft.LastIncludedIndex) - 1 + if relativeIndex < len(r.raft.log) && r.raft.log[relativeIndex].Term != entry.Term { + r.raft.log = r.raft.log[:relativeIndex] } - if logIndex >= len(r.raft.log) { + if relativeIndex >= len(r.raft.log) { r.raft.log = append(r.raft.log, entry) } } - // 持久化接收到的日志 if len(args.Entries) > 0 { r.raft.persistLocked() } if args.LeaderCommit > r.raft.commitIndex { - r.raft.commitIndex = min(args.LeaderCommit, len(r.raft.log)-1) + lastLogIndex := int(r.raft.LastIncludedIndex) + if len(r.raft.log) > 0 { + lastLogIndex = r.raft.log[len(r.raft.log)-1].Index + } + r.raft.commitIndex = min(args.LeaderCommit, lastLogIndex) r.applyCommittedLogs() } @@ -154,8 +185,11 @@ func (r *RaftRPC) AppendEntries(args *AppendEntriesArgs, reply *AppendEntriesRep func (r *RaftRPC) applyCommittedLogs() { for r.raft.lastApplied < r.raft.commitIndex { r.raft.lastApplied++ - if r.raft.ApplyCh != nil { - r.raft.ApplyCh <- r.raft.log[r.raft.lastApplied] + relativeIndex := r.raft.lastApplied - int(r.raft.LastIncludedIndex) - 1 + if relativeIndex >= 0 && relativeIndex < len(r.raft.log) { + if r.raft.ApplyCh != nil { + r.raft.ApplyCh <- r.raft.log[relativeIndex] + } } } } @@ -176,8 +210,8 @@ func (r *RaftRPC) InstallSnapshot(args *InstallSnapshotArgs, reply *InstallSnaps r.raft.votedFor = -1 } - if args.LastIncludedIndex <= int64(r.raft.commitIndex) { - // 快照比已提交的还旧,不需要应用 + if args.LastIncludedIndex <= r.raft.LastIncludedIndex { + // 快照比已有的还旧,不需要应用 reply.Success = false reply.Term = r.raft.Term return nil @@ -194,18 +228,15 @@ func (r *RaftRPC) InstallSnapshot(args *InstallSnapshotArgs, reply *InstallSnaps // 2. 删除旧快照 r.raft.wal.DeleteOldSnapshots(args.LastIncludedIndex) - // 3. 清理内存中的日志并重新编号 - if len(r.raft.log) > 0 { - // 计算需要保留的日志起始位置(相对于 LastIncludedIndex) - newLogStart := int(args.LastIncludedIndex) - int(r.raft.LastIncludedIndex) - if newLogStart > 0 && newLogStart <= len(r.raft.log) { - r.raft.log = r.raft.log[newLogStart:] - for i := range r.raft.log { - r.raft.log[i].Index = int(args.LastIncludedIndex) + 1 + i - } - } else { - r.raft.log = []LogEntry{} + // 3. 清理内存中的日志并重新编号(移除快照包含的条目) + newLogStart := int(args.LastIncludedIndex) + 1 - int(r.raft.LastIncludedIndex) + if newLogStart >= 0 && newLogStart <= len(r.raft.log) { + r.raft.log = r.raft.log[newLogStart:] + for i := range r.raft.log { + r.raft.log[i].Index = int(args.LastIncludedIndex) + 1 + i } + } else { + r.raft.log = []LogEntry{} } // 4. 截断 WAL 日志 @@ -219,6 +250,7 @@ func (r *RaftRPC) InstallSnapshot(args *InstallSnapshotArgs, reply *InstallSnaps // 5. 更新元数据 r.raft.commitIndex = int(args.LastIncludedIndex) r.raft.lastApplied = int(args.LastIncludedIndex) + r.raft.lastSnapshotIndex = int(args.LastIncludedIndex) r.raft.LastIncludedIndex = args.LastIncludedIndex r.raft.LastIncludedTerm = args.LastIncludedTerm diff --git a/Server/raft_data/raft_log.dat b/Server/raft_data/raft_log.dat index e69de29..ae7b5dc 100644 Binary files a/Server/raft_data/raft_log.dat and b/Server/raft_data/raft_log.dat differ diff --git a/Server/raft_data/raft_state.dat b/Server/raft_data/raft_state.dat index e69de29..c1a3377 100644 Binary files a/Server/raft_data/raft_state.dat and b/Server/raft_data/raft_state.dat differ diff --git a/config/config.json b/config/config.json index c1d0236..75fbefa 100644 --- a/config/config.json +++ b/config/config.json @@ -5,11 +5,11 @@ "Host": "127.0.0.1", "Port": 8080, "Name": "demo server", - "MaxConn": 3, + "MaxConn": 100, "MaxPackageSize": 1024, - "WorkPoolSize": 3, - "MaxWorkPoolTaskLen": 3, - "MaxMsgChanLen": 3, + "WorkPoolSize": 100, + "MaxWorkPoolTaskLen": 100, + "MaxMsgChanLen": 100, "MaxCompactionSize": 4, "MaxMemTableP": 0.5, "MaxMemTableLevel": 32, diff --git a/service/fsm.go b/service/fsm.go index c09a7d7..fffc9a8 100644 --- a/service/fsm.go +++ b/service/fsm.go @@ -8,6 +8,7 @@ import ( "github.com/NeverENG/BanDB/Raft" "github.com/NeverENG/BanDB/config" "github.com/NeverENG/BanDB/storage" + "github.com/NeverENG/BanDB/storage/istorage" "github.com/NeverENG/BanDB/storage/zstorage" ) @@ -53,14 +54,20 @@ func (k *KVServer) Run() { // Apply 应用日志到存储 func (k *KVServer) Apply(entry Raft.LogEntry) { + // 处理快照:异步重放到临时表 → Flush → SSTable,不阻塞 ApplyCh + if entry.IsSnapshot { + fmt.Printf("[FSM] Snapshot received, async replaying %d entries...\n", + len(Raft.DeserializeLogEntries(entry.Command))) + go k.replaySnapshot(entry) + return + } + var cmd Command if err := json.Unmarshal(entry.Command, &cmd); err != nil { fmt.Printf("[ERROR] Failed to unmarshal command: %v\n", err) return } - - switch cmd.Type { case "Put": err := k.storage.Put(cmd.Key, cmd.Value) @@ -79,6 +86,34 @@ func (k *KVServer) Apply(entry Raft.LogEntry) { } } +// replaySnapshot 异步重放快照中的日志条目到临时表并 Flush 到 SSTable +func (k *KVServer) replaySnapshot(entry Raft.LogEntry) { + entries := Raft.DeserializeLogEntries(entry.Command) + if len(entries) == 0 { + return + } + + kvEntries := make([]istorage.LogEntry, 0, len(entries)) + for _, e := range entries { + var cmd Command + if err := json.Unmarshal(e.Command, &cmd); err != nil { + continue + } + switch cmd.Type { + case "Put": + kvEntries = append(kvEntries, istorage.LogEntry{Key: cmd.Key, Value: cmd.Value}) + case "Delete": + kvEntries = append(kvEntries, istorage.LogEntry{Key: cmd.Key, Value: nil}) + } + } + + if err := k.storage.FlushToSSTable(kvEntries); err != nil { + fmt.Printf("[FSM ERROR] Snapshot replay failed: %v\n", err) + } else { + fmt.Printf("[FSM] Snapshot replay completed: %d entries flushed to SSTable\n", len(kvEntries)) + } +} + // Get 从存储获取值 func (k *KVServer) Get(key []byte) ([]byte, error) { value, err := k.storage.Get(key) diff --git a/service/router.go b/service/router.go index 5f01186..25ad50f 100644 --- a/service/router.go +++ b/service/router.go @@ -77,8 +77,6 @@ func (r *Router) handlePut(data []byte, request banIface.IRequest) { key := data[8 : 8+keyLen] value := data[8+keyLen : 8+keyLen+valueLen] - slog.Info("[INFO] handlePut", "key", string(key), "value", string(value)) - // 创建命令并通过 Raft 追加日志 cmd := Command{ Type: "Put", diff --git a/storage/engine.go b/storage/engine.go index 3f4ea75..982e308 100644 --- a/storage/engine.go +++ b/storage/engine.go @@ -50,6 +50,11 @@ func (e *Engine) GetApplyCh() chan StorageCommand { return e.applyCh } +// FlushToSSTable 快照重放到 SSTable(不经过 active 表,走临时表 → Flush → SSTable 路径) +func (e *Engine) FlushToSSTable(entries []istorage.LogEntry) error { + return e.memTable.FlushToSSTable(entries) +} + func (e *Engine) applyWorker() { for cmd := range e.applyCh { switch cmd.Type { @@ -59,4 +64,4 @@ func (e *Engine) applyWorker() { e.Delete(cmd.Key) } } -} +} \ No newline at end of file diff --git a/storage/istorage/IMemTable.go b/storage/istorage/IMemTable.go index 588e9d5..157e41b 100644 --- a/storage/istorage/IMemTable.go +++ b/storage/istorage/IMemTable.go @@ -6,4 +6,7 @@ type IMemTable interface { Delete(key []byte) error Size() int StartFlush() + // FlushToSSTable 将 entries 写入临时表并立即 Flush 到 SSTable + // 不经过 active 表,不阻塞正常读写 + FlushToSSTable(entries []LogEntry) error } diff --git a/storage/zstorage/memtable.go b/storage/zstorage/memtable.go index 5a4b030..daa9da1 100644 --- a/storage/zstorage/memtable.go +++ b/storage/zstorage/memtable.go @@ -129,7 +129,6 @@ func (m *MemTable) Get(key []byte) ([]byte, error) { // 最后在 SSTable 中查找 if val, found := m.getFromSSTables(key); found { - fmt.Printf("[MEMTABLE] Get found in SSTable: key=%s\n", string(key)) return val, nil } @@ -404,6 +403,44 @@ func (m *MemTable) getFromSSTables(key []byte) ([]byte, bool) { return nil, false } +// FlushToSSTable 将 entries 写入临时跳表并立即 Flush 到 SSTable +// 不经过 active 表,不影响正常读写,专用于快照重放等场景 +func (m *MemTable) FlushToSSTable(entries []istorage.LogEntry) error { + if len(entries) == 0 { + return nil + } + + // 创建临时跳表,按序插入(同 key 自动去重/更新) + tmp := newSkipList() + for _, entry := range entries { + if entry.Value == nil { + tmp.delete(entry.Key) + } else { + tmp.insert(entry.Key, entry.Value) + } + } + + // 从临时跳表收集有序条目 + sorted := collectAllEntry(tmp) + if len(sorted) == 0 { + return nil + } + + // 写入 SSTable(SSTable 内部有锁保护元数据并发安全) + if err := m.sst.WriteToSSTable(sorted); err != nil { + return fmt.Errorf("FlushToSSTable write error: %w", err) + } + + // 触发 Compaction 检查 + select { + case m.compactCh <- true: + default: + } + + fmt.Printf("[MEMTABLE] FlushToSSTable completed: %d entries\n", len(sorted)) + return nil +} + func (m *MemTable) WriteSSTable() error { m.mu.RLock() active := m.active