Changes needed for JEPA latent forecasting #2110

Closed
shmh40 wants to merge 12 commits into sophiex/dev-ssl/deep-ssl from shmh40/dev/jepa-latent-forecast

@shmh40
Contributor

@shmh40 shmh40 commented Mar 25, 2026

Description

We want to support a JEPA loss in which the student sees data at time t and the teacher sees data at time t+1. This requires a change in multi_stream_data_sampler to shift the time window of the teacher's source data. It also requires a light change to the SSL latent loss for the case without spatial masking: there the student and teacher masks are identical in latent space, so no patches would be left to compute the loss on.
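The masking problem described above can be illustrated with a minimal sketch. It is not the code in this PR: the function names are hypothetical, and it assumes the loss is normally taken over patches that the teacher sees and the student does not. When that set is empty (for example, with no spatial masking), the sketch falls back to every patch the teacher sees.

```python
def select_loss_patches(student_visible, teacher_visible):
    """Return indices of latent patches on which to compute the loss.

    Assumption (not confirmed by the PR): the loss normally covers patches
    that are visible to the teacher but hidden from the student.
    """
    if len(student_visible) != len(teacher_visible):
        raise ValueError("masks must have the same length")
    idx = [
        i
        for i, (s, t) in enumerate(zip(student_visible, teacher_visible))
        if t and not s
    ]
    if not idx:
        # No teacher-only patch exists (e.g. identical masks without spatial
        # masking): fall back to every patch the teacher sees.
        idx = [i for i, t in enumerate(teacher_visible) if t]
    return idx


def latent_mse(student_latents, teacher_latents, idx):
    """Mean over the selected patches of the per-patch mean squared error."""
    if not idx:
        raise ValueError("no patches selected for the loss")
    total = 0.0
    for i in idx:
        s, t = student_latents[i], teacher_latents[i]
        total += sum((a - b) ** 2 for a, b in zip(s, t)) / len(s)
    return total / len(idx)
```

With a time offset between student and teacher, the prediction task stays non-trivial even when both networks see every patch, which is why the fallback is reasonable in this setting.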

An example config is included. Potentially, teacher_time_offset should live inside target_input.

Issue Number

Closes #2109

Is this PR a draft? Mark it as draft.

Checklist before asking for review

  • I have performed a self-review of my code
  • My changes comply with basic sanity checks:
    • I have fixed formatting issues with ./scripts/actions.sh lint
    • I have run unit tests with ./scripts/actions.sh unit-test
    • I have documented my code and I have updated the docstrings.
    • I have added unit tests, if relevant
  • I have tried my changes with data and code:
    • I have run the integration tests with ./scripts/actions.sh integration-test
    • (bigger changes) I have run a full training and I have written in the comment the run_id(s): launch-slurm.py --time 60
    • (bigger changes and experiments) I have shared a HedgeDoc in the GitHub issue with all the configurations and runs for these experiments
  • I have informed and aligned with people impacted by my change:
    • for config changes: the MatterMost channels and/or a design doc
    • for changes of dependencies: the MatterMost software development channel

@shmh40 shmh40 requested a review from sophie-xhonneux March 25, 2026 15:26
@shmh40 shmh40 self-assigned this Mar 25, 2026
@github-actions github-actions bot added the model Related to model training or definition (not generic infra) label Mar 25, 2026
@shmh40 shmh40 changed the base branch from develop to develop-ssl March 26, 2026 12:14
@shmh40 shmh40 changed the base branch from develop-ssl to sophiex/dev-ssl/deep-ssl March 26, 2026 16:28
@clessig
Collaborator

clessig commented Mar 29, 2026

@shmh40: could you explain conceptually why a separate teacher offset is needed, and why this is not already covered by the following setting?

forecast:
  offset: 1

@shmh40
Contributor Author

shmh40 commented Mar 30, 2026

As far as I can tell, in student-teacher mode the multi_stream_data_sampler (msds) has source_select and target_select set to "network_input" only. Both therefore go through _build_stream_data_input alone, because _build_stream_data_output only adds to stream_data when the mode is "target_coords" or "target_values". The forecast: offset: setting sets output_offset, which is only used for output_data (in _get_data_windows and _build_stream_data_output). Since this output is never fed to the student or teacher network, the setting makes no difference here. I have confirmed this with some print statements.

teacher_time_offset explicitly offsets the input_data_target, which is then fed to the teacher.
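The difference between the two offsets can be sketched as follows. This is an illustration only, not the sampler's code: TimeWindow and build_windows are hypothetical names, and the sketch assumes both offsets are counted in units of one window length.

```python
from dataclasses import dataclass
from datetime import datetime, timedelta


@dataclass(frozen=True)
class TimeWindow:
    start: datetime
    end: datetime

    def shifted(self, steps: int, step: timedelta) -> "TimeWindow":
        """Return a copy of this window moved forward by `steps * step`."""
        delta = steps * step
        return TimeWindow(self.start + delta, self.end + delta)


def build_windows(t0, window_len, output_offset=0, teacher_time_offset=0):
    """Hypothetical helper: derive the three windows from one start time.

    Assumption: offsets are integer multiples of `window_len`.
    """
    student_input = TimeWindow(t0, t0 + window_len)
    return {
        # What the student network sees.
        "student_input": student_input,
        # What the teacher network sees; moved only by teacher_time_offset.
        "teacher_input": student_input.shifted(teacher_time_offset, window_len),
        # Output window; moved only by output_offset (forecast: offset:).
        "output": student_input.shifted(output_offset, window_len),
    }
```

In this sketch, setting output_offset=1 leaves teacher_input equal to student_input, whereas teacher_time_offset=1 moves the teacher's input one window ahead of the student's.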

@shmh40
Contributor Author

shmh40 commented Apr 10, 2026

Superseded by #2196 going into develop-ssl.

@clessig
Collaborator

clessig commented Apr 12, 2026

> Superseded by #2196 going into develop-ssl.

We cannot have a develop-ssl branch. It would cause development to diverge and make merging very hard later.

Labels

model:pretrain model Related to model training or definition (not generic infra)

Projects

Status: Done

Development

Successfully merging this pull request may close these issues.

Enable JEPA latent forecasting
