Create sinusoidal timestep embeddings. This matches the implementation in Denoising Diffusion Probabilistic Models. Args: timesteps (Tensor): a 1-D Tensor of N indices, one per batch element; these may be fractional. embedding_dim (int): the dimension of the output.
(
timesteps: Tensor,
embedding_dim: int,
flip_sin_to_cos: bool = False,
downscale_freq_shift: float = 1,
scale: float = 1,
max_period: int = 10000,
)
| 424 | |
| 425 | |
def get_timestep_embedding(
    timesteps: Tensor,
    embedding_dim: int,
    flip_sin_to_cos: bool = False,
    downscale_freq_shift: float = 1,
    scale: float = 1,
    max_period: int = 10000,
) -> Tensor:
    """
    Create sinusoidal timestep embeddings.

    This matches the implementation in Denoising Diffusion Probabilistic Models.

    Args
        timesteps (Tensor):
            a 1-D Tensor of N indices, one per batch element. These may be fractional.
        embedding_dim (int):
            the dimension of the output.
        flip_sin_to_cos (bool):
            Whether the embedding order should be `cos, sin` (if True) or `sin, cos` (if False)
        downscale_freq_shift (float):
            Controls the delta between frequencies between dimensions. The frequency
            exponents are divided by `embedding_dim // 2 - downscale_freq_shift`, so
            this value must differ from `embedding_dim // 2`.
        scale (float):
            Scaling factor applied to the timestep/frequency products before the
            sine and cosine are taken.
        max_period (int):
            Controls the lowest frequency of the embeddings. The highest frequency
            is always 1; with `downscale_freq_shift=1` the lowest is `1 / max_period`.
    Returns
        Tensor: an [N x dim] Tensor of positional embeddings.
    Raises
        AssertionError: if `timesteps` is not 1-D.
        ValueError: if `embedding_dim // 2` equals `downscale_freq_shift`.
    """
    assert len(timesteps.shape) == 1, "Timesteps should be a 1d-array"

    half_dim = embedding_dim // 2

    # Guard the division below: NumPy does not raise on a zero divisor, it
    # emits a warning and yields NaN/inf, which would silently be baked into
    # the embedding (e.g. embedding_dim in (2, 3) with the default shift of 1).
    freq_denominator = half_dim - downscale_freq_shift
    if freq_denominator == 0:
        raise ValueError(
            f"embedding_dim // 2 ({half_dim}) must not equal "
            f"downscale_freq_shift ({downscale_freq_shift}): the frequency "
            "exponents would be divided by zero and become non-finite")

    # exponent[i] = -ln(max_period) * i / (half_dim - downscale_freq_shift),
    # computed on the host with NumPy and then wrapped with `constant`.
    exponent = -math.log(max_period) * np.arange(
        start=0, stop=half_dim, dtype=np.float32)
    exponent = exponent / freq_denominator
    exponent = constant(exponent)

    # Per-dimension frequencies, then the outer product with the timesteps:
    # [N, 1] * [1, half_dim] -> [N, half_dim].
    emb = exp(exponent)
    emb = unsqueeze(timesteps, -1).cast('float32') * unsqueeze(emb, 0)

    # scale embeddings (the sin/cos arguments, not the final values)
    emb = scale * emb

    # flip sine and cosine embeddings
    if flip_sin_to_cos:
        emb = concat([cos(emb), sin(emb)], dim=-1)
    else:
        emb = concat([sin(emb), cos(emb)], dim=-1)

    # zero pad: the concat above yields 2 * half_dim columns, one short of an
    # odd embedding_dim
    if embedding_dim % 2 == 1:
        emb = pad(emb, (0, 1, 0, 0))
    return emb
| 477 | |
| 478 | |
| 479 | class TimestepEmbedding(Module): |