MCPcopy
hub / github.com/google-deepmind/gemma / pad

Function pad

gemma/gm/data/_functional.py:25–60  ·  view source on GitHub ↗

Pad the end of the sequence with `fill_value` (zeros by default) to reach the max length. Supports padding multiple arrays at once. Args: element: The sequence to pad. max_length: The max length of the sequence. truncate: Whether to truncate the sequence to the max length. If `False`, sequences longer than the `max_length` will raise an error.

(
    element: PyTree[Array["sequence"]],
    max_length: int,
    *,
    truncate: bool = False,
    fill_value: int = 0,
    axis: int = -1,
)

Source from the content-addressed store, hash-verified

23
# Do not @typechecked as `element` can be `list` too.
def pad(
    element: PyTree[Array["sequence"]],
    max_length: int,
    *,
    truncate: bool = False,
    fill_value: int = 0,
    axis: int = -1,
) -> PyTree[Array["max_length"]]:
  """Pad the end of each sequence with `fill_value` to reach `max_length`.

  Supports padding multiple arrays at once: the padding is applied to every
  leaf of `element` (via `jax.tree.map`), so the output keeps the same pytree
  structure as the input. Flat Python lists such as `[0, 1, ...]` are also
  accepted as sequences (`_is_list_array` is used as the `is_leaf` predicate).

  Args:
    element: The sequence (or pytree of sequences) to pad.
    max_length: The target length of the padded sequence.
    truncate: Whether to truncate sequences to `max_length`. If `False`,
      sequences longer than `max_length` will raise an error.
    fill_value: The value written into the padded positions (0 by default).
    axis: The axis along which to pad. Only `-1` (the last axis) is supported
      at the moment.

  Returns:
    The padded sequence(s) of length `max_length`, in the same pytree
    structure as `element`.

  Raises:
    NotImplementedError: If `axis` is anything other than `-1`.
  """
  if axis != -1:
    raise NotImplementedError("Only `axis=-1` is supported.")
  return jax.tree.map(
      lambda x: _pad(
          x,
          max_length=max_length,
          fill_value=fill_value,
          truncate=truncate,
      ),
      element,
      # Treat flat lists like `[0, 1, ...]` as a single sequence leaf instead
      # of letting `jax.tree.map` descend into their integer elements.
      is_leaf=_is_list_array,
  )
61
62
63def _pad(

Callers

nothing calls this directly

Calls 2

_pad · Function · 0.85
map · Method · 0.45

Tested by

no test coverage detected