This matches the implementation in Denoising Diffusion Probabilistic Models: Create sinusoidal timestep embeddings. :param timesteps: a 1-D Tensor of N indices, one per batch element. These may be fractional. :param embedding_dim: the dimension of the output. :param max_period: controls the minimum frequency of the embeddings. :return: an [N x dim] Tensor of positional embeddings.
(timesteps,
embedding_dim,
flip_sin_to_cos=False,
downscale_freq_shift=1.0,
scale=1.0,
max_period=10000,
dtype=None)
| 23 | |
| 24 | |
def get_timestep_embedding(timesteps,
                           embedding_dim,
                           flip_sin_to_cos=False,
                           downscale_freq_shift=1.0,
                           scale=1.0,
                           max_period=10000,
                           dtype=None):
    """
    Create sinusoidal timestep embeddings, following the formulation used in
    Denoising Diffusion Probabilistic Models.

    Each timestep is multiplied by ``embedding_dim // 2`` frequencies of the
    form ``exp(-i * log(max_period) / (embedding_dim // 2 - downscale_freq_shift))``,
    the products are scaled by ``scale``, and their sine and cosine are
    concatenated along dim 1.

    :param timesteps: a 1-D Tensor of N indices, one per batch element.
                      These may be fractional.
    :param embedding_dim: the requested dimension of the output.
    :param flip_sin_to_cos: if truthy, the cosine half is placed before the
                            sine half; otherwise sine comes first.
    :param downscale_freq_shift: value subtracted from ``embedding_dim // 2``
                                 in the denominator of the frequency exponents.
    :param scale: multiplier applied to the timestep/frequency products before
                  sine and cosine are taken.
    :param max_period: controls the minimum frequency of the embeddings.
    :param dtype: when equal to ``trt.DataType.HALF`` the frequency exponents
                  are packed with ``fp16_array``; any other value (including
                  None) uses ``fp32_array``.
    :return: a Tensor of positional embeddings. Two blocks of
             ``embedding_dim // 2`` columns are concatenated, so the last
             dimension is ``embedding_dim`` for even values and
             ``embedding_dim - 1`` for odd values (zero padding is not
             applied yet, see the TODO below).

    Note: the exponent denominator is zero when ``embedding_dim // 2`` equals
    ``downscale_freq_shift`` (e.g. ``embedding_dim`` of 2 or 3 with the default
    shift); the division then raises ZeroDivisionError.
    """
    assert timesteps.rank() == 1, "Timesteps should be a 1d-array"

    num_freqs = embedding_dim // 2

    # Exponents are computed on the host in plain Python and only afterwards
    # turned into a graph constant.
    freq_exponents = []
    for idx in range(num_freqs):
        freq_exponents.append(idx * -math.log(max_period) /
                              (num_freqs - downscale_freq_shift))

    # Choose the host-array packer that matches the requested precision.
    pack_array = fp16_array if dtype == trt.DataType.HALF else fp32_array
    freqs = exp(constant(pack_array(freq_exponents)))

    # Append a trailing unit axis to the timesteps and prepend a leading unit
    # axis to the frequencies before multiplying them together.
    column_shape = [*timesteps.size(), 1]
    row_shape = [1, *freqs.size()]
    angles = scale * (timesteps.view(column_shape) * freqs.view(row_shape))

    # Concatenate the sine and cosine halves; flip_sin_to_cos swaps their order.
    trig_order = (cos, sin) if flip_sin_to_cos else (sin, cos)
    emb = concat([trig(angles) for trig in trig_order], dim=1)

    # TODO Enable zero padding for odd embedding_dim when TensorRT LLM supports
    # the pad feature. Reference logic:
    #   if embedding_dim % 2 == 1:
    #       emb = torch.nn.functional.pad(emb, (0, 1, 0, 0))
    return emb
| 74 | |
| 75 | |
| 76 | class TimestepEmbedding(Module): |