Changes needed for jepa latent forecasting #2110
shmh40 wants to merge 12 commits into sophiex/dev-ssl/deep-ssl
…hmh40/dev/jepa-latent-forecast
@shmh40: could you explain conceptually why a separate teacher offset is needed and why this is not covered by
As far as I can tell, in student-teacher mode, msds (multi_stream_data_sampler) has source_select and target_select as "network_input" only. This means they only go through _build_stream_data_input, since _build_stream_data_output only adds to the stream_data if the mode is "target_coords" or "target_values".
forecast: offset: sets output_offset, which is only used for output_data (in _get_data_windows and _build_stream_data_output). Since this output is never fed to the student or teacher network, it makes no difference here. I have checked this with some print statements.
teacher_time_offset, by contrast, explicitly offsets input_data_target, which is then fed to the teacher.
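To make the distinction concrete, here is a minimal sketch of the window logic described above. It is not the code in multi_stream_data_sampler: the function name `sketch_windows`, the use of `timedelta` for the offsets, and the tuple layout are all assumptions made for illustration.

```python
from datetime import datetime, timedelta


def sketch_windows(t0, window_len, output_offset, teacher_time_offset):
    """Return (student_input, teacher_input, output) windows as (start, end) pairs.

    Hypothetical helper: names and units are illustrative assumptions only.
    """
    # The student always sees the unshifted input window.
    student_input = (t0, t0 + window_len)
    # teacher_time_offset shifts the teacher's *input* window.
    teacher_start = t0 + teacher_time_offset
    teacher_input = (teacher_start, teacher_start + window_len)
    # output_offset only moves the output window, which is never used
    # when every stream is "network_input".
    output_start = t0 + output_offset
    output = (output_start, output_start + window_len)
    return student_input, teacher_input, output


t0 = datetime(2020, 1, 1, 0)
step = timedelta(hours=6)

# Changing only output_offset leaves both network inputs untouched.
a = sketch_windows(t0, step, output_offset=timedelta(0), teacher_time_offset=timedelta(0))
b = sketch_windows(t0, step, output_offset=step, teacher_time_offset=timedelta(0))

# Changing teacher_time_offset moves the teacher input to t+1.
c = sketch_windows(t0, step, output_offset=timedelta(0), teacher_time_offset=step)
```

In this sketch, `a` and `b` have identical student and teacher inputs, while in `c` the teacher input starts one step later than the student input.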
Superseded by #2196 going into develop-ssl.
We cannot have a branch
Description
We want to support a JEPA loss where the student sees data at time t and the teacher sees data at time t+1. This requires changes in multi_stream_data_sampler to shift the window for the teacher source data. It also requires a light change to the SSL latent loss if we want to do this without spatial masking: otherwise the student and teacher masks are identical in latent space, so there are no patches to compute the loss on.
Example config included. Potentially teacher_time_offset should be inside the target_input.
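The masking problem mentioned in the description can be illustrated with a small sketch. It assumes the latent loss is computed on patches that the teacher sees but the student does not; that selection rule, the function name `loss_patch_indices`, and the `use_all_when_empty` fallback are assumptions for illustration and not the actual change made to the SSL latent loss.

```python
def loss_patch_indices(student_visible, teacher_visible, use_all_when_empty=False):
    """Indices of latent patches on which the loss would be computed.

    Assumed rule (hypothetical): use patches visible to the teacher
    but hidden from the student.
    """
    idx = [
        i
        for i, (s, t) in enumerate(zip(student_visible, teacher_visible))
        if t and not s
    ]
    if not idx and use_all_when_empty:
        # Hypothetical fallback for the time-offset case without spatial
        # masking: compare all teacher-visible patches.
        idx = [i for i, t in enumerate(teacher_visible) if t]
    return idx


# No spatial masking: student and teacher see every patch.
no_mask = [True] * 4

empty = loss_patch_indices(no_mask, no_mask)
fallback = loss_patch_indices(no_mask, no_mask, use_all_when_empty=True)

print(empty)     # []
print(fallback)  # [0, 1, 2, 3]
```

With identical masks the selection is empty, which is why some change to the loss is needed when the only difference between student and teacher is the time offset.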
Issue Number
Closes #2109
Is this PR a draft? Mark it as draft.
Checklist before asking for review
./scripts/actions.sh lint
./scripts/actions.sh unit-test
./scripts/actions.sh integration-test
launch-slurm.py --time 60