Merged
5 changes: 2 additions & 3 deletions .github/workflows/build_and_test_maxtext.yml
Original file line number Diff line number Diff line change
@@ -231,9 +231,8 @@ jobs:
steps:
- name: Check test results
run: |
# If doc-only, all tests should be skipped
if [ "${NEEDS_DOC_ONLY_CHECK_OUTPUTS_RUN_TESTS}" == "false" ]; then
echo "Documentation-only changes detected, tests were skipped"
if [ "${NEEDS_ANALYZE_CHANGES_OUTPUTS_RUN_TESTS}" == "false" ]; then
Collaborator:
nit: this change seems like it belongs in a separate PR? Is this change needed to upgrade jax?

echo "Tests were skipped"
exit 0
fi
12 changes: 6 additions & 6 deletions .github/workflows/run_pathways_tests.yml
@@ -107,32 +107,32 @@ jobs:
PYTHONPATH: "${{ github.workspace }}/src"
services:
resource_manager:
image: us-docker.pkg.dev/cloud-tpu-v2-images/pathways/server:latest
image: us-docker.pkg.dev/cloud-tpu-v2-images/pathways/server:20260413-jax_0.9.2
ports:
- "29001:29001"
- "29002:29002"
options:
--entrypoint=[/usr/pathways/run/cloud_pathways_server_sanitized, --server_port=29001, --node_type=resource_manager, --instance_count=1, --instance_type=tpuv6e:2x2, --gcs_scratch_location=gs://cloud-pathways-staging/tmp]
--entrypoint=[/usr/pathways/run/cloud_pathways_server_sanitized, --server_port=29001, --node_type=resource_manager, --enforce_kernel_ipv6_support=false, --instance_count=1, --instance_type=tpuv6e:2x2, --gcs_scratch_location=gs://cloud-pathways-staging/tmp]
env:
TPU_SKIP_MDS_QUERY: true

worker:
image: us-docker.pkg.dev/cloud-tpu-v2-images/pathways/server:latest
image: us-docker.pkg.dev/cloud-tpu-v2-images/pathways/server:20260413-jax_0.9.2
ports:
- "29005:29005"
- "29006:29006"
- "8471:8471"
- "8080:8080"
options:
--entrypoint=[/usr/pathways/run/cloud_pathways_server_sanitized, --server_port=29005, --resource_manager_address=localhost:29001, --gcs_scratch_location=gs://cloud-pathways-staging/tmp]
--entrypoint=[/usr/pathways/run/cloud_pathways_server_sanitized, --server_port=29005, --resource_manager_address=localhost:29001, --enforce_kernel_ipv6_support=false, --gcs_scratch_location=gs://cloud-pathways-staging/tmp]
--tpu=4

proxy:
image: us-docker.pkg.dev/cloud-tpu-v2-images/pathways/proxy_server:latest
image: us-docker.pkg.dev/cloud-tpu-v2-images/pathways/proxy_server:20260413-jax_0.9.2
ports:
- "29000:29000"
env:
IFRT_PROXY_USE_INSECURE_GRPC_CREDENTIALS: true
XLA_FLAGS: "--xla_dump_to=/tmp/aot_test_dump --xla_dump_hlo_as_text --xla_dump_hlo_module_re=jit_train_step"
options:
--entrypoint=[/usr/pathways/run/cloud_proxy_server_sanitized, --server_port=29000, --resource_manager_address=localhost:29001, --gcs_scratch_location=gs://cloud-pathways-staging/tmp, --xla_tpu_scoped_vmem_limit_kib=65536, --xla_tpu_spmd_rng_bit_generator_unsafe=true]
--entrypoint=[/usr/pathways/run/cloud_proxy_server_sanitized, --server_port=29000, --resource_manager_address=localhost:29001, --enforce_kernel_ipv6_support=false, --gcs_scratch_location=gs://cloud-pathways-staging/tmp, --xla_tpu_scoped_vmem_limit_kib=65536, --xla_tpu_spmd_rng_bit_generator_unsafe=true]
8 changes: 8 additions & 0 deletions .github/workflows/run_tests_against_package.yml
@@ -152,6 +152,14 @@ jobs:
# omit these libtpu init args for gpu tests
if [ "${INPUTS_DEVICE_TYPE}" != "cuda12" ]; then
export LIBTPU_INIT_ARGS='--xla_tpu_scoped_vmem_limit_kib=65536'
else
Collaborator:
should this also be in this PR?

# For cuda12, explicitly point to the pip-installed CUDA libraries
# to avoid conflicts with system-level installations on the runner.
if [ -d ".venv/lib/python3.12/site-packages/nvidia" ]; then
export LD_LIBRARY_PATH=$(pwd)/.venv/lib/python3.12/site-packages/nvidia/cudnn/lib:${LD_LIBRARY_PATH}
else
echo "Warning: Could not find pinned nvidia libraries in .venv."
fi
fi
if [ "${INPUTS_TOTAL_WORKERS}" -gt 1 ]; then
$PYTHON_EXE -m pip install --quiet pytest-split pytest-xdist
67 changes: 23 additions & 44 deletions docs/development/update_dependencies.md
@@ -34,19 +34,15 @@ to keep dependencies in sync for users installing MaxText from source.
To update dependencies, you will follow these general steps:

1. **Modify base requirements**: Update the desired dependencies in
`src/dependencies/requirements/base_requirements/requirements.txt` or the hardware-specific files
(`src/dependencies/requirements/base_requirements/tpu-base-requirements.txt`,
`src/dependencies/requirements/base_requirements/gpu-base-requirements.txt`).
`src/dependencies/requirements/base_requirements/requirements.txt` or the hardware-specific pre-training files
(`src/dependencies/requirements/base_requirements/tpu-requirements.txt`,
`src/dependencies/requirements/base_requirements/cuda12-requirements.txt`).
2. **Find the JAX build commit hash**: The dependency generation process is
pinned to a specific nightly build of JAX. You need to find the commit hash
for the desired JAX build.
3. **Generate the requirement files**: Run the `seed-env` CLI tool to generate
new, fully-pinned requirements files based on your changes.
4. **Update project files**: Copy the newly generated files into the
`src/dependencies/requirements/generated_requirements/` directory. If
necessary, also update any dependencies that are installed directly from
GitHub from the generated files to `src/dependencies/extra_deps`.
5. **Verify the new dependencies**: Test the new dependencies to ensure the
4. **Verify the new dependencies**: Test the new dependencies to ensure the
project installs and runs correctly.

The following sections provide detailed instructions for each step.
@@ -70,31 +66,23 @@ if you want to build `seed-env` from source.

## Step 1: Modify base requirements

Update the desired dependencies in
`src/dependencies/requirements/base_requirements/requirements.txt` or the
hardware-specific files
(`src/dependencies/requirements/base_requirements/tpu-base-requirements.txt`,
`src/dependencies/requirements/base_requirements/gpu-base-requirements.txt`).
Update the desired dependencies in `src/dependencies/requirements/base_requirements/requirements.txt` or the hardware-specific pre-training files (`src/dependencies/requirements/base_requirements/tpu-requirements.txt`, `src/dependencies/requirements/base_requirements/cuda12-requirements.txt`).

## Step 2: Find the JAX build commit hash

The dependency generation process is pinned to a specific nightly build of JAX.
You need to find the commit hash for the desired JAX build.

You can find the latest commit hashes in the
[JAX `build/` folder](https://github.com/jax-ml/jax/commits/main/build). Choose
a recent, successful build and copy its full commit hash.
The dependency generation process is pinned to a specific nightly build of JAX. Find the commit for the desired build in the [JAX `build/` folder](https://github.com/jax-ml/jax/commits/main/build) and copy its full commit hash.
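Before using the hash, it can help to sanity-check that you copied the full 40-character SHA rather than an abbreviated one. The snippet below is an illustrative sketch, not part of the documented workflow; the hash shown is the example pin mentioned later in this page:

```shell
# Hypothetical sanity check: the doc asks for a *full* commit hash, so
# verify the copied value is exactly 40 hexadecimal characters.
jax_commit="e0d2967b50abbefd651d563dbcd7afbcb963d08c"  # example hash from this doc
if printf '%s' "$jax_commit" | grep -qE '^[0-9a-f]{40}$'; then
  echo "ok: full commit hash"
else
  echo "error: expected a full 40-character commit hash" >&2
fi
```

This prints `ok: full commit hash` for the example value above.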

## Step 3: Generate the requirements files

Next, run the `seed-env` CLI to generate the new requirements files. You will
need to do this separately for the TPU and GPU environments. The generated files
will be placed in a directory specified by `--output-dir`.

### For TPU
> **Note:** The current `src/dependencies/requirements/generated_requirements/` in the repository were generated using JAX build commit hash: [e0d2967b50abbefd651d563dbcd7afbcb963d08c](https://github.com/jax-ml/jax/commit/e0d2967b50abbefd651d563dbcd7afbcb963d08c).

### TPU Pre-Training

Run the following command, replacing `<jax-build-commit-hash>` with the hash you
copied in the previous step.
If you have made changes to the TPU pre-training dependencies in `src/dependencies/requirements/base_requirements/tpu-requirements.txt`, regenerate the pinned requirements in the `generated_requirements/` directory. Run the following command, replacing `<jax-build-commit-hash>` with the hash you copied in the previous step:

```bash
seed-env \
@@ -104,45 +92,36 @@ seed-env \
--python-version=3.12 \
--requirements-txt=tpu-requirements.txt \
--output-dir=generated_tpu_artifacts

# Copy generated requirements to src/dependencies/requirements/generated_requirements
mv generated_tpu_artifacts/tpu-requirements.txt \
src/dependencies/requirements/generated_requirements/tpu-requirements.txt
```
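After copying the generated file over, a quick way to confirm it is fully pinned (every dependency resolved to an exact version) is a check like the one below. This is a hypothetical helper, not part of the documented workflow; it is demonstrated on a small inline sample, and in practice you would point `$req_file` at `src/dependencies/requirements/generated_requirements/tpu-requirements.txt`:

```shell
# Hypothetical check: count requirement lines that are neither comments,
# blanks, direct-URL dependencies ('@'), nor pinned with '=='.
req_file=$(mktemp)
cat > "$req_file" <<'EOF'
absl-py==2.1.0
optax==0.2.3
mlperf-logging @ https://github.com/mlcommons/logging/archive/38ab226.zip
EOF

unpinned=$(grep -vE '^[[:space:]]*(#|$)' "$req_file" | grep -v '@' | grep -cv '==' || true)
if [ "$unpinned" -eq 0 ]; then
  echo "fully pinned"
else
  echo "found $unpinned unpinned requirement(s)" >&2
fi
rm -f "$req_file"
```

For the sample above this prints `fully pinned`, since the only line without `==` is a direct-URL dependency.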

### For GPU
### GPU Pre-Training

Similarly, run the command for the GPU requirements.
If you have made changes to the GPU pre-training dependencies in `src/dependencies/requirements/base_requirements/cuda12-requirements.txt`, regenerate the pinned requirements in the `generated_requirements/` directory. Run the following command, replacing `<jax-build-commit-hash>` with the hash you copied in the previous step:

```bash
seed-env \
--local-requirements=src/dependencies/requirements/base_requirements/gpu-base-requirements.txt \
--local-requirements=src/dependencies/requirements/base_requirements/cuda12-requirements.txt \
--host-name=MaxText \
--seed-commit=<jax-build-commit-hash> \
--python-version=3.12 \
--requirements-txt=cuda12-requirements.txt \
--hardware=cuda12 \
--output-dir=generated_gpu_artifacts
```

## Step 4: Update project files

After generating the new requirements, you need to update the files in the
MaxText repository.

1. **Copy the generated files:**

- Move `generated_tpu_artifacts/tpu-requirements.txt` to `generated_requirements/tpu-requirements.txt`.
- Move `generated_gpu_artifacts/cuda12-requirements.txt` to `generated_requirements/cuda12-requirements.txt`.

2. **Update `src/dependencies/extra_deps` (if necessary):**
Currently, MaxText uses a few dependencies, such as `mlperf-logging` and
`google-jetstream`, that are installed directly from GitHub source. These are
defined in `base_requirements/requirements.txt`, and the `seed-env` tool will
carry them over to the generated requirements files.
# Copy generated requirements to src/dependencies/requirements/generated_requirements
mv generated_gpu_artifacts/cuda12-requirements.txt \
src/dependencies/requirements/generated_requirements/cuda12-requirements.txt
```

## Step 5: Verify the new dependencies
## Step 4: Verify the new dependencies

Finally, test that the new dependencies install correctly and that MaxText runs
as expected.

1. **Install MaxText:** Follow the instructions to
[install MaxText from source](install-from-source).
1. **Install MaxText and dependencies**: For instructions on installing MaxText on your VM, please refer to the [official documentation](https://maxtext.readthedocs.io/en/latest/install_maxtext.html#from-source).

2. **Run tests:** Run MaxText tests to ensure there are no regressions.
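A cheap smoke check before the full test run can catch missing modules early. The helper below is a sketch; the module names passed to it are assumptions about MaxText's core dependencies, and the actual test invocation should follow the repository's own instructions:

```shell
# Hypothetical helper: report which of the given top-level modules are
# not importable in the current environment.
smoke_check() {
  python3 - "$@" <<'PY'
import importlib.util
import sys

missing = [m for m in sys.argv[1:] if importlib.util.find_spec(m) is None]
print("missing:", ",".join(missing) if missing else "none")
PY
}

# Example usage (module names are assumptions; adjust as needed):
smoke_check jax flax optax
```

If this reports `missing: none`, proceed to the test suite (for example, `python3 -m pytest`).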
@@ -1,8 +1,9 @@
absl-py
aqtp
array-record
chex
cloud-accelerator-diagnostics
cloud-tpu-diagnostics
cloud-tpu-diagnostics!=1.1.14
datasets
drjax
flax
@@ -24,6 +25,7 @@ numpy
omegaconf
optax
orbax-checkpoint
parameterized
pathwaysutils
pillow
pre-commit
@@ -34,15 +36,14 @@ pylint
pytest
pytype
sentencepiece
seqio
tensorboard-plugin-profile
tensorboardx
tensorflow-datasets
tensorflow-text
tensorflow
tiktoken
tokamax
tokamax!=0.1.0
transformers
uvloop
qwix
google-jetstream @ https://github.com/AI-Hypercomputer/JetStream/archive/29329e8e73820993f77cfc8efe34eb2a73f5de98.zip
mlperf-logging @ https://github.com/mlcommons/logging/archive/38ab22670527888c8eb7825a4ece176fcc36a95d.zip
@@ -1,2 +1 @@
-r requirements.txt
google-tunix