Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
76 changes: 53 additions & 23 deletions Raft/raft.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package Raft
import (
"encoding/binary"
"fmt"
"log/slog"
"math/rand"
"sync"
"time"
Expand Down Expand Up @@ -95,7 +96,7 @@ func NewRaft(peers []string, me int) *Raft {

// 从磁盘加载持久化状态(currentTerm, votedFor, log, snapshot metadata)
if err := r.readPersist(); err != nil {
fmt.Printf("[RAFT WARN] Failed to load persisted state: %v\n", err)
slog.Warn("failed to load persisted state", "error", err)
}

// 如果有快照,通知 FSM
Expand All @@ -110,7 +111,7 @@ func NewRaft(peers []string, me int) *Raft {
IsSnapshot: true,
}:
default:
fmt.Println("[WARN] ApplyCh is full during initialization, snapshot skipped")
slog.Warn("ApplyCh full during init, snapshot skipped")
}
}
}
Expand All @@ -122,7 +123,14 @@ func NewRaft(peers []string, me int) *Raft {
return r
}

// persistLocked 持久化 Raft 状态(必须在持有锁的情况下调用)
// persistStateLocked writes only the current term and the vote to the WAL
// (incremental persistence; the original note describes this as O(1)).
// A failed save is reported through slog and is not returned to the caller.
// NOTE(review): the "Locked" suffix suggests the caller must already hold the
// Raft mutex — confirm against the call sites.
func (r *Raft) persistStateLocked() {
	term, vote := int64(r.Term), int64(r.votedFor)
	err := r.wal.SaveState(term, vote)
	if err == nil {
		return
	}
	slog.Error("failed to persist state", "error", err)
}
Comment on lines +126 to +131
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🟠 Major | ⚡ Quick win

Persist every Term/votedFor mutation, not just election start.

This helper is a good optimization, but the higher-term reply paths still update Term and votedFor without calling it. If the node crashes after one of those step-downs, it can restart with a stale term/vote.

🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@Raft/raft.go` around lines 126 - 131, the persistStateLocked helper is currently
called on only some of the paths that change Term/votedFor; ensure every mutation of
r.Term or r.votedFor is followed by a call to persistStateLocked while still holding the
raft lock, so the state is durable. Find all code paths that assign r.Term or r.votedFor (e.g., the
higher-term reply/response handlers such as the functions that process
AppendEntries/RequestVote replies or any onReceiveHigherTerm/becomeFollower
helpers) and add a call to persistStateLocked() immediately after the assignment
(still under the same lock), propagating/logging any persistence error as the
existing SaveState path does. Ensure no code updates Term/votedFor without
invoking persistStateLocked under lock.


// persistLocked 全量持久化 Raft 状态(仅用于日志冲突截断等特殊情况)
func (r *Raft) persistLocked() {
data := PersistData{
CurrentTerm: int64(r.Term),
Expand All @@ -133,7 +141,7 @@ func (r *Raft) persistLocked() {
}

if err := r.wal.SavePersist(data); err != nil {
fmt.Printf("[RAFT ERROR] Failed to persist state: %v\n", err)
slog.Error("failed to persist state", "error", err)
}
}

Expand Down Expand Up @@ -188,18 +196,21 @@ func (r *Raft) startElection() {
return
}

fmt.Printf("[RAFT] Starting election, current state=%v, Term=%d\n", r.state, r.Term)
slog.Info("election commenced", "term", r.Term)

r.state = Candidate
r.Term++
r.votedFor = r.me
r.persistLocked() // 持久化 Term 和 votedFor
r.persistStateLocked()

lastLogIndex := int(r.LastIncludedIndex)
lastLogTerm := int(r.LastIncludedTerm)
lastLogIndex := -1
lastLogTerm := 0
if len(r.log) > 0 {
lastLogIndex = r.log[len(r.log)-1].Index
lastLogTerm = r.log[len(r.log)-1].Term
} else if r.LastIncludedIndex > 0 {
lastLogIndex = int(r.LastIncludedIndex)
lastLogTerm = int(r.LastIncludedTerm)
Comment on lines +206 to +213
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🔴 Critical | 🏗️ Heavy lift

Keep the new zero-based index convention consistent across Raft.

After these changes, the first real log entry becomes index 0, but the rest of the file still assumes the first in-memory entry is LastIncludedIndex + 1 and that LastIncludedIndex == 0 means “no snapshot”. In a fresh multi-node cluster that makes replicateLog() compute relativeStart == -1 on the first append, and a snapshot taken at index 0 is ignored by the > 0 guards during later bootstrap/election flows. Please align on one baseline before merging.

Also applies to: 306-310, 524-527, 540-545

🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@Raft/raft.go` around lines 206 - 213, The code mixes a new zero-based
log-index convention with old guards that treat LastIncludedIndex==0 as “no
snapshot”, causing relativeStart to be -1 in replicateLog() and skipping valid
snapshot index 0; update the logic around lastLogIndex/lastLogTerm and snapshot
checks to consistently treat LastIncludedIndex as the actual last included log
index (allowing 0), e.g. initialize lastLogIndex = int(r.LastIncludedIndex) when
r.log is empty, compute relativeStart as int(start) - int(r.LastIncludedIndex)
(not subtracting 1), and replace all `> 0` snapshot guards with `>= 0` or
explicit nil/empty checks so functions like replicateLog(), and code paths
referencing r.LastIncludedIndex/r.LastIncludedTerm/r.log use the same baseline
across the file.

}

args := &RequestVoteArgs{
Expand Down Expand Up @@ -247,6 +258,16 @@ func (r *Raft) startElection() {

r.mu.Unlock()

// 单节点模式:无需等待投票,直接成为 Leader
if peerCount == 0 {
r.mu.Lock()
if r.state == Candidate {
r.becomeLeader()
}
r.mu.Unlock()
return
}

// 等待投票结果或超时
timeout := time.After(500 * time.Millisecond)
for j := 0; j < peerCount; j++ {
Expand Down Expand Up @@ -278,21 +299,22 @@ func (r *Raft) startElection() {
}

func (r *Raft) becomeLeader() {
fmt.Printf("[RAFT] Becoming Leader, Term=%d\n", r.Term)
slog.Info("leader elected", "term", r.Term)
r.state = Leader

// 计算下一个日志的绝对索引(考虑快照偏移)
nextLogIndex := int(r.LastIncludedIndex) + 1
nextLogIndex := 0
if len(r.log) > 0 {
nextLogIndex = r.log[len(r.log)-1].Index + 1
} else if r.LastIncludedIndex > 0 {
nextLogIndex = int(r.LastIncludedIndex) + 1
}

for i := range r.peers {
r.nextIndex[i] = nextLogIndex
r.matchIndex[i] = int(r.LastIncludedIndex)
}

fmt.Printf("[RAFT] Started heartbeat loop\n")
r.startHeartbeatLoop()
}

Expand Down Expand Up @@ -472,10 +494,9 @@ func (r *Raft) checkSnapshotTrigger() {
if logLength > threshold {
snapshotIndex := r.commitIndex - keepEntries
if snapshotIndex > r.lastSnapshotIndex {
fmt.Printf("[RAFT] Auto-triggering snapshot at index %d (log length=%d, threshold=%d)\n",
snapshotIndex, logLength, threshold)
// 异步调用避免持锁死锁(checkSnapshotTrigger 在持锁上下文中被调用)
go r.TakeSnapshot(snapshotIndex)
slog.Info("auto-triggering snapshot", "index", snapshotIndex, "logLen", logLength, "threshold", threshold)
// 异步调用避免持锁死锁(checkSnapshotTrigger 在持锁上下文中被调用)
go r.TakeSnapshot(snapshotIndex)
}
}
}
Expand All @@ -500,22 +521,27 @@ func (r *Raft) getLastLogIndex() int {
if len(r.log) > 0 {
return r.log[len(r.log)-1].Index
}
return int(r.LastIncludedIndex)
if r.LastIncludedIndex > 0 {
return int(r.LastIncludedIndex)
}
return -1
}

func (r *Raft) AppendEntry(command []byte) (int, error) {
r.mu.Lock()
defer r.mu.Unlock()

if r.state != Leader {
fmt.Printf("[RAFT] AppendEntry failed: not leader, state=%v\n", r.state)
slog.Warn("AppendEntry rejected, not leader", "state", r.state)
return -1, fmt.Errorf("not leader")
}

// 计算绝对索引(考虑快照偏移)
lastLogIndex := int(r.LastIncludedIndex)
lastLogIndex := -1
if len(r.log) > 0 {
lastLogIndex = r.log[len(r.log)-1].Index
} else if r.LastIncludedIndex > 0 {
lastLogIndex = int(r.LastIncludedIndex)
}

entry := LogEntry{
Expand All @@ -524,7 +550,11 @@ func (r *Raft) AppendEntry(command []byte) (int, error) {
Command: command,
}
r.log = append(r.log, entry)
r.persistLocked()

// 增量持久化:仅追加一条日志
if err := r.wal.AppendLog(entry); err != nil {
slog.Error("failed to append log", "error", err)
}

// 单节点模式:立即提交
if len(r.peers) == 1 {
Expand Down Expand Up @@ -736,14 +766,14 @@ func (r *Raft) TakeSnapshot(index int) error {
}
select {
case r.ApplyCh <- snapshotEntry:
fmt.Printf("[RAFT] Snapshot replay sent to FSM: Index=%d, entries=%d\n", index, len(snapshotEntries))
slog.Info("snapshot replay sent to FSM", "index", index, "entries", len(snapshotEntries))
default:
fmt.Println("[WARN] ApplyCh is full, snapshot replay skipped")
slog.Warn("ApplyCh full, snapshot replay skipped")
}
}

// 7. 持久化状态
r.persistLocked()
// 7. 持久化状态(日志已由 TruncateLogs 处理)
r.persistStateLocked()

return nil
}
Expand Down
5 changes: 4 additions & 1 deletion Server/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,8 +28,11 @@ func main() {
server.AddRouter(2, router) // GET 操作
server.AddRouter(3, router) // DELETE 操作

// 注册连接生命周期回调
server.SetConnStartFunc(router.OnConnStart)
server.SetConnStopFunc(router.OnConnStop)

// 启动服务
fmt.Println("Starting Server...")
fmt.Printf("HA initialized, initial health status: %v\n", ha.IsHealthy())
server.Serve()
}
36 changes: 9 additions & 27 deletions service/fsm.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ package service
import (
"encoding/json"
"errors"
"fmt"
"log/slog"

"github.com/NeverENG/BanDB/Raft"
"github.com/NeverENG/BanDB/config"
Expand Down Expand Up @@ -46,42 +46,33 @@ func NewKVServer() *KVServer {

// Run 运行 FSM
func (k *KVServer) Run() {
fmt.Println("[INFO] KVServer Run started, waiting for Raft entries...")
slog.Info("KVServer commenced — awaiting Raft entries")
for entry := range k.raft.GetApplyCh() {
k.Apply(entry)
}
}

// Apply 应用日志到存储
func (k *KVServer) Apply(entry Raft.LogEntry) {
// 处理快照:异步重放到临时表 → Flush → SSTable,不阻塞 ApplyCh
if entry.IsSnapshot {
fmt.Printf("[FSM] Snapshot received, async replaying %d entries...\n",
len(Raft.DeserializeLogEntries(entry.Command)))
go k.replaySnapshot(entry)
return
}

var cmd Command
if err := json.Unmarshal(entry.Command, &cmd); err != nil {
fmt.Printf("[ERROR] Failed to unmarshal command: %v\n", err)
slog.Error("failed to unmarshal command", "error", err)
return
}

switch cmd.Type {
case "Put":
err := k.storage.Put(cmd.Key, cmd.Value)
if err != nil {
fmt.Printf("[ERROR] Failed to put: %v\n", err)
} else {
fmt.Printf("[INFO] Put success: %s = %s\n", string(cmd.Key), string(cmd.Value))
if err := k.storage.Put(cmd.Key, cmd.Value); err != nil {
slog.Error("failed to put", "error", err)
}
case "Delete":
err := k.storage.Delete(cmd.Key)
if err != nil {
fmt.Printf("[ERROR] Failed to delete: %v\n", err)
} else {
fmt.Printf("[INFO] Delete success: %s\n", string(cmd.Key))
if err := k.storage.Delete(cmd.Key); err != nil {
slog.Error("failed to delete", "error", err)
}
}
}
Expand All @@ -108,24 +99,15 @@ func (k *KVServer) replaySnapshot(entry Raft.LogEntry) {
}

if err := k.storage.FlushToSSTable(kvEntries); err != nil {
fmt.Printf("[FSM ERROR] Snapshot replay failed: %v\n", err)
} else {
fmt.Printf("[FSM] Snapshot replay completed: %d entries flushed to SSTable\n", len(kvEntries))
slog.Error("snapshot replay failed", "error", err)
}
}

// Get 从存储获取值
func (k *KVServer) Get(key []byte) ([]byte, error) {
value, err := k.storage.Get(key)
if err != nil {
fmt.Printf("[ERROR] Get failed: %v\n", err)
} else {
fmt.Printf("[INFO] Get result: %s\n", string(value))
}
if value == nil {
if value == nil && err == nil {
return nil, errors.New("key not found")
}

return value, err
}

Expand Down