diff --git a/examples/tokenize_data.py b/examples/tokenize_data.py index 327f9cd1..314b7eaf 100644 --- a/examples/tokenize_data.py +++ b/examples/tokenize_data.py @@ -76,30 +76,49 @@ def pack_sequences( ['▁toys', '▁.', '', '', '▁but', '▁just', '▁one', '▁look'] """ packed_sequences = [] + packed_position_ids = [] buffer = [] + position_buffer = [] for input_ids in batch["input_ids"]: - # Add the current sequence to the buffer - buffer.extend(input_ids) - buffer.append(eos_token_id) # Add EOS at the end of each sequence + # Truncate sequences that individually exceed max_seq_len (including EOS token). + seq_with_eos = (input_ids + [eos_token_id])[:max_seq_len] + # Position IDs reset to 0 at the start of each sub-sequence. + seq_positions = list(range(len(seq_with_eos))) - # Check if buffer needs to be split into chunks - while len(buffer) > max_seq_len: - # Take a full chunk from the buffer and append it to packed_sequences - packed_sequences.append(buffer[:max_seq_len]) - # Remove the processed chunk from the buffer - buffer = buffer[max_seq_len:] + # If adding this sequence would overflow, flush the current buffer first. + # This ensures every chunk starts at a sequence boundary (position_ids[0] == 0). + if buffer and len(buffer) + len(seq_with_eos) > max_seq_len: + padding_length = max_seq_len - len(buffer) + packed_sequences.append(buffer + [pad_token_id] * padding_length) + packed_position_ids.append(position_buffer + [0] * padding_length) + buffer = [] + position_buffer = [] + + buffer.extend(seq_with_eos) + position_buffer.extend(seq_positions) + + # Flush immediately if exactly full (no padding needed). + if len(buffer) == max_seq_len: + packed_sequences.append(buffer) + packed_position_ids.append(position_buffer) + buffer = [] + position_buffer = [] # Add the last buffer if it's exactly chunk_size if len(buffer) == max_seq_len: packed_sequences.append(buffer) + packed_position_ids.append(position_buffer) elif len(buffer) > cutoff_size: # if the buffer is larger than the cutoff size, pad it to the chunk_size # if not, we do not include in the packed_sequences - buffer.extend([pad_token_id] * (max_seq_len - len(buffer))) + padding_length = max_seq_len - len(buffer) + buffer.extend([pad_token_id] * padding_length) + position_buffer.extend([0] * padding_length) packed_sequences.append(buffer) + packed_position_ids.append(position_buffer) - output = {"input_ids": packed_sequences} + output = {"input_ids": packed_sequences, "position_ids": packed_position_ids} if add_labels: output["labels"] = [ [ @@ -109,7 +128,6 @@ def pack_sequences( for example in output["input_ids"] ] - # mask attention for padding tokens, a better version would also mask cross-sequence dependencies output["attention_mask"] = [ [0 if token_id == pad_token_id else 1 for token_id in example] for example in output["input_ids"]