Add YOLO26 object detection contrib model #151

Open
jimburtoft wants to merge 8 commits into aws-neuron:main from jimburtoft:contrib/yolo26
@jimburtoft
Contributor

Summary

  • Adds Ultralytics YOLO26 object detection models (n/s/m/l/x, 2.4-58.9M params) as a contrib model for real-time inference on AWS Trainium2 and Inferentia2 via torch_neuronx.trace()
  • All 5 detection variants compile and run with high accuracy (CosSim 0.987-0.997), plus pose and OBB task heads
  • Neuron outperforms a compiled A10G GPU by 1.4-4.5x on the s/m/l/x variants at peak DP throughput; YOLO26n is the exception at 0.13x

Validation

Validated on 4 configurations: trn2.3xlarge × {SDK 2.28, 2.29} and inf2.xlarge × {SDK 2.28, 2.29}.

| Instance | SDK | Tests | yolo26n CosSim | yolo26s CosSim | yolo26n img/s | yolo26s img/s |
|---|---|---|---|---|---|---|
| trn2.3xlarge | 2.28 | 13/13 pytest | 0.9943 | 0.9931 | 32.3 | 66.0 |
| trn2.3xlarge | 2.29 | 13/13 pytest | 0.9941 | 0.9931 | 33.2 | 65.5 |
| inf2.xlarge | 2.28 | 6/6 standalone | 0.9965 | 0.9931 | 60.1 | 64.1 |
| inf2.xlarge | 2.29 | 6/6 standalone | 0.9965 | 0.9931 | 69.7 | 76.7 |
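
The PR does not show how the CosSim column is computed. A minimal sketch, assuming it is the plain cosine similarity between the flattened Neuron output and the flattened CPU reference output (the function name is hypothetical and is not taken from the contrib code):

```python
import math
from typing import Sequence


def cosine_similarity(a: Sequence[float], b: Sequence[float]) -> float:
    """Cosine similarity of two flattened outputs of equal length.

    Assumption: this mirrors the metric reported as CosSim; the actual
    validation code in the PR is not shown.
    """
    if len(a) != len(b):
        raise ValueError("inputs must have the same number of elements")
    dot = sum(x * y for x, y in zip(a, b))
    norm_a = math.sqrt(sum(x * x for x in a))
    norm_b = math.sqrt(sum(y * y for y in b))
    if norm_a == 0.0 or norm_b == 0.0:
        # Undefined for an all-zero vector; report 0.0 so a dead output fails a threshold check.
        return 0.0
    return dot / (norm_a * norm_b)
```

A value near 1.0 means the two outputs point in almost the same direction. The metric ignores overall scale, so it is usually paired with a task-level check.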

Peak Throughput (trn2.3xlarge, LNC=1, DP=8)

| Variant | Params | Dtype | img/s | vs A10G Compiled |
|---|---|---|---|---|
| YOLO26n | 2.4M | FP32 | 272 | 0.13x |
| YOLO26s | 10.0M | FP32 | 1,523 | 1.43x |
| YOLO26m | 21.9M | BF16 | 1,267 | 2.67x |
| YOLO26l | 26.3M | BF16 | 1,093 | 2.95x |
| YOLO26x | 58.9M | BF16 | 876 | 4.49x |

Files

contrib/models/YOLO26/
  README.md                          # Model card, benchmarks, compatibility matrix
  yolo26_neuron_notebook.ipynb       # Complete workflow notebook (tested end-to-end)
  src/
    __init__.py                      # Exports: YOLO26NeuronModel, compile_yolo26, etc.
    modeling_yolo26.py               # Trace wrapper, DP support, validation (~280 lines)
  test/
    __init__.py
    unit/__init__.py
    integration/
      __init__.py
      test_model.py                  # 13 integration tests (compile, accuracy, DP, perf)

Key Design Decisions

  • torch_neuronx.trace() (not NxDI model classes): YOLO26 is a CNN with no KV cache, no attention matrices, no token generation. All variants fit on a single NeuronCore (<180 MB NEFF). Data Parallelism provides throughput scaling.
  • end2end=False: topk/sort operations are not supported on Neuron (NCC_EVRF029). Raw [B, 84, 8400] output with CPU postprocessing (~0.1ms overhead).
  • BF16 for m/l/x: FP32 exceeds SB allocation for larger variants (NCC_IGCA030). n/s use FP32.
  • No --auto-cast flags: matmult autocast produces NaN for Conv2d-dominant models.
  • LNC-aware compilation: the --lnc 1 compiler flag is required when running in LNC=1 mode.
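
The dtype and flag decisions above can be collected in one small helper. This is a sketch only: `choose_compile_settings` and its return shape are hypothetical, and it assumes the `--lnc` flag is simply omitted when LNC=1 is not in use.

```python
from typing import List, Optional, Tuple

# Per the design decisions: n/s stay in FP32, m/l/x are cast to BF16.
FP32_VARIANTS = {"n", "s"}
BF16_VARIANTS = {"m", "l", "x"}


def choose_compile_settings(variant: str, lnc: Optional[int] = None) -> Tuple[str, List[str]]:
    """Return (dtype name, compiler args) for a YOLO26 variant.

    Hypothetical helper, not part of the contrib module. No --auto-cast
    flag is ever added, matching the decision to avoid autocast.
    """
    if variant in BF16_VARIANTS:
        dtype = "bfloat16"
    elif variant in FP32_VARIANTS:
        dtype = "float32"
    else:
        raise ValueError(f"unknown YOLO26 variant: {variant!r}")

    compiler_args: List[str] = []
    if lnc == 1:
        # Assumption: the flag is only needed for LNC=1 and omitted otherwise.
        compiler_args += ["--lnc", "1"]
    return dtype, compiler_args
```

The returned list would typically be passed as `compiler_args` to `torch_neuronx.trace()`, with the model and example input cast to the returned dtype beforehand.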

Target

aws-neuron/neuronx-distributed-inference main branch.

Ultralytics YOLO26 (n/s/m/l/x) on Trainium2 via torch_neuronx.trace().
All 5 detection variants plus pose and OBB task heads compile and run
with high accuracy (CosSim 0.987-0.997).

Peak throughput on trn2.3xlarge (LNC=1, DP=8):
- YOLO26s: 1,523 img/s (1.43x vs A10G compiled)
- YOLO26m: 1,267 img/s (2.67x vs A10G compiled)
- YOLO26l: 1,093 img/s (2.95x vs A10G compiled)
- YOLO26x:   876 img/s (4.49x vs A10G compiled)

Includes modeling module, 13 integration tests (all passing),
Jupyter notebook, and README with benchmarks.
Tested all 4 combinations:
- trn2.3xlarge SDK 2.28: 13/13 pytest passed
- trn2.3xlarge SDK 2.29: 13/13 pytest passed
- inf2.xlarge SDK 2.28: 6/6 standalone tests passed
- inf2.xlarge SDK 2.29: 6/6 standalone tests passed

inf2 single-core throughput: yolo26n 60-70 img/s, yolo26s 64-77 img/s.
Updated compatibility matrix and notebook prerequisites.
This contrib uses `torch_neuronx.trace()` rather than NxDI model classes because: (1) all variants fit trivially on a single NeuronCore (<180 MB NEFF), (2) there is no KV cache or token generation, and (3) the Conv2d-dominant architecture does not benefit from NxDI's attention infrastructure. Data Parallelism across NeuronCores provides throughput scaling.

Key Neuron porting challenges:
- **`topk`/`sort` unsupported:** End-to-end postprocessing requires `torch.topk` which fails with `NCC_EVRF029`. Solution: trace with `end2end=False` for raw output, run postprocessing on CPU.

The topk/sort limitation requiring end2end=False means users get raw [B, 84, 8400] output and must implement their own NMS postprocessing. We should include a CPU postprocessing util function (even if simple) so users have a complete detection pipeline.
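
For illustration, a dependency-free sketch of what such a CPU utility could look like for one image. It assumes the raw layout is 4 box rows (cx, cy, w, h) followed by one row per class score, which is what the 84 = 4 + 80 shape suggests. The function names and thresholds are hypothetical, and this is not the utility later added in the PR.

```python
from typing import List, Sequence, Tuple

# (x1, y1, x2, y2, score, class_id)
Detection = Tuple[float, float, float, float, float, int]


def iou(a: Sequence[float], b: Sequence[float]) -> float:
    """Intersection-over-union of two (x1, y1, x2, y2, ...) boxes."""
    iw = max(0.0, min(a[2], b[2]) - max(a[0], b[0]))
    ih = max(0.0, min(a[3], b[3]) - max(a[1], b[1]))
    inter = iw * ih
    union = (a[2] - a[0]) * (a[3] - a[1]) + (b[2] - b[0]) * (b[3] - b[1]) - inter
    return inter / union if union > 0 else 0.0


def postprocess_sketch(
    pred: Sequence[Sequence[float]],
    conf_thres: float = 0.25,
    iou_thres: float = 0.45,
) -> List[Detection]:
    """Decode one image's raw output (4 + num_classes rows by N anchors) and apply class-wise NMS."""
    num_anchors = len(pred[0])
    candidates: List[Detection] = []
    for i in range(num_anchors):
        scores = [row[i] for row in pred[4:]]
        score = max(scores)
        if score < conf_thres:
            continue
        cls = scores.index(score)
        cx, cy, w, h = (pred[k][i] for k in range(4))
        # Convert centre/size to corner coordinates.
        candidates.append((cx - w / 2, cy - h / 2, cx + w / 2, cy + h / 2, score, cls))

    # Greedy NMS: visit boxes by descending score, drop same-class overlaps.
    candidates.sort(key=lambda d: d[4], reverse=True)
    kept: List[Detection] = []
    for det in candidates:
        if all(det[5] != k[5] or iou(det, k) <= iou_thres for k in kept):
            kept.append(det)
    return kept
```

A production version would vectorize the decode step and could delegate suppression to `torchvision.ops.nms`; the loop form is used here only to keep the logic visible.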

        data = json.loads(result.stdout)
        total = sum(dev.get("nc_count", 0) for dev in data)
        return total
    except Exception:

This catches a generic `Exception`; catch the specific exception types instead.
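
One way to narrow the handler, sketched under assumptions: only the JSON parsing and the `nc_count` field appear in the reviewed excerpt, so the `neuron-ls --json-output` command, the timeout, and the fallback value of 0 below are guesses rather than the PR's code.

```python
import json
import subprocess
from typing import Sequence


def parse_core_count(stdout: str) -> int:
    """Sum the nc_count field over the devices listed in the JSON output."""
    data = json.loads(stdout)
    if not isinstance(data, list):
        return 0
    return sum(int(dev.get("nc_count", 0)) for dev in data if isinstance(dev, dict))


def get_neuron_core_count(cmd: Sequence[str] = ("neuron-ls", "--json-output")) -> int:
    """Return the number of NeuronCores, or 0 if they cannot be queried.

    Assumption: `cmd` prints a JSON list of devices; the real command used
    by the contrib module is not visible in the review excerpt.
    """
    try:
        result = subprocess.run(
            list(cmd), capture_output=True, text=True, check=True, timeout=30
        )
        return parse_core_count(result.stdout)
    except FileNotFoundError:
        # The tool is not installed, e.g. on a non-Neuron host.
        return 0
    except (subprocess.CalledProcessError, subprocess.TimeoutExpired):
        # The tool ran but failed or hung.
        return 0
    except (json.JSONDecodeError, ValueError):
        # The output was not the JSON shape expected above.
        return 0
```

Listing the exceptions separately keeps unexpected failures, such as a programming error inside the function, visible instead of silently returning 0.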

… document C2PSA issue

Review feedback from tejasamx-aws:
- Add postprocess_detections() CPU utility with NMS for complete
  detection pipeline from raw [B, 84, 8400] output
- Replace bare 'except Exception' with specific exception types in
  get_neuron_core_count()

C2PSA known issue (confirmed on SDK 2.29.1, inf2.xlarge):
- C2PSA attention module (layer 10) produces numerically incorrect
  output on NCv2 at all resolutions (CosSim ~0.46 vs CPU reference)
- Full-model CosSim of 0.99+ masks this because C2PSA output is
  diluted by correct backbone/neck/head outputs
- Backbone compilation fails at batch size >= 4 with non-square
  input (NCC_IPCC901 compiler error)
- Documented in Known Issues with workaround guidance
Root cause: neuronx-cc produces incorrect output for
torch.Tensor.split() with unequal split sizes on dim=2 of a 4D tensor
after a .view() reshape. The C2PSA Attention module's
.split([key_dim, key_dim, head_dim], dim=2) triggers this, causing
CosSim ~0.45 vs CPU (should be >0.99).

Fix: Patch Attention.forward in prepare_yolo26() to use explicit
tensor slicing ([:, :, :key_dim, :]) instead of .split(). This
compiles correctly and produces CosSim 0.9999.

Verified on inf2.xlarge, SDK 2.29.1:
- Before fix: C2PSA CosSim 0.45, full-model detection accuracy broken
- After fix: C2PSA CosSim 0.9999, matching CPU reference

The underlying compiler issue is tracked in aws-neuron-sdk#1323.
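
The replacement described in this commit can be illustrated on CPU, where both forms agree. This is a sketch: the helper names and tensor sizes are made up, and only `key_dim`, `head_dim`, and the `dim=2` split come from the commit message.

```python
import torch


def qkv_with_split(qkv: torch.Tensor, key_dim: int, head_dim: int):
    """Original pattern: unequal split on dim=2 of a 4D tensor."""
    return qkv.split([key_dim, key_dim, head_dim], dim=2)


def qkv_with_slices(qkv: torch.Tensor, key_dim: int, head_dim: int):
    """Workaround pattern: explicit slicing that selects the same three ranges."""
    q = qkv[:, :, :key_dim, :]
    k = qkv[:, :, key_dim:2 * key_dim, :]
    v = qkv[:, :, 2 * key_dim:2 * key_dim + head_dim, :]
    return q, k, v


# Illustrative shapes only: batch, heads, and sequence length are arbitrary.
batch, heads, key_dim, head_dim, seq = 2, 4, 32, 64, 100
flat = torch.randn(batch, heads * (2 * key_dim + head_dim), seq)
qkv = flat.view(batch, heads, 2 * key_dim + head_dim, seq)  # .view() precedes the split, as in the report

for a, b in zip(qkv_with_split(qkv, key_dim, head_dim), qkv_with_slices(qkv, key_dim, head_dim)):
    assert torch.equal(a, b)
```

Because the two forms are identical on CPU, the CPU reference output is unchanged by the patch and can still serve as the baseline when checking the compiled model.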
…around)

The neuronx-cc compiler crashes (exit code 70) when .split((c, c), dim=1)
is used in C2f blocks at batch_size=4 with small spatial dimensions
(H*W < ~264). Using .chunk(2, 1) -- semantically identical -- compiles
correctly at all batch sizes and spatial dimensions.

Root cause investigation: only bs=4 triggers the failure (bs=1-3,5-16
all pass). The boundary is ~264 total spatial pixels at 256 channels.
This is the same underlying .split() compiler bug as the C2PSA issue.

Verified: Layer 8 (C3k2) at [4, 256, 12, 20] now compiles with chunk.
See: aws-neuron/aws-neuron-sdk#1323
…ound

Root cause identified: C2PSA.forward uses .split((c,c), dim=1) which, when
combined with downstream attention, corrupts all batch elements except element 0
(CosSim ~0.08-0.23 vs CPU reference). This is the same neuronx-cc .split() bug
as Issues 1 and 2, manifesting as silent numerical corruption instead of a
compilation failure.

Fix: Patch C2PSA.forward to use .chunk(2, 1) which is semantically identical
but produces correct HLO. Verified: full YOLO26l model at bs=2 now produces
CosSim > 0.999 for all batch elements.
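
Two points from these commits can be checked with a short CPU sketch: that `.chunk(2, 1)` returns the same halves as `.split((c, c), dim=1)`, and that computing CosSim per batch element exposes corruption limited to some elements. The helper name and shapes are illustrative, not taken from the contrib module.

```python
import torch
import torch.nn.functional as F


def per_element_cossim(out: torch.Tensor, ref: torch.Tensor) -> torch.Tensor:
    """Cosine similarity per batch element, returned as a tensor of shape [B]."""
    return F.cosine_similarity(out.flatten(1).float(), ref.flatten(1).float(), dim=1)


# 1) split((c, c), dim=1) and chunk(2, 1) agree when dim 1 has exactly 2*c channels.
c = 128
x = torch.randn(4, 2 * c, 12, 20)
for a, b in zip(x.split((c, c), dim=1), x.chunk(2, 1)):
    assert torch.equal(a, b)

# 2) Simulate the failure mode: element 0 is correct, element 1 is corrupted.
torch.manual_seed(0)
ref = torch.randn(2, 84, 100)
out = ref.clone()
out[1] = torch.randn(84, 100)
sims = per_element_cossim(out, ref)
print(sims)  # element 0 is ~1.0, element 1 is far lower
```

Validating with a batch size above 1 and reporting the minimum over elements would catch this class of bug, which a batch-size-1 check cannot see.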