Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions src/maxtext/kernels/megablox/backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -522,7 +522,7 @@ def out_transform_indices(n_i, grid_id, k_i, group_metadata, group_offset):
}
call_gmm = qpl.pallas_call(
kernel,
out_shape=jax.ShapeDtypeStruct((m, n), preferred_element_type),
out_shape=jax.ShapeDtypeStruct((m, n), preferred_element_type, manual_type=jax.sharding.ManualAxisType()),
grid_spec=pltpu.PrefetchScalarGridSpec(
num_scalar_prefetch=2,
in_specs=[
Expand Down Expand Up @@ -775,7 +775,7 @@ def out_transform_indices(n_i, k_i, grid_id, group_metadata, group_offset):
}
call_gmm = qpl.pallas_call(
kernel,
out_shape=jax.ShapeDtypeStruct((num_actual_groups, k, n), preferred_element_type),
out_shape=jax.ShapeDtypeStruct((num_actual_groups, k, n), preferred_element_type, manual_type=jax.sharding.ManualAxisType()),
grid_spec=pltpu.PrefetchScalarGridSpec(
num_scalar_prefetch=2,
in_specs=[
Expand Down
15 changes: 14 additions & 1 deletion src/maxtext/kernels/megablox/ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,8 @@ def gmm(
weight_gather_axes: List[Tuple[str, int]] | None = None,
input_buffer_count: tuple[int, int, int] = (2, 2, 2),
combine_scopes: bool = False,
lhs_vma_axes: tuple = tuple(),
rhs_vma_axes: tuple = tuple(),
# TODO(amandaliang): get rid of the qwix_rule in favor of Qwix's interception feature
qwix_rule: qwix.QtRule | None = None,
):
Expand All @@ -65,7 +67,7 @@ def gmm(
)

gmm_fwd_bwd = lambda *args: _gmm_fwd(*args)[0] # pylint: disable=C3001
gmm_fwd_bwd = jax.custom_vjp(gmm_fwd_bwd, nondiff_argnums=(3, 4, 5, 6, 9, 10, 11, 12, 13))
gmm_fwd_bwd = jax.custom_vjp(gmm_fwd_bwd, nondiff_argnums=(3, 4, 5, 6, 9, 10, 11, 12, 13, 14, 15))
gmm_fwd_bwd.defvjp(_gmm_fwd, functools.partial(_gmm_bwd, lhs.dtype, rhs.dtype))
return gmm_fwd_bwd(
lhs,
Expand All @@ -82,6 +84,8 @@ def gmm(
quantization_rule,
use_tokamax_backend,
weight_gather_axes,
lhs_vma_axes,
rhs_vma_axes,
)


Expand All @@ -100,6 +104,8 @@ def _gmm_fwd(
quantization_rule: qwix.QtRule | None = None,
use_tokamax_backend: bool = False,
weight_gather_axes: List[Tuple[str, int]] | None = None,
lhs_vma_axes: tuple = tuple(),
rhs_vma_axes: tuple = tuple(),
) -> tuple[
jnp.ndarray,
tuple[
Expand Down Expand Up @@ -175,6 +181,8 @@ def _gmm_bwd(
quantization_rule: qwix.QtRule | None,
use_tokamax_backend: bool,
weight_gather_axes: List[Tuple[str, int]] | None,
lhs_vma_axes: tuple,
rhs_vma_axes: tuple,
residual: tuple[
jnp.ndarray | qpl.QArray,
jnp.ndarray | qpl.QArray,
Expand Down Expand Up @@ -264,6 +272,9 @@ def _gmm_bwd(
transpose_rhs=not transpose_rhs,
interpret=interpret,
)
for axis in lhs_vma_axes:
dlhs = jax.lax.pcast(dlhs, axis_name=axis, to="varying")

drhs = backend.tgmm(
lhs.swapaxes(0, 1),
drhs_dout,
Expand All @@ -274,6 +285,8 @@ def _gmm_bwd(
num_actual_groups,
interpret=interpret,
)
for axis in rhs_vma_axes:
drhs = jax.lax.pcast(drhs, axis_name=axis, to="varying")

# NOTE: If the rhs transposition is fused into the forward pass we need to
# return the transpose of the rhs gradient that we calculated above.
Expand Down
22 changes: 21 additions & 1 deletion src/maxtext/layers/moe.py
Original file line number Diff line number Diff line change
Expand Up @@ -893,9 +893,18 @@ def sparse_matmul(
):
"""Perform sparse matrix multiplication of inputs and Experts."""

use_vma = True
def gmm(
inputs, kernel, tiling, group_sizes, expert_assignments, weight_gather_axes, input_buffer_count, combine_scopes
):
def extract_vma(tensor):
  """Return the varying-manual-axes (``vma``) of *tensor*'s abstract value.

  Args:
    tensor: any value; typically a JAX array traced under explicit sharding.

  Returns:
    A sorted tuple of axis names from ``jax.typeof(tensor).vma``, or an
    empty tuple when the aval has no (or an empty) ``vma`` attribute, or
    when ``jax.typeof`` itself fails (e.g. a non-JAX input).
  """
  try:
    # Only jax.typeof can realistically raise here; keep the try minimal.
    aval = jax.typeof(tensor)
  except Exception:  # best-effort probe: fall back to "no varying axes"
    return tuple()
  vma = getattr(aval, "vma", None)
  return tuple(sorted(vma)) if vma else tuple()
# TODO (b/491979205) pipeline fsdp ag per repeat fails tokamax gmm
if self.config.using_pipeline_parallelism and self.config.pipeline_fsdp_ag_per_repeat:
tokamax_group_sizes = group_sizes
Expand Down Expand Up @@ -931,6 +940,8 @@ def gmm(
)
if self.config.use_tokamax_gmm:
if self.config.quantization:
lhs_vma_axes = extract_vma(inputs)
rhs_vma_axes = extract_vma(kernel)
output = mblx.gmm(
lhs=inputs,
rhs=kernel,
Expand All @@ -944,6 +955,8 @@ def gmm(
weight_gather_axes=weight_gather_axes,
input_buffer_count=input_buffer_count,
combine_scopes=combine_scopes,
lhs_vma_axes=lhs_vma_axes,
rhs_vma_axes=rhs_vma_axes,
)
else:
output = tokamax.ragged_dot(
Expand All @@ -956,6 +969,8 @@ def gmm(
)
else:
if self.config.megablox:
lhs_vma_axes = extract_vma(inputs)
rhs_vma_axes = extract_vma(kernel)
output = mblx.gmm(
lhs=inputs,
rhs=kernel,
Expand All @@ -967,6 +982,8 @@ def gmm(
use_qwix_quantization=self.config.use_qwix_quantization,
use_tokamax_backend=self.config.use_tokamax_gmm,
weight_gather_axes=weight_gather_axes,
lhs_vma_axes=lhs_vma_axes,
rhs_vma_axes=rhs_vma_axes,
)
else:
rhs_inputs = kernel
Expand Down Expand Up @@ -1103,7 +1120,7 @@ def gmm(
P(), # Handle None or replicate the output
P(), # Handle None or replicate the output
),
check_vma=False,
check_vma=use_vma,
)
def wrapper(x, logits, pre_bias_logits, w0, w1, wo, w0_bias, w1_bias, wo_bias, rngs):
batch_size, sequence_length, _ = x.shape
Expand Down Expand Up @@ -1289,7 +1306,10 @@ def get_active_sharding_axes(pspec_dim_axes, tensor_dim_index):
if self.config.mlp_bias:
layer_w1 = layer_w1 + w1_bias
layer_w1 = adc.checkpoint_name(layer_w1, "moe_mlpwi_1")
# Combine the W_gate and W_up projections (element-wise via the activation)
# before the downward projection.
intermediate_layer = self.apply_ffn_activation(layer_w0, layer_w1)
# Downward projection of the activated intermediate — produces the FFN output.

intermediate_output = gmm_fn(
intermediate_layer,
Expand Down
2 changes: 0 additions & 2 deletions src/maxtext/layers/pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -228,7 +228,6 @@ def func_to_vmap(body_instance, stages_inputs, stages_segment_ids, stages_positi
vmap_func = nn.vmap(
func_to_vmap,
in_axes=(0, 0, 0, None, None),
spmd_axis_name=self.spmd_axis_name,
variable_axes={"params": 0, "_overwrite_with_gradient": 0},
split_rngs={"params": self.is_initializing(), "dropout": self.config.enable_dropout},
metadata_params={
Expand Down Expand Up @@ -265,7 +264,6 @@ def func_to_vmap(
vmap_func = nn.vmap(
func_to_vmap,
in_axes=(0, 0, 0, 0, None, None),
spmd_axis_name=self.spmd_axis_name,
variable_axes={"params": 0},
split_rngs={"params": self.is_initializing(), "dropout": self.config.enable_dropout},
metadata_params={
Expand Down
Loading