-
Notifications
You must be signed in to change notification settings - Fork 66
Refactor Wan Model Training & Add Wan-VACE Training Support #352
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,115 @@ | ||
| """Copyright 2025 Google LLC | ||
|
|
||
| Licensed under the Apache License, Version 2.0 (the "License"); | ||
| you may not use this file except in compliance with the License. | ||
| You may obtain a copy of the License at | ||
|
|
||
| https://www.apache.org/licenses/LICENSE-2.0 | ||
|
|
||
| Unless required by applicable law or agreed to in writing, software | ||
| distributed under the License is distributed on an "AS IS" BASIS, | ||
| WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
| See the License for the specific language governing permissions and | ||
| limitations under the License. | ||
| """ | ||
|
|
||
| import json | ||
| from typing import Optional, Tuple | ||
| import jax | ||
| from jax.sharding import Mesh, NamedSharding, PartitionSpec as P | ||
| from maxdiffusion.checkpointing.wan_checkpointer import WanCheckpointer | ||
| import numpy as np | ||
| import orbax.checkpoint as ocp | ||
| from .. import max_logging | ||
| from ..pipelines.wan.wan_vace_pipeline_2_1 import VaceWanPipeline2_1 | ||
|
|
||
|
|
||
| class WanVaceCheckpointer2_1(WanCheckpointer): | ||
|
|
||
| def load_wan_configs_from_orbax(self, step: Optional[int]) -> Tuple[Optional[dict], Optional[int]]: | ||
| if step is None: | ||
| step = self.checkpoint_manager.latest_step() | ||
| max_logging.log(f"Latest WAN checkpoint step: {step}") | ||
| if step is None: | ||
| max_logging.log("No WAN checkpoint found.") | ||
| return None, None | ||
| max_logging.log(f"Loading WAN checkpoint from step {step}") | ||
|
|
||
| cpu_devices = np.array(jax.devices(backend="cpu")) | ||
| mesh = Mesh(cpu_devices, axis_names=("data",)) | ||
| replicated_sharding = NamedSharding(mesh, P()) | ||
|
|
||
| metadatas = self.checkpoint_manager.item_metadata(step) | ||
| state = metadatas.wan_state | ||
|
|
||
| def add_sharding_to_struct(leaf_struct, sharding): | ||
| struct = ocp.utils.to_shape_dtype_struct(leaf_struct) | ||
| if hasattr(struct, "shape") and hasattr(struct, "dtype"): | ||
| return jax.ShapeDtypeStruct( | ||
| shape=struct.shape, dtype=struct.dtype, sharding=sharding | ||
| ) | ||
| return struct | ||
|
|
||
| target_shardings = jax.tree_util.tree_map( | ||
| lambda x: replicated_sharding, state | ||
| ) | ||
|
|
||
| with mesh: | ||
| abstract_train_state_with_sharding = jax.tree_util.tree_map( | ||
| add_sharding_to_struct, state, target_shardings | ||
| ) | ||
|
|
||
| max_logging.log("Restoring WAN checkpoint") | ||
| restored_checkpoint = self.checkpoint_manager.restore( | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. is this replicating the sharding across devices? If so, would this be able to load on a trillium tpu with 32GB of HBM?
Collaborator
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The mesh is created using CPU devices, not TPU devices. So the model is being loaded into RAM. |
||
| step=step, | ||
| args=ocp.args.Composite( | ||
| wan_config=ocp.args.JsonRestore(), | ||
| wan_state=ocp.args.StandardRestore( | ||
| abstract_train_state_with_sharding | ||
| ), | ||
| ), | ||
| ) | ||
| max_logging.log(f"restored checkpoint {restored_checkpoint.keys()}") | ||
| max_logging.log(f"restored checkpoint wan_state {restored_checkpoint.wan_state.keys()}") | ||
| max_logging.log(f"optimizer found in checkpoint {'opt_state' in restored_checkpoint.wan_state.keys()}") | ||
| max_logging.log(f"optimizer state saved in attribute self.opt_state {self.opt_state}") | ||
| return restored_checkpoint, step | ||
|
|
||
| def load_diffusers_checkpoint(self): | ||
| pipeline = VaceWanPipeline2_1.from_pretrained(self.config) | ||
| return pipeline | ||
|
|
||
| def load_checkpoint(self, step=None) -> Tuple[VaceWanPipeline2_1, Optional[dict], Optional[int]]: | ||
| restored_checkpoint, step = self.load_wan_configs_from_orbax(step) | ||
| opt_state = None | ||
| if restored_checkpoint: | ||
| max_logging.log("Loading WAN pipeline from checkpoint") | ||
| pipeline = VaceWanPipeline2_1.from_checkpoint(self.config, restored_checkpoint) | ||
| if "opt_state" in restored_checkpoint.wan_state.keys(): | ||
| opt_state = restored_checkpoint.wan_state["opt_state"] | ||
| else: | ||
| max_logging.log("No checkpoint found, loading default pipeline.") | ||
| pipeline = self.load_diffusers_checkpoint() | ||
|
|
||
| return pipeline, opt_state, step | ||
|
|
||
| def save_checkpoint( | ||
| self, train_step, pipeline: VaceWanPipeline2_1, train_states: dict | ||
| ): | ||
| """Saves the training state and model configurations.""" | ||
|
|
||
| def config_to_json(model_or_config): | ||
| return json.loads(model_or_config.to_json_string()) | ||
|
|
||
| max_logging.log(f"Saving checkpoint for step {train_step}") | ||
|
|
||
| # Save the checkpoint | ||
| self.checkpoint_manager.save( | ||
| train_step, | ||
| args=ocp.args.Composite( | ||
| wan_config=ocp.args.JsonSave(config_to_json(pipeline.transformer)), | ||
| wan_state=ocp.args.StandardSave(train_states), | ||
| ), | ||
| ) | ||
|
|
||
| max_logging.log(f"Checkpoint for step {train_step} is saved.") | ||
Uh oh!
There was an error while loading. Please reload this page.