(x: Tensor, limit: Optional[float] = None)
| 3386 | |
| 3387 | |
def gegelu(x: Tensor, limit: Optional[float] = None) -> Tensor:
    """Apply a gated quick-GELU over the interleaved halves of ``x``.

    The last axis of ``x`` is de-interleaved as
    ``a, b = x[..., ::2], x[..., 1::2]`` and the result is
    ``quick_gelu(a) * (b + 1)``.

    Args:
        x: Input tensor whose last axis holds ``a`` and ``b`` interleaved.
        limit: Optional clipping bound. When not ``None``, ``a`` is clipped
            with ``alpha=-1e20, beta=limit`` and ``b`` with
            ``alpha=-limit, beta=limit`` before the activation.

    Returns:
        The tensor ``quick_gelu(a) * (b + 1)``.
    """
    rank = x.ndim()
    last_axis = rank - 1

    # Slice parameters, one entry per axis. Only the last axis is strided:
    # the even slice starts at 0, the odd slice starts at 1, both step by 2.
    even_starts = [0] * rank
    odd_starts = [1 if axis == last_axis else 0 for axis in range(rank)]

    # Output extent of each slice: identical to x except the last axis,
    # which is halved.
    half_dims = []
    for axis in range(rank):
        extent = shape(x, axis)
        if axis == last_axis:
            extent = extent / 2
        half_dims.append(extent)
    half_shape = concat(half_dims)

    steps = [2 if axis == last_axis else 1 for axis in range(rank)]

    # gate, linear = x[..., ::2], x[..., 1::2]
    gate = slice(x, even_starts, half_shape, steps)
    linear = slice(x, odd_starts, half_shape, steps)

    if limit is not None:
        # NOTE(review): alpha/beta look like the lower/upper clip bounds, so
        # the gate is effectively bounded from above only — confirm against
        # clip().
        gate = clip(gate, alpha=float(-1e20), beta=limit)
        linear = clip(linear, alpha=-limit, beta=limit)

    # Build the "+ 1" operand in linear's dtype.
    # NOTE(review): arange(1, 2) presumably produces the single value 1;
    # verify against arange().
    range_start = constant(int32_array(1))
    range_stop = constant(int32_array(2))
    one = arange(range_start, range_stop, trt_dtype_to_str(linear.dtype))

    # Prepend leading axes until `one` has the same rank as `linear`.
    for _ in range(last_axis):
        one = expand_dims(one, 0)

    # Expand the constant to linear's full shape before adding.
    linear_shape = concat([shape(linear, axis) for axis in range(rank)])
    ones = expand(one, linear_shape)

    activated = quick_gelu(gate)
    shifted = linear + ones
    return activated * shifted
| 3415 | |
| 3416 | |
| 3417 | def group_norm(input: Tensor, |
nothing calls this directly
no test coverage detected