Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions src/maxdiffusion/checkpointing/ltx2_checkpointer.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,19 +79,19 @@ def load_ltx2_configs_from_orbax(self, step: Optional[int]) -> Tuple[Optional[di
return restored_checkpoint, step

def load_checkpoint(
self, step=None, vae_only=False, load_transformer=True
self, step=None, vae_only=False, load_transformer=True, load_upsampler=False
) -> Tuple[LTX2Pipeline, Optional[dict], Optional[int]]:
restored_checkpoint, step = self.load_ltx2_configs_from_orbax(step)
opt_state = None

if restored_checkpoint:
max_logging.log("Loading LTX2 pipeline from checkpoint")
pipeline = LTX2Pipeline.from_checkpoint(self.config, restored_checkpoint, vae_only, load_transformer)
pipeline = LTX2Pipeline.from_checkpoint(self.config, restored_checkpoint, vae_only, load_transformer, load_upsampler)
if "opt_state" in restored_checkpoint.ltx2_state.keys():
opt_state = restored_checkpoint.ltx2_state["opt_state"]
else:
max_logging.log("No checkpoint found, loading pipeline from pretrained hub")
pipeline = LTX2Pipeline.from_pretrained(self.config, vae_only, load_transformer)
pipeline = LTX2Pipeline.from_pretrained(self.config, vae_only, load_transformer, load_upsampler)

return pipeline, opt_state, step

Expand Down
10 changes: 10 additions & 0 deletions src/maxdiffusion/configs/ltx2_video.yml
Original file line number Diff line number Diff line change
Expand Up @@ -102,3 +102,13 @@ jit_initializers: True
enable_single_replica_ckpt_restoring: False
seed: 0
audio_format: "s16"

# LTX-2 Latent Upsampler
run_latent_upsampler: False
upsampler_model_path: "Lightricks/LTX-2"
upsampler_spatial_patch_size: 1
upsampler_temporal_patch_size: 1
upsampler_adain_factor: 0.0
upsampler_tone_map_compression_ratio: 0.0
upsampler_rational_spatial_scale: 2.0
upsampler_output_type: "pil"
11 changes: 8 additions & 3 deletions src/maxdiffusion/generate_ltx2.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,6 @@ def get_git_commit_hash():


def call_pipeline(config, pipeline, prompt, negative_prompt):
# Set default generation arguments
generator = jax.random.key(config.seed) if hasattr(config, "seed") else jax.random.key(0)
guidance_scale = config.guidance_scale if hasattr(config, "guidance_scale") else 3.0

Expand All @@ -99,6 +98,7 @@ def call_pipeline(config, pipeline, prompt, negative_prompt):
decode_noise_scale=getattr(config, "decode_noise_scale", None),
max_sequence_length=getattr(config, "max_sequence_length", 1024),
dtype=jnp.bfloat16 if getattr(config, "activations_dtype", "bfloat16") == "bfloat16" else jnp.float32,
output_type=getattr(config, "upsampler_output_type", "pil"),
)
return out

Expand All @@ -114,9 +114,11 @@ def run(config, pipeline=None, filename_prefix="", commit_hash=None):
else:
max_logging.log("Could not retrieve Git commit hash.")

checkpoint_loader = LTX2Checkpointer(config=config)
if pipeline is None:
checkpoint_loader = LTX2Checkpointer(config=config)
pipeline, _, _ = checkpoint_loader.load_checkpoint()
# Use the config flag to determine if the upsampler should be loaded
run_latent_upsampler = getattr(config, "run_latent_upsampler", False)
pipeline, _, _ = checkpoint_loader.load_checkpoint(load_upsampler=run_latent_upsampler)

pipeline.enable_vae_slicing()
pipeline.enable_vae_tiling()
Expand All @@ -135,6 +137,7 @@ def run(config, pipeline=None, filename_prefix="", commit_hash=None):
)

out = call_pipeline(config, pipeline, prompt, negative_prompt)

# out should have .frames and .audio
videos = out.frames if hasattr(out, "frames") else out[0]
audios = out.audio if hasattr(out, "audio") else None
Expand All @@ -143,6 +146,8 @@ def run(config, pipeline=None, filename_prefix="", commit_hash=None):
max_logging.log(f"model name: {getattr(config, 'model_name', 'ltx-video')}")
max_logging.log(f"model path: {config.pretrained_model_name_or_path}")
max_logging.log(f"model type: {getattr(config, 'model_type', 'T2V')}")
if getattr(config, "run_latent_upsampler", False):
max_logging.log(f"upsampler model path: {config.upsampler_model_path}")
max_logging.log(f"hardware: {jax.devices()[0].platform}")
max_logging.log(f"number of devices: {jax.device_count()}")
max_logging.log(f"per_device_batch_size: {config.per_device_batch_size}")
Expand Down
Loading
Loading