diff --git a/gcp/src/main/java/org/apache/iceberg/gcp/gcs/GCSFileIO.java b/gcp/src/main/java/org/apache/iceberg/gcp/gcs/GCSFileIO.java
index 5e040943f703..d2b582d0c135 100644
--- a/gcp/src/main/java/org/apache/iceberg/gcp/gcs/GCSFileIO.java
+++ b/gcp/src/main/java/org/apache/iceberg/gcp/gcs/GCSFileIO.java
@@ -27,6 +27,7 @@
 import java.time.Instant;
 import java.time.temporal.ChronoUnit;
 import java.util.Comparator;
+import java.util.LinkedHashMap;
 import java.util.List;
 import java.util.Map;
 import java.util.Objects;
@@ -327,16 +328,28 @@ public void deleteFiles(Iterable<String> pathsToDelete) throws BulkDeletionFailu
 
   @SuppressWarnings("resource")
   private void internalDeleteFiles(Stream<BlobId> blobIdsToDelete) {
-    Streams.stream(
-            Iterators.partition(
-                blobIdsToDelete.iterator(),
-                clientForStoragePath(ROOT_STORAGE_PREFIX).gcpProperties().deleteBatchSize()))
-        .forEach(
-            batch -> {
-              if (!batch.isEmpty()) {
-                clientForStoragePath(batch.get(0).toGsUtilUri()).storage().delete(batch);
-              }
-            });
+    // Buffer blobs per prefix client and flush a client's buffer as soon as it reaches that
+    // client's delete batch size. Each GCS call therefore uses the credentials configured for
+    // the prefix that owns the blob, while memory stays bounded by (clients x batch size)
+    // instead of growing with the size of the input stream.
+    Map<PrefixedStorage, List<BlobId>> pendingByClient = new LinkedHashMap<>();
+    blobIdsToDelete.forEach(
+        blobId -> {
+          PrefixedStorage client = clientForStoragePath(blobId.toGsUtilUri());
+          List<BlobId> pending =
+              pendingByClient.computeIfAbsent(client, key -> Lists.newArrayList());
+          pending.add(blobId);
+          if (pending.size() >= client.gcpProperties().deleteBatchSize()) {
+            // Detach the full buffer before deleting so the list handed to the client is never
+            // mutated afterwards; the next blob for this client starts a fresh buffer.
+            pendingByClient.remove(client);
+            client.storage().delete(pending);
+          }
+        });
+
+    // Flush the partially filled buffers. A buffer is only created when a blob is added to it
+    // and is removed when flushed, so every list still in the map is non-empty.
+    pendingByClient.forEach((client, pending) -> client.storage().delete(pending));
   }
 
   @Override
diff --git a/gcp/src/test/java/org/apache/iceberg/gcp/gcs/TestGCSFileIO.java b/gcp/src/test/java/org/apache/iceberg/gcp/gcs/TestGCSFileIO.java
index f6841664e0d3..9f019d467fe7 100644
--- a/gcp/src/test/java/org/apache/iceberg/gcp/gcs/TestGCSFileIO.java
+++ b/gcp/src/test/java/org/apache/iceberg/gcp/gcs/TestGCSFileIO.java
@@ -19,14 +19,18 @@
 package org.apache.iceberg.gcp.gcs;
 
 import static java.lang.String.format;
+import static org.apache.iceberg.gcp.GCPProperties.GCS_DELETE_BATCH_SIZE;
 import static org.apache.iceberg.gcp.GCPProperties.GCS_OAUTH2_REFRESH_CREDENTIALS_ENABLED;
 import static org.apache.iceberg.gcp.GCPProperties.GCS_OAUTH2_REFRESH_CREDENTIALS_ENDPOINT;
 import static org.apache.iceberg.gcp.GCPProperties.GCS_OAUTH2_TOKEN;
 import static org.apache.iceberg.gcp.GCPProperties.GCS_OAUTH2_TOKEN_EXPIRES_AT;
 import static org.assertj.core.api.Assertions.assertThat;
 import static org.mockito.Mockito.any;
+import static org.mockito.Mockito.atLeastOnce;
 import static org.mockito.Mockito.doAnswer;
+import static org.mockito.Mockito.never;
 import static org.mockito.Mockito.spy;
+import static org.mockito.Mockito.verify;
 
 import com.google.auth.oauth2.AccessToken;
 import com.google.auth.oauth2.OAuth2Credentials;
@@ -64,6 +68,7 @@
 import org.junit.jupiter.api.Test;
 import org.junit.jupiter.params.ParameterizedTest;
 import org.junit.jupiter.params.provider.MethodSource;
+import org.mockito.ArgumentCaptor;
 
 public class TestGCSFileIO {
   private static final String TEST_BUCKET = "TEST_BUCKET";
@@ -204,6 +209,148 @@ public void testDeletePrefix() {
         .isEqualTo(1);
   }
 
+  @Test
+  public void deleteFilesRoutesToCorrectClientPerPrefix() {
+    String bucket1 = "bucket1";
+    String bucket2 = "bucket2";
+    Storage backing = LocalStorageHelper.getOptions().getService();
+    BlobId blob1a = BlobId.of(bucket1, "table/a.dat");
+    BlobId blob1b = BlobId.of(bucket1, "table/b.dat");
+    BlobId blob2a = BlobId.of(bucket2, "table/a.dat");
+    BlobId blob2b = BlobId.of(bucket2, "table/b.dat");
+    for (BlobId blobId : ImmutableList.of(blob1a, blob1b, blob2a, blob2b)) {
+      backing.create(BlobInfo.newBuilder(blobId).build());
+    }
+
+    try (GCSFileIO fileIO = new GCSFileIO(() -> spyWithBatchDeleteStub(backing))) {
+      fileIO.setCredentials(
+          ImmutableList.of(
+              StorageCredential.create(
+                  "gs://" + bucket1,
+                  ImmutableMap.of(GCS_OAUTH2_TOKEN, "token1", GCS_OAUTH2_TOKEN_EXPIRES_AT, "2000")),
+              StorageCredential.create(
+                  "gs://" + bucket2,
+                  ImmutableMap.of(
+                      GCS_OAUTH2_TOKEN, "token2", GCS_OAUTH2_TOKEN_EXPIRES_AT, "3000"))));
+      fileIO.initialize(
+          ImmutableMap.of(GCS_OAUTH2_TOKEN, "rootToken", GCS_OAUTH2_TOKEN_EXPIRES_AT, "1000"));
+
+      // Interleave so the first object in the batch is in bucket1 but later objects are in
+      // bucket2. With the bug, the bucket1 client would be reused for the whole batch.
+      Iterable<String> deletes =
+          ImmutableList.of(
+              "gs://" + bucket1 + "/table/a.dat",
+              "gs://" + bucket2 + "/table/a.dat",
+              "gs://" + bucket1 + "/table/b.dat",
+              "gs://" + bucket2 + "/table/b.dat");
+      fileIO.deleteFiles(deletes);
+
+      assertThat(backing.list(bucket1).iterateAll()).isEmpty();
+      assertThat(backing.list(bucket2).iterateAll()).isEmpty();
+
+      Storage client1 = fileIO.client("gs://" + bucket1 + "/anything");
+      Storage client2 = fileIO.client("gs://" + bucket2 + "/anything");
+      Storage rootClient = fileIO.client("gs://random-bucket/anything");
+      assertThat(client1).isNotSameAs(client2).isNotSameAs(rootClient);
+
+      ArgumentCaptor<Iterable<BlobId>> client1Batches = captorForBlobBatches();
+      verify(client1).delete(client1Batches.capture());
+      assertThat(ImmutableList.copyOf(client1Batches.getValue())).containsExactly(blob1a, blob1b);
+
+      ArgumentCaptor<Iterable<BlobId>> client2Batches = captorForBlobBatches();
+      verify(client2).delete(client2Batches.capture());
+      assertThat(ImmutableList.copyOf(client2Batches.getValue())).containsExactly(blob2a, blob2b);
+
+      verify(rootClient, never()).delete(any(Iterable.class));
+    }
+  }
+
+  @Test
+  public void deleteFilesBatchesPerClient() {
+    String bucket1 = "bucket1";
+    String bucket2 = "bucket2";
+    Storage backing = LocalStorageHelper.getOptions().getService();
+    List<BlobId> bucket1Blobs = Lists.newArrayList();
+    List<BlobId> bucket2Blobs = Lists.newArrayList();
+    for (int i = 0; i < 5; i++) {
+      BlobId b1 = BlobId.of(bucket1, "table/file" + i + ".dat");
+      BlobId b2 = BlobId.of(bucket2, "table/file" + i + ".dat");
+      backing.create(BlobInfo.newBuilder(b1).build());
+      backing.create(BlobInfo.newBuilder(b2).build());
+      bucket1Blobs.add(b1);
+      bucket2Blobs.add(b2);
+    }
+
+    try (GCSFileIO fileIO = new GCSFileIO(() -> spyWithBatchDeleteStub(backing))) {
+      fileIO.setCredentials(
+          ImmutableList.of(
+              StorageCredential.create(
+                  "gs://" + bucket1,
+                  ImmutableMap.of(GCS_OAUTH2_TOKEN, "token1", GCS_OAUTH2_TOKEN_EXPIRES_AT, "2000")),
+              StorageCredential.create(
+                  "gs://" + bucket2,
+                  ImmutableMap.of(
+                      GCS_OAUTH2_TOKEN, "token2", GCS_OAUTH2_TOKEN_EXPIRES_AT, "3000"))));
+      fileIO.initialize(
+          ImmutableMap.of(
+              GCS_OAUTH2_TOKEN,
+              "rootToken",
+              GCS_OAUTH2_TOKEN_EXPIRES_AT,
+              "1000",
+              GCS_DELETE_BATCH_SIZE,
+              "2"));
+
+      List<String> deletes = Lists.newArrayList();
+      for (int i = 0; i < 5; i++) {
+        deletes.add("gs://" + bucket1 + "/table/file" + i + ".dat");
+        deletes.add("gs://" + bucket2 + "/table/file" + i + ".dat");
+      }
+      fileIO.deleteFiles(deletes);
+
+      assertThat(backing.list(bucket1).iterateAll()).isEmpty();
+      assertThat(backing.list(bucket2).iterateAll()).isEmpty();
+
+      Storage client1 = fileIO.client("gs://" + bucket1 + "/anything");
+      Storage client2 = fileIO.client("gs://" + bucket2 + "/anything");
+
+      assertPerClientBatches(client1, bucket1, bucket1Blobs);
+      assertPerClientBatches(client2, bucket2, bucket2Blobs);
+    }
+  }
+
+  private void assertPerClientBatches(Storage client, String bucket, List<BlobId> expectedBlobs) {
+    ArgumentCaptor<Iterable<BlobId>> batches = captorForBlobBatches();
+    verify(client, atLeastOnce()).delete(batches.capture());
+    List<BlobId> seen = Lists.newArrayList();
+    for (Iterable<BlobId> batch : batches.getAllValues()) {
+      List<BlobId> batchList = ImmutableList.copyOf(batch);
+      assertThat(batchList).isNotEmpty().hasSizeLessThanOrEqualTo(2);
+      assertThat(batchList).allMatch(id -> bucket.equals(id.getBucket()));
+      seen.addAll(batchList);
+    }
+    assertThat(seen).containsExactlyInAnyOrderElementsOf(expectedBlobs);
+  }
+
+  @SuppressWarnings({"unchecked", "rawtypes"})
+  private static ArgumentCaptor<Iterable<BlobId>> captorForBlobBatches() {
+    return ArgumentCaptor.forClass((Class) Iterable.class);
+  }
+
+  @SuppressWarnings("unchecked")
+  private static Storage spyWithBatchDeleteStub(Storage backing) {
+    Storage spied = spy(backing);
+    doAnswer(
+            invoke -> {
+              Iterable<BlobId> iter = invoke.getArgument(0);
+              List<Boolean> answer = Lists.newArrayList();
+              iter.forEach(blobId -> answer.add(backing.delete(blobId)));
+              return answer;
+            })
+        .when(spied)
+        .delete(any(Iterable.class));
+    return spied;
+  }
+
   @ParameterizedTest
   @MethodSource("org.apache.iceberg.TestHelpers#serializers")
   public void testGCSFileIOSerialization(