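Written out, `timestep_embedding` maps each timestep $t$ (one entry of `timesteps`) to a `dim`-dimensional vector. The symbols here are introduced only for illustration: $d$ stands for `dim`, $P$ for `max_period`, and $h = \lfloor d/2 \rfloor$ for `half`.

$$
\omega_k = \exp\!\left(-\frac{k \ln P}{h}\right) = P^{-k/h}, \qquad k = 0, \dots, h-1
$$

$$
\mathbf{e}(t) = \big[\cos(t\,\omega_0), \dots, \cos(t\,\omega_{h-1}),\ \sin(t\,\omega_0), \dots, \sin(t\,\omega_{h-1})\big]
$$

A single trailing $0$ is appended when $d$ is odd. The highest frequency is therefore $\omega_0 = 1$ and the lowest is $\omega_{h-1} = P^{-(h-1)/h}$, just above $1/P$. This is the sense in which `max_period` controls the minimum frequency.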
```python
import math

import torch as th


def timestep_embedding(timesteps, dim, max_period=10000):
    """
    Create sinusoidal timestep embeddings.

    :param timesteps: a 1-D Tensor of N indices, one per batch element.
                      These may be fractional.
    :param dim: the dimension of the output.
    :param max_period: controls the minimum frequency of the embeddings.
    :return: an [N x dim] Tensor of positional embeddings.
    """
    half = dim // 2
    # Geometric sequence of frequencies max_period ** (-k / half), k = 0..half-1,
    # i.e. from 1 down to just above 1 / max_period.
    freqs = th.exp(
        -math.log(max_period) * th.arange(start=0, end=half, dtype=th.float32) / half
    ).to(device=timesteps.device)
    # Broadcast [N, 1] * [1, half] -> [N, half] phase angles.
    args = timesteps[:, None].float() * freqs[None]
    # First half of each row holds cosines, second half holds sines.
    embedding = th.cat([th.cos(args), th.sin(args)], dim=-1)
    # For odd dim, append one zero column so the output is exactly [N, dim].
    if dim % 2:
        embedding = th.cat([embedding, th.zeros_like(embedding[:, :1])], dim=-1)
    return embedding
```

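To check the output layout by hand without PyTorch, here is a minimal pure-Python sketch of the same formula for a single scalar timestep. `timestep_embedding_ref` is a hypothetical helper written for illustration. It is not part of this module, and it returns a plain list rather than a tensor.

```python
import math


def timestep_embedding_ref(t, dim, max_period=10000):
    """Hypothetical pure-Python reference for one timestep `t`.

    Mirrors the layout of `timestep_embedding`: `dim // 2` cosines,
    then `dim // 2` sines, then a single zero if `dim` is odd.
    """
    half = dim // 2
    # Same geometric frequencies: exp(-ln(max_period) * k / half), k = 0..half-1.
    freqs = [math.exp(-math.log(max_period) * k / half) for k in range(half)]
    args = [float(t) * f for f in freqs]
    embedding = [math.cos(a) for a in args] + [math.sin(a) for a in args]
    if dim % 2:
        # Zero-pad so the vector has exactly `dim` entries.
        embedding.append(0.0)
    return embedding


# Usage: one list per timestep, like the rows of the [N x dim] tensor.
rows = [timestep_embedding_ref(t, dim=5) for t in (0, 1.5, 10)]
for row in rows:
    print([round(x, 4) for x in row])
# The first line printed (t = 0) is [1.0, 1.0, 0.0, 0.0, 0.0].
```

Each list should match the corresponding row of `timestep_embedding(th.tensor([0, 1.5, 10]), 5)` up to float32 rounding. The PyTorch version computes frequencies and phases in `float32`, while this sketch uses Python floats.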
| 124 | def checkpoint(func, inputs, params, flag): |