hub / github.com/apache/tvm / attention_python

Function attention_python

python/tvm/topi/testing/attention_python.py:25–125

Attention operator in python.

(
    q: np.ndarray,
    k: np.ndarray,
    v: np.ndarray,
    bias: np.ndarray | None,
    qk_scale: float,
    causal: str,
    window_size: int | None = None,
    layout: str = "BSNH",
)
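The `causal` parameter selects how a causal mask is aligned when the query and key lengths differ. As a rough illustration only, the sketch below builds the two alignments with `np.tril`. The helper name `causal_mask` is hypothetical, and the offset rule (diagonal anchored at the top-left corner versus the bottom-right corner of the score matrix) is the commonly used convention, assumed here rather than taken from this function's body.

```python
import numpy as np


def causal_mask(s: int, s_kv: int, mode: str) -> np.ndarray:
    """Hypothetical helper: boolean [s, s_kv] mask, True where attention is allowed."""
    allowed = np.ones((s, s_kv), dtype=bool)
    if mode == "none":
        return allowed
    if mode == "TopLeft":
        offset = 0  # assumption: diagonal starts at the top-left corner
    elif mode == "BottomRight":
        offset = s_kv - s  # assumption: diagonal ends at the bottom-right corner
    else:
        raise ValueError(f"unknown causal mode: {mode}")
    # np.tril keeps entries with column - row <= offset and zeroes the rest.
    return np.tril(allowed, k=offset)


print(causal_mask(2, 4, "TopLeft").astype(int))
print(causal_mask(2, 4, "BottomRight").astype(int))
```

With 2 queries and 4 keys, the top-left alignment lets the first query see only the first key, while the bottom-right alignment lets the last query see every key.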

Source from the content-addressed store, hash-verified

def attention_python(
    q: np.ndarray,
    k: np.ndarray,
    v: np.ndarray,
    bias: np.ndarray | None,
    qk_scale: float,
    causal: str,
    window_size: int | None = None,
    layout: str = "BSNH",
):  # pylint: disable=too-many-arguments, too-many-locals, invalid-name
    """Attention operator in python

    Parameters
    ----------
    q : np.ndarray
        Query tensor with shape [batch, seq_length, num_heads, head_dim] in the layout specified by
        `layout`.
    k : np.ndarray
        Key tensor with shape [batch, seq_length_kv, num_kv_heads, head_dim] in the layout specified
        by `layout`.
    v : np.ndarray
        Value tensor with shape [batch, seq_length_kv, num_kv_heads, head_dim_v] in the layout
        specified by `layout`.
    bias : np.ndarray
        Bias tensor with shape [batch, num_heads, seq_length, seq_length]
    qk_scale : float
        Scale factor for the query-key product.
    causal : str
        The type of causal mask to apply. Can be "none", "TopLeft", or "BottomRight".
    window_size : Optional[int]
        The window size for the causal mask.
    layout : str
        The layout of the input tensors, e.g. "BSNH" or "BNSH".

    Returns
    -------
    np.ndarray
        The output tensor with shape [batch, seq_length, num_heads, head_dim_v] in the layout
        specified by `layout`.
    """
    assert layout in ["BSNH", "BNSH", "SBNH"]

    dim_b = layout.find("B")
    dim_s = layout.find("S")
    dim_n = layout.find("N")
    dim_h = layout.find("H")

    q = q.transpose(dim_b, dim_n, dim_s, dim_h)  # b, n, s, h
    k = k.transpose(dim_b, dim_n, dim_s, dim_h)  # b, n, s_kv, h
    kt = k.transpose(0, 1, 3, 2)  # b, n, h, s_kv
    v = v.transpose(dim_b, dim_n, dim_s, dim_h)

    num_heads = q.shape[1]
    num_kv_heads = k.shape[1]
    s = q.shape[2]
    s_kv = k.shape[2]

    if num_heads != num_kv_heads:
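
The transposes at the top of the listing bring every accepted layout into a common [batch, heads, seq, head_dim] order by looking up each axis letter's position in the `layout` string. A minimal standalone sketch of that step follows; the helper name `to_bnsh` is hypothetical and exists only for illustration.

```python
import numpy as np


def to_bnsh(x: np.ndarray, layout: str) -> np.ndarray:
    """Hypothetical helper: reorder a 4-D tensor from `layout` to B, N, S, H."""
    assert layout in ["BSNH", "BNSH", "SBNH"]
    # layout.find(c) is the axis index of letter c in the input tensor.
    axes = tuple(layout.find(c) for c in "BNSH")
    return x.transpose(axes)


# batch=2, seq=7, heads=4, head_dim=8, stored in three different layouts
print(to_bnsh(np.zeros((2, 7, 4, 8)), "BSNH").shape)
print(to_bnsh(np.zeros((2, 4, 7, 8)), "BNSH").shape)
print(to_bnsh(np.zeros((7, 2, 4, 8)), "SBNH").shape)
```

All three calls produce the same shape, so the code after this step can assume a single axis order regardless of the caller's layout.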

Callers

nothing calls this directly

Calls 10

softmax_python · Function · 0.90
tril · Method · 0.80
triu · Method · 0.80
max · Method · 0.80
abs · Function · 0.50
transpose · Method · 0.45
repeat · Method · 0.45
exp · Method · 0.45
sum · Method · 0.45
divide · Method · 0.45
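
Several of the listed calls (`max`, `exp`, `sum`, `divide`, `transpose`) are the usual building blocks of softmax attention. The sketch below is not the TVM implementation. It is a bare scaled dot-product attention in NumPy that assumes inputs are already in [batch, heads, seq, head_dim] order with equal query and key/value head counts, and it omits bias, masking, and windowing. The names `softmax_last_axis` and `sdpa_bnsh` are hypothetical.

```python
import numpy as np


def softmax_last_axis(x: np.ndarray) -> np.ndarray:
    """Numerically stable softmax over the last axis."""
    shifted = x - np.max(x, axis=-1, keepdims=True)  # avoid overflow in exp
    e = np.exp(shifted)
    return np.divide(e, np.sum(e, axis=-1, keepdims=True))


def sdpa_bnsh(q: np.ndarray, k: np.ndarray, v: np.ndarray, qk_scale: float) -> np.ndarray:
    """Hypothetical sketch: q [b, n, s, h], k [b, n, s_kv, h], v [b, n, s_kv, h_v]."""
    kt = k.transpose(0, 1, 3, 2)      # b, n, h, s_kv
    score = (q @ kt) * qk_scale       # b, n, s, s_kv
    prob = softmax_last_axis(score)   # each row sums to 1
    return prob @ v                   # b, n, s, h_v


rng = np.random.default_rng(0)
q = rng.standard_normal((1, 2, 3, 4))
k = rng.standard_normal((1, 2, 5, 4))
v = rng.standard_normal((1, 2, 5, 6))
out = sdpa_bnsh(q, k, v, qk_scale=0.5)
print(out.shape)
```

A quick sanity check for a sketch like this: with an all-zero query every key gets equal weight, so the output should equal the mean of `v` over the key axis.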

Tested by

no test coverage detected
