Repository: github.com/NVIDIA/TensorRT-LLM

Function gegelu

tensorrt_llm/functional.py, lines 3388–3414
(x: Tensor, limit: Optional[float] = None)


def gegelu(x: Tensor, limit: Optional[float] = None) -> Tensor:
    # Split the last axis into interleaved halves:
    # a, b = x[..., ::2], x[..., 1::2]
    ndim = x.ndim()
    a_starts = [0 for i in range(ndim)]
    b_starts = [1 if i == (ndim - 1) else 0 for i in range(ndim)]
    # Output size per axis: half of the last axis, unchanged elsewhere.
    shapes = concat([
        shape(x, i) / 2 if i == (ndim - 1) else shape(x, i) for i in range(ndim)
    ])
    # Step by 2 along the last axis to take every other element.
    strides = [2 if i == (ndim - 1) else 1 for i in range(ndim)]

    a = slice(x, a_starts, shapes, strides)
    b = slice(x, b_starts, shapes, strides)

    if limit is not None:
        # Clamp a from above only (-1e20 acts as "no lower bound");
        # clamp b symmetrically to [-limit, limit].
        a = clip(a, alpha=float(-1e20), beta=limit)
        b = clip(b, alpha=-limit, beta=limit)

    # C = B + 1: build a constant 1 (arange over [1, 2)) in b's dtype,
    # raise it to rank ndim, then broadcast it to b's shape.
    const1 = arange(constant(int32_array(1)), constant(int32_array(2)),
                    trt_dtype_to_str(b.dtype))
    for _ in range(ndim - 1):
        const1 = expand_dims(const1, 0)

    b_shape = concat([shape(b, i) for i in range(ndim)])
    const1_arr = expand(const1, b_shape)

    return quick_gelu(a) * (b + const1_arr)
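To check what the graph-building code above computes, a NumPy reference of the same math can help. This is a sketch under stated assumptions, not TensorRT-LLM's API: `gegelu_reference` and `strided_slice` are hypothetical names, `quick_gelu` is assumed to be x * sigmoid(1.702 * x), and the last axis is assumed to have even length. The start/size/stride lists mirror `a_starts`, `b_starts`, `shapes`, and `strides` in the original.

```python
import numpy as np
from typing import List, Optional


def quick_gelu(x: np.ndarray) -> np.ndarray:
    # Assumed definition: sigmoid approximation of GELU, x * sigmoid(1.702 * x).
    return x * (1.0 / (1.0 + np.exp(-1.702 * x)))


def strided_slice(x: np.ndarray, starts: List[int], sizes: List[int],
                  strides: List[int]) -> np.ndarray:
    # Hypothetical helper emulating a start/size/stride slice: on each axis,
    # take sizes[i] elements beginning at starts[i], stepping by strides[i].
    index = tuple(
        slice(s, s + n * st, st) for s, n, st in zip(starts, sizes, strides)
    )
    return x[index]


def gegelu_reference(x: np.ndarray,
                     limit: Optional[float] = None) -> np.ndarray:
    ndim = x.ndim
    last = ndim - 1
    # Same parameters as the TensorRT-LLM version: even and odd channels of
    # the last axis, i.e. a = x[..., ::2] and b = x[..., 1::2].
    a_starts = [0] * ndim
    b_starts = [1 if i == last else 0 for i in range(ndim)]
    sizes = [x.shape[i] // 2 if i == last else x.shape[i] for i in range(ndim)]
    strides = [2 if i == last else 1 for i in range(ndim)]
    a = strided_slice(x, a_starts, sizes, strides)
    b = strided_slice(x, b_starts, sizes, strides)

    if limit is not None:
        # a is clamped from above only; b is clamped to [-limit, limit].
        a = np.clip(a, -1e20, limit)
        b = np.clip(b, -limit, limit)

    # The broadcast tensor of ones in the graph version reduces to "b + 1".
    return quick_gelu(a) * (b + 1.0)


if __name__ == "__main__":
    x = np.array([[1.0, 2.0, -3.0, 0.5]])
    # a = [1, -3], b = [2, 0.5], so the result is quick_gelu(a) * [3, 1.5].
    print(gegelu_reference(x))
```

Spelling the split as explicit start/size/stride lists, rather than writing `x[..., ::2]`, mirrors how the graph version has to express each half as a single strided slice.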

Callers

nothing calls this directly

Calls (11)

concat (function) · 0.85
slice (function) · 0.85
arange (function) · 0.85
constant (function) · 0.85
trt_dtype_to_str (function) · 0.85
expand_dims (function) · 0.85
expand (function) · 0.85
quick_gelu (function) · 0.85
shape (function) · 0.70
clip (function) · 0.70
ndim (method) · 0.45

Tested by

no test coverage detected