x: (batch, dim, seqlen)
weight: (dim, width)
bias: (dim,)
seq_idx: (batch, seqlen)
initial_states: (batch, dim, width - 1)
final_states_out: (batch, dim, width - 1), an output buffer that is written to
activation: either None, "silu", or "swish"
out: (batch, dim, seqlen)
(
x,
weight,
bias=None,
seq_idx=None,
initial_states=None,
return_final_states=False,
final_states_out=None,
activation=None,
)
| 98 | |
| 99 | |
def causal_conv1d_fn(
    x,
    weight,
    bias=None,
    seq_idx=None,
    initial_states=None,
    return_final_states=False,
    final_states_out=None,
    activation=None,
):
    """Run the causal 1D convolution implemented by ``CausalConv1dFn``.

    This is a functional wrapper: it does no work of its own and forwards
    every argument, unchanged and in signature order, to
    ``CausalConv1dFn.apply``.

    Args:
        x: input tensor of shape (batch, dim, seqlen).
        weight: filter of shape (dim, width).
        bias: optional tensor of shape (dim,).
        seq_idx: optional tensor of shape (batch, seqlen).
        initial_states: optional tensor of shape (batch, dim, width - 1).
        return_final_states: flag forwarded as-is; presumably asks for the
            final states to be returned as well -- TODO confirm against
            ``CausalConv1dFn``.
        final_states_out: optional buffer of shape (batch, dim, width - 1)
            to be written to.
        activation: either None, "silu", or "swish".

    Returns:
        Whatever ``CausalConv1dFn.apply`` returns. The output is documented
        as shape (batch, dim, seqlen).
    """
    # NOTE(review): CausalConv1dFn is presumably a torch.autograd.Function,
    # whose apply() takes positional arguments only. Keep this tuple's order
    # in sync with its forward() signature.
    forwarded = (
        x,
        weight,
        bias,
        seq_idx,
        initial_states,
        return_final_states,
        final_states_out,
        activation,
    )
    return CausalConv1dFn.apply(*forwarded)
| 131 | |
| 132 | |
| 133 | def causal_conv1d_ref( |
no outgoing calls