Skip to content

per cell lejepa implementation#2111

Closed
shmh40 wants to merge 7 commits intoshmh40/dev/jepa-latent-forecastfrom
shmh40/dev/lejepa-per-cell
Closed

per cell lejepa implementation#2111
shmh40 wants to merge 7 commits intoshmh40/dev/jepa-latent-forecastfrom
shmh40/dev/lejepa-per-cell

Conversation

@shmh40
Copy link
Copy Markdown
Contributor

@shmh40 shmh40 commented Mar 25, 2026

Description

Draft implementation of per-cell LeJEPA with SIGReg loss. LeJEPALoss should probably just go under the LatentLoss class, to do. Also needs checking.

We have a "SelfTeacher" since we don't do any stop grad/EMA etc. with LeJEPA.

Issue Number

Closes ???

Is this PR a draft? Mark it as draft.

Checklist before asking for review

  • I have performed a self-review of my code
  • My changes comply with basic sanity checks:
    • I have fixed formatting issues with ./scripts/actions.sh lint
    • I have run unit tests with ./scripts/actions.sh unit-test
    • I have documented my code and I have updated the docstrings.
    • I have added unit tests, if relevant
  • I have tried my changes with data and code:
    • I have run the integration tests with ./scripts/actions.sh integration-test
    • (bigger changes) I have run a full training and I have written in the comment the run_id(s): launch-slurm.py --time 60
    • (bigger changes and experiments) I have shared a hegdedoc in the github issue with all the configurations and runs for this experiments
  • I have informed and aligned with people impacted by my change:
    • for config changes: the MatterMost channels and/or a design doc
    • for changes of dependencies: the MatterMost software development channel

@shmh40 shmh40 self-assigned this Mar 25, 2026
@shmh40
Copy link
Copy Markdown
Contributor Author

shmh40 commented Mar 25, 2026

image

@shmh40 shmh40 requested a review from sophie-xhonneux March 25, 2026 19:00
@github-actions github-actions bot added the model Related to model training or definition (not generic infra) label Mar 25, 2026
@github-actions github-actions bot added data Anything related to the datasets used in the project eval anything related to the model evaluation pipeline infra Issues related to infrastructure labels Mar 26, 2026
self.latent_heads = nn.ModuleDict()
self.latent_pre_norm = nn.LayerNorm(cf.ae_global_dim_embed)

ssl_loss_types = ("LossLatentSSLStudentTeacher", "LossLeJEPA")
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I would prefer if we come up with a consistent naming, e.g. SSL_ for the ssl losses, so that we do not need to maintain an explicit list here.

student_masks = student_masks.squeeze(dim=1)
teacher_masks = teacher_masks.squeeze(dim=1)

if temporal:
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why do we need this here? This should be apparent from the masking--forecasting has no mask on the target.

z = z.float() # float32 for cos/sin precision
n, d = z.shape

# Trapezoidal quadrature weights with Gaussian window on [0, 3]
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The quadrature should go to a separate function.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

data Anything related to the datasets used in the project eval anything related to the model evaluation pipeline infra Issues related to infrastructure model:pretrain model Related to model training or definition (not generic infra)

Projects

Status: Done

Development

Successfully merging this pull request may close these issues.

3 participants