Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
306 changes: 236 additions & 70 deletions Raft/raft.go

Large diffs are not rendered by default.

Binary file added Raft/raft_data/raft_log.dat
Binary file not shown.
6 changes: 2 additions & 4 deletions Raft/raft_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -204,8 +204,7 @@ func TestSnapshotCreation(t *testing.T) {
r.mu.Unlock()

// 创建快照
snapshotData := []byte("snapshot state")
err := r.TakeSnapshot(1, snapshotData)
err := r.TakeSnapshot(1)
if err != nil {
t.Fatalf("Failed to create snapshot: %v", err)
}
Expand Down Expand Up @@ -240,8 +239,7 @@ func TestSnapshotPersistence(t *testing.T) {
r.mu.Unlock()

// 创建快照(包含索引 0 和 1)
snapshotData := []byte("snapshot state at index 1")
err := r.TakeSnapshot(1, snapshotData)
err := r.TakeSnapshot(1)
if err != nil {
t.Fatalf("Failed to create snapshot: %v", err)
}
Expand Down
92 changes: 62 additions & 30 deletions Raft/rpc.go
Original file line number Diff line number Diff line change
Expand Up @@ -89,15 +89,24 @@ func (r *RaftRPC) RequestVote(args *RequestVoteArgs, reply *RequestVoteReply) er
}

func (r *RaftRPC) isLogUpToDate(candidateLastIndex, candidateLastTerm int) bool {
if len(r.raft.log) == 0 {
// 当前节点日志为空且无快照时,候选者日志始终是最新的
if len(r.raft.log) == 0 && r.raft.LastIncludedIndex == 0 {
return true
}

lastLog := r.raft.log[len(r.raft.log)-1]
if candidateLastTerm > lastLog.Term {
// 获取当前节点的最后日志索引和任期(考虑快照)
lastIndex := int(r.raft.LastIncludedIndex)
lastTerm := int(r.raft.LastIncludedTerm)
if len(r.raft.log) > 0 {
lastLog := r.raft.log[len(r.raft.log)-1]
lastIndex = lastLog.Index
lastTerm = lastLog.Term
}

if candidateLastTerm > lastTerm {
return true
}
if candidateLastTerm == lastLog.Term && candidateLastIndex >= len(r.raft.log)-1 {
if candidateLastTerm == lastTerm && candidateLastIndex >= lastIndex {
return true
}
return false
Expand All @@ -120,29 +129,51 @@ func (r *RaftRPC) AppendEntries(args *AppendEntriesArgs, reply *AppendEntriesRep
r.raft.persistLocked()
}

if len(r.raft.log) > 0 && (args.PrevLogIndex >= len(r.raft.log) || r.raft.log[args.PrevLogIndex].Term != args.PrevLogTerm) {
reply.Success = false
reply.Term = r.raft.Term
return nil
// 检查 PrevLogIndex 是否匹配(考虑快照偏移)
if args.PrevLogIndex >= 0 {
if args.PrevLogIndex == int(r.raft.LastIncludedIndex) && r.raft.LastIncludedIndex > 0 {
// prevLogIndex 匹配快照,检查 term 是否一致
if args.PrevLogTerm != int(r.raft.LastIncludedTerm) {
reply.Success = false
reply.Term = r.raft.Term
return nil
}
} else if args.PrevLogIndex > int(r.raft.LastIncludedIndex) {
relativeIndex := args.PrevLogIndex - int(r.raft.LastIncludedIndex) - 1
if relativeIndex >= len(r.raft.log) || r.raft.log[relativeIndex].Term != args.PrevLogTerm {
reply.Success = false
reply.Term = r.raft.Term
return nil
}
} else {
// PrevLogIndex 小于 LastIncludedIndex,日志不一致
reply.Success = false
reply.Term = r.raft.Term
return nil
}
}

for i, entry := range args.Entries {
logIndex := args.PrevLogIndex + i + 1
if logIndex < len(r.raft.log) && r.raft.log[logIndex].Term != entry.Term {
r.raft.log = r.raft.log[:logIndex]
// 追加新日志条目
for _, entry := range args.Entries {
relativeIndex := entry.Index - int(r.raft.LastIncludedIndex) - 1
if relativeIndex < len(r.raft.log) && r.raft.log[relativeIndex].Term != entry.Term {
r.raft.log = r.raft.log[:relativeIndex]
}
if logIndex >= len(r.raft.log) {
if relativeIndex >= len(r.raft.log) {
r.raft.log = append(r.raft.log, entry)
}
}

// 持久化接收到的日志
if len(args.Entries) > 0 {
r.raft.persistLocked()
}

if args.LeaderCommit > r.raft.commitIndex {
r.raft.commitIndex = min(args.LeaderCommit, len(r.raft.log)-1)
lastLogIndex := int(r.raft.LastIncludedIndex)
if len(r.raft.log) > 0 {
lastLogIndex = r.raft.log[len(r.raft.log)-1].Index
}
r.raft.commitIndex = min(args.LeaderCommit, lastLogIndex)
r.applyCommittedLogs()
}

Expand All @@ -154,8 +185,11 @@ func (r *RaftRPC) AppendEntries(args *AppendEntriesArgs, reply *AppendEntriesRep
func (r *RaftRPC) applyCommittedLogs() {
for r.raft.lastApplied < r.raft.commitIndex {
r.raft.lastApplied++
if r.raft.ApplyCh != nil {
r.raft.ApplyCh <- r.raft.log[r.raft.lastApplied]
relativeIndex := r.raft.lastApplied - int(r.raft.LastIncludedIndex) - 1
if relativeIndex >= 0 && relativeIndex < len(r.raft.log) {
if r.raft.ApplyCh != nil {
r.raft.ApplyCh <- r.raft.log[relativeIndex]
}
}
}
}
Expand All @@ -176,8 +210,8 @@ func (r *RaftRPC) InstallSnapshot(args *InstallSnapshotArgs, reply *InstallSnaps
r.raft.votedFor = -1
}

if args.LastIncludedIndex <= int64(r.raft.commitIndex) {
// 快照比已提交的还旧,不需要应用
if args.LastIncludedIndex <= r.raft.LastIncludedIndex {
// 快照比已有的还旧,不需要应用
reply.Success = false
reply.Term = r.raft.Term
return nil
Expand All @@ -194,18 +228,15 @@ func (r *RaftRPC) InstallSnapshot(args *InstallSnapshotArgs, reply *InstallSnaps
// 2. 删除旧快照
r.raft.wal.DeleteOldSnapshots(args.LastIncludedIndex)

// 3. 清理内存中的日志并重新编号
if len(r.raft.log) > 0 {
// 计算需要保留的日志起始位置(相对于 LastIncludedIndex)
newLogStart := int(args.LastIncludedIndex) - int(r.raft.LastIncludedIndex)
if newLogStart > 0 && newLogStart <= len(r.raft.log) {
r.raft.log = r.raft.log[newLogStart:]
for i := range r.raft.log {
r.raft.log[i].Index = int(args.LastIncludedIndex) + 1 + i
}
} else {
r.raft.log = []LogEntry{}
// 3. 清理内存中的日志并重新编号(移除快照包含的条目)
newLogStart := int(args.LastIncludedIndex) + 1 - int(r.raft.LastIncludedIndex)
if newLogStart >= 0 && newLogStart <= len(r.raft.log) {
r.raft.log = r.raft.log[newLogStart:]
for i := range r.raft.log {
r.raft.log[i].Index = int(args.LastIncludedIndex) + 1 + i
}
} else {
r.raft.log = []LogEntry{}
}

// 4. 截断 WAL 日志
Expand All @@ -219,6 +250,7 @@ func (r *RaftRPC) InstallSnapshot(args *InstallSnapshotArgs, reply *InstallSnaps
// 5. 更新元数据
r.raft.commitIndex = int(args.LastIncludedIndex)
r.raft.lastApplied = int(args.LastIncludedIndex)
r.raft.lastSnapshotIndex = int(args.LastIncludedIndex)
r.raft.LastIncludedIndex = args.LastIncludedIndex
r.raft.LastIncludedTerm = args.LastIncludedTerm

Expand Down
Binary file modified Server/raft_data/raft_log.dat
Binary file not shown.
Binary file modified Server/raft_data/raft_state.dat
Binary file not shown.
8 changes: 4 additions & 4 deletions config/config.json
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,11 @@
"Host": "127.0.0.1",
"Port": 8080,
"Name": "demo server",
"MaxConn": 3,
"MaxConn": 100,
"MaxPackageSize": 1024,
"WorkPoolSize": 3,
"MaxWorkPoolTaskLen": 3,
"MaxMsgChanLen": 3,
"WorkPoolSize": 100,
"MaxWorkPoolTaskLen": 100,
"MaxMsgChanLen": 100,
"MaxCompactionSize": 4,
"MaxMemTableP": 0.5,
"MaxMemTableLevel": 32,
Expand Down
39 changes: 37 additions & 2 deletions service/fsm.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ import (
"github.com/NeverENG/BanDB/Raft"
"github.com/NeverENG/BanDB/config"
"github.com/NeverENG/BanDB/storage"
"github.com/NeverENG/BanDB/storage/istorage"
"github.com/NeverENG/BanDB/storage/zstorage"
)

Expand Down Expand Up @@ -53,14 +54,20 @@ func (k *KVServer) Run() {

// Apply 应用日志到存储
func (k *KVServer) Apply(entry Raft.LogEntry) {
// 处理快照:异步重放到临时表 → Flush → SSTable,不阻塞 ApplyCh
if entry.IsSnapshot {
fmt.Printf("[FSM] Snapshot received, async replaying %d entries...\n",
len(Raft.DeserializeLogEntries(entry.Command)))
go k.replaySnapshot(entry)
return
}

var cmd Command
if err := json.Unmarshal(entry.Command, &cmd); err != nil {
fmt.Printf("[ERROR] Failed to unmarshal command: %v\n", err)
return
}



switch cmd.Type {
case "Put":
err := k.storage.Put(cmd.Key, cmd.Value)
Expand All @@ -79,6 +86,34 @@ func (k *KVServer) Apply(entry Raft.LogEntry) {
}
}

// replaySnapshot 异步重放快照中的日志条目到临时表并 Flush 到 SSTable
func (k *KVServer) replaySnapshot(entry Raft.LogEntry) {
entries := Raft.DeserializeLogEntries(entry.Command)
if len(entries) == 0 {
return
}

kvEntries := make([]istorage.LogEntry, 0, len(entries))
for _, e := range entries {
var cmd Command
if err := json.Unmarshal(e.Command, &cmd); err != nil {
continue
}
switch cmd.Type {
case "Put":
kvEntries = append(kvEntries, istorage.LogEntry{Key: cmd.Key, Value: cmd.Value})
case "Delete":
kvEntries = append(kvEntries, istorage.LogEntry{Key: cmd.Key, Value: nil})
}
}

if err := k.storage.FlushToSSTable(kvEntries); err != nil {
fmt.Printf("[FSM ERROR] Snapshot replay failed: %v\n", err)
} else {
fmt.Printf("[FSM] Snapshot replay completed: %d entries flushed to SSTable\n", len(kvEntries))
}
}

// Get 从存储获取值
func (k *KVServer) Get(key []byte) ([]byte, error) {
value, err := k.storage.Get(key)
Expand Down
2 changes: 0 additions & 2 deletions service/router.go
Original file line number Diff line number Diff line change
Expand Up @@ -77,8 +77,6 @@ func (r *Router) handlePut(data []byte, request banIface.IRequest) {
key := data[8 : 8+keyLen]
value := data[8+keyLen : 8+keyLen+valueLen]

slog.Info("[INFO] handlePut", "key", string(key), "value", string(value))

// 创建命令并通过 Raft 追加日志
cmd := Command{
Type: "Put",
Expand Down
7 changes: 6 additions & 1 deletion storage/engine.go
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,11 @@ func (e *Engine) GetApplyCh() chan StorageCommand {
return e.applyCh
}

// FlushToSSTable 快照重放到 SSTable(不经过 active 表,走临时表 → Flush → SSTable 路径)
func (e *Engine) FlushToSSTable(entries []istorage.LogEntry) error {
return e.memTable.FlushToSSTable(entries)
}

func (e *Engine) applyWorker() {
for cmd := range e.applyCh {
switch cmd.Type {
Expand All @@ -59,4 +64,4 @@ func (e *Engine) applyWorker() {
e.Delete(cmd.Key)
}
}
}
}
3 changes: 3 additions & 0 deletions storage/istorage/IMemTable.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,4 +6,7 @@ type IMemTable interface {
Delete(key []byte) error
Size() int
StartFlush()
// FlushToSSTable 将 entries 写入临时表并立即 Flush 到 SSTable
// 不经过 active 表,不阻塞正常读写
FlushToSSTable(entries []LogEntry) error
}
39 changes: 38 additions & 1 deletion storage/zstorage/memtable.go
Original file line number Diff line number Diff line change
Expand Up @@ -129,7 +129,6 @@ func (m *MemTable) Get(key []byte) ([]byte, error) {

// 最后在 SSTable 中查找
if val, found := m.getFromSSTables(key); found {
fmt.Printf("[MEMTABLE] Get found in SSTable: key=%s\n", string(key))
return val, nil
}

Expand Down Expand Up @@ -404,6 +403,44 @@ func (m *MemTable) getFromSSTables(key []byte) ([]byte, bool) {
return nil, false
}

// FlushToSSTable 将 entries 写入临时跳表并立即 Flush 到 SSTable
// 不经过 active 表,不影响正常读写,专用于快照重放等场景
func (m *MemTable) FlushToSSTable(entries []istorage.LogEntry) error {
if len(entries) == 0 {
return nil
}

// 创建临时跳表,按序插入(同 key 自动去重/更新)
tmp := newSkipList()
for _, entry := range entries {
if entry.Value == nil {
tmp.delete(entry.Key)
} else {
tmp.insert(entry.Key, entry.Value)
}
}

// 从临时跳表收集有序条目
sorted := collectAllEntry(tmp)
if len(sorted) == 0 {
return nil
}

// 写入 SSTable(SSTable 内部有锁保护元数据并发安全)
if err := m.sst.WriteToSSTable(sorted); err != nil {
return fmt.Errorf("FlushToSSTable write error: %w", err)
}

// 触发 Compaction 检查
select {
case m.compactCh <- true:
default:
}

fmt.Printf("[MEMTABLE] FlushToSSTable completed: %d entries\n", len(sorted))
return nil
}

func (m *MemTable) WriteSSTable() error {
m.mu.RLock()
active := m.active
Expand Down
Loading