Pad the end of the sequence with `fill_value` (zeros by default) to reach the max length. Supports padding multiple arrays at once. Args: element: The sequence to pad. max_length: The max length of the sequence. truncate: Whether to truncate the sequence to the max length. If `False`, sequences longer than `max_length` will raise an error.
(
element: PyTree[Array["sequence"]],
max_length: int,
*,
truncate: bool = False,
fill_value: int = 0,
axis: int = -1,
)
| 23 | |
# Do not @typechecked as `element` can be `list` too.
def pad(
    element: PyTree[Array["sequence"]],
    max_length: int,
    *,
    truncate: bool = False,
    fill_value: int = 0,
    axis: int = -1,
) -> PyTree[Array["max_length"]]:
    """Pad the end of each sequence with `fill_value` to reach `max_length`.

    Supports padding multiple arrays at once: `element` may be a nested
    structure (pytree) of sequences, and each leaf is padded independently
    via `jax.tree.map`. Flat Python lists such as `[0, 1, ...]` are also
    accepted as sequences (they are detected as leaves by `_is_list_array`).

    Args:
        element: The sequence, or pytree of sequences, to pad.
        max_length: The target length of the padded sequence(s).
        truncate: Whether to truncate sequences longer than `max_length`.
            If `False`, sequences longer than `max_length` will raise an
            error.
        fill_value: The value written into the padded positions (defaults
            to `0`, i.e. zero-padding).
        axis: The axis along which to pad. Only `-1` (the last axis) is
            supported at the moment.

    Returns:
        The padded sequence(s) of length `max_length`, with the same tree
        structure as `element`.

    Raises:
        NotImplementedError: If `axis` is anything other than `-1`.
    """
    # Padding along other axes is not implemented; fail loudly instead of
    # silently padding the wrong dimension.
    if axis != -1:
        raise NotImplementedError("Only `axis=-1` is supported.")
    return jax.tree.map(
        lambda x: _pad(
            x,
            max_length=max_length,
            fill_value=fill_value,
            truncate=truncate,
        ),
        element,
        is_leaf=_is_list_array,  # Also supports `[0, 1, ...]`
    )
| 61 | |
| 62 | |
| 63 | def _pad( |