-
Notifications
You must be signed in to change notification settings - Fork 1
Expand file tree
/
Copy pathbatch_test_gemm.sh
More file actions
executable file
·61 lines (54 loc) · 2.4 KB
/
batch_test_gemm.sh
File metadata and controls
executable file
·61 lines (54 loc) · 2.4 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
#!/bin/bash
# Batch GEMM test configuration: environment setup and the list of
# KernelBench problem files that the rest of this script iterates over.

# Exported so that child processes started by this script inherit it.
# NOTE(review): presumably read by the PyTorch/ROCm toolchain to select the
# target GPU architecture - confirm against the PyTorch build documentation.
export PYTORCH_ROCM_ARCH=gfx950

# Root directory of the KernelBench dataset. Defaults to the original
# hard-coded location; set KERNELBENCH_ROOT in the environment to run
# against a checkout that lives somewhere else.
KERNELBENCH_ROOT="${KERNELBENCH_ROOT:-/root/agent/kernel-agent/datasets/KernelBench}"

# Problem files to test, one element per problem. The order here is the
# order in which problems are run and reported.
problems=(
  # Level 1 - Basic GEMM
  "${KERNELBENCH_ROOT}/level1/1_Square_matrix_multiplication_.py"
  "${KERNELBENCH_ROOT}/level1/2_Standard_matrix_multiplication_.py"
  "${KERNELBENCH_ROOT}/level1/6_Matmul_with_large_K_dimension_.py"
  "${KERNELBENCH_ROOT}/level1/7_Matmul_with_small_K_dimension_.py"
  "${KERNELBENCH_ROOT}/level1/8_Matmul_with_irregular_shapes_.py"
  "${KERNELBENCH_ROOT}/level1/9_Tall_skinny_matrix_multiplication_.py"
  "${KERNELBENCH_ROOT}/level1/16_Matmul_with_transposed_A.py"
  "${KERNELBENCH_ROOT}/level1/17_Matmul_with_transposed_B.py"
  # Level 2 - Fused GEMM
  "${KERNELBENCH_ROOT}/level2/12_Gemm_Multiply_LeakyReLU.py"
  "${KERNELBENCH_ROOT}/level2/29_Matmul_Mish_Mish.py"
  "${KERNELBENCH_ROOT}/level2/40_Matmul_Scaling_ResidualAdd.py"
  "${KERNELBENCH_ROOT}/level2/76_Gemm_Add_ReLU.py"
  "${KERNELBENCH_ROOT}/level2/86_Matmul_Divide_GELU.py"
)
#######################################
# Print the speedup figure found in one run's captured output.
# Arguments: $1 - combined stdout/stderr text of a single run
# Outputs:   the third whitespace-separated field of the LAST line that
#            contains "Best speedup:"; "N/A" when there is no such value,
#            so the recorded result is never empty or multi-line
#######################################
extract_speedup() {
  local value
  value=$(awk '/Best speedup:/ { v = $3 } END { print v }' <<<"$1")
  printf '%s\n' "${value:-N/A}"
}

#######################################
# Print the first line of a run's output that mentions a known failure
# marker (NaN, FAIL or "Compile error").
# Arguments: $1 - combined stdout/stderr text of a single run
# Outputs:   the matching line, or a fixed fallback message when no line
#            matches, so a failure is never reported without a reason
#######################################
first_error_line() {
  local line
  line=$(grep -E "NaN|FAIL|Compile error" <<<"$1" | head -n 1)
  printf '%s\n' "${line:-no recognizable error in output}"
}

echo "================================================"
echo "BATCH GEMM TEST - $(date)"
echo "================================================"

total=0      # number of problems attempted
passed=0     # number of problems whose output contained "Accuracy: PASS"
speedups=()  # one entry per problem, in the same order as problems[]

for problem in "${problems[@]}"; do
  name=$(basename "$problem" .py)
  printf '\n--- Testing: %s ---\n' "$name"

  # Run a single attempt. stderr is merged into stdout so that failure
  # markers printed on either stream end up in the captured text.
  # NOTE(review): run_loop.py is resolved relative to the current working
  # directory, so this script must be started from the directory holding it.
  result=$(python run_loop.py --problem "$problem" --max-attempts 1 2>&1)

  # Classify the run purely from its output text.
  if grep -q "Accuracy: PASS" <<<"$result"; then
    speedup=$(extract_speedup "$result")
    echo "✓ PASS - Speedup: $speedup"
    passed=$((passed + 1))
    speedups+=("$speedup")
  else
    error=$(first_error_line "$result")
    echo "✗ FAIL - $error"
    # Placeholder keeps speedups[] index-aligned with problems[].
    speedups+=("0.00x")
  fi
  total=$((total + 1))
done
# Final report: a banner, the pass count, then one line per problem that
# pairs the problem's file name (directory and .py suffix removed) with the
# entry stored at the same index in speedups[].
rule="================================================"
printf '\n'
printf '%s\n' "$rule" "BATCH SUMMARY" "$rule"
printf 'Passed: %s / %s\n' "$passed" "$total"
printf '\n'
printf '%s\n' "Individual results:"
for idx in "${!problems[@]}"; do
  label=$(basename "${problems[$idx]}" .py)
  printf ' %s: %s\n' "$label" "${speedups[$idx]}"
done