diff --git a/MLExamples/TinyTransformer/README.md b/MLExamples/TinyTransformer/README.md index bb1bd657..0768f448 100644 --- a/MLExamples/TinyTransformer/README.md +++ b/MLExamples/TinyTransformer/README.md @@ -42,7 +42,7 @@ This workshop follows a progressive optimization methodology with four implement | **V3 Triton** | 156,652 | 52.3 | 51.3 | 0.6 | 0.4 | 916.2 | **3.13x** | | **V4 Ultra** | 157,169 | 52.1 | 51.1 | 0.6 | 0.4 | 916.5 | **3.14x** | -**See [PERFORMANCE_RESULTS.md](PERFORMANCE_RESULTS.md) for complete analysis** +Performance figures for the small and medium configurations are summarized in the tables above and in [Key Performance Insights](#key-performance-insights). ### Profiling Tools Progression @@ -70,15 +70,15 @@ Each version introduces additional profiling capabilities: ## Quick Start -### 0. Set up environment -On the training cluster's compute node, the required environment may be set up using the following -commands: +### 0. Set up and verify environment +On the training cluster's compute node, load the modules (adjust names/versions for your site): ```bash module load rocm pytorch openmpi rocprofiler-compute rocprofiler-systems/develop ``` -### 1. Verify Environment +Then confirm ROCm, PyTorch, and the GPU(s) are setup correctly: + ```bash # Check ROCm installation rocminfo @@ -90,7 +90,7 @@ rocm-smi python -c "import torch; print(f'PyTorch: {torch.__version__}'); print(f'CUDA Available: {torch.cuda.is_available()}'); print(f'GPU: {torch.cuda.get_device_name(0) if torch.cuda.is_available() else \"N/A\"}')" ``` -### 2. Run Version 1 (Baseline) - 5 minutes +### 1. Run Version 1 (Baseline) - 5 minutes ```bash cd version1_pytorch_baseline/ python tiny_llama_v1.py --batch-size 8 --seq-len 128 --num-steps 20 @@ -105,7 +105,7 @@ For a deeper analysis with the PyTorch profiler, and visualizing the output in T please follow the workshop exercises in [version1_pytorch_baseline/README.md](https://github.com/amd/HPCTrainingExamples/tree/main/MLExamples/TinyTransformer/version1_pytorch_baseline#workshop-exercises). -### 3. Run Version 2 (Fused) - 5 minutes +### 2. Run Version 2 (Fused) - 5 minutes ```bash cd version2_pytorch_fused python tiny_llama_v2.py --batch-size 8 --seq-len 128 --num-steps 30 @@ -148,7 +148,7 @@ with `rocprof-sys` using the command below: ```bash rocprof-sys-run --profile --trace -- python tiny_llama_v3.py --batch-size 8 --seq-len 128 --num-steps 30 ``` -View the trace at [https://ui.perfetto.dev](https://ui.perfetto.dev). +View the trace with [https://ui.perfetto.dev](https://ui.perfetto.dev). ### 4. Run Version 4 (Ultra optimized) - 5 minutes ```bash @@ -174,75 +174,33 @@ cd version3_triton/exercises/performance_debugging/ ## Directory Structure +Layout under `MLExamples/TinyTransformer/` in this repository: + ``` -ai-workshop-training/ - README.md # This overview - setup/ # Environment and prerequisites - environment_setup.md # Detailed setup instructions - environment_setup.sh # Automated setup script - requirements.txt # Python dependencies - validation_scripts/ # Environment validation - test_environment.py # Comprehensive environment test - test_rocm_installation.py # ROCm stack validation - test_profiling_tools.py # Profiling tools validation - version1_pytorch_baseline/ # Standard PyTorch implementation - README.md # Detailed guided instructions - tiny_llama_v1.py # Enhanced baseline implementation - run_pytorch_profiler.py # PyTorch profiler integration - run_deepspeed_flops.py # DeepSpeed FLOPS profiler - run_all_profilers.sh # Orchestrated profiling script - exercises/ # Hands-on exercises and analysis - exercise_1_baseline_analysis.md - exercise_2_memory_analysis.md - exercise_3_bottleneck_identification.md - version2_pytorch_fused/ # Fused operations optimization - README.md # Fusion optimization guide - tiny_llama_v2.py # Fused implementation - run_pytorch_profiler.py # Enhanced PyTorch profiling - run_deepspeed_flops.py # FLOPS analysis - run_rocprofv3.sh # rocprofv3 integration - run_rocprof_sys.sh # System profiling - run_rocprof_compute.sh # Kernel-level profiling - run_all_profilers.sh # Complete profiling suite - exercises/ # Advanced profiling exercises - exercise_1_fusion_analysis.md - exercise_2_flash_attention.md - exercise_3_rocm_tools_intro.md - version3_triton/ # Triton kernel integration - README.md # Triton optimization guide - tiny_llama_v3.py # Triton-enhanced implementation - triton_kernels.py # Custom Triton kernels - run_pytorch_profiler.py # Framework profiling - run_deepspeed_flops.py # Computational analysis - run_rocprofv3.sh # Legacy profiling - run_rocprof_sys.sh # System monitoring - run_rocprof_compute.sh # Advanced kernel analysis - run_all_profilers.sh # Complete profiling - exercises/ # Triton development exercises - exercise_1_triton_basics.md - exercise_2_custom_kernels.md - exercise_3_performance_tuning.md - version4_pytorch_sdpa/ # Ultra-fused implementation - README.md # Ultra-optimization guide - tiny_llama_v4.py # Ultra-fused implementation - triton_ultra_kernels.py # Ultra-fused kernels - [profiling scripts] # Complete profiling suite - exercises/ # Advanced optimization - exercise_1_ultra_fusion.md - exercise_2_register_optimization.md - exercise_3_production_deployment.md - analysis_tools/ # Performance analysis utilities - compare_versions.py # Cross-version performance comparison - roofline_analysis.py # Roofline model implementation - performance_dashboard.py # Interactive performance dashboard - regression_tester.py # Automated regression testing - report_generator.py # Comprehensive report generation - slides/ # Presentation materials - luka_presentation_materials/ # AI workshop slides - workshop_overview.pptx - profiling_methodology.pptx - optimization_techniques.pptx - results_analysis.pptx +TinyTransformer/ +├── README.md +├── TINY_LLAMA_ARCHITECTURE.md +├── TECHNICAL_APPENDICES.md +├── version1_pytorch_baseline/ +│ ├── tiny_llama_v1.py +│ ├── run_pytorch_profiler.py, run_deepspeed_flops.py, run_all_profilers.sh +│ ├── run_*.sh, launch_performance_study.sh +│ └── exercises/ +│ ├── exercise_1_baseline_analysis.md +│ ├── exercise_2_memory_analysis.md +│ └── exercise_3_bottleneck_identification.md +├── version2_pytorch_fused/ +│ ├── tiny_llama_v2.py +│ ├── run_*.py, run_*.sh, launch_performance_study.sh +│ └── exercises/ +├── version3_triton/ +│ ├── tiny_llama_v3.py, run_triton_profiling.py, run_rocprof_triton.sh +│ ├── launch_performance_study.sh +│ └── exercises/ (including performance_debugging/) +└── version4_pytorch_sdpa/ + ├── tiny_llama_v4.py, run_ultra_profiling.py, launch_performance_study.sh + └── exercises/ + └── exercise1_ultra_fusion.md ``` ## Workshop Execution Timeline @@ -345,10 +303,10 @@ Developed for the CASTIEL AI Workshop (October 16, 2024) by HPC/AI performance e ## License -MIT License - See LICENSE file for details +MIT License — see the repository [`LICENSE.md`](../../LICENSE.md) at the git root of **HPCTrainingExamples**. --- -**Ready to start profiling? Begin with the [Environment Setup Guide](setup/environment_setup.md)** +**Ready to start profiling?** Begin with [Quick Start](#quick-start) (environment modules and first runs) above. diff --git a/MLExamples/TinyTransformer/version1_pytorch_baseline/PYTORCH_BASELINE_WORKSHOP_WALKTHROUGH.md b/MLExamples/TinyTransformer/version1_pytorch_baseline/PYTORCH_BASELINE_WORKSHOP_WALKTHROUGH.md index 59d84818..ef30687b 100644 --- a/MLExamples/TinyTransformer/version1_pytorch_baseline/PYTORCH_BASELINE_WORKSHOP_WALKTHROUGH.md +++ b/MLExamples/TinyTransformer/version1_pytorch_baseline/PYTORCH_BASELINE_WORKSHOP_WALKTHROUGH.md @@ -807,7 +807,7 @@ Understanding the available options: ```bash --enable-pytorch-profiler # Enable PyTorch profiler --profile-dir ./profiles # Directory for profile output ---profile-memory # Include memory profiling +--enable-memory-profiling # CUDA alloc records in profiler trace (TensorBoard memory views) --profile-operators # Detailed operator profiling --profile-steps 5 # Number of steps to profile ``` @@ -934,7 +934,7 @@ python3 tiny_llama_v1.py \ --num-steps 20 \ --enable-pytorch-profiler \ --profile-dir ./pytorch_profiles \ - --profile-memory + --enable-memory-profiling ``` **With DeepSpeed FLOPS Profiler:** @@ -1472,7 +1472,7 @@ python3 tiny_llama_v1.py \ --seq-len 128 \ --num-steps 15 \ --enable-pytorch-profiler \ - --profile-memory \ + --enable-memory-profiling \ --profile-dir ./memory_analysis_bs4 # Batch size 8 @@ -1481,7 +1481,7 @@ python3 tiny_llama_v1.py \ --seq-len 128 \ --num-steps 15 \ --enable-pytorch-profiler \ - --profile-memory \ + --enable-memory-profiling \ --profile-dir ./memory_analysis_bs8 # Batch size 16 @@ -1490,7 +1490,7 @@ python3 tiny_llama_v1.py \ --seq-len 128 \ --num-steps 15 \ --enable-pytorch-profiler \ - --profile-memory \ + --enable-memory-profiling \ --profile-dir ./memory_analysis_bs16 ``` @@ -1618,7 +1618,7 @@ python3 tiny_llama_v1.py \ --batch-size 8 \ --seq-len 64 \ --num-steps 10 \ - --profile-memory \ + --enable-memory-profiling \ --profile-dir ./memory_seq64 # Sequence length 128 (baseline) @@ -1626,7 +1626,7 @@ python3 tiny_llama_v1.py \ --batch-size 8 \ --seq-len 128 \ --num-steps 10 \ - --profile-memory \ + --enable-memory-profiling \ --profile-dir ./memory_seq128 # Sequence length 256 @@ -1634,7 +1634,7 @@ python3 tiny_llama_v1.py \ --batch-size 8 \ --seq-len 256 \ --num-steps 10 \ - --profile-memory \ + --enable-memory-profiling \ --profile-dir ./memory_seq256 # Sequence length 512 (might OOM - use smaller batch if needed) @@ -1642,7 +1642,7 @@ python3 tiny_llama_v1.py \ --batch-size 4 \ --seq-len 512 \ --num-steps 5 \ - --profile-memory \ + --enable-memory-profiling \ --profile-dir ./memory_seq512 ``` @@ -1694,7 +1694,7 @@ python3 tiny_llama_v1.py \ --seq-len 128 \ --num-steps 10 \ --enable-pytorch-profiler \ - --profile-memory \ + --enable-memory-profiling \ --profile-operators \ --profile-dir ./memory_hotspots ``` diff --git a/MLExamples/TinyTransformer/version1_pytorch_baseline/README.md b/MLExamples/TinyTransformer/version1_pytorch_baseline/README.md index 7c1f20d3..4c44bf43 100644 --- a/MLExamples/TinyTransformer/version1_pytorch_baseline/README.md +++ b/MLExamples/TinyTransformer/version1_pytorch_baseline/README.md @@ -398,7 +398,7 @@ python tiny_llama_v1.py \ --batch-size 8 \ --seq-len 128 \ --enable-pytorch-profiler \ - --profile-memory \ + --enable-memory-profiling \ --profile-dir ./memory_analysis # Generate memory timeline visualization diff --git a/MLExamples/TinyTransformer/version1_pytorch_baseline/exercises/exercise_1_baseline_analysis.md b/MLExamples/TinyTransformer/version1_pytorch_baseline/exercises/exercise_1_baseline_analysis.md index 1cb9b199..9e934a16 100644 --- a/MLExamples/TinyTransformer/version1_pytorch_baseline/exercises/exercise_1_baseline_analysis.md +++ b/MLExamples/TinyTransformer/version1_pytorch_baseline/exercises/exercise_1_baseline_analysis.md @@ -8,7 +8,7 @@ Establish baseline performance metrics for Tiny LLaMA V1 and understand the prof ### Prerequisites -- Completed environment setup from `../setup/` +- Environment ready for PyTorch + ROCm (see workshop `README.md` [Quick Start](../../README.md#quick-start)) - Verified environment with validation scripts ### Duration diff --git a/MLExamples/TinyTransformer/version1_pytorch_baseline/tiny_llama_v1.py b/MLExamples/TinyTransformer/version1_pytorch_baseline/tiny_llama_v1.py index defb8dca..ecf3e27e 100644 --- a/MLExamples/TinyTransformer/version1_pytorch_baseline/tiny_llama_v1.py +++ b/MLExamples/TinyTransformer/version1_pytorch_baseline/tiny_llama_v1.py @@ -22,7 +22,7 @@ python tiny_llama_v1.py --enable-pytorch-profiler --profile-dir ./profiles # With memory profiling - python tiny_llama_v1.py --enable-pytorch-profiler --profile-memory + python tiny_llama_v1.py --enable-pytorch-profiler --enable-memory-profiling # Complete profiling suite python tiny_llama_v1.py --enable-all-profiling --profile-dir ./complete_analysis @@ -117,6 +117,7 @@ def reset(self): self.metrics = { 'training_speed': [], 'memory_usage': [], + 'gpu_peak_memory_mb': [], 'gpu_utilization': [], 'loss_values': [], 'batch_times': [], @@ -141,8 +142,14 @@ def end_timing(self) -> float: self.start_time = None return elapsed - def record_batch_metrics(self, batch_size: int, loss: float, timings: Dict[str, float]): - """Record metrics for a training batch.""" + def record_batch_metrics(self, batch_size: int, loss: float, timings: Dict[str, float], + gpu_peak_memory_mb: Optional[float] = None): + """Record metrics for a training batch. + + gpu_peak_memory_mb: per-step peak device memory (bytes->MB) from + torch.cuda.max_memory_allocated() after reset_peak_memory_stats() at + step start; captures transient activations during backward. + """ self.total_samples += batch_size self.metrics['loss_values'].append(loss) self.metrics['batch_times'].append(timings.get('total', 0)) @@ -154,6 +161,8 @@ def record_batch_metrics(self, batch_size: int, loss: float, timings: Dict[str, if torch.cuda.is_available(): memory_mb = torch.cuda.memory_allocated() / (1024**2) self.metrics['memory_usage'].append(memory_mb) + if gpu_peak_memory_mb is not None: + self.metrics['gpu_peak_memory_mb'].append(gpu_peak_memory_mb) # Training speed (samples per second) if timings.get('total', 0) > 0: @@ -175,7 +184,13 @@ def get_summary(self) -> Dict[str, Any]: 'avg_optimizer_time': np.mean(self.metrics['optimizer_times']), } - if self.metrics['memory_usage']: + if self.metrics['gpu_peak_memory_mb']: + summary.update({ + 'peak_memory_mb': max(self.metrics['gpu_peak_memory_mb']), + 'avg_peak_memory_mb': np.mean(self.metrics['gpu_peak_memory_mb']), + 'avg_memory_mb': np.mean(self.metrics['memory_usage']) if self.metrics['memory_usage'] else 0.0, + }) + elif self.metrics['memory_usage']: summary.update({ 'peak_memory_mb': max(self.metrics['memory_usage']), 'avg_memory_mb': np.mean(self.metrics['memory_usage']) @@ -647,6 +662,9 @@ def train_tiny_llama( print("=" * 70) for step in range(num_steps): + if torch.cuda.is_available(): + torch.cuda.reset_peak_memory_stats() + # Start batch timing batch_timings = {} monitor.start_timing() @@ -692,8 +710,14 @@ def train_tiny_llama( # Total batch time batch_timings['total'] = sum(batch_timings.values()) + peak_mb: Optional[float] = None + if torch.cuda.is_available(): + torch.cuda.synchronize() + peak_mb = torch.cuda.max_memory_allocated() / (1024**2) + # Record metrics - monitor.record_batch_metrics(batch_size, loss.item(), batch_timings) + monitor.record_batch_metrics(batch_size, loss.item(), batch_timings, + gpu_peak_memory_mb=peak_mb) # PyTorch profiler step if pytorch_profiler: @@ -702,12 +726,13 @@ def train_tiny_llama( # Progress logging if step % 10 == 0: speed = batch_size / batch_timings['total'] if batch_timings['total'] > 0 else 0 - memory_mb = torch.cuda.memory_allocated() / (1024**2) if torch.cuda.is_available() else 0 + live_mb = torch.cuda.memory_allocated() / (1024**2) if torch.cuda.is_available() else 0 + peak_log = f"{peak_mb:6.1f}" if peak_mb is not None else " n/a" print(f"Step {step:3d}/{num_steps} | " f"Loss: {loss.item():.4f} | " f"Speed: {speed:5.1f} samples/sec | " - f"Memory: {memory_mb:6.1f} MB | " + f"Peak: {peak_log} MB | Live: {live_mb:6.1f} MB | " f"Time: {batch_timings['total']*1000:5.1f}ms") print("=" * 70) @@ -743,7 +768,9 @@ def train_tiny_llama( print(f" Final loss: {summary.get('avg_loss', 0):.4f}") if 'peak_memory_mb' in summary: - print(f" Peak memory usage: {summary['peak_memory_mb']:.1f} MB") + print(f" Peak device memory (high-water per step): {summary['peak_memory_mb']:.1f} MB") + if 'avg_peak_memory_mb' in summary: + print(f" Avg peak per step: {summary['avg_peak_memory_mb']:.1f} MB") # Save performance data if profiler_config.profile_dir: @@ -771,6 +798,7 @@ def train_tiny_llama( } profile_path = Path(profiler_config.profile_dir) / "performance_summary.json" + profile_path.parent.mkdir(parents=True, exist_ok=True) with open(profile_path, 'w') as f: json.dump(profile_data, f, indent=2) diff --git a/MLExamples/TinyTransformer/version2_pytorch_fused/README.md b/MLExamples/TinyTransformer/version2_pytorch_fused/README.md index 60e73ffe..9329edca 100644 --- a/MLExamples/TinyTransformer/version2_pytorch_fused/README.md +++ b/MLExamples/TinyTransformer/version2_pytorch_fused/README.md @@ -474,73 +474,116 @@ def calculate_arithmetic_intensity(operation_type, batch_size, seq_len, hidden_d ## Workshop Exercises +**Host–GPU affinity:** On multi-NUMA systems, it is crucial to pin the CPU cores, local memory, and GPU correctly. Poor affinity increases cross-socket traffic significantly causing misleading timings. +A quick way to pin the Python process to the first CPU and GPU is: +```bash +ROCR_VISIBLE_DEVICES=0 numactl -C 0 -m 0 +``` +See the [Affinity exercises](https://github.com/amd/HPCTrainingExamples/tree/main/Affinity) for how to discover your topology and set the affinity accordingly. + + ### Exercise 1: Kernel Fusion Analysis -**Objective**: Compare baseline vs. fused implementations to quantify fusion benefits. +**Objective**: Compare the unfused, fused, and compiled configurations on the same `tiny_llama_v2.py` code path to quantify the benefits of fusion. + + +#### Step 1: Three-way throughput comparison + +From `version2_pytorch_fused/`, run the same batch size, sequence length, and step count three times. Save each run to its own `--profile-dir` so JSON summaries do not overwrite each other. -#### Step 1: Baseline Comparison ```bash -# Run Version 1 baseline for comparison -cd ../version1_pytorch_baseline -python tiny_llama_v1.py --batch-size 8 --seq-len 128 --num-steps 30 > ../version2_baseline_comparison.log +cd version2_pytorch_fused -# Run Version 2 fused implementation -cd ../version2_pytorch_fused -python tiny_llama_v2.py --batch-size 8 --seq-len 128 --num-steps 30 > fused_performance.log +# 1. Unfused baseline (equivalent to Version 1) +python tiny_llama_v2.py \ + --batch-size 8 --seq-len 128 --num-steps 30 --disable-all-fusion \ + --profile-dir ./bench_no_fusion + +# 2. Fused QKV + Flash Attention + SwiGLU +python tiny_llama_v2.py \ + --batch-size 8 --seq-len 128 --num-steps 30 \ + --profile-dir ./bench_fused + +# 3. Fused + torch.compile +python tiny_llama_v2.py \ + --batch-size 8 --seq-len 128 --num-steps 30 --enable-torch-compile \ + --profile-dir ./bench_torch_compile ``` -#### Step 2: Kernel Count Analysis -```bash -# PyTorch profiler comparison -python run_pytorch_profiler.py --batch-size 8 --profile-dir ./fusion_analysis --generate-report +Compare the performance you see for the different models. What are the differences? + +#### Step 2: Optional operator-level profiling + +Compare the kernel launch patterns between the three cases with the built-in PyTorch profiler: -# Compare kernel counts between versions -python analyze_kernel_reduction.py --baseline ../version1_pytorch_baseline/pytorch_profiles --fused ./fusion_analysis +```bash +python tiny_llama_v2.py \ + --batch-size 8 --seq-len 128 --num-steps 10 --enable-pytorch-profiler \ + --profile-dir ./fusion_analysis ``` +Open the Chrome trace or TensorBoard timeline and compare the unfused and fused versions. Do you see the ~43% fewer attention-related kernels per layer reported by the Python script? -**Expected Results:** -- 40-60% reduction in kernel launch count -- 1.4-1.8x speedup in overall training -- Improved GPU utilization metrics +#### Reference results + +The following reference results have been obtained on an MI300A with PyTorch 2.9.1 and ROCm 7.2.0 with the same model setup as described above. + +| Configuration | Throughput (samples/s) | Avg batch time (ms) | Peak device memory (MB) | +|-----------------|------------------------|---------------------|-------------------------| +| `--disable-all-fusion` (V1-equivalent) | 293 | 27.3 | 998 | +| Default fused | 437 | 18.3 | 967 | +| `--enable-torch-compile` | 794 | 10.1 | 875 | + +On this setup, fusion yields ~**1.5×** throughput over the unfused path; adding `torch.compile` reaches ~**2.7×** vs. unfused and ~**1.8×** vs. fused alone. +With the short sequence length of `seq=128` because with this short sequence length, the majority of the memory is consumed by the weights and gradients. +Continue to exercise 2 to learn more about the impact of kernel fusion and Flash Attention on the memory consumption. ### Exercise 2: Flash Attention Memory Analysis -**Objective**: Analyze memory efficiency improvements from Flash Attention. +**Objective**: Show how peak device memory scales with sequence length for naive attention vs. Flash Attention. + +#### Memory scaling of unfused and fused attention + +Next, investigate how the memory consumption scales if we increase the sequence length with both naive unfused attention and the fused Flash Attention kernel. +For this, enable `--enable-memory-profiling` so the summary reports **peak device memory** per run. Keep `batch-size 4` and `num-steps 20` fixed while sweeping sequence length. +Run this for both variants and compare the scaling. Below, you can find some reference results to compare to. -#### Step 1: Memory Scaling Test ```bash -# Test memory scaling with sequence length for seq_len in 128 256 512 1024; do python tiny_llama_v2.py \ --seq-len $seq_len \ --batch-size 4 \ + --num-steps 20 \ --enable-memory-profiling \ --profile-dir ./flash_attention_seq${seq_len} done ``` -#### Step 2: Memory Bandwidth Analysis -```bash -# Analyze memory bandwidth utilization -python run_deepspeed_flops.py \ - --batch-size 8 \ - --seq-len 256 \ - --computational-intensity \ - --generate-roofline -``` +#### Reference results -**Expected Results:** +The following reference results have been obtained on an MI300A with PyTorch 2.9.1 and ROCm 7.2.0 with the same model setup as described above. -- Linear memory scaling vs. quadratic for baseline -- 2-4x memory reduction for longer sequences -- Improved arithmetic intensity metrics +| Configuration | seq=128 | seq=256 | seq=512 | seq=1024 | +|---------------|---------|---------|---------|----------| +| `--disable-all-fusion` | 764 | 1031 | 1669 | 3471 | +| Default fused (Flash Attention) | 764 | 967 | 1414 | 2302 | +|---------------|---------|---------|---------|----------| +| Ratio | 1.00x | 1.06x | 1.18x | 1.51x | -### Exercise 3: ROCm Tools Deep Dive +Clearly, the fused attention kernel reduces the required memory significantly. Why is that? +Unfused attention materializes an \(S \times S\) attention matrix, so the peak memory rises close to **quadratically** in sequence length once that tensor dominates. Flash Attention avoids storing the full matrix as it computes the local attention scores on-the-fly resulting in a roughly **linear** scaling in \(S\). At `seq=128`, all both models report the same peak because weights and activations dominate. -**Objective**: Master ROCm profiling tools for hardware-level optimization. +Does the further fusion with `torch.compile` lower the peak even more? Try it out! + +### Exercise 3: Using ROCm Tools + +**Objective**: Explore ROCm profiling tools for hardware-level optimization. AMD offers three performance profiling tools for ROCm based applications: -`rocprofv3`, `rocprof-sys`, and `rocprof-compute`. For more details about these tools, see + - `rocprofv3` (hotspot analysis and timeline traces) + - `rocprof-sys` (hotspot and timeline profiling including CPU and MPI) + - `rocprof-compute` (in-depth profiling of kernel) + +For more details about these tools, see [Appendix C of the TECHNICAL_APPENDICES.md](https://github.com/amd/HPCTrainingExamples/blob/main/MLExamples/TinyTransformer/TECHNICAL_APPENDICES.md#appendix-c-rocm-profiling-tools-reference). about each tool. @@ -574,7 +617,7 @@ rocprof-compute profile -n roof --kernel-names --roof-only --device 0 -- python This generates three PDF files: two roofline plots and a legend. -To collect a profile, then analyze a particular dispatch, run the following commands: +To collect a profile, then analyze a particular kernel dispatch, run the following commands: ```bash rocprof-compute profile -n ver2 --no-roof -- python3 tiny_llama_v2.py --batch-size 8 --seq-len 128 --num-steps 30 diff --git a/MLExamples/TinyTransformer/version2_pytorch_fused/tiny_llama_v2.py b/MLExamples/TinyTransformer/version2_pytorch_fused/tiny_llama_v2.py index 716e8225..028dfc33 100644 --- a/MLExamples/TinyTransformer/version2_pytorch_fused/tiny_llama_v2.py +++ b/MLExamples/TinyTransformer/version2_pytorch_fused/tiny_llama_v2.py @@ -141,6 +141,7 @@ def reset(self): self.metrics = { 'training_speed': [], 'memory_usage': [], + 'gpu_peak_memory_mb': [], 'gpu_utilization': [], 'loss_values': [], 'batch_times': [], @@ -168,8 +169,14 @@ def end_timing(self) -> float: self.start_time = None return elapsed - def record_batch_metrics(self, batch_size: int, loss: float, timings: Dict[str, float], fusion_stats: Dict[str, Any] = None): - """Record metrics for a training batch with fusion statistics.""" + def record_batch_metrics(self, batch_size: int, loss: float, timings: Dict[str, float], fusion_stats: Dict[str, Any] = None, + gpu_peak_memory_mb: Optional[float] = None): + """Record metrics for a training batch with fusion statistics. + + gpu_peak_memory_mb: per-step peak device memory (bytes->MB) from + torch.cuda.max_memory_allocated() after reset_peak_memory_stats() at + step start; captures transient activations during backward. + """ self.total_samples += batch_size self.metrics['loss_values'].append(loss) self.metrics['batch_times'].append(timings.get('total', 0)) @@ -181,6 +188,8 @@ def record_batch_metrics(self, batch_size: int, loss: float, timings: Dict[str, if torch.cuda.is_available(): memory_mb = torch.cuda.memory_allocated() / (1024**2) self.metrics['memory_usage'].append(memory_mb) + if gpu_peak_memory_mb is not None: + self.metrics['gpu_peak_memory_mb'].append(gpu_peak_memory_mb) # Training speed if timings.get('total', 0) > 0: @@ -206,7 +215,13 @@ def get_summary(self) -> Dict[str, Any]: 'avg_optimizer_time': np.mean(self.metrics['optimizer_times']), } - if self.metrics['memory_usage']: + if self.metrics['gpu_peak_memory_mb']: + summary.update({ + 'peak_memory_mb': max(self.metrics['gpu_peak_memory_mb']), + 'avg_peak_memory_mb': np.mean(self.metrics['gpu_peak_memory_mb']), + 'avg_memory_mb': np.mean(self.metrics['memory_usage']) if self.metrics['memory_usage'] else 0.0, + }) + elif self.metrics['memory_usage']: summary.update({ 'peak_memory_mb': max(self.metrics['memory_usage']), 'avg_memory_mb': np.mean(self.metrics['memory_usage']) @@ -223,10 +238,16 @@ def get_summary(self) -> Dict[str, Any]: fusion_summary = {} for key, values in total_fusion_stats.items(): - if isinstance(values[0], (int, float)): - fusion_summary[f'avg_{key}'] = np.mean(values) + sample = values[0] + # bool subclasses int — must branch on bool first so flags keep canonical keys + if isinstance(sample, bool): + fusion_summary[key] = bool(sample) + elif isinstance(sample, int): + fusion_summary[key] = int(round(np.mean(values))) + elif isinstance(sample, float): + fusion_summary[key] = float(np.mean(values)) else: - fusion_summary[key] = values[-1] # Keep latest non-numeric value + fusion_summary[key] = values[-1] summary['fusion_statistics'] = fusion_summary @@ -818,6 +839,9 @@ def train_tiny_llama_v2( print("=" * 70) for step in range(num_steps): + if torch.cuda.is_available(): + torch.cuda.reset_peak_memory_stats() + # Start batch timing batch_timings = {} monitor.start_timing() @@ -863,12 +887,18 @@ def train_tiny_llama_v2( # Total batch time batch_timings['total'] = sum(batch_timings.values()) + peak_mb: Optional[float] = None + if torch.cuda.is_available(): + torch.cuda.synchronize() + peak_mb = torch.cuda.max_memory_allocated() / (1024**2) + # Record metrics with fusion statistics monitor.record_batch_metrics( batch_size, loss.item(), batch_timings, - fusion_stats + fusion_stats, + gpu_peak_memory_mb=peak_mb, ) # PyTorch profiler step @@ -878,12 +908,13 @@ def train_tiny_llama_v2( # Progress logging if step % 10 == 0: speed = batch_size / batch_timings['total'] if batch_timings['total'] > 0 else 0 - memory_mb = torch.cuda.memory_allocated() / (1024**2) if torch.cuda.is_available() else 0 + live_mb = torch.cuda.memory_allocated() / (1024**2) if torch.cuda.is_available() else 0 + peak_log = f"{peak_mb:6.1f}" if peak_mb is not None else " n/a" print(f"Step {step:3d}/{num_steps} | " f"Loss: {loss.item():.4f} | " f"Speed: {speed:5.1f} samples/sec | " - f"Memory: {memory_mb:6.1f} MB | " + f"Peak: {peak_log} MB | Live: {live_mb:6.1f} MB | " f"Time: {batch_timings['total']*1000:5.1f}ms") print("=" * 70) @@ -919,7 +950,11 @@ def train_tiny_llama_v2( print(f" Final loss: {summary.get('avg_loss', 0):.4f}") if 'peak_memory_mb' in summary: - print(f" Peak memory usage: {summary['peak_memory_mb']:.1f} MB") + print(f" Peak device memory (high-water per step): {summary['peak_memory_mb']:.1f} MB") + if 'avg_peak_memory_mb' in summary: + print(f" Avg peak per step: {summary['avg_peak_memory_mb']:.1f} MB") + if 'avg_memory_mb' in summary: + print(f" Avg live allocations after step: {summary['avg_memory_mb']:.1f} MB") # Fusion efficiency summary if 'fusion_statistics' in summary: @@ -1136,4 +1171,4 @@ def main(): if __name__ == "__main__": - main() \ No newline at end of file + main() diff --git a/MLExamples/TinyTransformer/version3_triton/README.md b/MLExamples/TinyTransformer/version3_triton/README.md index 24d5e8b2..c8f5794a 100644 --- a/MLExamples/TinyTransformer/version3_triton/README.md +++ b/MLExamples/TinyTransformer/version3_triton/README.md @@ -474,7 +474,7 @@ MEMORY_ACCESS_PATTERNS = { Ensure Triton is installed in your environment: ```bash -# Should already be installed from setup/ +# Should already be installed with your PyTorch / ROCm modules pip install triton ```