Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 15 additions & 2 deletions src/simulation/scripts/launch_train_multiprocessing.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@
import torch.nn as nn
from stable_baselines3 import PPO
from stable_baselines3.common.vec_env import SubprocVecEnv
from stable_baselines3 import SAC
from stable_baselines3.common.callbacks import CheckpointCallback

from extractors import ( # noqa: F401
CNN1DExtractor,
Expand Down Expand Up @@ -58,6 +60,15 @@
print(f"{model.batch_size=}")
print(f"{model.device=}")

# Periodically back up the model during training: one checkpoint every
# 50_000 callback steps, written under "Backups/" with the "back_up_model"
# filename prefix.
checkpoint_callback = CheckpointCallback(
save_freq=50_000, # Back up every 50_000 callback steps. NOTE(review): SB3 counts callback calls, not timesteps; with a vectorized env each call presumably advances n_envs timesteps - confirm the intended interval.
save_path="Backups/",
name_prefix="back_up_model",
save_replay_buffer=True,
save_vecnormalize=True,
)

while True:
onnx_utils.export_onnx(
model,
Expand All @@ -71,10 +82,12 @@
model.learn(
total_timesteps=c.total_timesteps,
progress_bar=False,
callback=PlotModelIO(),
callback=[PlotModelIO(),checkpoint_callback],
)
else:
model.learn(total_timesteps=c.total_timesteps, progress_bar=True)


model.learn(total_timesteps=c.total_timesteps, progress_bar=True,callback=checkpoint_callback)

print("iteration over")
# TODO: we could just use a callback to save checkpoints or export the model to onnx
Expand Down