MCPcopy
hub / github.com/hustvl/Vim / causal_conv1d_ref

Function causal_conv1d_ref

causal-conv1d/causal_conv1d/causal_conv1d_interface.py:133–172  ·  view source on GitHub ↗

Shapes — x: (batch, dim, seqlen); weight: (dim, width); bias: (dim,); initial_states: (batch, dim, width - 1); final_states_out: (batch, dim, width - 1). Returns out: (batch, dim, seqlen).

(
    x,
    weight,
    bias=None,
    initial_states=None,
    return_final_states=False,
    final_states_out=None,
    activation=None,
)

Source from the content-addressed store, hash-verified

131
132
def causal_conv1d_ref(
    x,
    weight,
    bias=None,
    initial_states=None,
    return_final_states=False,
    final_states_out=None,
    activation=None,
):
    """Reference causal depthwise 1-D convolution built from plain PyTorch ops.

    The convolution is evaluated in ``weight.dtype`` and the results are cast
    back to the dtype ``x`` had on entry.

    Shapes:
        x: (batch, dim, seqlen)
        weight: (dim, width)
        bias: (dim,)
        initial_states: (batch, dim, width - 1)
        final_states_out: (batch, dim, width - 1)

        out: (batch, dim, seqlen)

    Args:
        x: Input sequence.
        weight: Per-channel filter taps; one filter of length ``width`` per
            channel (the convolution uses ``groups=dim``).
        bias: Optional per-channel bias forwarded to ``F.conv1d``.
        initial_states: Optional values prepended to ``x`` along the last
            axis in place of the implicit zero left-padding. Cast to
            ``weight.dtype`` before use so any floating dtype is accepted.
        return_final_states: When true, also return the trailing
            ``width - 1`` positions of the (state-prefixed) input.
        final_states_out: Optional pre-allocated tensor; when given and
            ``return_final_states`` is true, the final states are copied into
            it in place and that same tensor is returned.
        activation: ``None`` for no activation, or ``"silu"`` / ``"swish"``
            (both apply ``F.silu``).

    Returns:
        ``out`` alone, or ``(out, final_states_out)`` when
        ``return_final_states`` is true.

    Raises:
        NotImplementedError: If ``activation`` is not ``None``, ``"silu"`` or
            ``"swish"``.
    """
    if activation not in [None, "silu", "swish"]:
        raise NotImplementedError("activation must be None, silu, or swish")
    dtype_in = x.dtype
    x = x.to(weight.dtype)
    seqlen = x.shape[-1]
    dim, width = weight.shape
    if initial_states is None:
        # Pad both sides by width - 1; the surplus outputs on the right are
        # dropped below, leaving only outputs that depend on past inputs.
        out = F.conv1d(x, weight.unsqueeze(1), bias, padding=width - 1, groups=dim)
    else:
        # Cast the states like ``x``: torch.cat would otherwise promote to the
        # wider dtype and conv1d would then reject the mismatch with ``weight``.
        x = torch.cat([initial_states.to(weight.dtype), x], dim=-1)
        out = F.conv1d(x, weight.unsqueeze(1), bias, padding=0, groups=dim)
    out = out[..., :seqlen]
    if return_final_states:
        # A negative left pad crops from the left, keeping the last width - 1
        # positions; a positive one zero-fills when the input is shorter.
        final_states = F.pad(x, (width - 1 - x.shape[-1], 0)).to(
            dtype_in
        )  # (batch, dim, width - 1)
        if final_states_out is not None:
            final_states_out.copy_(final_states)
        else:
            final_states_out = final_states
    out = (out if activation is None else F.silu(out)).to(dtype=dtype_in)
    return out if not return_final_states else (out, final_states_out)
173
174
175def causal_conv1d_update(x, conv_state, weight, bias=None, activation=None, cache_seqlens=None):

Callers 2

test_causal_conv1dFunction · 0.90

Calls 2

toMethod · 0.45
catMethod · 0.45

Tested by 2

test_causal_conv1dFunction · 0.72