Function get_timestep_embedding

tensorrt_llm/layers/embedding.py:426–476 · github.com/NVIDIA/TensorRT-LLM

Creates sinusoidal timestep embeddings, matching the implementation in Denoising Diffusion Probabilistic Models. Takes a 1-D tensor of N timesteps (one per batch element, possibly fractional) and returns an [N x embedding_dim] tensor.

(
    timesteps: Tensor,
    embedding_dim: int,
    flip_sin_to_cos: bool = False,
    downscale_freq_shift: float = 1,
    scale: float = 1,
    max_period: int = 10000,
)
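Written out, the listing below computes the following. This is a derivation from the code, using shorthand symbols introduced only for this page: d = embedding_dim, h = ⌊d/2⌋, s = downscale_freq_shift, P = max_period, c = scale, and t_n = the n-th timestep.

```latex
\begin{aligned}
\omega_i &= \exp\!\left(-\frac{i \ln P}{h - s}\right) = P^{-i/(h - s)}, \qquad i = 0, \dots, h - 1 \\
a_{n,i} &= c \, t_n \, \omega_i \\
e_n &= \bigl[\sin a_{n,0}, \dots, \sin a_{n,h-1},\ \cos a_{n,0}, \dots, \cos a_{n,h-1}\bigr]
\end{aligned}
```

With flip_sin_to_cos=True the cosine half comes first, and for odd d a single zero column is appended so each row has d entries. The highest frequency is always ω_0 = 1; with the default s = 1 the lowest is ω_{h-1} = 1/P, while s = 0 stops at P^{-(h-1)/h}. So max_period sets the longest period (lowest frequency). The expression requires h ≠ s: with the default s = 1, an embedding_dim of 2 or 3 gives h = 1 and a zero denominator.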

def get_timestep_embedding(
    timesteps: Tensor,
    embedding_dim: int,
    flip_sin_to_cos: bool = False,
    downscale_freq_shift: float = 1,
    scale: float = 1,
    max_period: int = 10000,
) -> Tensor:
    """
    Create sinusoidal timestep embeddings, matching the implementation in
    Denoising Diffusion Probabilistic Models.

    Args:
        timesteps (Tensor):
            A 1-D Tensor of N indices, one per batch element. These may be fractional.
        embedding_dim (int):
            The dimension of the output.
        flip_sin_to_cos (bool):
            Whether the embedding order is `cos, sin` (if True) or `sin, cos` (if False).
        downscale_freq_shift (float):
            Offset subtracted from half_dim in the exponent's denominator; controls
            the spacing of the frequencies across dimensions.
        scale (float):
            Scaling factor applied to the sinusoid arguments before sin/cos.
        max_period (int):
            Controls the minimum frequency (longest period) of the embeddings.
    Returns:
        Tensor: an [N x embedding_dim] Tensor of positional embeddings.
    """
    assert len(timesteps.shape) == 1, "Timesteps should be a 1d-array"

    half_dim = embedding_dim // 2
    # Exponents -ln(max_period) * i / (half_dim - downscale_freq_shift), computed
    # on the host with NumPy and then turned into a constant tensor.
    exponent = -math.log(max_period) * np.arange(
        start=0, stop=half_dim, dtype=np.float32)
    exponent = exponent / (half_dim - downscale_freq_shift)
    exponent = constant(exponent)

    emb = exp(exponent)
    # Outer product of timesteps and frequencies: [N, 1] * [1, half_dim] -> [N, half_dim]
    emb = unsqueeze(timesteps, -1).cast('float32') * unsqueeze(emb, 0)

    # scale embeddings (the sinusoid arguments)
    emb = scale * emb

    # order the halves: cos first if flip_sin_to_cos, otherwise sin first
    if flip_sin_to_cos:
        emb = concat([cos(emb), sin(emb)], dim=-1)
    else:
        emb = concat([sin(emb), cos(emb)], dim=-1)

    # zero pad up to embedding_dim when it is odd
    if embedding_dim % 2 == 1:
        emb = pad(emb, (0, 1, 0, 0))
    return emb
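To check engine outputs on the host, the same computation can be mirrored in NumPy. The sketch below is written for this page and is not part of TensorRT-LLM; the helper name `get_timestep_embedding_np` is hypothetical. It assumes that `pad(emb, (0, 1, 0, 0))` appends the zero column at the end of the last dimension, which is what the documented [N x embedding_dim] output implies.

```python
import math
import numpy as np


def get_timestep_embedding_np(timesteps, embedding_dim, flip_sin_to_cos=False,
                              downscale_freq_shift=1.0, scale=1.0, max_period=10000):
    """Hypothetical NumPy mirror of get_timestep_embedding (not a TensorRT-LLM API)."""
    timesteps = np.asarray(timesteps)
    assert timesteps.ndim == 1, "Timesteps should be a 1d-array"

    half_dim = embedding_dim // 2
    # Same exponent schedule as the listing: -ln(max_period) * i / (half_dim - shift)
    exponent = -math.log(max_period) * np.arange(0, half_dim, dtype=np.float32)
    exponent = exponent / (half_dim - downscale_freq_shift)
    freqs = np.exp(exponent)  # shape [half_dim]

    # Outer product of timesteps and frequencies: [N, 1] * [1, half_dim]
    emb = timesteps[:, None].astype(np.float32) * freqs[None, :]
    emb = scale * emb  # scale the sinusoid arguments

    # Concatenate the two halves in the requested order
    if flip_sin_to_cos:
        emb = np.concatenate([np.cos(emb), np.sin(emb)], axis=-1)
    else:
        emb = np.concatenate([np.sin(emb), np.cos(emb)], axis=-1)

    # Assumption: odd embedding_dim gets one zero column appended on the right
    if embedding_dim % 2 == 1:
        emb = np.pad(emb, ((0, 0), (0, 1)))
    return emb


t = np.array([0.0, 1.5, 999.0])
emb = get_timestep_embedding_np(t, embedding_dim=7)
print(emb.shape)  # (3, 7)
```

At t = 0 every sine entry is 0 and every cosine entry is 1, which makes the row a quick visual check of the sin/cos ordering and of the padding column.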

Callers (1)

forward · method · 0.70

Calls (6)

constant · function · 0.85
unsqueeze · function · 0.85
concat · function · 0.85
cast · method · 0.80
pad · function · 0.50
log · method · 0.45

Tested by

no test coverage detected
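No tests are listed for this function. As a hypothetical starting point, the frequency schedule it builds with NumPy before calling `constant` can be checked on its own: the highest frequency is always 1, and `downscale_freq_shift` decides whether the lowest one reaches exactly `1 / max_period`. The helper names below are illustrative, not TensorRT-LLM APIs, and they exercise a NumPy mirror of the exponent computation rather than the TensorRT graph itself.

```python
import math
import numpy as np


def timestep_frequencies(embedding_dim, downscale_freq_shift=1.0, max_period=10000):
    """Hypothetical helper: the exp(exponent) vector built before `constant(...)`."""
    half_dim = embedding_dim // 2
    exponent = -math.log(max_period) * np.arange(0, half_dim, dtype=np.float32)
    exponent = exponent / (half_dim - downscale_freq_shift)
    return np.exp(exponent)


def check_frequency_schedule(embedding_dim=320, max_period=10000):
    """Property checks on the schedule; assumes half_dim > 1 so no denominator is zero."""
    half_dim = embedding_dim // 2
    for shift in (0.0, 1.0):
        freqs = timestep_frequencies(embedding_dim, shift, max_period)
        assert freqs.shape == (half_dim,)
        assert freqs[0] == 1.0              # highest frequency is always 1
        assert np.all(np.diff(freqs) < 0)   # strictly decreasing across dimensions

    lowest_shift1 = timestep_frequencies(embedding_dim, 1.0, max_period)[-1]
    lowest_shift0 = timestep_frequencies(embedding_dim, 0.0, max_period)[-1]
    # shift = 1: the lowest frequency is 1 / max_period (up to float32 rounding)
    assert np.isclose(lowest_shift1, 1.0 / max_period, rtol=1e-4)
    # shift = 0: the schedule stops short, at max_period ** (-(half_dim - 1) / half_dim)
    assert np.isclose(lowest_shift0, max_period ** (-(half_dim - 1) / half_dim), rtol=1e-4)


check_frequency_schedule()
```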