An unofficial implementation of GameNGen that simulates classic Atari games with a diffusion model.
- Install uv and ensure it is available on your PATH.
To install all required dependencies, run:
uv sync
This sets up your environment according to the project's pyproject.toml configuration.
First, activate the virtual environment created by uv. For detailed instructions, refer to the uv documentation.
This project uses a customized version of the rl-baselines3-zoo framework, which primarily adds a wrapper around the environment to capture gameplay videos and agent actions during RL training.
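To make the idea of a recording wrapper concrete, here is a minimal sketch that uses only the Python standard library. It is not the wrapper shipped with the customized rl-baselines3-zoo: the names `RecordingWrapper`, `CountingEnv`, and `record_one_episode` are hypothetical, the environment is a toy stand-in that follows the Gymnasium-style `reset`/`step` return convention, and observations are stored as JSON rather than encoded as video.

```python
import json
from pathlib import Path
from typing import Any


class CountingEnv:
    """Toy stand-in for an Atari environment (assumption): the observation is a step counter."""

    def __init__(self, length: int = 3) -> None:
        self.length = length
        self.t = 0

    def reset(self):
        self.t = 0
        return self.t, {}

    def step(self, action):
        self.t += 1
        terminated = self.t >= self.length
        return self.t, 0.0, terminated, False, {}


class RecordingWrapper:
    """Hypothetical wrapper that logs the observations and actions of each episode."""

    def __init__(self, env: Any, out_dir: str) -> None:
        self.env = env
        self.out_dir = Path(out_dir)
        self.out_dir.mkdir(parents=True, exist_ok=True)
        self.episode_index = 0
        self.frames: list = []
        self.actions: list = []

    def reset(self):
        obs, info = self.env.reset()
        # Start a fresh episode buffer with the initial observation.
        self.frames = [obs]
        self.actions = []
        return obs, info

    def step(self, action):
        obs, reward, terminated, truncated, info = self.env.step(action)
        self.actions.append(action)
        self.frames.append(obs)
        if terminated or truncated:
            self._save_episode()
        return obs, reward, terminated, truncated, info

    def _save_episode(self) -> None:
        # One file per episode; the file layout is an assumption for illustration.
        path = self.out_dir / f"episode_{self.episode_index:05d}.json"
        path.write_text(json.dumps({"frames": self.frames, "actions": self.actions}))
        self.episode_index += 1


def record_one_episode(out_dir: str) -> list:
    """Play one episode with a constant action and return the recorded file names."""
    env = RecordingWrapper(CountingEnv(length=3), out_dir)
    env.reset()
    done = False
    while not done:
        _, _, terminated, truncated, _ = env.step(1)
        done = terminated or truncated
    return sorted(p.name for p in Path(out_dir).iterdir())
```

Because the wrapper exposes the same `reset` and `step` interface as the environment it wraps, the training loop does not need to know that recording is taking place.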
To collect gameplay data with rl-baselines3-zoo, run:
python -m rl_zoo3.train --algo ppo --env PongNoFrameskip-v4 --data-collect-dir data
This command creates a dataset of game episodes in the data directory, which can be used for subsequent training.
Before training, preprocess the raw recorded episodes into a format suitable for training the diffusion model. The scripts/prepare_gamengen_dataset.py script compiles and optionally shuffles the episode files into a final dataset directory.
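As a rough picture of what "compile and optionally shuffle" can mean, the following sketch gathers episode files, shuffles them with a fixed seed if requested, and copies them into an output directory under sequential names. It is an illustration only and does not reproduce the repository's script: the function name `prepare_dataset`, the `episode_*` file pattern, the `seed` parameter, and the output naming scheme are all assumptions.

```python
import random
import shutil
from pathlib import Path


def prepare_dataset(data_dir, output_dir, shuffle=True, seed=0, pattern="episode_*"):
    """Copy episode files into output_dir under sequential names (hypothetical layout)."""
    # Sort first so the result is reproducible regardless of filesystem order.
    episodes = sorted(p for p in Path(data_dir).glob(pattern) if p.is_file())
    if shuffle:
        # A dedicated Random instance keeps the shuffle deterministic for a given seed.
        random.Random(seed).shuffle(episodes)

    out = Path(output_dir)
    out.mkdir(parents=True, exist_ok=True)

    written = []
    for index, src in enumerate(episodes):
        dst = out / f"{index:06d}{src.suffix}"
        shutil.copy2(src, dst)
        written.append(dst)
    return written
```

The repository's own script, invoked with the command that follows, remains the authoritative version of this step.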
python scripts/prepare_gamengen_dataset.py --data-dir data/ppo/PongNoFrameskip-v4_1 --output-dir data/gamengen/pong
Once your dataset is ready, you can begin training the GameNGen model. The following command launches the training script with the default settings defined in scripts/train_gamengen.sh. This script handles model initialization, configuration, and checkpointing.
bash scripts/train_gamengen.sh