Deciphering src_key_padding_mask
Introduction
How do you know your ML model is learning correctly?
After all, given enough capacity, it could probably fit any input you give it.
You can have a low loss and decently correct outputs, but is your model actually doing what you intend?
I was checking my code recently and decided to dive deep into how src_key_padding_mask in PyTorch Transformers works, to verify that my outputs are actually correct.
Padding in Transformers
In most Transformer models, padding is necessary because your inputs may have different lengths. For example, one sentence could be 100 tokens long while another is 101 tokens long. To batch them together for efficient training, you need to pad the shorter inputs to a common length. One can argue that the extra padding introduces some inefficiency, but that is a separate discussion for another day; in general you would still pad.
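As a minimal sketch of what this looks like in practice (the sequence lengths and embedding dimension here are just the example numbers from above, and pad_sequence is PyTorch's utility for stacking variable-length tensors):

import torch
from torch.nn.utils.rnn import pad_sequence

# Two sequences of token embeddings with different lengths (100 and 101 tokens, dim 512)
a = torch.randn(100, 512)
b = torch.randn(101, 512)

# Pad to a common length so they can be stacked into a single batch
batch = pad_sequence([a, b], batch_first=True)  # shape: (2, 101, 512)

# Derive the key padding mask from the true lengths:
# True marks the padded positions that attention should ignore
lengths = torch.tensor([100, 101])
positions = torch.arange(batch.size(1))                        # (101,)
src_key_padding_mask = positions[None, :] >= lengths[:, None]  # (2, 101), bool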
In the code example below, we specifically test the TransformerEncoder.
It is also relatively easy to think of a test that would convince us that it is doing the right thing.
Consider an input of shape (32, 10, 512), where the dimensions are the batch size, sequence length, and embedding dimension respectively.
Further suppose that only the first two tokens of each sequence are relevant, and the other eight are simply padding tokens.
You would expect that this input, together with a correctly constructed src_key_padding_mask, would yield relevant outputs at the first two positions (output[:, :2, :]), with the remaining outputs being irrelevant (output[:, 2:, :]).
Finally, if you feed input[:, :2, :] into this same Transformer, the output should be identical to the output[:, :2, :] obtained with the src_key_padding_mask.
Put simply, self-attention over the two relevant tokens is equivalent to self-attention over all ten tokens with the correct src_key_padding_mask.
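Concretely, for this (32, 10, 512) example, a small sketch of the mask and of the equivalence being claimed looks like the following (the full test below fleshes this out; "encoder" in the comment stands for any TransformerEncoder):

import torch

# Only the first two of the ten positions are real tokens; the rest are padding
bsz, seq_len = 32, 10
src_key_padding_mask = torch.zeros(bsz, seq_len, dtype=torch.bool)
src_key_padding_mask[:, 2:] = True  # True = ignore this position as an attention key

# The claim being tested:
#   encoder(input, src_key_padding_mask=src_key_padding_mask)[:, :2, :] == encoder(input[:, :2, :])
# i.e. masked self-attention over all ten tokens agrees with plain
# self-attention over just the two relevant tokens.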
With this intuition, I worked with Google Gemini to generate some starter code to test this, and indeed the padding is handled correctly!
Code
import torch
import torch.nn as nn

# Test the model
if __name__ == "__main__":
    # Create input data (bsz x seq_len x dim)
    bsz, seq_len, dim = 32, 10, 512
    input_data = torch.randn(bsz, seq_len, dim)

    # Initialize model
    encoder_layer = nn.TransformerEncoderLayer(
        d_model=512,
        nhead=8,
        batch_first=True,  # (bsz, seq_len, dim) format
        dropout=0.0,  # ensures more determinism
    )
    transformer_encoder = nn.TransformerEncoder(encoder_layer, num_layers=3)

    for pad_idx in range(3, 9):
        # Create the src_key_padding_mask: True marks the masked (padded) positions
        src_key_padding_mask = torch.zeros(bsz, seq_len, dtype=torch.bool)
        src_key_padding_mask[:, pad_idx:] = True

        # Forward pass through each encoder layer
        output = input_data
        trunc = input_data[:, :pad_idx, :]
        for mod in transformer_encoder.layers:
            output = mod(output, src_key_padding_mask=src_key_padding_mask)  # full input with mask
            trunc = mod(trunc)  # truncated input containing only the relevant tokens

        # Observe how the truncated output is identical to the masked output at the
        # relevant positions, indicating that the padding is handled correctly
        assert torch.all(torch.isclose(trunc, output[:, :pad_idx, :], atol=1e-06))

    print("All test cases passed!")