Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
283 changes: 35 additions & 248 deletions google/cloud/bigtable/batcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,14 +13,12 @@
# limitations under the License.

"""User friendly container for Google Cloud Bigtable MutationBatcher."""
import threading
import queue
import concurrent.futures
import atexit


from google.api_core.exceptions import from_grpc_status
from dataclasses import dataclass
from google.cloud.bigtable.data.exceptions import MutationsExceptionGroup
from google.cloud.bigtable.data.mutations import RowMutationEntry


FLUSH_COUNT = 100 # after this many elements, send out the batch
Expand All @@ -41,131 +39,6 @@ def __init__(self, message, exc):
super().__init__(self.message)


class _MutationsBatchQueue(object):
    """Private thread-safe queue that accumulates rows for batching.

    Keeps running totals of the queued mutation count and byte size so
    callers can decide when a flush is due.
    """

    def __init__(self, max_mutation_bytes=MAX_MUTATION_SIZE, flush_count=FLUSH_COUNT):
        """Specify the queue constraints."""
        self._queue = queue.Queue()
        self.total_mutation_count = 0
        self.total_size = 0
        self.max_mutation_bytes = max_mutation_bytes
        self.flush_count = flush_count

    def get(self):
        """Pop one row and subtract its footprint from the running totals.

        Returns None when the queue is empty.
        """
        try:
            row = self._queue.get_nowait()
        except queue.Empty:
            return None
        row_size = row.get_mutations_size()
        self.total_mutation_count -= len(row._get_mutations())
        self.total_size -= row_size
        return row

    def put(self, item):
        """Enqueue a row and add its footprint to the running totals."""
        added_mutations = len(item._get_mutations())
        self._queue.put(item)
        self.total_size += item.get_mutations_size()
        self.total_mutation_count += added_mutations

    def full(self):
        """Return True once either running total has reached its threshold."""
        return (
            self.total_mutation_count >= self.flush_count
            or self.total_size >= self.max_mutation_bytes
        )


@dataclass
class _BatchInfo:
    """Running size totals for one in-progress batch.

    Used to decide when a batch is full and by flow control to
    reserve and release inflight capacity.
    """

    # number of individual mutations across all rows in the batch
    mutations_count: int = 0
    # number of rows in the batch
    rows_count: int = 0
    # total size in bytes of the batch's mutations
    mutations_size: int = 0


class _FlowControl(object):
    """Throttle inflight requests by mutation count and byte size.

    As batches are submitted the inflight totals grow; once either total
    reaches its maximum the shared event is cleared so producers block in
    wait(). release() returns capacity and reopens the flow.
    """

    def __init__(
        self,
        max_mutations=MAX_OUTSTANDING_ELEMENTS,
        max_mutation_bytes=MAX_OUTSTANDING_BYTES,
    ):
        self.max_mutations = max_mutations
        self.max_mutation_bytes = max_mutation_bytes
        self.inflight_mutations = 0
        self.inflight_size = 0
        # Event set == flow open; cleared == producers must wait.
        self.event = threading.Event()
        self.event.set()
        self._lock = threading.Lock()

    def is_blocked(self):
        """Return True when either inflight total has reached its maximum."""
        if self.inflight_mutations >= self.max_mutations:
            return True
        return self.inflight_size >= self.max_mutation_bytes

    def control_flow(self, batch_info):
        """Reserve the capacity used by *batch_info* and update block state."""
        with self._lock:
            self.inflight_mutations += batch_info.mutations_count
            self.inflight_size += batch_info.mutations_size
            self.set_flow_control_status()

    def wait(self):
        """Block until flow control pushback has been released.

        Returns as soon as the event is set.
        """
        self.event.wait()

    def set_flow_control_status(self):
        """Clear the event while over threshold; set it once capacity frees."""
        if self.is_blocked():
            self.event.clear()  # producers will sleep in wait()
        else:
            self.event.set()  # wake any waiting producers

    def release(self, batch_info):
        """Return the capacity held by *batch_info* and update block state."""
        with self._lock:
            self.inflight_mutations -= batch_info.mutations_count
            self.inflight_size -= batch_info.mutations_size
            self.set_flow_control_status()


class MutationsBatcher(object):
"""A MutationsBatcher is used in batch cases where the number of mutations
is large or unknown. It will store :class:`DirectRow` in memory until one of the
Expand Down Expand Up @@ -214,29 +87,41 @@ def __init__(
flush_interval=1,
batch_completed_callback=None,
):
self._rows = _MutationsBatchQueue(
max_mutation_bytes=max_row_bytes, flush_count=flush_count
)
self.table = table
self._executor = concurrent.futures.ThreadPoolExecutor()
atexit.register(self.close)
self._timer = threading.Timer(flush_interval, self.flush)
self._timer.start()
self.flow_control = _FlowControl(
max_mutations=MAX_OUTSTANDING_ELEMENTS,
max_mutation_bytes=MAX_OUTSTANDING_BYTES,
)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It looks like you're discarding the flow control? You should be able to pass these through to the data client batcher

self.futures_mapping = {}
self.exceptions = queue.Queue()
self._flush_count = flush_count
self._max_row_bytes = max_row_bytes
self._flush_interval = flush_interval
self._user_batch_completed_callback = batch_completed_callback
self._init_batcher()
atexit.register(self.close)
self._exceptions = queue.Queue()

@property
def flush_count(self):
return self._rows.flush_count
return self._flush_count

@property
def max_row_bytes(self):
return self._rows.max_mutation_bytes
return self._max_row_bytes

def _init_batcher(self):
self._batcher = self.table._table_impl.mutations_batcher(
flush_interval=self._flush_interval,
flush_limit_mutation_count=self._flush_count,
flush_limit_bytes=self._max_row_bytes,
)
self._batcher._user_batch_completed_callback = (
self._user_batch_completed_callback
)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It looks like you're only storing self._flush_interval, self._flush_count, and self._max_row_bytes so you have the context to rebuild the batcher later?

This pattern is fine, but I think functools.partial can make this kind of thing cleaner, letting you condense all the state into a single nullable function

self._batcher_build_fn = partial(self._build_batcher, callback, interval=flush_interval, ...)

def _build_batcher(self, callback, **kwargs):
    batcher = self.table._table_impl.mutations_batcher(**kwargs)
    batcher._user_batch_completed_callback = callback
    return batcher

And then every time you need a new batcher, you just call self._batcher = self._batcher_build_fn()

(It could be simplified more to get rid of the extra _build_batcher function if we make _user_batch_completed_callback into an unadvertised kwarg in the data client, but I'm not sure if that's worth it)


def _close_batcher(self):
try:
self._batcher.close()
except MutationsExceptionGroup as exc_group:
for error in exc_group.exceptions:
# Return the cause of the FailedMutationEntryError to the user,
# as this might be more what they're expecting.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can this comment be improved, to be more definitive? Maybe something like # Unpack the root cause from FailedMutationEntryError wrapper

self._exceptions.put(error.__cause__)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This might be a good time to address this TODO?

The code seems to assume it will be FailedMutationEntryError, so we should make the types agree


def __enter__(self):
"""Starting the MutationsBatcher as a context manager"""
Expand All @@ -260,10 +145,7 @@ def mutate(self, row):
* :exc:`~.table._BigtableRetryableError` if any row returned a transient error.
* :exc:`RuntimeError` if the number of responses doesn't match the number of rows that were retried
"""
self._rows.put(row)

if self._rows.full():
self._flush_async()
self._batcher.append(RowMutationEntry(row.row_key, row._get_mutations()))

def mutate_rows(self, rows):
"""Add multiple rows to the batch. If the current batch meets one of the size
Expand Down Expand Up @@ -298,102 +180,8 @@ def flush(self):
:raises:
* :exc:`.batcherMutationsBatchError` if there's any error in the mutations.
"""
rows_to_flush = []
row = self._rows.get()
while row is not None:
rows_to_flush.append(row)
row = self._rows.get()
response = self._flush_rows(rows_to_flush)
return response

def _flush_async(self):
    """Sends the current batch to Cloud Bigtable asynchronously.

    Drains the internal row queue, packing rows into batches that fit the
    flush-count / byte-size limits, and submits each batch to the thread
    pool executor. Blocks (via flow control) while too much work is
    already inflight.

    :raises:
        * :exc:`.batcherMutationsBatchError` if there's any error in the mutations.
    """
    next_row = self._rows.get()
    while next_row is not None:
        # start a new batch
        # NOTE: the first row is always accepted, so a single row larger
        # than the limits is still sent as its own one-row batch.
        rows_to_flush = [next_row]
        batch_info = _BatchInfo(
            mutations_count=len(next_row._get_mutations()),
            rows_count=1,
            mutations_size=next_row.get_mutations_size(),
        )
        # fill up batch with rows until one no longer fits
        next_row = self._rows.get()
        while next_row is not None and self._row_fits_in_batch(
            next_row, batch_info
        ):
            rows_to_flush.append(next_row)
            batch_info.mutations_count += len(next_row._get_mutations())
            batch_info.rows_count += 1
            batch_info.mutations_size += next_row.get_mutations_size()
            next_row = self._rows.get()
        # send batch over network
        # wait for resources to become available
        self.flow_control.wait()
        # once unblocked, submit the batch
        # event flag will be set by control_flow to block subsequent thread, but not blocking this one
        self.flow_control.control_flow(batch_info)
        future = self._executor.submit(self._flush_rows, rows_to_flush)
        # schedule release of resources from flow control when the batch
        # finishes (see _batch_completed_callback)
        self.futures_mapping[future] = batch_info
        future.add_done_callback(self._batch_completed_callback)

def _batch_completed_callback(self, future):
"""Callback for when the mutation has finished to clean up the current batch
and release items from the flow controller.
Raise exceptions if there's any.
Release the resources locked by the flow control and allow enqueued tasks to be run.
"""
processed_rows = self.futures_mapping[future]
self.flow_control.release(processed_rows)
del self.futures_mapping[future]

def _row_fits_in_batch(self, row, batch_info):
"""Checks if a row can fit in the current batch.

:type row: class
:param row: :class:`~google.cloud.bigtable.row.DirectRow`.

:type batch_info: :class:`_BatchInfo`
:param batch_info: Information about the current batch.

:rtype: bool
:returns: True if the row can fit in the current batch.
"""
new_rows_count = batch_info.rows_count + 1
new_mutations_count = batch_info.mutations_count + len(row._get_mutations())
new_mutations_size = batch_info.mutations_size + row.get_mutations_size()
return (
new_rows_count <= self.flush_count
and new_mutations_size <= self.max_row_bytes
and new_mutations_count <= self.flow_control.max_mutations
and new_mutations_size <= self.flow_control.max_mutation_bytes
)

def _flush_rows(self, rows_to_flush):
"""Mutate the specified rows.

:raises:
* :exc:`.batcherMutationsBatchError` if there's any error in the mutations.
"""
responses = []
if len(rows_to_flush) > 0:
response = self.table.mutate_rows(rows_to_flush)

if self._user_batch_completed_callback:
self._user_batch_completed_callback(response)

for result in response:
if result.code != 0:
exc = from_grpc_status(result.code, result.message)
self.exceptions.put(exc)
responses.append(result)

return responses
self._close_batcher()
self._init_batcher()

def __exit__(self, exc_type, exc_value, exc_traceback):
"""Clean up resources. Flush and shutdown the ThreadPoolExecutor."""
Expand All @@ -406,9 +194,8 @@ def close(self):
:raises:
* :exc:`.batcherMutationsBatchError` if there's any error in the mutations.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think the exc name in the docstrings might be broken?

"""
self.flush()
self._executor.shutdown(wait=True)
self._close_batcher()
atexit.unregister(self.close)
if self.exceptions.qsize() > 0:
exc = list(self.exceptions.queue)
if self._exceptions.qsize() > 0:
exc = list(self._exceptions.queue)
raise MutationsBatchError("Errors in batch mutations.", exc=exc)
2 changes: 0 additions & 2 deletions google/cloud/bigtable/data/_async/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@
from __future__ import annotations

from typing import (
Callable,
cast,
Any,
AsyncIterable,
Expand Down Expand Up @@ -116,7 +115,6 @@
if TYPE_CHECKING:
from google.cloud.bigtable.data._helpers import RowKeySamples
from google.cloud.bigtable.data._helpers import ShardedQuery
from google.rpc import status_pb2

if CrossSync.is_async:
from google.cloud.bigtable.data._async.mutations_batcher import (
Expand Down
4 changes: 3 additions & 1 deletion google/cloud/bigtable/data/_async/mutations_batcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -276,7 +276,9 @@ def __init__(
self._newest_exceptions: deque[Exception] = deque(
maxlen=self._exception_list_limit
)
self._user_batch_completed_callback = None
self._user_batch_completed_callback: Optional[
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

we should probably add a comment here describing that this is currently just used by the shim

Callable[[list[status_pb2.Status]], None]
] = None
# clean up on program exit
atexit.register(self._on_exit)

Expand Down
21 changes: 12 additions & 9 deletions google/cloud/bigtable/data/_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
"""
from __future__ import annotations

from typing import Callable, Sequence, List, Optional, Tuple, TYPE_CHECKING, Union
from typing import cast, Callable, Sequence, List, Optional, Tuple, TYPE_CHECKING, Union
import time
import enum
from collections import namedtuple
Expand Down Expand Up @@ -272,14 +272,17 @@ def _get_status(exc: Optional[Exception]) -> status_pb2.Status:
Returns:
status_pb2.Status: A Status proto object.
"""
if (
isinstance(exc, core_exceptions.GoogleAPICallError)
and exc.grpc_status_code is not None
):
return status_pb2.Status( # type: ignore[unreachable]
code=exc.grpc_status_code.value[0],
message=exc.message,
details=exc.details,
if isinstance(exc, core_exceptions.GoogleAPICallError):
status_code = cast(Optional["grpc.StatusCode"], exc.grpc_status_code)
if status_code is not None:
return status_pb2.Status(
code=status_code.value[0],
message=exc.message,
details=exc.details,
)
return status_pb2.Status(
code=code_pb2.Code.UNKNOWN,
message="An unknown error has occurred",
)

return status_pb2.Status(
Expand Down
Loading
Loading