From c068c90b66ce58c11fd9702c014e6ea7882f3def Mon Sep 17 00:00:00 2001
From: Arthur RIPOLL
Date: Mon, 6 Apr 2026 22:01:17 +0200
Subject: [PATCH] feat: now the model saves callback checkpoints

---
 .../scripts/launch_train_multiprocessing.py | 17 +++++++++++++++--
 1 file changed, 15 insertions(+), 2 deletions(-)

diff --git a/src/simulation/scripts/launch_train_multiprocessing.py b/src/simulation/scripts/launch_train_multiprocessing.py
index fc01f97..3cd8902 100644
--- a/src/simulation/scripts/launch_train_multiprocessing.py
+++ b/src/simulation/scripts/launch_train_multiprocessing.py
@@ -6,6 +6,8 @@
 import torch.nn as nn
 from stable_baselines3 import PPO
 from stable_baselines3.common.vec_env import SubprocVecEnv
+from stable_baselines3 import SAC
+from stable_baselines3.common.callbacks import CheckpointCallback
 
 from extractors import (  # noqa: F401
     CNN1DExtractor,
@@ -58,6 +60,15 @@
     print(f"{model.batch_size=}")
     print(f"{model.device=}")
 
+    # Save a checkpoint every 50_000 steps
+    checkpoint_callback = CheckpointCallback(
+        save_freq=50_000,  # makes a backup every 50_000 iterations
+        save_path="Backups/",
+        name_prefix="back_up_model",
+        save_replay_buffer=True,
+        save_vecnormalize=True,
+    )
+
     while True:
         onnx_utils.export_onnx(
             model,
@@ -71,10 +82,12 @@
             model.learn(
                 total_timesteps=c.total_timesteps,
                 progress_bar=False,
-                callback=PlotModelIO(),
+                callback=[PlotModelIO(),checkpoint_callback],
             )
         else:
-            model.learn(total_timesteps=c.total_timesteps, progress_bar=True)
+
+
+            model.learn(total_timesteps=c.total_timesteps, progress_bar=True,callback=checkpoint_callback)
 
     print("iteration over")
     # TODO: we could just use a callback to save checkpoints or export the model to onnx