optimize positional encoding#2172

Merged
clessig merged 4 commits into ecmwf:develop from javak87:javad/dev/optimize_encoder
Apr 9, 2026

Conversation

Contributor

@javak87 javak87 commented Apr 5, 2026

Description

This PR introduces a minor change in the code, resulting in a significant performance gain.

Issue Number

Fixes #2173

Is this PR a draft? Mark it as draft.

Checklist before asking for review

  • I have performed a self-review of my code
  • My changes comply with basic sanity checks:
    • I have fixed formatting issues with ./scripts/actions.sh lint
    • I have run unit tests with ./scripts/actions.sh unit-test
    • I have documented my code and I have updated the docstrings.
    • I have added unit tests, if relevant
  • I have tried my changes with data and code:
    • I have run the integration tests with ./scripts/actions.sh integration-test
    • (bigger changes) I have run a full training and I have written in the comment the run_id(s): launch-slurm.py --time 60
    • (bigger changes and experiments) I have shared a HedgeDoc in the GitHub issue with all the configurations and runs for these experiments
  • I have informed and aligned with people impacted by my change:
    • for config changes: the MatterMost channels and/or a design doc
    • for changes of dependencies: the MatterMost software development channel

Performance comparison with develop branch

  • Run ../WeatherGenerator-private/hpc/launch-slurm.py --time 60
| run_id   | HPC | PR                                  | Ingested Samples per GPU |
|----------|-----|-------------------------------------|--------------------------|
| mf3vpsec | JWB | develop (1 node)                    | 1044                     |
| ve8c4vuy | JWB | javad/dev/optimize_encoder (1 node) | 1294                     |
Screenshot from 2026-04-05 19-38-29
  • Run ../WeatherGenerator-private/hpc/launch-slurm.py --time 60 --base-config ./config/config_forecasting.yml
| run_id   | HPC | PR                                  | Ingested Samples per GPU |
|----------|-----|-------------------------------------|--------------------------|
| v3ngv74i | JWB | develop (1 node)                    | 670                      |
| yuvx9hwm | JWB | javad/dev/optimize_encoder (1 node) | 760                      |
Screenshot from 2026-04-05 19-53-39
  • Run ../WeatherGenerator-private/hpc/launch-slurm.py --time 60 --base-config ./config/config_jepa.yml
| run_id   | HPC | PR                                  | Ingested Samples per GPU |
|----------|-----|-------------------------------------|--------------------------|
| d4vw793o | JWB | develop (1 node)                    | 2308                     |
| ob69315g | JWB | javad/dev/optimize_encoder (1 node) | 4486                     |

Depending on the configuration, performance improvements ranging from 14% to 94% are observed.

@javak87 javak87 marked this pull request as draft April 5, 2026 19:05
@github-actions github-actions bot added the model Related to model training or definition (not generic infra) label Apr 5, 2026
@clessig clessig marked this pull request as ready for review April 6, 2026 14:35
Collaborator

clessig commented Apr 6, 2026

Performance improvement with multiple streams:

New:

000 : 00280/04096 : 000280 : loss = 7.8765E-01 (lr=2.93E-05, s/sec=0.876)

LossPhysical.ERA5.mse.avg : 9.4160E-01 
LossPhysical.NPPATMS.mse.avg : 3.8870E-01 
LossPhysical.SurfaceCombined.mse.avg : 5.8787E-01 
LossPhysical.loss_avg : 7.8765E-01 

Old:

000 : 00110/04096 : 000110 : loss = 9.2028E-01 (lr=6.49E-06, s/sec=0.627)

LossPhysical.ERA5.mse.avg : 1.0701E+00 
LossPhysical.NPPATMS.mse.avg : 3.5584E-01 
LossPhysical.SurfaceCombined.mse.avg : 7.0956E-01 
LossPhysical.loss_avg : 9.2028E-01 

Collaborator

clessig commented Apr 6, 2026

@javak87 : I am happy to merge it. Any reason it was still marked as draft?

Contributor Author

javak87 commented Apr 6, 2026

> @javak87 : I am happy to merge it. Any reason it was still marked as draft?

Not a specific reason. You can merge it.

Collaborator

clessig commented Apr 6, 2026

@javak87 : Can we use this version:

        rows = torch.arange( tok_counts.max(), device=tok_counts.device).unsqueeze(0)
        rows = rows.expand(tok_counts.shape[0], -1)
        pe_idxs = rows[rows < tok_counts.unsqueeze(1)]

It's equivalent to your code but avoids one shape promotion (in the third and fourth lines of your code this happens once implicitly and once explicitly).
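For illustration, the masking trick in the snippet above can be sketched in NumPy (a stand-alone analogue of the torch code; the `tok_counts` values here are hypothetical):

```python
import numpy as np

# Hypothetical per-batch token counts (stand-in for the torch tensor
# tok_counts in the snippet above).
tok_counts = np.array([3, 1, 4])

# Build one row of indices [0, max_count), broadcast it across the batch,
# then keep only the entries below each row's token count.
rows = np.arange(tok_counts.max())[None, :]                         # shape (1, 4)
rows = np.broadcast_to(rows, (tok_counts.shape[0], rows.shape[1]))  # shape (3, 4)
pe_idxs = rows[rows < tok_counts[:, None]]                          # flat index vector

print(pe_idxs.tolist())  # [0, 1, 2, 0, 0, 1, 2, 3]
```

Each row contributes indices `0 .. tok_counts[i]-1`, concatenated into one flat vector, without any Python-level loop over the batch.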

Contributor Author

javak87 commented Apr 6, 2026

> @javak87 : Can we use this version:
>
>         rows = torch.arange( tok_counts.max(), device=tok_counts.device).unsqueeze(0)
>         rows = rows.expand(tok_counts.shape[0], -1)
>         pe_idxs = rows[rows < tok_counts.unsqueeze(1)]
>
> It's equivalent to your code but avoids one shape promotion (in the third and fourth lines of your code this happens once implicitly and once explicitly).

Good suggestion! Let me run it and make sure it's performant.

Collaborator

clessig commented Apr 6, 2026

> @javak87 : Can we use this version:
>
>         rows = torch.arange( tok_counts.max(), device=tok_counts.device).unsqueeze(0)
>         rows = rows.expand(tok_counts.shape[0], -1)
>         pe_idxs = rows[rows < tok_counts.unsqueeze(1)]
>
> It's equivalent to your code but avoids one shape promotion (in the third and fourth lines of your code this happens once implicitly and once explicitly).
>
> Good suggestion! Let me run it and make sure it's performant.

Ok, please double-check and then we can merge.

Contributor Author

javak87 commented Apr 7, 2026

> @javak87 : Can we use this version:
>
>         rows = torch.arange( tok_counts.max(), device=tok_counts.device).unsqueeze(0)
>         rows = rows.expand(tok_counts.shape[0], -1)
>         pe_idxs = rows[rows < tok_counts.unsqueeze(1)]
>
> It's equivalent to your code but avoids one shape promotion (in the third and fourth lines of your code this happens once implicitly and once explicitly).
>
> Good suggestion! Let me run it and make sure it's performant.
>
> Ok, please double-check and then we can merge.

Since config_jepa.yml is more sensitive to this optimization, I tested your suggested changes. The number of ingested samples decreased from 4486 to 4416 per GPU.

Given this, I think my proposed change performs slightly better.

Collaborator

clessig commented Apr 7, 2026

> decreased from 4486 to 4416 per GPU.

@javak87 : For me this is in the noise range. Can you reproduce this difference reliably?

Contributor Author

javak87 commented Apr 8, 2026

>> decreased from 4486 to 4416 per GPU.
>
> @javak87 : For me this is in the noise range. Can you reproduce this difference reliably?

I ran config_jepa.yml for 180 minutes: again, the count decreased from 13532 to 13392 samples per GPU. Here is the result:

Screenshot from 2026-04-08 23-18-47

Collaborator

clessig commented Apr 9, 2026

>> decreased from 4486 to 4416 per GPU.
>>
>> @javak87 : For me this is in the noise range. Can you reproduce this difference reliably?
>
> I ran config_jepa.yml for 180 minutes: again, the count decreased from 13532 to 13392 samples per GPU. Here is the result:
>
> Screenshot from 2026-04-08 23-18-47

Let's use

        rows = torch.arange( tok_counts.max(), device=tok_counts.device).unsqueeze(0)
        rows = rows.expand(tok_counts.shape[0], -1)
        pe_idxs = rows[rows < tok_counts.unsqueeze(1)]

It creates one temporary tensor less. The small performance degradation might change with minor PyTorch updates, and I prefer the cleaner solution.
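As a sanity check on the chosen variant, it can be compared against a naive per-row loop (a NumPy sketch with hypothetical `tok_counts` values; the real code operates on torch tensors):

```python
import numpy as np

def pe_indices(tok_counts: np.ndarray) -> np.ndarray:
    # Vectorized variant chosen above, transcribed to NumPy:
    # one arange row, broadcast across the batch, masked per row.
    rows = np.arange(tok_counts.max())[None, :]
    rows = np.broadcast_to(rows, (tok_counts.shape[0], rows.shape[1]))
    return rows[rows < tok_counts[:, None]]

def pe_indices_naive(tok_counts: np.ndarray) -> np.ndarray:
    # Reference implementation: concatenate arange(n) per row.
    return np.concatenate([np.arange(n) for n in tok_counts])

tok_counts = np.array([2, 5, 1, 3])  # hypothetical counts
assert np.array_equal(pe_indices(tok_counts), pe_indices_naive(tok_counts))
```

The loop-based reference makes the intended semantics explicit, while the vectorized version avoids the per-row Python iteration and the host/device round-trips that motivated this PR.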

@clessig clessig merged commit 3d50683 into ecmwf:develop Apr 9, 2026
5 checks passed
wael-mika pushed a commit to wael-mika/WeatherGenerator that referenced this pull request Apr 13, 2026
* optimize positional encoding

* update positional encoding impl

---------

Co-authored-by: Javad Kasravi <j.kasravi@fz-juelich.de>
Co-authored-by: Christian Lessig <christian.lessig@ecmwf.int>

Labels

model Related to model training or definition (not generic infra)

Projects

Status: Done

Development

Successfully merging this pull request may close these issues.

Many Memcpy D to H

2 participants