x: (batch, dim, seqlen); weight: (dim, width); bias: (dim,); initial_states: (batch, dim, width - 1); final_states_out: (batch, dim, width - 1). Returns out: (batch, dim, seqlen).
(
x,
weight,
bias=None,
initial_states=None,
return_final_states=False,
final_states_out=None,
activation=None,
)
| 131 | |
| 132 | |
def causal_conv1d_ref(
    x,
    weight,
    bias=None,
    initial_states=None,
    return_final_states=False,
    final_states_out=None,
    activation=None,
):
    """Reference causal depthwise 1-D convolution built from PyTorch ops.

    The computation runs in ``weight.dtype``; the returned tensors are cast
    back to the dtype ``x`` had on entry.

    Args:
        x: (batch, dim, seqlen)
        weight: (dim, width)
        bias: (dim,) or None.
        initial_states: (batch, dim, width - 1) or None. When given, it is
            prepended to ``x`` along the last dimension in place of the
            implicit zero left-padding.
        return_final_states: when True, also return the trailing
            ``width - 1`` input positions.
        final_states_out: (batch, dim, width - 1) or None. When given (and
            ``return_final_states`` is True) the final states are copied
            into this tensor in place and it is the tensor returned.
        activation: None, "silu" or "swish"; both strings apply ``F.silu``.

    Returns:
        ``out`` of shape (batch, dim, seqlen), or the tuple
        ``(out, final_states_out)`` when ``return_final_states`` is True.

    Raises:
        NotImplementedError: if ``activation`` is not one of the supported
            values.
    """
    if activation not in (None, "silu", "swish"):
        raise NotImplementedError("activation must be None, silu, or swish")
    dtype_in = x.dtype
    x = x.to(weight.dtype)
    seqlen = x.shape[-1]
    dim, width = weight.shape
    # groups=dim with a (dim, 1, width) kernel gives one filter per channel.
    kernel = weight.unsqueeze(1)
    if initial_states is None:
        # conv1d pads both sides; the surplus right-hand outputs are dropped
        # by the slice below, which leaves only the causal part.
        out = F.conv1d(x, kernel, bias, padding=width - 1, groups=dim)
    else:
        # Cast the state like x: otherwise torch.cat may promote to a dtype
        # that no longer matches the kernel and conv1d would raise.
        x = torch.cat([initial_states.to(weight.dtype), x], dim=-1)
        out = F.conv1d(x, kernel, bias, padding=0, groups=dim)
    out = out[..., :seqlen]
    if return_final_states:
        # Positive left pad zero-fills short inputs; a negative pad crops the
        # front, so exactly the last width - 1 positions always remain.
        final_states = F.pad(x, (width - 1 - x.shape[-1], 0)).to(
            dtype_in
        )  # (batch, dim, width - 1)
        if final_states_out is not None:
            final_states_out.copy_(final_states)
        else:
            final_states_out = final_states
    if activation is not None:
        out = F.silu(out)
    out = out.to(dtype=dtype_in)
    if return_final_states:
        return out, final_states_out
    return out
| 173 | |
| 174 | |
| 175 | def causal_conv1d_update(x, conv_state, weight, bias=None, activation=None, cache_seqlens=None): |