From f84d92c4c5d975ed01c97d9dff5504fd651624fa Mon Sep 17 00:00:00 2001 From: mina5rovic Date: Wed, 18 Feb 2026 19:44:12 +0100 Subject: [PATCH 1/3] Add new percentile aggregator and tests for comparing it with Byzantine --- .../aggregator/percentile_clipping.spec.ts | 423 ++++++++++++++++++ discojs/src/aggregator/percentile_clipping.ts | 132 ++++++ 2 files changed, 555 insertions(+) create mode 100644 discojs/src/aggregator/percentile_clipping.spec.ts create mode 100644 discojs/src/aggregator/percentile_clipping.ts diff --git a/discojs/src/aggregator/percentile_clipping.spec.ts b/discojs/src/aggregator/percentile_clipping.spec.ts new file mode 100644 index 000000000..674d2cd09 --- /dev/null +++ b/discojs/src/aggregator/percentile_clipping.spec.ts @@ -0,0 +1,423 @@ +import { Set } from "immutable"; +import { describe, expect, it } from "vitest"; + +import { WeightsContainer } from "../index.js"; +import { ByzantineRobustAggregator } from "./byzantine.js"; +import { PercentileClippingAggregator } from "./percentile_clipping.js"; + +// Helper to convert WeightsContainer → number[][] for easy assertions +async function WSIntoArrays(ws: WeightsContainer): Promise<number[][]> { + return Promise.all(ws.weights.map(async t => Array.from(await t.data()))); +} + +// Timing measurement helper +interface TimingResult { + name: string; + time: number; + result: number; +} + +async function measureAggregation( + aggregator: ByzantineRobustAggregator | PercentileClippingAggregator, + name: string, + peers: { id: string; value: number }[] +): Promise<TimingResult> { + const promise = aggregator.getPromiseForAggregation(); + const currentRound = aggregator.round; + + const startTime = performance.now(); + peers.forEach(peer => { + aggregator.add(peer.id, WeightsContainer.of([peer.value]), currentRound); + }); + + const result = await promise; + const endTime = performance.now(); + + const arr = await WSIntoArrays(result); + const aggregatedValue = arr[0][0]; + + return { + name, + time: endTime - startTime, 
+ result: aggregatedValue, + }; +} + +function formatTiming(timings: TimingResult[]): string { + const maxNameLen = Math.max(...timings.map(t => t.name.length)); + return timings + .map(t => ` ${t.name.padEnd(maxNameLen)} | ${t.time.toFixed(3)}ms | result: ${t.result.toFixed(4)}`) + .join('\n'); +} + +describe("Performance Comparison: Old vs New Byzantine Aggregators", () => { + /** + * Comparison test: Centered Clipping (Current) vs Percentile-based Clipping (Old) + * + * Setup: + * - Multiple honest updates (value=1.0) + * - Multiple Byzantine updates (value=100) + * - Measure robustness of aggregated result + */ + + it("both aggregators handle simple outlier rejection", async () => { + const honestPeers = ["honest1", "honest2"]; + const byzantinePeers = ["byzantine1"]; + const allPeers = honestPeers.concat(byzantinePeers); + + // --- TEST 1: New Aggregator (Centered Clipping with iterations) --- + const newAgg = new ByzantineRobustAggregator(0, 3, "absolute", 1.0, 1, 0); + newAgg.setNodes(Set(allPeers)); + + const peersWithValues = [ + ...honestPeers.map(id => ({ id, value: 1.0 })), + ...byzantinePeers.map(id => ({ id, value: 100.0 })), + ]; + + const timingNew = await measureAggregation(newAgg, "New ByzantineRobust", peersWithValues); + + // --- TEST 2: Old Aggregator (Percentile Clipping) --- + const oldAgg = new PercentileClippingAggregator(0, 3, "absolute", 0.1); + oldAgg.setNodes(Set(allPeers)); + + const timingOld = await measureAggregation(oldAgg, "Old PercentileClipping", peersWithValues); + + // Both should produce a value closer to 1.0 than 100.0 + expect(timingNew.result).toBeLessThan(50); + expect(timingOld.result).toBeLessThan(50); + + // Print timing comparison + console.log("\n=== Timing Comparison: Simple Outlier Rejection ==="); + console.log(formatTiming([timingNew, timingOld])); + }); + + it("old aggregator with different percentiles", async () => { + const honestPeers = ["honest1", "honest2", "honest3"]; + const byzantinePeers = 
["byzantine1", "byzantine2"]; + const allPeers = honestPeers.concat(byzantinePeers); + + const peersWithValues = [ + ...honestPeers.map(id => ({ id, value: 1.0 })), + ...byzantinePeers.map(id => ({ id, value: 50.0 })), + ]; + + const testPercentiles = [0.05, 0.1, 0.2, 0.5]; + const timings: TimingResult[] = []; + + for (const tau of testPercentiles) { + const agg = new PercentileClippingAggregator(0, 5, "absolute", tau); + agg.setNodes(Set(allPeers)); + + const timing = await measureAggregation(agg, `tau=${tau}`, peersWithValues); + timings.push(timing); + + // Should clip towards honest value + expect(timing.result).toBeLessThan(30); + } + + console.log("\n=== Timing Comparison: Old Aggregator with Different Percentiles ==="); + console.log(formatTiming(timings)); + }); + + it("new aggregator with different clipping radii", async () => { + const honestPeers = ["honest1", "honest2", "honest3"]; + const byzantinePeers = ["byzantine1", "byzantine2"]; + const allPeers = honestPeers.concat(byzantinePeers); + + const peersWithValues = [ + ...honestPeers.map(id => ({ id, value: 1.0 })), + ...byzantinePeers.map(id => ({ id, value: 50.0 })), + ]; + + const testRadii = [0.5, 1.0, 2.0, 5.0]; + const timings: TimingResult[] = []; + + for (const radius of testRadii) { + const agg = new ByzantineRobustAggregator(0, 5, "absolute", radius, 1, 0); + agg.setNodes(Set(allPeers)); + + const timing = await measureAggregation(agg, `radius=${radius}`, peersWithValues); + timings.push(timing); + + // With larger radius, more Byzantine influence + expect(timing.result).toBeGreaterThan(0); + } + + console.log("\n=== Timing Comparison: New Aggregator with Different Clipping Radii ==="); + console.log(formatTiming(timings)); + }); + + it("old aggregator stores previous aggregation state", async () => { + const agg = new PercentileClippingAggregator(0, 2, "absolute", 0.1); + const [peer1, peer2] = ["peer1", "peer2"]; + agg.setNodes(Set([peer1, peer2])); + + // Round 1 + const peersRound1 = [ 
+ { id: peer1, value: 5.0 }, + { id: peer2, value: 5.0 }, + ]; + + const timingRound1 = await measureAggregation(agg, "Round 1", peersRound1); + expect(timingRound1.result).to.equal(5.0); + + // Round 2 - should center around previous result + const peersRound2 = [ + { id: peer1, value: 10.0 }, + { id: peer2, value: 10.0 }, + ]; + + const timingRound2 = await measureAggregation(agg, "Round 2", peersRound2); + + // With centering on previous (5.0), updates to 10.0 should result in something close to 10.0 + expect(timingRound2.result).toBeGreaterThan(5.0); + + console.log("\n=== Timing Comparison: State Preservation Across Rounds ==="); + console.log(formatTiming([timingRound1, timingRound2])); + }); + + it("scalability: larger peer set (10 peers)", async () => { + const numHonest = 7; + const numByzantine = 3; + const honestPeers = Array.from({ length: numHonest }, (_, i) => `honest${i}`); + const byzantinePeers = Array.from({ length: numByzantine }, (_, i) => `byzantine${i}`); + const allPeers = honestPeers.concat(byzantinePeers); + + const peersWithValues = [ + ...honestPeers.map(id => ({ id, value: 1.0 })), + ...byzantinePeers.map(id => ({ id, value: 100.0 })), + ]; + + const newAgg = new ByzantineRobustAggregator(0, allPeers.length, "absolute", 1.0, 1, 0); + newAgg.setNodes(Set(allPeers)); + const timingNew = await measureAggregation(newAgg, "New (10 peers)", peersWithValues); + + const oldAgg = new PercentileClippingAggregator(0, allPeers.length, "absolute", 0.1); + oldAgg.setNodes(Set(allPeers)); + const timingOld = await measureAggregation(oldAgg, "Old (10 peers)", peersWithValues); + + console.log("\n=== Scalability Test: 10 Peers (7 honest, 3 Byzantine) ==="); + console.log(formatTiming([timingNew, timingOld])); + console.log(` Speedup: ${(timingOld.time / timingNew.time).toFixed(2)}x`); + }); + + it("iterative refinement: new aggregator with multiple iterations", async () => { + const honestPeers = ["honest1", "honest2", "honest3"]; + const byzantinePeers 
= ["byzantine1", "byzantine2"]; + const allPeers = honestPeers.concat(byzantinePeers); + + const peersWithValues = [ + ...honestPeers.map(id => ({ id, value: 1.0 })), + ...byzantinePeers.map(id => ({ id, value: 50.0 })), + ]; + + const iterations = [1, 2, 5, 10]; + const timings: TimingResult[] = []; + + for (const iter of iterations) { + const agg = new ByzantineRobustAggregator(0, 5, "absolute", 1.0, iter, 0); + agg.setNodes(Set(allPeers)); + + const timing = await measureAggregation(agg, `iterations=${iter}`, peersWithValues); + timings.push(timing); + } + + console.log("\n=== Performance Impact of Iterative Refinement ==="); + console.log(formatTiming(timings)); + }); + + it("equivalence: new aggregator with 1 iteration matches old aggregator", async () => { + const honestPeers = ["honest1", "honest2", "honest3"]; + const byzantinePeers = ["byzantine1", "byzantine2"]; + const allPeers = honestPeers.concat(byzantinePeers); + + const peersWithValues = [ + ...honestPeers.map(id => ({ id, value: 1.0 })), + ...byzantinePeers.map(id => ({ id, value: 50.0 })), + ]; + + const newAggWithOneIter = new ByzantineRobustAggregator(0, 5, "absolute", 1.0, 1, 0); + newAggWithOneIter.setNodes(Set(allPeers)); + + const oldAgg = new PercentileClippingAggregator(0, 5, "absolute", 0.1); + oldAgg.setNodes(Set(allPeers)); + + const timingNew = await measureAggregation(newAggWithOneIter, "New (maxIter=1)", peersWithValues); + const timingOld = await measureAggregation(oldAgg, "Old (tau=0.1)", peersWithValues); + + console.log("\n=== Equivalence Test: Single Iteration Convergence ==="); + console.log(formatTiming([timingNew, timingOld])); + console.log(` Result difference: ${Math.abs(timingNew.result - timingOld.result).toFixed(4)}`); + console.log(` Speed ratio (new/old): ${(timingNew.time / timingOld.time).toFixed(2)}x`); + + expect(timingNew.result).toBeLessThan(30); + expect(timingOld.result).toBeLessThan(30); + + // With single iteration, results should be very close (within 
reasonable tolerance) + expect(Math.abs(timingNew.result - timingOld.result)).toBeLessThan(5); + }); + + it("byzantine robustness: high ratio attack (40% malicious peers)", async () => { + const numHonest = 6; + const numByzantine = 4; + const honestPeers = Array.from({ length: numHonest }, (_, i) => `honest${i}`); + const byzantinePeers = Array.from({ length: numByzantine }, (_, i) => `byzantine${i}`); + const allPeers = honestPeers.concat(byzantinePeers); + + // Honest send gradient 1.0, Byzantine send large outlier + const peersWithValues = [ + ...honestPeers.map(id => ({ id, value: 1.0 })), + ...byzantinePeers.map(id => ({ id, value: 100.0 })), // 40% Byzantine pulling result up + ]; + + const newAgg = new ByzantineRobustAggregator(0, allPeers.length, "absolute", 1.0, 1, 0); + newAgg.setNodes(Set(allPeers)); + const timingNew = await measureAggregation(newAgg, "New (40% Byzantine)", peersWithValues); + + const oldAgg = new PercentileClippingAggregator(0, allPeers.length, "absolute", 0.1); + oldAgg.setNodes(Set(allPeers)); + const timingOld = await measureAggregation(oldAgg, "Old (40% Byzantine)", peersWithValues); + + console.log("\n=== Byzantine Robustness: High Ratio Attack (4/10 = 40% malicious) ==="); + console.log(formatTiming([timingNew, timingOld])); + console.log(` Result gap: new=${timingNew.result.toFixed(4)}, old=${timingOld.result.toFixed(4)}`); + console.log(` Winner: ${timingNew.result < timingOld.result ? 
"NEW (closer to honest 1.0)" : "OLD (closer to honest 1.0)"}`); + }); + + it("byzantine robustness: gradient poisoning attack (crafted gradients)", async () => { + // Byzantine gradient attack: send gradients designed to manipulate centroid + const honestPeers = ["honest1", "honest2", "honest3"]; + const byzantinePeers = ["byzantine1", "byzantine2"]; + const allPeers = honestPeers.concat(byzantinePeers); + + // Honest: standard gradient 1.0 + // Byzantine: crafted to move result away from honest consensus + // Poisoning strategy: send same large value to coordinate + const peersWithValues = [ + ...honestPeers.map(id => ({ id, value: 1.0 })), + ...byzantinePeers.map(id => ({ id, value: 10.0 })), + ]; + + const newAgg = new ByzantineRobustAggregator(0, allPeers.length, "absolute", 2.0, 5, 0); + newAgg.setNodes(Set(allPeers)); + const timingNew = await measureAggregation(newAgg, "New (5 iterations)", peersWithValues); + + const oldAgg = new PercentileClippingAggregator(0, allPeers.length, "absolute", 0.2); + oldAgg.setNodes(Set(allPeers)); + const timingOld = await measureAggregation(oldAgg, "Old (tau=0.2)", peersWithValues); + + console.log("\n=== Gradient Poisoning Attack (coordinated Byzantine values) ==="); + console.log(formatTiming([timingNew, timingOld])); + console.log(` Result gap: new=${timingNew.result.toFixed(4)}, old=${timingOld.result.toFixed(4)}`); + console.log(` Expected honest value: 1.0000`); + console.log(` Winner: ${Math.abs(timingNew.result - 1.0) < Math.abs(timingOld.result - 1.0) ? 
"NEW (closer to honest)" : "OLD (closer to honest)"}`); + }); + + it("byzantine robustness: adaptive multi-round attack", async () => { + // Multi-round attack: Byzantine adapts based on previous aggregation + // Round 1: test the aggregator behavior + // Round 2: Byzantine sends crafted response gradient + + const honestPeers = ["honest1", "honest2", "honest3"]; + const byzantinePeers = ["byzantine1"]; + const allPeers = honestPeers.concat(byzantinePeers); + + // Round 1 setup + const round1Values = [ + ...honestPeers.map(id => ({ id, value: 5.0 })), + ...byzantinePeers.map(id => ({ id, value: 5.0 })), // Byzantine cooperates round 1 + ]; + + const newAgg = new ByzantineRobustAggregator(0, allPeers.length, "absolute", 1.0, 5, 0); + newAgg.setNodes(Set(allPeers)); + const timing1New = await measureAggregation(newAgg, "New Round 1", round1Values); + + const oldAgg = new PercentileClippingAggregator(0, allPeers.length, "absolute", 0.1); + oldAgg.setNodes(Set(allPeers)); + const timing1Old = await measureAggregation(oldAgg, "Old Round 1", round1Values); + + // Round 2: Byzantine launches adaptive attack + const round2Values = [ + ...honestPeers.map(id => ({ id, value: 10.0 })), // Honest update + ...byzantinePeers.map(id => ({ id, value: 50.0 })), // Byzantine aggressive attack in round 2 + ]; + + const timing2New = await measureAggregation(newAgg, "New Round 2 (attack)", round2Values); + const timing2Old = await measureAggregation(oldAgg, "Old Round 2 (attack)", round2Values); + + console.log("\n=== Adaptive Multi-Round Attack ==="); + console.log("Round 1 (cooperation):"); + console.log(formatTiming([timing1New, timing1Old])); + console.log("\nRound 2 (adaptive Byzantine attack):"); + console.log(formatTiming([timing2New, timing2Old])); + console.log(` New result: ${timing2New.result.toFixed(4)} (expected ~10.0)`); + console.log(` Old result: ${timing2Old.result.toFixed(4)} (expected ~10.0)`); + console.log(` Winner: ${Math.abs(timing2New.result - 10.0) < 
Math.abs(timing2Old.result - 10.0) ? "NEW (better rejects attack)" : "OLD (better rejects attack)"}`); + }); + + it("heterogeneous gradients: realistic multi-tensor federated model", async () => { + // Realistic FL scenario: aggregate weights across multiple layers/tensors with different scales + // Layer 1: large values (e.g., from first dense layer) + // Layer 2: small values (e.g., from final output layer) + + const honestPeers = ["honest1", "honest2", "honest3", "honest4"]; + const byzantinePeers = ["byzantine1"]; + const allPeers = honestPeers.concat(byzantinePeers); + + // Create multi-tensor contributions + const createHeterogeneousGradient = (baseValue: number): WeightsContainer => { + return WeightsContainer.of([baseValue * 100, baseValue * 10, baseValue]); // Different scales + }; + + const newAgg = new ByzantineRobustAggregator(0, allPeers.length, "absolute", 5.0, 3, 0); + newAgg.setNodes(Set(allPeers)); + const promiseNew = newAgg.getPromiseForAggregation(); + const startNew = performance.now(); + + honestPeers.forEach(id => { + newAgg.add(id, createHeterogeneousGradient(1.0), newAgg.round); + }); + byzantinePeers.forEach(id => { + newAgg.add(id, createHeterogeneousGradient(100.0), newAgg.round); // Byzantine + }); + + const resultNew = await promiseNew; + const timeNew = performance.now() - startNew; + const arrNew = await WSIntoArrays(resultNew); + + const oldAgg = new PercentileClippingAggregator(0, allPeers.length, "absolute", 0.1); + oldAgg.setNodes(Set(allPeers)); + const promiseOld = oldAgg.getPromiseForAggregation(); + const startOld = performance.now(); + + honestPeers.forEach(id => { + oldAgg.add(id, createHeterogeneousGradient(1.0), oldAgg.round); + }); + byzantinePeers.forEach(id => { + oldAgg.add(id, createHeterogeneousGradient(100.0), oldAgg.round); // Byzantine + }); + + const resultOld = await promiseOld; + const timeOld = performance.now() - startOld; + const arrOld = await WSIntoArrays(resultOld); + + console.log("\n=== Heterogeneous 
Gradients: Multi-Tensor Federated Model ==="); + console.log("Gradient structure: [layer1=100×value, layer2=10×value, layer3=value]"); + console.log("Honest peers send: [100, 10, 1]"); + console.log("Byzantine sends: [10000, 1000, 100]"); + console.log(`\nNew aggregator (${timeNew.toFixed(2)}ms):`); + console.log(` Layer 1: ${arrNew[0][0].toFixed(2)} (expected ~100)`); + console.log(` Layer 2: ${arrNew[0][1].toFixed(2)} (expected ~10)`); + console.log(` Layer 3: ${arrNew[0][2].toFixed(2)} (expected ~1)`); + console.log(`\nOld aggregator (${timeOld.toFixed(2)}ms):`); + console.log(` Layer 1: ${arrOld[0][0].toFixed(2)} (expected ~100)`); + console.log(` Layer 2: ${arrOld[0][1].toFixed(2)} (expected ~10)`); + console.log(` Layer 3: ${arrOld[0][2].toFixed(2)} (expected ~1)`); + + // Check relative error + const newError = Math.abs((arrNew[0][0] - 100) / 100) + Math.abs((arrNew[0][1] - 10) / 10) + Math.abs(arrNew[0][2] - 1); + const oldError = Math.abs((arrOld[0][0] - 100) / 100) + Math.abs((arrOld[0][1] - 10) / 10) + Math.abs(arrOld[0][2] - 1); + console.log(`\nTotal relative error: new=${newError.toFixed(3)}, old=${oldError.toFixed(3)}`); + console.log(`Winner: ${newError < oldError ? 
"NEW (better handles multi-scale)" : "OLD (better handles multi-scale)"}`); + }); +}); diff --git a/discojs/src/aggregator/percentile_clipping.ts b/discojs/src/aggregator/percentile_clipping.ts new file mode 100644 index 000000000..88a28b957 --- /dev/null +++ b/discojs/src/aggregator/percentile_clipping.ts @@ -0,0 +1,132 @@ +import { Map } from "immutable"; +import * as tf from '@tensorflow/tfjs'; +import { AggregationStep } from "./aggregator.js"; +import { MultiRoundAggregator, ThresholdType } from "./multiround.js"; +import { WeightsContainer, client } from "../index.js"; +import { aggregation } from "../index.js"; + +/** + * Old Byzantine-robust aggregator using Percentile-based Clipping + * + * This class implements a gradient aggregation rule that clips updates based on a + * percentile-computed threshold (tau) to mitigate the influence of Byzantine nodes. + * Unlike the iterative Centered Clipping approach, this uses a single-pass percentile-based clipping. + * + * Algorithm: + * 1. Center all peer weights w.r.t. the previous aggregation + * 2. Compute Frobenius norm for each centered weight + * 3. Compute tau as the percentile of the norm array + * 4. Clip each centered weight: clip = centeredWeight * min(1, tau / norm) + * 5. Average all clipped weights + */ + +export class PercentileClippingAggregator extends MultiRoundAggregator { + private readonly tauPercentile: number; + private prevAggregate: WeightsContainer | null = null; + + /** + * @property tauPercentile The percentile (0 < tau < 1) used to compute the clipping threshold. + * - Type: `number` + * - Determines which percentile of the Frobenius norms to use as the clipping threshold. + * - For example, 0.1 clips at the 10th percentile of norms. + * - Smaller values are more aggressive (clip more updates). + * - Default value is 0.1. 
+ */ + + constructor(roundCutoff = 0, threshold = 1, thresholdType?: ThresholdType, tauPercentile = 0.1) { + super(roundCutoff, threshold, thresholdType); + if (tauPercentile <= 0 || tauPercentile >= 1) { + throw new Error("Tau percentile must be between 0 and 1 (exclusive)."); + } + this.tauPercentile = tauPercentile; + } + + override _add(nodeId: client.NodeID, contribution: WeightsContainer): void { + this.log( + this.contributions.hasIn([0, nodeId]) ? AggregationStep.UPDATE : AggregationStep.ADD, + nodeId, + ); + // Store contribution as is, without client-side momentum + this.contributions = this.contributions.setIn([0, nodeId], contribution); + } + + override aggregate(): WeightsContainer { + const currentContributions = this.contributions.get(0); + if (!currentContributions) throw new Error("aggregating without any contribution"); + + this.log(AggregationStep.AGGREGATE); + + // Step 1: Get the centering reference (previous aggregation or zero vector) + let centerReference: WeightsContainer; + if (this.prevAggregate) { + centerReference = this.prevAggregate; + } else { + // Use shape of the first contribution to create zero vector + const first = currentContributions.values().next(); + if (first.done) throw new Error("zero sized contribution"); + centerReference = first.value.map((t: tf.Tensor) => tf.zerosLike(t)); + } + + // Step 2: Center the weights with respect to the reference + const centeredWeights = Array.from(currentContributions.values()).map(w => + w.sub(centerReference) + ); + + // Step 3: Compute Frobenius norms for each centered weight + const normArray = centeredWeights.map(w => frobeniusNorm(w)); + + // Step 4: Compute tau as the percentile of the norm array + const tau = this.computePercentile(normArray, this.tauPercentile); + + // Step 5: Clip weights based on tau + // Each peer gets one scale factor based on their Frobenius norm + const clippedWeights = centeredWeights.map((w, peerIdx) => { + const scaleFactor = Math.min(1, tau / 
normArray[peerIdx]); + return w.map((t: tf.Tensor) => t.mul(scaleFactor)); + }); + + // Step 6: Average the clipped weights and add back the reference + const clippedAvg = aggregation.avg(clippedWeights); + const result = centerReference.add(clippedAvg); + + clippedWeights.forEach(w => w.dispose()); + clippedAvg.dispose(); + if (!this.prevAggregate) { + centerReference.dispose(); + } + + // Step 7: Store result for next round + this.prevAggregate = result; + return result; + } + + private computePercentile(array: number[], percentile: number): number { + // Linear interpolation for percentile calculation + const sorted = [...array].sort((a, b) => a - b); + const pos = (sorted.length - 1) * percentile; + const base = Math.floor(pos); + const rest = pos - base; + + if (sorted[base + 1] !== undefined) { + return sorted[base] + rest * (sorted[base + 1] - sorted[base]); + } else { + return sorted[base]; + } + } + + override makePayloads(weights: WeightsContainer): Map<client.NodeID, WeightsContainer> { + return this.nodes.toMap().map(() => weights); + } +} + +function frobeniusNorm(w: WeightsContainer): number { + // Computes the Frobenius (L2) norm of all tensors in a WeightsContainer + // sqrt(sum of all squared elements across all tensors) + return tf.tidy(() => { + const norms: tf.Scalar[] = w.weights.map(t => tf.sum(tf.square(t))); + const total = norms.reduce((a, b) => tf.add(a, b)); + const result = tf.sqrt(total); + const value = result.dataSync()[0]; + return value; + }); +} From 2ecd9231bb6f789ebcff004228f2c7eae8f4d192 Mon Sep 17 00:00:00 2001 From: mina5rovic Date: Tue, 14 Apr 2026 19:48:31 +0200 Subject: [PATCH 2/3] Fix Byzantine (paper-like) and compare the use cases with percentile --- discojs/package.json | 3 +- discojs/src/aggregator/byzantine.spec.ts | 238 ++++++++- discojs/src/aggregator/byzantine.ts | 83 ++- .../byzantine_vs_percentile.spec.ts | 362 +++++++++++++ .../aggregator/percentile_clipping.spec.ts | 482 +++++------------- discojs/src/aggregator/percentile_clipping.ts | 54 +- 
package-lock.json | 54 +- 7 files changed, 850 insertions(+), 426 deletions(-) create mode 100644 discojs/src/aggregator/byzantine_vs_percentile.spec.ts diff --git a/discojs/package.json b/discojs/package.json index 4e3b63a5f..2cc709fca 100644 --- a/discojs/package.json +++ b/discojs/package.json @@ -34,6 +34,7 @@ "@tensorflow/tfjs-node": "4", "@types/simple-peer": "9", "nodemon": "3", - "ts-node": "10" + "ts-node": "10", + "fast-check": "3" } } diff --git a/discojs/src/aggregator/byzantine.spec.ts b/discojs/src/aggregator/byzantine.spec.ts index d300fbb4c..4770947ef 100644 --- a/discojs/src/aggregator/byzantine.spec.ts +++ b/discojs/src/aggregator/byzantine.spec.ts @@ -1,5 +1,6 @@ import { Set } from "immutable"; import { describe, expect, it } from "vitest"; +import fc from "fast-check"; import { WeightsContainer } from "../index.js"; import { ByzantineRobustAggregator } from "./byzantine.js"; @@ -31,8 +32,8 @@ describe("ByzantineRobustAggregator", () => { expect(arr).to.deep.equal([[2], [3]]); }); - it("clips a single outlier with small radius", async () => { - const agg = new ByzantineRobustAggregator(0, 3, 'absolute', 1.0, 1, 0); + it("reduces influence of a single outlier", async () => { + const agg = new ByzantineRobustAggregator(0, 3, 'absolute', 1.0, 10, 0); const [c1, c2, bad] = ["c1", "c2", "bad"]; agg.setNodes(Set.of(c1, c2, bad)); @@ -43,21 +44,40 @@ describe("ByzantineRobustAggregator", () => { const out = await p; const arr = await WSIntoArrays(out); - expect(arr[0][0]).to.be.closeTo(1, 1e-6); + + const result = arr[0][0]; + const mean = (1 + 1 + 100) / 3; + + expect(Math.abs(result - 1)).to.be.lessThan(Math.abs(mean - 1)); }); - it("applies multiple clipping iterations (maxIterations > 1)", async () => { - const agg = new ByzantineRobustAggregator(0, 2, 'absolute', 1.0, 3, 0); - const [c1, bad] = ["c1", "bad"]; - agg.setNodes(Set.of(c1, bad)); + it("multiple iterations improve the estimate", async () => { + const [c1, c2, bad] = ["c1", "c2", "bad"]; 
- const p = agg.getPromiseForAggregation(); - agg.add(c1, WeightsContainer.of([0]), 0); - agg.add(bad, WeightsContainer.of([10]), 0); + const agg1 = new ByzantineRobustAggregator(0, 3, 'absolute', 1.0, 1, 0); + agg1.setNodes(Set.of(c1, c2, bad)); - const out = await p; - const arr = await WSIntoArrays(out); - expect(arr[0][0]).to.be.lessThan(1); // clipped closer to 0 + const p1 = agg1.getPromiseForAggregation(); + agg1.add(c1, WeightsContainer.of([0]), 0); + agg1.add(c2, WeightsContainer.of([0]), 0); + agg1.add(bad, WeightsContainer.of([10]), 0); + const out1 = await p1; + const arr1 = await WSIntoArrays(out1); + + const agg3 = new ByzantineRobustAggregator(0, 3, 'absolute', 1.0, 3, 0); + agg3.setNodes(Set.of(c1, c2, bad)); + + const p3 = agg3.getPromiseForAggregation(); + agg3.add(c1, WeightsContainer.of([0]), 0); + agg3.add(c2, WeightsContainer.of([0]), 0); + agg3.add(bad, WeightsContainer.of([10]), 0); + const out3 = await p3; + const arr3 = await WSIntoArrays(out3); + + const honest = 0; + + expect(Math.abs(arr3[0][0] - honest)) + .to.be.lessThanOrEqual(Math.abs(arr1[0][0] - honest)); }); it("uses momentum when beta > 0", async () => { @@ -65,21 +85,25 @@ describe("ByzantineRobustAggregator", () => { const [c1, c2] = ["c1", "c2"]; agg.setNodes(Set.of(c1, c2)); + // Round 1 const p1 = agg.getPromiseForAggregation(); agg.add(c1, WeightsContainer.of([2]), 0); agg.add(c2, WeightsContainer.of([2]), 0); const out1 = await p1; const arr1 = await WSIntoArrays(out1); - expect(arr1[0][0]).to.equal(2); + // m₀ = (1 - β) * g = 1 + expect(arr1[0][0]).to.be.closeTo(1, 1e-6); + + // Round 2 const p2 = agg.getPromiseForAggregation(); agg.add(c1, WeightsContainer.of([4]), 1); agg.add(c2, WeightsContainer.of([4]), 1); const out2 = await p2; const arr2 = await WSIntoArrays(out2); - // With momentum = 0.5, result = 0.5 * prev + 0.5 * current = 3.0 - expect(arr2[0][0]).to.be.closeTo(3, 1e-6); + // m₁ = 0.5*4 + 0.5*1 = 2.5 → avg = 2.5 + expect(arr2[0][0]).to.be.closeTo(2.5, 1e-6); 
}); it("respects roundCutoff — ignores old contributions", async () => { @@ -100,4 +124,186 @@ describe("ByzantineRobustAggregator", () => { const arr2 = await WSIntoArrays(out2); expect(arr2[0][0]).to.equal(20); }); + + it("remains robust with 30% Byzantine clients", async () => { + const honest = Array(7).fill(1); + const byzantine = Array(3).fill(100); + + const agg = new ByzantineRobustAggregator(0, 10, 'absolute', 1.0, 5, 0); + const ids = [...honest, ...byzantine].map((_, i) => `c${i}`); + agg.setNodes(Set(ids)); + + const p = agg.getPromiseForAggregation(); + honest.forEach((v, i) => agg.add(`c${i}`, WeightsContainer.of([v]), 0)); + byzantine.forEach((v, i) => agg.add(`c${i + honest.length}`, WeightsContainer.of([v]), 0)); + + const out = await p; + const arr = await WSIntoArrays(out); + + const honestMean = honest.reduce((a, b) => a + b, 0) / honest.length; + const rawMean = [...honest, ...byzantine].reduce((a, b) => a + b, 0) / (honest.length + byzantine.length); + + expect(Math.abs(arr[0][0] - honestMean)).to.be.lessThan(Math.abs(rawMean - honestMean)); + }); + + it("moves closer to the honest signal under constant input", async () => { + const honest = 1; + + const agg = new ByzantineRobustAggregator(0, 4, "absolute", 1.0, 3, 0); + agg.setNodes(Set(["a", "b", "c", "d"])); + + const p = agg.getPromiseForAggregation(); + agg.add("a", WeightsContainer.of([1]), 0); + agg.add("b", WeightsContainer.of([1]), 0); + agg.add("c", WeightsContainer.of([1]), 0); + agg.add("d", WeightsContainer.of([10]), 0); + + const out = await p; + const v = (await out.weights[0].data())[0]; + + const mean = (1 + 1 + 1 + 10) / 4; + + expect(Math.abs(v - honest)).to.be.lessThan(Math.abs(mean - honest)); + }); + + it("does not significantly worsen deviation compared to mean", async () => { + const clipRadius = 1.0; + + await fc.assert( + fc.asyncProperty( + fc.array( + fc.double({ + min: -1, + max: 1, + noNaN: true, + noDefaultInfinity: true + }), + { minLength: 3, maxLength: 10 } + 
) + // avoid degenerate constant arrays (no signal) + .filter(arr => arr.some(v => Math.abs(v - arr[0]) > 1e-8)), + + async (honest) => { + const n = honest.length + 1; + + // clean aggregation + const aggClean = new ByzantineRobustAggregator(0, honest.length, "absolute", clipRadius, 1, 0); + const honestIds = honest.map((_, i) => `h${i}`); + aggClean.setNodes(Set(honestIds)); + + const pClean = aggClean.getPromiseForAggregation(); + honest.forEach((v, i) => aggClean.add(`h${i}`, WeightsContainer.of([v]), 0)); + const cleanOut = await pClean; + const clean = (await cleanOut.weights[0].data())[0]; + + // aggregation with Byzantine + const aggByz = new ByzantineRobustAggregator(0, n, "absolute", clipRadius, 1, 0); + const ids = honestIds.concat("byz"); + aggByz.setNodes(Set(ids)); + + const pByz = aggByz.getPromiseForAggregation(); + honest.forEach((v, i) => aggByz.add(`h${i}`, WeightsContainer.of([v]), 0)); + aggByz.add("byz", WeightsContainer.of([1e9]), 0); + + const byzOut = await pByz; + const byz = (await byzOut.weights[0].data())[0]; + + const deviation = Math.abs(byz - clean); + const mean = [...honest, 1e9].reduce((a, b) => a + b, 0) / n; + const baseline = Math.abs(mean - clean); + + // combined tolerance (absolute + relative) + const ABS_EPS = 1e-6; + const REL_EPS = 1e-6; + + expect(deviation).toBeLessThanOrEqual( + baseline * (1 + REL_EPS) + ABS_EPS + ); + } + ), + { numRuns: 500 } + ); + }); + + it("is invariant to client ordering", async () => { + const values = [0, 1, 100]; + const ids1 = ["a", "b", "c"]; + const ids2 = ["c", "a", "b"]; + + const run = async (ids: string[]) => { + const agg = new ByzantineRobustAggregator(0, 3, "absolute", 1.0, 3, 0); + agg.setNodes(Set(ids)); + const p = agg.getPromiseForAggregation(); + ids.forEach((id, i) => + agg.add(id, WeightsContainer.of([values[i]]), 0) + ); + return (await (await p).weights[0].data())[0]; + }; + + const out1 = await run(ids1); + const out2 = await run(ids2); + + 
expect(out1).to.be.closeTo(out2, 1e-6); + }); + + it("is idempotent when all inputs are identical and within clipping radius", async () => { + const agg = new ByzantineRobustAggregator(0, 5, "absolute", 10.0, 5, 0); + const ids = ["a", "b", "c", "d", "e"]; + agg.setNodes(Set(ids)); + + const p = agg.getPromiseForAggregation(); + ids.forEach(id => agg.add(id, WeightsContainer.of([3.14]), 0)); + const out = await p; + + const v = (await out.weights[0].data())[0]; + expect(v).to.be.closeTo(3.14, 1e-6); + }); + + it("limits bias under symmetric Byzantine attacks", async () => { + const agg = new ByzantineRobustAggregator(0, 4, "absolute", 1.0, 3, 0); + agg.setNodes(Set(["h1", "h2", "b1", "b2"])); + + const p = agg.getPromiseForAggregation(); + agg.add("h1", WeightsContainer.of([1]), 0); + agg.add("h2", WeightsContainer.of([1]), 0); + agg.add("b1", WeightsContainer.of([100]), 0); + agg.add("b2", WeightsContainer.of([-100]), 0); + + const out = await p; + const v = (await out.weights[0].data())[0]; + + expect(Math.abs(v - 1)).to.be.lessThan(Math.abs((1 + 1 + 100 - 100)/4 - 1)); + }); + + it("reduces influence of extreme outliers", async () => { + const agg = new ByzantineRobustAggregator(0, 4, "absolute", 1.0, 3, 0); + agg.setNodes(Set(["a", "b", "c", "d"])); + + const p = agg.getPromiseForAggregation(); + agg.add("a", WeightsContainer.of([0]), 0); + agg.add("b", WeightsContainer.of([0.5]), 0); + agg.add("c", WeightsContainer.of([1]), 0); + agg.add("d", WeightsContainer.of([100]), 0); + + const out = await p; + const v = (await out.weights[0].data())[0]; + + const mean = (0 + 0.5 + 1 + 100) / 4; + const honestCenter = (0 + 0.5 + 1) / 3; + + expect(Math.abs(v - honestCenter)).to.be.lessThan(Math.abs(mean - honestCenter)); + }); + + it("reset state when starting fresh aggregator", async () => { + const run = async () => { + const agg = new ByzantineRobustAggregator(0, 2, "absolute", 1.0, 3, 0.9); + agg.setNodes(Set(["a", "b"])); + const p = agg.getPromiseForAggregation(); 
+ agg.add("a", WeightsContainer.of([1]), 0); + agg.add("b", WeightsContainer.of([1]), 0); + return (await (await p).weights[0].data())[0]; + }; + + expect(await run()).to.be.closeTo(await run(), 1e-6); + }); }); diff --git a/discojs/src/aggregator/byzantine.ts b/discojs/src/aggregator/byzantine.ts index 64d5cbe43..f61346124 100644 --- a/discojs/src/aggregator/byzantine.ts +++ b/discojs/src/aggregator/byzantine.ts @@ -9,8 +9,33 @@ import { aggregation } from "../index.js"; * Byzantine-robust aggregator using Centered Clipping (CC), based on the * "Learning from History for Byzantine Robust Optimization" paper: https://arxiv.org/abs/2012.10333 * - * This class implements a gradient aggregation rule that clips updates - * in an iterative fashion to mitigate the influence of Byzantine nodes, as well as momentum calculations. + * This class implements Centered Clipping (Algorithm 1) with an additional + * server-side per-client momentum mechanism inspired by Algorithm 2. + * + * We initialize using the mean of contributions when no previous + * aggregate exists. This improves convergence compared to zero initialization. + * + * NOTE: + * - Momentum: + * m_i^t = (1 - β) g_i^t + β m_i^{t-1} + * - Aggregation is then performed on {m_i} + * + * WARNING: + * This implementation requires stable client identities and is not + * compatible with secure aggregation, since per-client momentum + * must be tracked on the server. + * + * Use Case: + * + * Designed for federated or distributed learning with potentially malicious + * (Byzantine) clients. Centered Clipping limits the influence of extreme or + * corrupted updates by bounding each client's contribution. + * + * CC alone can be sensitive to poor initialization (e.g., early extreme + * Byzantine updates), as clipping limits updates but does not correct a + * bad initial estimate. The added per-client momentum helps stabilize + * training over time by leveraging historical information. 
+ * */ export class ByzantineRobustAggregator extends MultiRoundAggregator { private readonly clippingRadius: number; @@ -35,7 +60,7 @@ export class ByzantineRobustAggregator extends MultiRoundAggregator { * - Type: `number` * - Must be between 0 and 1. * - Used to compute the exponential moving average of past aggregates (i.e., momentum vector). - * The update typically looks like: `v_t = beta * v_{t-1} + (1 - beta) * g_t`, where `g_t` is the current clipped average. + * The update typically looks like: `m_i^t = (1 - β) g_i^t + β m_i^{t-1}`. * - A higher beta gives more weight to past rounds (more smoothing), while a lower beta makes the aggregator more responsive to new updates. */ @@ -60,7 +85,7 @@ export class ByzantineRobustAggregator extends MultiRoundAggregator { const prevMomentum = this.historyMomentums.get(nodeId); const newMomentum = prevMomentum ? contribution.mapWith(prevMomentum, (g, m) => g.mul(1 - this.beta).add(m.mul(this.beta))) - : contribution; // no scaling on first momentum + : contribution.map(g => g.mul(1 - this.beta)); this.historyMomentums = this.historyMomentums.set(nodeId, newMomentum); this.contributions = this.contributions.setIn([0, nodeId], newMomentum); @@ -77,34 +102,55 @@ export class ByzantineRobustAggregator extends MultiRoundAggregator { return aggregation.avg(currentContributions.values()); } - // Step 1: Initialize v to average of previous aggregations + // Step 1: Initialize v using previous aggregate or mean of contributions let v: WeightsContainer; if (this.prevAggregate) { - v = this.prevAggregate; + v = this.prevAggregate.map(t => tf.clone(t)); // Clone to avoid in-place modifications } else { - // Use shape of the first contribution to create zero vector - const first = currentContributions.values().next(); - if (first.done) throw new Error("zero sized contribution") - v = first.value.map((t: tf.Tensor) => tf.zerosLike(t)); + v = aggregation.avg(currentContributions.values()); } + + const eps = tf.scalar(1e-12); + const 
one = tf.scalar(1); + const radius = tf.scalar(this.clippingRadius); + // Step 2: Iterative Centered Clipping for (let l = 0; l < this.maxIterations; l++) { const clippedDiffs = Array.from(currentContributions.values()).map(m => { const diff = m.sub(v); - const norm = tf.tidy(() => euclideanNorm(diff)); - const scale = tf.tidy(() => tf.minimum(tf.scalar(1), tf.div(tf.scalar(this.clippingRadius), norm))); + + const norm = euclideanNorm(diff); + + const safeNorm = tf.maximum(norm, eps); + + const scale = tf.minimum( + one, + tf.div(radius, safeNorm) + ); + const clipped = diff.mul(scale); - norm.dispose(); scale.dispose(); + + norm.dispose(); + safeNorm.dispose(); + scale.dispose(); + return clipped; }); const avgClip = aggregation.avg(clippedDiffs); const newV = v.add(avgClip); + clippedDiffs.forEach(d => d.dispose()); - v.dispose(); // Safe if v is no longer needed + + const oldV = v; v = newV; + oldV.dispose(); } - // Step 3: Update momentum history + + eps.dispose(); + one.dispose(); + radius.dispose(); + // Step 3: Update history this.prevAggregate = v; return v; } @@ -119,8 +165,11 @@ export class ByzantineRobustAggregator extends MultiRoundAggregator { function euclideanNorm(w: WeightsContainer): tf.Scalar { // Computes the Euclidean (L2) norm of all tensors in a WeightsContainer by summing the squares of their elements and taking the square root. 
return tf.tidy(() => { - const norms: tf.Scalar[] = w.weights.map(t => tf.sum(tf.square(t))); - const total = norms.reduce((a, b) => tf.add(a, b)); + const zero = tf.scalar(0); + + const total = w.weights + .map(t => tf.sum(tf.square(t)) as tf.Scalar) + .reduce((a, b) => tf.add(a, b) as tf.Scalar, zero); return tf.sqrt(total); }); } \ No newline at end of file diff --git a/discojs/src/aggregator/byzantine_vs_percentile.spec.ts b/discojs/src/aggregator/byzantine_vs_percentile.spec.ts new file mode 100644 index 000000000..3c141e53d --- /dev/null +++ b/discojs/src/aggregator/byzantine_vs_percentile.spec.ts @@ -0,0 +1,362 @@ +import { Set } from "immutable"; +import { describe, expect, it } from "vitest"; + +import { WeightsContainer } from "../index.js"; +import { ByzantineRobustAggregator } from "./byzantine.js"; +import { PercentileClippingAggregator } from "./percentile_clipping.js"; + +// Helper to convert WeightsContainer → number[][] for easy assertions +async function WSIntoArrays(ws: WeightsContainer): Promise { + return Promise.all(ws.weights.map(async t => Array.from(await t.data()))); +} + +// Timing measurement helper +interface TimingResult { + name: string; + time: number; + result: number; +} + +async function measureAggregation( + aggregator: ByzantineRobustAggregator | PercentileClippingAggregator, + name: string, + peers: { id: string; value: number }[] +): Promise { + const promise = aggregator.getPromiseForAggregation(); + const currentRound = aggregator.round; + + const startTime = performance.now(); + peers.forEach(peer => { + aggregator.add(peer.id, WeightsContainer.of([peer.value]), currentRound); + }); + + const result = await promise; + const endTime = performance.now(); + + const arr = await WSIntoArrays(result); + const aggregatedValue = arr[0][0]; + + return { + name, + time: endTime - startTime, + result: aggregatedValue, + }; +} + +function formatTiming(timings: TimingResult[]): string { + const maxNameLen = Math.max(...timings.map(t => 
t.name.length)); + return timings + .map(t => ` ${t.name.padEnd(maxNameLen)} | ${t.time.toFixed(3)}ms | result: ${t.result.toFixed(4)}`) + .join('\n'); +} + +describe("Comparison: Centered Clipping vs Percentile Clipping", () => { +/** + * ============================================================ + * Comparison: Centered Clipping (CC) vs Percentile Clipping + * ============================================================ + * + * These tests highlight the fundamental differences between two + * aggregation strategies used in adversarial / federated settings. + * + * Centered Clipping (CC): + * - Iterative, principled aggregation rule with bounded updates + * - Provides theoretical robustness against Byzantine clients + * - Gradually refines the estimate over multiple iterations + * - More stable across rounds and symmetric/adversarial scenarios + * - Computationally more expensive (multiple passes over data) + * - Converges slowly if initialized far from the true signal + * + * Percentile Clipping: + * - Single-pass, heuristic aggregation based on norm thresholds + * - Fast and simple, with low computational overhead + * - Works well when outliers are clearly separable + * - Highly sensitive to data distribution and chosen percentile (tau) + * - Can behave like simple averaging in moderate/noisy settings + * - Can fail when Byzantine clients dominate or blend with honest ones + * + * When to use which: + * + * - Use Centered Clipping when: + * - Robustness is critical (adversarial or unreliable clients) + * - You can afford additional computation + * - You expect persistent or structured Byzantine behavior + * + * - Use Percentile Clipping when: + * - You need fast, lightweight aggregation + * - Data is mostly clean with occasional outliers + * - Strong robustness guarantees are not required + * + * Summary: + * CC - slower but principled and robust + * Percentile - faster but heuristic and less reliable + * + * The tests below illustrate these trade-offs across 
different + * regimes (extreme outliers, moderate attacks, multi-round behavior, etc.). + */ + it("CC improves with more iterations", async () => { + const peers = [ + { id: "h1", value: 1 }, + { id: "h2", value: 1 }, + { id: "h3", value: 1 }, + { id: "b1", value: 1000 }, + ]; + + const ids = peers.map(p => p.id); + + const cc1 = new ByzantineRobustAggregator(0, 4, "absolute", 1.0, 1, 0); + const cc50 = new ByzantineRobustAggregator(0, 4, "absolute", 1.0, 50, 0); + + cc1.setNodes(Set(ids)); + cc50.setNodes(Set(ids)); + + const r1 = await measureAggregation(cc1, "cc1", peers); + const r50 = await measureAggregation(cc50, "cc50", peers); + + const honest = 1; + + expect(Math.abs(r50.result - honest)) + .to.be.lessThan(Math.abs(r1.result - honest)); + }); + + it("percentile behaves like mean under moderate Byzantine values", async () => { + const peers = [ + { id: "h1", value: 1 }, + { id: "h2", value: 1 }, + { id: "h3", value: 1 }, + { id: "b1", value: 3 }, + { id: "b2", value: 3 }, + ]; + + const ids = peers.map(p => p.id); + + const cc = new ByzantineRobustAggregator(0, 5, "absolute", 1.0, 5, 0); + const pc = new PercentileClippingAggregator(0, 5, "absolute", 0.5); + + cc.setNodes(Set(ids)); + pc.setNodes(Set(ids)); + + const resPC = await measureAggregation(pc, "pc", peers); + + const honest = 1; + const mean = (1 + 1 + 1 + 3 + 3) / 5; + + // percentile behaves close to mean + expect(Math.abs(resPC.result - mean)).to.be.lessThan(0.5); + + // both are biased away from honest + expect(Math.abs(resPC.result - honest)).to.be.greaterThan(0.3); + }); + + it("iterations improve CC but not percentile", async () => { + const peers = [ + { id: "h1", value: 0 }, + { id: "h2", value: 0 }, + { id: "b1", value: 10 }, + ]; + + const ids = peers.map(p => p.id); + + const cc1 = new ByzantineRobustAggregator(0, 3, "absolute", 1.0, 1, 0); + const cc5 = new ByzantineRobustAggregator(0, 3, "absolute", 1.0, 5, 0); + + cc1.setNodes(Set(ids)); + cc5.setNodes(Set(ids)); + + const r1 = 
await measureAggregation(cc1, "cc1", peers); + const r5 = await measureAggregation(cc5, "cc5", peers); + + expect(Math.abs(r5.result)) + .to.be.lessThanOrEqual(Math.abs(r1.result)); + }); + + it("percentile sensitivity to tau parameter", async () => { + const peers = [ + { id: "h1", value: 1 }, + { id: "h2", value: 1 }, + { id: "b1", value: 100 }, + ]; + + const ids = peers.map(p => p.id); + + const lowTau = new PercentileClippingAggregator(0, 3, "absolute", 0.1); + const highTau = new PercentileClippingAggregator(0, 3, "absolute", 0.9); + + lowTau.setNodes(Set(ids)); + highTau.setNodes(Set(ids)); + + const rLow = await measureAggregation(lowTau, "low", peers); + const rHigh = await measureAggregation(highTau, "high", peers); + + expect(Math.abs(rLow.result - 1)) + .to.be.lessThan(Math.abs(rHigh.result - 1)); + }); + + it("CC is at least as stable across rounds as percentile", async () => { + const ids = ["h1", "h2", "b1"]; + + const cc = new ByzantineRobustAggregator(0, 3, "absolute", 1.0, 5, 0); + const pc = new PercentileClippingAggregator(0, 3, "absolute", 0.5); + + cc.setNodes(Set(ids)); + pc.setNodes(Set(ids)); + + // Round 1 + const prev = 5; + + await measureAggregation(cc, "cc1", [ + { id: "h1", value: prev }, + { id: "h2", value: prev }, + { id: "b1", value: prev }, + ]); + + await measureAggregation(pc, "pc1", [ + { id: "h1", value: prev }, + { id: "h2", value: prev }, + { id: "b1", value: prev }, + ]); + + // Round 2 (attack) + const rCC = await measureAggregation(cc, "cc2", [ + { id: "h1", value: 10 }, + { id: "h2", value: 10 }, + { id: "b1", value: 100 }, + ]); + + const rPC = await measureAggregation(pc, "pc2", [ + { id: "h1", value: 10 }, + { id: "h2", value: 10 }, + { id: "b1", value: 100 }, + ]); + + const deltaCC = Math.abs(rCC.result - prev); + const deltaPC = Math.abs(rPC.result - prev); + + expect(deltaCC).to.be.at.most(deltaPC + 1e-6); + }); + + it("percentile breaks when Byzantine dominate percentile", async () => { + const peers = [ + { id: 
"h1", value: 1 }, + { id: "h2", value: 1 }, + { id: "b1", value: 10 }, + { id: "b2", value: 10 }, + { id: "b3", value: 10 }, + ]; + + const ids = peers.map(p => p.id); + + const cc = new ByzantineRobustAggregator(0, 5, "absolute", 1.0, 10, 0); + const pc = new PercentileClippingAggregator(0, 5, "absolute", 0.5); + + cc.setNodes(Set(ids)); + pc.setNodes(Set(ids)); + + const resCC = await measureAggregation(cc, "cc", peers); + const resPC = await measureAggregation(pc, "pc", peers); + + const honest = 1; + + // Percentile clearly drifts + expect(resPC.result).to.be.greaterThan(3); + + // Both are worse than honest + expect(Math.abs(resCC.result - honest)).to.be.greaterThan(1); + expect(Math.abs(resPC.result - honest)).to.be.greaterThan(1); + }); + + it("both aggregators behave similarly without Byzantine clients", async () => { + const peers = [ + { id: "a", value: 1 }, + { id: "b", value: 2 }, + { id: "c", value: 3 }, + ]; + + const ids = peers.map(p => p.id); + + const cc = new ByzantineRobustAggregator(0, 3, "absolute", 10, 1, 0); + const pc = new PercentileClippingAggregator(0, 3, "absolute", 0.5); + + cc.setNodes(Set(ids)); + pc.setNodes(Set(ids)); + + const resCC = await measureAggregation(cc, "cc", peers); + const resPC = await measureAggregation(pc, "pc", peers); + + expect(resCC.result).to.be.closeTo(resPC.result, 1e-6); + }); + + it("prints timing comparison (CC vs Percentile)", async () => { + const peers = [ + { id: "h1", value: 1 }, + { id: "h2", value: 1 }, + { id: "h3", value: 1 }, + { id: "b1", value: 1000 }, + ]; + + const ids = peers.map(p => p.id); + + const cc = new ByzantineRobustAggregator(0, 4, "absolute", 1.0, 20, 0); + const pc = new PercentileClippingAggregator(0, 4, "absolute", 0.5); + + cc.setNodes(Set(ids)); + pc.setNodes(Set(ids)); + + const resCC = await measureAggregation(cc, "CC", peers); + const resPC = await measureAggregation(pc, "Percentile", peers); + + console.log("\nTiming comparison:\n" + formatTiming([resCC, resPC])); + + 
expect(resPC.time).to.be.lessThan(resCC.time); + }); + + it("CC handles symmetric attacks better than percentile", async () => { + const peers = [ + { id: "h1", value: 1 }, + { id: "h2", value: 1 }, + { id: "b1", value: 100 }, + { id: "b2", value: -100 }, + ]; + + const ids = peers.map(p => p.id); + + const cc = new ByzantineRobustAggregator(0, 4, "absolute", 1.0, 5, 0); + const pc = new PercentileClippingAggregator(0, 4, "absolute", 0.5); + + cc.setNodes(Set(ids)); + pc.setNodes(Set(ids)); + + const resCC = await measureAggregation(cc, "cc", peers); + const resPC = await measureAggregation(pc, "pc", peers); + + const honest = 1; + + expect(Math.abs(resCC.result - honest)) + .to.be.lessThan(Math.abs(resPC.result - honest)); + }); + + it("percentile is sensitive to honest variance", async () => { + const peers = [ + { id: "h1", value: 1 }, + { id: "h2", value: 2 }, + { id: "h3", value: 3 }, + { id: "b1", value: 10 }, + ]; + + const ids = peers.map(p => p.id); + + const cc = new ByzantineRobustAggregator(0, 4, "absolute", 1.0, 5, 0); + const pc = new PercentileClippingAggregator(0, 4, "absolute", 0.5); + + cc.setNodes(Set(ids)); + pc.setNodes(Set(ids)); + + const resCC = await measureAggregation(cc, "cc", peers); + const resPC = await measureAggregation(pc, "pc", peers); + + const honestMean = (1 + 2 + 3) / 3; + + expect(Math.abs(resCC.result - honestMean)) + .to.be.lessThan(Math.abs(resPC.result - honestMean)); + }); +}); \ No newline at end of file diff --git a/discojs/src/aggregator/percentile_clipping.spec.ts b/discojs/src/aggregator/percentile_clipping.spec.ts index 674d2cd09..1ba39f8c3 100644 --- a/discojs/src/aggregator/percentile_clipping.spec.ts +++ b/discojs/src/aggregator/percentile_clipping.spec.ts @@ -2,422 +2,168 @@ import { Set } from "immutable"; import { describe, expect, it } from "vitest"; import { WeightsContainer } from "../index.js"; -import { ByzantineRobustAggregator } from "./byzantine.js"; import { PercentileClippingAggregator } from 
"./percentile_clipping.js"; -// Helper to convert WeightsContainer → number[][] for easy assertions async function WSIntoArrays(ws: WeightsContainer): Promise { return Promise.all(ws.weights.map(async t => Array.from(await t.data()))); } -// Timing measurement helper -interface TimingResult { - name: string; - time: number; - result: number; -} - -async function measureAggregation( - aggregator: ByzantineRobustAggregator | PercentileClippingAggregator, - name: string, - peers: { id: string; value: number }[] -): Promise { - const promise = aggregator.getPromiseForAggregation(); - const currentRound = aggregator.round; - - const startTime = performance.now(); - peers.forEach(peer => { - aggregator.add(peer.id, WeightsContainer.of([peer.value]), currentRound); - }); - - const result = await promise; - const endTime = performance.now(); - - const arr = await WSIntoArrays(result); - const aggregatedValue = arr[0][0]; - - return { - name, - time: endTime - startTime, - result: aggregatedValue, - }; -} +describe("PercentileClippingAggregator", () => { -function formatTiming(timings: TimingResult[]): string { - const maxNameLen = Math.max(...timings.map(t => t.name.length)); - return timings - .map(t => ` ${t.name.padEnd(maxNameLen)} | ${t.time.toFixed(3)}ms | result: ${t.result.toFixed(4)}`) - .join('\n'); -} - -describe("Performance Comparison: Old vs New Byzantine Aggregators", () => { - /** - * Comparison test: Centered Clipping (Current) vs Percentile-based Clipping (Old) - * - * Setup: - * - Multiple honest updates (value=1.0) - * - Multiple Byzantine updates (value=100) - * - Measure robustness of aggregated result - */ - - it("both aggregators handle simple outlier rejection", async () => { - const honestPeers = ["honest1", "honest2"]; - const byzantinePeers = ["byzantine1"]; - const allPeers = honestPeers.concat(byzantinePeers); - - // --- TEST 1: New Aggregator (Centered Clipping with iterations) --- - const newAgg = new ByzantineRobustAggregator(0, 3, 
"absolute", 1.0, 1, 0); - newAgg.setNodes(Set(allPeers)); - - const peersWithValues = [ - ...honestPeers.map(id => ({ id, value: 1.0 })), - ...byzantinePeers.map(id => ({ id, value: 100.0 })), - ]; - - const timingNew = await measureAggregation(newAgg, "New ByzantineRobust", peersWithValues); - - // --- TEST 2: Old Aggregator (Percentile Clipping) --- - const oldAgg = new PercentileClippingAggregator(0, 3, "absolute", 0.1); - oldAgg.setNodes(Set(allPeers)); - - const timingOld = await measureAggregation(oldAgg, "Old PercentileClipping", peersWithValues); - - // Both should produce a value closer to 1.0 than 100.0 - expect(timingNew.result).toBeLessThan(50); - expect(timingOld.result).toBeLessThan(50); - - // Print timing comparison - console.log("\n=== Timing Comparison: Simple Outlier Rejection ==="); - console.log(formatTiming([timingNew, timingOld])); + it("throws on invalid constructor parameters", () => { + expect(() => new PercentileClippingAggregator(0, 1, "absolute", 0)).to.throw(); + expect(() => new PercentileClippingAggregator(0, 1, "absolute", 1)).to.throw(); + expect(() => new PercentileClippingAggregator(0, 1, "absolute", -0.1)).to.throw(); }); - it("old aggregator with different percentiles", async () => { - const honestPeers = ["honest1", "honest2", "honest3"]; - const byzantinePeers = ["byzantine1", "byzantine2"]; - const allPeers = honestPeers.concat(byzantinePeers); - - const peersWithValues = [ - ...honestPeers.map(id => ({ id, value: 1.0 })), - ...byzantinePeers.map(id => ({ id, value: 50.0 })), - ]; - - const testPercentiles = [0.05, 0.1, 0.2, 0.5]; - const timings: TimingResult[] = []; + it("behaves like mean when no clipping occurs", async () => { + const agg = new PercentileClippingAggregator(0, 3, "absolute", 0.5); + agg.setNodes(Set(["a", "b", "c"])); - for (const tau of testPercentiles) { - const agg = new PercentileClippingAggregator(0, 5, "absolute", tau); - agg.setNodes(Set(allPeers)); + const p = agg.getPromiseForAggregation(); + 
agg.add("a", WeightsContainer.of([1]), 0); + agg.add("b", WeightsContainer.of([2]), 0); + agg.add("c", WeightsContainer.of([3]), 0); - const timing = await measureAggregation(agg, `tau=${tau}`, peersWithValues); - timings.push(timing); + const out = await p; + const arr = await WSIntoArrays(out); - // Should clip towards honest value - expect(timing.result).toBeLessThan(30); - } - - console.log("\n=== Timing Comparison: Old Aggregator with Different Percentiles ==="); - console.log(formatTiming(timings)); + expect(arr[0][0]).to.be.closeTo(2, 1e-6); }); - it("new aggregator with different clipping radii", async () => { - const honestPeers = ["honest1", "honest2", "honest3"]; - const byzantinePeers = ["byzantine1", "byzantine2"]; - const allPeers = honestPeers.concat(byzantinePeers); - - const peersWithValues = [ - ...honestPeers.map(id => ({ id, value: 1.0 })), - ...byzantinePeers.map(id => ({ id, value: 50.0 })), - ]; + it("reduces influence of a large outlier (heuristically)", async () => { + const agg = new PercentileClippingAggregator(0, 4, "absolute", 0.5); + agg.setNodes(Set(["a", "b", "c", "d"])); - const testRadii = [0.5, 1.0, 2.0, 5.0]; - const timings: TimingResult[] = []; + const p = agg.getPromiseForAggregation(); + agg.add("a", WeightsContainer.of([1]), 0); + agg.add("b", WeightsContainer.of([1]), 0); + agg.add("c", WeightsContainer.of([1]), 0); + agg.add("d", WeightsContainer.of([100]), 0); - for (const radius of testRadii) { - const agg = new ByzantineRobustAggregator(0, 5, "absolute", radius, 1, 0); - agg.setNodes(Set(allPeers)); + const out = await p; + const v = (await out.weights[0].data())[0]; - const timing = await measureAggregation(agg, `radius=${radius}`, peersWithValues); - timings.push(timing); + const mean = (1 + 1 + 1 + 100) / 4; - // With larger radius, more Byzantine influence - expect(timing.result).toBeGreaterThan(0); - } - - console.log("\n=== Timing Comparison: New Aggregator with Different Clipping Radii ==="); - 
console.log(formatTiming(timings)); + expect(Math.abs(v - 1)).to.be.lessThan(Math.abs(mean - 1)); }); - it("old aggregator stores previous aggregation state", async () => { - const agg = new PercentileClippingAggregator(0, 2, "absolute", 0.1); - const [peer1, peer2] = ["peer1", "peer2"]; - agg.setNodes(Set([peer1, peer2])); + it("is idempotent when all inputs are identical", async () => { + const agg = new PercentileClippingAggregator(0, 4, "absolute", 0.5); + agg.setNodes(Set(["a", "b", "c", "d"])); - // Round 1 - const peersRound1 = [ - { id: peer1, value: 5.0 }, - { id: peer2, value: 5.0 }, - ]; + const p = agg.getPromiseForAggregation(); + ["a", "b", "c", "d"].forEach(id => agg.add(id, WeightsContainer.of([5]), 0)); - const timingRound1 = await measureAggregation(agg, "Round 1", peersRound1); - expect(timingRound1.result).to.equal(5.0); + const out = await p; + const v = (await out.weights[0].data())[0]; - // Round 2 - should center around previous result - const peersRound2 = [ - { id: peer1, value: 10.0 }, - { id: peer2, value: 10.0 }, - ]; + expect(v).to.be.closeTo(5, 1e-6); + }); - const timingRound2 = await measureAggregation(agg, "Round 2", peersRound2); + it("is invariant to client ordering", async () => { + const values = [1, 2, 100]; + const ids1 = ["a", "b", "c"]; + const ids2 = ["c", "a", "b"]; + + const run = async (ids: string[]) => { + const agg = new PercentileClippingAggregator(0, 3, "absolute", 0.5); + agg.setNodes(Set(ids)); + const p = agg.getPromiseForAggregation(); + ids.forEach((id, i) => + agg.add(id, WeightsContainer.of([values[i]]), 0) + ); + return (await (await p).weights[0].data())[0]; + }; - // With centering on previous (5.0), updates to 10.0 should result in something close to 10.0 - expect(timingRound2.result).toBeGreaterThan(5.0); + const out1 = await run(ids1); + const out2 = await run(ids2); - console.log("\n=== Timing Comparison: State Preservation Across Rounds ==="); - console.log(formatTiming([timingRound1, 
timingRound2])); + expect(out1).to.be.closeTo(out2, 1e-6); }); - it("scalability: larger peer set (10 peers)", async () => { - const numHonest = 7; - const numByzantine = 3; - const honestPeers = Array.from({ length: numHonest }, (_, i) => `honest${i}`); - const byzantinePeers = Array.from({ length: numByzantine }, (_, i) => `byzantine${i}`); - const allPeers = honestPeers.concat(byzantinePeers); - - const peersWithValues = [ - ...honestPeers.map(id => ({ id, value: 1.0 })), - ...byzantinePeers.map(id => ({ id, value: 100.0 })), - ]; - - const newAgg = new ByzantineRobustAggregator(0, allPeers.length, "absolute", 1.0, 1, 0); - newAgg.setNodes(Set(allPeers)); - const timingNew = await measureAggregation(newAgg, "New (10 peers)", peersWithValues); - - const oldAgg = new PercentileClippingAggregator(0, allPeers.length, "absolute", 0.1); - oldAgg.setNodes(Set(allPeers)); - const timingOld = await measureAggregation(oldAgg, "Old (10 peers)", peersWithValues); - - console.log("\n=== Scalability Test: 10 Peers (7 honest, 3 Byzantine) ==="); - console.log(formatTiming([timingNew, timingOld])); - console.log(` Speedup: ${(timingOld.time / timingNew.time).toFixed(2)}x`); - }); + it("lower percentile increases clipping strength", async () => { + const nodes = ["a", "b", "c", "d"]; + const inputs = [1, 1, 1, 100]; - it("iterative refinement: new aggregator with multiple iterations", async () => { - const honestPeers = ["honest1", "honest2", "honest3"]; - const byzantinePeers = ["byzantine1", "byzantine2"]; - const allPeers = honestPeers.concat(byzantinePeers); + const aggLow = new PercentileClippingAggregator(0, 4, "absolute", 0.1); + const aggHigh = new PercentileClippingAggregator(0, 4, "absolute", 0.9); - const peersWithValues = [ - ...honestPeers.map(id => ({ id, value: 1.0 })), - ...byzantinePeers.map(id => ({ id, value: 50.0 })), - ]; + aggLow.setNodes(Set(nodes)); + aggHigh.setNodes(Set(nodes)); - const iterations = [1, 2, 5, 10]; - const timings: TimingResult[] = []; + 
const pLow = aggLow.getPromiseForAggregation(); + const pHigh = aggHigh.getPromiseForAggregation(); - for (const iter of iterations) { - const agg = new ByzantineRobustAggregator(0, 5, "absolute", 1.0, iter, 0); - agg.setNodes(Set(allPeers)); + nodes.forEach((n, i) => { + aggLow.add(n, WeightsContainer.of([inputs[i]]), 0); + aggHigh.add(n, WeightsContainer.of([inputs[i]]), 0); + }); - const timing = await measureAggregation(agg, `iterations=${iter}`, peersWithValues); - timings.push(timing); - } + const vLow = (await (await pLow).weights[0].data())[0]; + const vHigh = (await (await pHigh).weights[0].data())[0]; - console.log("\n=== Performance Impact of Iterative Refinement ==="); - console.log(formatTiming(timings)); + expect(Math.abs(vLow - 1)).to.be.lessThan(Math.abs(vHigh - 1)); }); - it("equivalence: new aggregator with 1 iteration matches old aggregator", async () => { - const honestPeers = ["honest1", "honest2", "honest3"]; - const byzantinePeers = ["byzantine1", "byzantine2"]; - const allPeers = honestPeers.concat(byzantinePeers); + it("handles zero-norm inputs without NaN", async () => { + const agg = new PercentileClippingAggregator(0, 2, "absolute", 0.5); + agg.setNodes(Set(["a", "b"])); - const peersWithValues = [ - ...honestPeers.map(id => ({ id, value: 1.0 })), - ...byzantinePeers.map(id => ({ id, value: 50.0 })), - ]; + const p = agg.getPromiseForAggregation(); + agg.add("a", WeightsContainer.of([0]), 0); + agg.add("b", WeightsContainer.of([0]), 0); - const newAggWithOneIter = new ByzantineRobustAggregator(0, 5, "absolute", 1.0, 1, 0); - newAggWithOneIter.setNodes(Set(allPeers)); + const out = await p; + const v = (await out.weights[0].data())[0]; - const oldAgg = new PercentileClippingAggregator(0, 5, "absolute", 0.1); - oldAgg.setNodes(Set(allPeers)); - - const timingNew = await measureAggregation(newAggWithOneIter, "New (maxIter=1)", peersWithValues); - const timingOld = await measureAggregation(oldAgg, "Old (tau=0.1)", peersWithValues); - - 
console.log("\n=== Equivalence Test: Single Iteration Convergence ==="); - console.log(formatTiming([timingNew, timingOld])); - console.log(` Result difference: ${Math.abs(timingNew.result - timingOld.result).toFixed(4)}`); - console.log(` Speed ratio (new/old): ${(timingNew.time / timingOld.time).toFixed(2)}x`); - - expect(timingNew.result).toBeLessThan(30); - expect(timingOld.result).toBeLessThan(30); - - // With single iteration, results should be very close (within reasonable tolerance) - expect(Math.abs(timingNew.result - timingOld.result)).toBeLessThan(5); + expect(Number.isFinite(v)).to.be.true; }); - it("byzantine robustness: high ratio attack (40% malicious peers)", async () => { - const numHonest = 6; - const numByzantine = 4; - const honestPeers = Array.from({ length: numHonest }, (_, i) => `honest${i}`); - const byzantinePeers = Array.from({ length: numByzantine }, (_, i) => `byzantine${i}`); - const allPeers = honestPeers.concat(byzantinePeers); - - // Honest send gradient 1.0, Byzantine send large outlier - const peersWithValues = [ - ...honestPeers.map(id => ({ id, value: 1.0 })), - ...byzantinePeers.map(id => ({ id, value: 100.0 })), // 40% Byzantine pulling result up - ]; - - const newAgg = new ByzantineRobustAggregator(0, allPeers.length, "absolute", 1.0, 1, 0); - newAgg.setNodes(Set(allPeers)); - const timingNew = await measureAggregation(newAgg, "New (40% Byzantine)", peersWithValues); - - const oldAgg = new PercentileClippingAggregator(0, allPeers.length, "absolute", 0.1); - oldAgg.setNodes(Set(allPeers)); - const timingOld = await measureAggregation(oldAgg, "Old (40% Byzantine)", peersWithValues); - - console.log("\n=== Byzantine Robustness: High Ratio Attack (4/10 = 40% malicious) ==="); - console.log(formatTiming([timingNew, timingOld])); - console.log(` Result gap: new=${timingNew.result.toFixed(4)}, old=${timingOld.result.toFixed(4)}`); - console.log(` Winner: ${timingNew.result < timingOld.result ? 
"NEW (closer to honest 1.0)" : "OLD (closer to honest 1.0)"}`); - }); + it("respects roundCutoff", async () => { + const agg = new PercentileClippingAggregator(1, 1, "absolute", 0.5); + agg.setNodes(Set(["a"])); - it("byzantine robustness: gradient poisoning attack (crafted gradients)", async () => { - // Byzantine gradient attack: send gradients designed to manipulate centroid - const honestPeers = ["honest1", "honest2", "honest3"]; - const byzantinePeers = ["byzantine1", "byzantine2"]; - const allPeers = honestPeers.concat(byzantinePeers); - - // Honest: standard gradient 1.0 - // Byzantine: crafted to move result away from honest consensus - // Poisoning strategy: send same large value to coordinate - const peersWithValues = [ - ...honestPeers.map(id => ({ id, value: 1.0 })), - ...byzantinePeers.map(id => ({ id, value: 10.0 })), - ]; - - const newAgg = new ByzantineRobustAggregator(0, allPeers.length, "absolute", 2.0, 5, 0); - newAgg.setNodes(Set(allPeers)); - const timingNew = await measureAggregation(newAgg, "New (5 iterations)", peersWithValues); - - const oldAgg = new PercentileClippingAggregator(0, allPeers.length, "absolute", 0.2); - oldAgg.setNodes(Set(allPeers)); - const timingOld = await measureAggregation(oldAgg, "Old (tau=0.2)", peersWithValues); - - console.log("\n=== Gradient Poisoning Attack (coordinated Byzantine values) ==="); - console.log(formatTiming([timingNew, timingOld])); - console.log(` Result gap: new=${timingNew.result.toFixed(4)}, old=${timingOld.result.toFixed(4)}`); - console.log(` Expected honest value: 1.0000`); - console.log(` Winner: ${Math.abs(timingNew.result - 1.0) < Math.abs(timingOld.result - 1.0) ? 
"NEW (closer to honest)" : "OLD (closer to honest)"}`); - }); + const p0 = agg.getPromiseForAggregation(); + agg.add("a", WeightsContainer.of([10]), 0); + const v0 = (await (await p0).weights[0].data())[0]; + expect(v0).to.equal(10); - it("byzantine robustness: adaptive multi-round attack", async () => { - // Multi-round attack: Byzantine adapts based on previous aggregation - // Round 1: test the aggregator behavior - // Round 2: Byzantine sends crafted response gradient - - const honestPeers = ["honest1", "honest2", "honest3"]; - const byzantinePeers = ["byzantine1"]; - const allPeers = honestPeers.concat(byzantinePeers); - - // Round 1 setup - const round1Values = [ - ...honestPeers.map(id => ({ id, value: 5.0 })), - ...byzantinePeers.map(id => ({ id, value: 5.0 })), // Byzantine cooperates round 1 - ]; - - const newAgg = new ByzantineRobustAggregator(0, allPeers.length, "absolute", 1.0, 5, 0); - newAgg.setNodes(Set(allPeers)); - const timing1New = await measureAggregation(newAgg, "New Round 1", round1Values); - - const oldAgg = new PercentileClippingAggregator(0, allPeers.length, "absolute", 0.1); - oldAgg.setNodes(Set(allPeers)); - const timing1Old = await measureAggregation(oldAgg, "Old Round 1", round1Values); - - // Round 2: Byzantine launches adaptive attack - const round2Values = [ - ...honestPeers.map(id => ({ id, value: 10.0 })), // Honest update - ...byzantinePeers.map(id => ({ id, value: 50.0 })), // Byzantine aggressive attack in round 2 - ]; - - const timing2New = await measureAggregation(newAgg, "New Round 2 (attack)", round2Values); - const timing2Old = await measureAggregation(oldAgg, "Old Round 2 (attack)", round2Values); - - console.log("\n=== Adaptive Multi-Round Attack ==="); - console.log("Round 1 (cooperation):"); - console.log(formatTiming([timing1New, timing1Old])); - console.log("\nRound 2 (adaptive Byzantine attack):"); - console.log(formatTiming([timing2New, timing2Old])); - console.log(` New result: ${timing2New.result.toFixed(4)} 
(expected ~10.0)`); - console.log(` Old result: ${timing2Old.result.toFixed(4)} (expected ~10.0)`); - console.log(` Winner: ${Math.abs(timing2New.result - 10.0) < Math.abs(timing2Old.result - 10.0) ? "NEW (better rejects attack)" : "OLD (better rejects attack)"}`); + const p2 = agg.getPromiseForAggregation(); + agg.add("a", WeightsContainer.of([20]), 2); + const v2 = (await (await p2).weights[0].data())[0]; + expect(v2).to.equal(20); }); - it("heterogeneous gradients: realistic multi-tensor federated model", async () => { - // Realistic FL scenario: aggregate weights across multiple layers/tensors with different scales - // Layer 1: large values (e.g., from first dense layer) - // Layer 2: small values (e.g., from final output layer) - - const honestPeers = ["honest1", "honest2", "honest3", "honest4"]; - const byzantinePeers = ["byzantine1"]; - const allPeers = honestPeers.concat(byzantinePeers); - - // Create multi-tensor contributions - const createHeterogeneousGradient = (baseValue: number): WeightsContainer => { - return WeightsContainer.of([baseValue * 100, baseValue * 10, baseValue]); // Different scales - }; - - const newAgg = new ByzantineRobustAggregator(0, allPeers.length, "absolute", 5.0, 3, 0); - newAgg.setNodes(Set(allPeers)); - const promiseNew = newAgg.getPromiseForAggregation(); - const startNew = performance.now(); + it("can fail under strong Byzantine attack (documented limitation)", async () => { + const agg = new PercentileClippingAggregator(0, 4, "absolute", 0.5); + agg.setNodes(Set(["a", "b", "c", "d"])); - honestPeers.forEach(id => { - newAgg.add(id, createHeterogeneousGradient(1.0), newAgg.round); - }); - byzantinePeers.forEach(id => { - newAgg.add(id, createHeterogeneousGradient(100.0), newAgg.round); // Byzantine - }); + const p = agg.getPromiseForAggregation(); + agg.add("a", WeightsContainer.of([1]), 0); + agg.add("b", WeightsContainer.of([1]), 0); + agg.add("c", WeightsContainer.of([50]), 0); + agg.add("d", WeightsContainer.of([100]), 
0); - const resultNew = await promiseNew; - const timeNew = performance.now() - startNew; - const arrNew = await WSIntoArrays(resultNew); + const out = await p; + const v = (await out.weights[0].data())[0]; - const oldAgg = new PercentileClippingAggregator(0, allPeers.length, "absolute", 0.1); - oldAgg.setNodes(Set(allPeers)); - const promiseOld = oldAgg.getPromiseForAggregation(); - const startOld = performance.now(); + // We don't assert correctness — only that it doesn't explode + expect(Number.isFinite(v)).to.be.true; + }); - honestPeers.forEach(id => { - oldAgg.add(id, createHeterogeneousGradient(1.0), oldAgg.round); - }); - byzantinePeers.forEach(id => { - oldAgg.add(id, createHeterogeneousGradient(100.0), oldAgg.round); // Byzantine - }); + it("reset state when starting fresh aggregator", async () => { + const run = async () => { + const agg = new PercentileClippingAggregator(0, 2, "absolute", 0.5); + agg.setNodes(Set(["a", "b"])); + const p = agg.getPromiseForAggregation(); + agg.add("a", WeightsContainer.of([1]), 0); + agg.add("b", WeightsContainer.of([1]), 0); + return (await (await p).weights[0].data())[0]; + }; - const resultOld = await promiseOld; - const timeOld = performance.now() - startOld; - const arrOld = await WSIntoArrays(resultOld); - - console.log("\n=== Heterogeneous Gradients: Multi-Tensor Federated Model ==="); - console.log("Gradient structure: [layer1=100×value, layer2=10×value, layer3=value]"); - console.log("Honest peers send: [100, 10, 1]"); - console.log("Byzantine sends: [10000, 1000, 100]"); - console.log(`\nNew aggregator (${timeNew.toFixed(2)}ms):`); - console.log(` Layer 1: ${arrNew[0][0].toFixed(2)} (expected ~100)`); - console.log(` Layer 2: ${arrNew[0][1].toFixed(2)} (expected ~10)`); - console.log(` Layer 3: ${arrNew[0][2].toFixed(2)} (expected ~1)`); - console.log(`\nOld aggregator (${timeOld.toFixed(2)}ms):`); - console.log(` Layer 1: ${arrOld[0][0].toFixed(2)} (expected ~100)`); - console.log(` Layer 2: 
${arrOld[0][1].toFixed(2)} (expected ~10)`); - console.log(` Layer 3: ${arrOld[0][2].toFixed(2)} (expected ~1)`); - - // Check relative error - const newError = Math.abs((arrNew[0][0] - 100) / 100) + Math.abs((arrNew[0][1] - 10) / 10) + Math.abs(arrNew[0][2] - 1); - const oldError = Math.abs((arrOld[0][0] - 100) / 100) + Math.abs((arrOld[0][1] - 10) / 10) + Math.abs(arrOld[0][2] - 1); - console.log(`\nTotal relative error: new=${newError.toFixed(3)}, old=${oldError.toFixed(3)}`); - console.log(`Winner: ${newError < oldError ? "NEW (better handles multi-scale)" : "OLD (better handles multi-scale)"}`); + expect(await run()).to.be.closeTo(await run(), 1e-6); }); -}); + +}); \ No newline at end of file diff --git a/discojs/src/aggregator/percentile_clipping.ts b/discojs/src/aggregator/percentile_clipping.ts index 88a28b957..5f2c6a575 100644 --- a/discojs/src/aggregator/percentile_clipping.ts +++ b/discojs/src/aggregator/percentile_clipping.ts @@ -6,11 +6,16 @@ import { WeightsContainer, client } from "../index.js"; import { aggregation } from "../index.js"; /** - * Old Byzantine-robust aggregator using Percentile-based Clipping + * Percentile-based clipping aggregator. * - * This class implements a gradient aggregation rule that clips updates based on a - * percentile-computed threshold (tau) to mitigate the influence of Byzantine nodes. - * Unlike the iterative Centered Clipping approach, this uses a single-pass percentile-based clipping. + * This method clips updates using a threshold τ computed as a percentile + * of update norms. Unlike Centered Clipping, this is a single-pass heuristic + * and does not provide formal Byzantine robustness guarantees. + * + * Use Case: + * Suitable for mitigating mild outliers or noisy updates when most clients + * are honest. Not suitable for adversarial Byzantine settings, as the + * percentile threshold can be influenced by malicious clients. * * Algorithm: * 1. Center all peer weights w.r.t. 
the previous aggregation @@ -52,23 +57,20 @@ export class PercentileClippingAggregator extends MultiRoundAggregator { override aggregate(): WeightsContainer { const currentContributions = this.contributions.get(0); - if (!currentContributions) throw new Error("aggregating without any contribution"); + if (!currentContributions || currentContributions.size === 0) throw new Error("aggregating without any contribution"); this.log(AggregationStep.AGGREGATE); - // Step 1: Get the centering reference (previous aggregation or zero vector) + // Step 1: Get the centering reference (previous aggregation or initial avg vector) let centerReference: WeightsContainer; if (this.prevAggregate) { - centerReference = this.prevAggregate; + centerReference = this.prevAggregate.map(t => tf.clone(t)); } else { - // Use shape of the first contribution to create zero vector - const first = currentContributions.values().next(); - if (first.done) throw new Error("zero sized contribution"); - centerReference = first.value.map((t: tf.Tensor) => tf.zerosLike(t)); + centerReference = aggregation.avg(currentContributions.values()).map(t => tf.clone(t)); } // Step 2: Center the weights with respect to the reference - const centeredWeights = Array.from(currentContributions.values()).map(w => + const centeredWeights = Array.from(currentContributions.values()).map(w => w.sub(centerReference) ); @@ -81,19 +83,23 @@ export class PercentileClippingAggregator extends MultiRoundAggregator { // Step 5: Clip weights based on tau // Each peer gets one scale factor based on their Frobenius norm const clippedWeights = centeredWeights.map((w, peerIdx) => { - const scaleFactor = Math.min(1, tau / normArray[peerIdx]); + //const scaleFactor = Math.min(1, tau / normArray[peerIdx]); + const norm = normArray[peerIdx]; + const safeNorm = Math.max(norm, 1e-12); + + const scaleFactor = Math.min(1, tau / safeNorm); return w.map((t: tf.Tensor) => t.mul(scaleFactor)); }); + centeredWeights.forEach(w => w.dispose()); + // 
Step 6: Average the clipped weights and add back the reference const clippedAvg = aggregation.avg(clippedWeights); const result = centerReference.add(clippedAvg); + centerReference.dispose(); clippedWeights.forEach(w => w.dispose()); clippedAvg.dispose(); - if (!this.prevAggregate) { - centerReference.dispose(); - } // Step 7: Store result for next round this.prevAggregate = result; @@ -102,7 +108,10 @@ export class PercentileClippingAggregator extends MultiRoundAggregator { private computePercentile(array: number[], percentile: number): number { // Linear interpolation for percentile calculation - const sorted = [...array].sort((a, b) => a - b); + const clean = array.filter(Number.isFinite); + if (clean.length === 0) return 0; + + const sorted = [...clean].sort((a, b) => a - b); const pos = (sorted.length - 1) * percentile; const base = Math.floor(pos); const rest = pos - base; @@ -121,12 +130,11 @@ export class PercentileClippingAggregator extends MultiRoundAggregator { function frobeniusNorm(w: WeightsContainer): number { // Computes the Frobenius (L2) norm of all tensors in a WeightsContainer - // sqrt(sum of all squared elements across all tensors) return tf.tidy(() => { - const norms: tf.Scalar[] = w.weights.map(t => tf.sum(tf.square(t))); - const total = norms.reduce((a, b) => tf.add(a, b)); - const result = tf.sqrt(total); - const value = result.dataSync()[0]; - return value; + const total = w.weights + .map(t => tf.sum(tf.square(t))) + .reduce((a, b) => tf.add(a, b), tf.scalar(0)); + + return tf.sqrt(total).dataSync()[0]; }); } diff --git a/package-lock.json b/package-lock.json index 618fb71a6..9e5450d31 100644 --- a/package-lock.json +++ b/package-lock.json @@ -64,6 +64,7 @@ "devDependencies": { "@tensorflow/tfjs-node": "4", "@types/simple-peer": "9", + "fast-check": "3", "nodemon": "3", "ts-node": "10" } @@ -7917,6 +7918,29 @@ ], "license": "MIT" }, + "node_modules/fast-check": { + "version": "3.23.2", + "resolved": 
"https://registry.npmjs.org/fast-check/-/fast-check-3.23.2.tgz", + "integrity": "sha512-h5+1OzzfCC3Ef7VbtKdcv7zsstUQwUDlYpUTvjeUsJAssPgLn7QzbboPtL5ro04Mq0rPOsMzl7q5hIbRs2wD1A==", + "dev": true, + "funding": [ + { + "type": "individual", + "url": "https://github.com/sponsors/dubzzz" + }, + { + "type": "opencollective", + "url": "https://opencollective.com/fast-check" + } + ], + "license": "MIT", + "dependencies": { + "pure-rand": "^6.1.0" + }, + "engines": { + "node": ">=8.0.0" + } + }, "node_modules/fast-deep-equal": { "version": "3.1.3", "resolved": "https://registry.npmjs.org/fast-deep-equal/-/fast-deep-equal-3.1.3.tgz", @@ -9646,6 +9670,7 @@ "os": [ "android" ], + "peer": true, "engines": { "node": ">= 12.0.0" }, @@ -9666,6 +9691,7 @@ "os": [ "darwin" ], + "peer": true, "engines": { "node": ">= 12.0.0" }, @@ -9686,6 +9712,7 @@ "os": [ "darwin" ], + "peer": true, "engines": { "node": ">= 12.0.0" }, @@ -9706,6 +9733,7 @@ "os": [ "freebsd" ], + "peer": true, "engines": { "node": ">= 12.0.0" }, @@ -9726,6 +9754,7 @@ "os": [ "linux" ], + "peer": true, "engines": { "node": ">= 12.0.0" }, @@ -9746,6 +9775,7 @@ "os": [ "linux" ], + "peer": true, "engines": { "node": ">= 12.0.0" }, @@ -9766,6 +9796,7 @@ "os": [ "linux" ], + "peer": true, "engines": { "node": ">= 12.0.0" }, @@ -9786,6 +9817,7 @@ "os": [ "linux" ], + "peer": true, "engines": { "node": ">= 12.0.0" }, @@ -9806,6 +9838,7 @@ "os": [ "linux" ], + "peer": true, "engines": { "node": ">= 12.0.0" }, @@ -9826,6 +9859,7 @@ "os": [ "win32" ], + "peer": true, "engines": { "node": ">= 12.0.0" }, @@ -9846,6 +9880,7 @@ "os": [ "win32" ], + "peer": true, "engines": { "node": ">= 12.0.0" }, @@ -11504,6 +11539,23 @@ "node": ">=6" } }, + "node_modules/pure-rand": { + "version": "6.1.0", + "resolved": "https://registry.npmjs.org/pure-rand/-/pure-rand-6.1.0.tgz", + "integrity": "sha512-bVWawvoZoBYpp6yIoQtQXHZjmz35RSVHnUOTefl8Vcjr8snTPY1wnpSPMWekcFwbxI6gtmT7rSYPFvz71ldiOA==", + "dev": true, + "funding": [ + { + "type": 
"individual", + "url": "https://github.com/sponsors/dubzzz" + }, + { + "type": "opencollective", + "url": "https://opencollective.com/fast-check" + } + ], + "license": "MIT" + }, "node_modules/qs": { "version": "6.14.1", "resolved": "https://registry.npmjs.org/qs/-/qs-6.14.1.tgz", @@ -14678,7 +14730,7 @@ "@tsconfig/node20": "20", "@types/d3": "7", "@types/jsdom": "27", - "@vitejs/plugin-vue": "^6.0.4", + "@vitejs/plugin-vue": "6", "@vue/test-utils": "2", "@vue/tsconfig": "0.8", "canvas": "3", From 3f496cd2f285ce2be2b57cc78150a1da414a1159 Mon Sep 17 00:00:00 2001 From: mina5rovic Date: Tue, 14 Apr 2026 20:09:07 +0200 Subject: [PATCH 3/3] Fix casting error --- discojs/src/aggregator/byzantine.ts | 9 +++------ 1 file changed, 3 insertions(+), 6 deletions(-) diff --git a/discojs/src/aggregator/byzantine.ts b/discojs/src/aggregator/byzantine.ts index f61346124..2d5c25e77 100644 --- a/discojs/src/aggregator/byzantine.ts +++ b/discojs/src/aggregator/byzantine.ts @@ -165,11 +165,8 @@ export class ByzantineRobustAggregator extends MultiRoundAggregator { function euclideanNorm(w: WeightsContainer): tf.Scalar { // Computes the Euclidean (L2) norm of all tensors in a WeightsContainer by summing the squares of their elements and taking the square root. return tf.tidy(() => { - const zero = tf.scalar(0); - - const total = w.weights - .map(t => tf.sum(tf.square(t)) as tf.Scalar) - .reduce((a, b) => tf.add(a, b) as tf.Scalar, zero); - return tf.sqrt(total); + const squaredSums = w.weights.map(t => tf.sum(tf.square(t))); + const total = tf.addN(squaredSums); + return tf.sqrt(total) as tf.Scalar; }); } \ No newline at end of file