(self, latent)
| 390 | return spatial_pos_embed |
| 391 | |
def forward(self, latent):
    """Project ``latent`` into patch tokens and add positional embeddings.

    Steps:
      1. Read the grid size from ``latent``'s trailing two dims. With
         ``pos_embed_max_size`` unset, they are divided by ``patch_size``.
         Otherwise the raw trailing dims are kept and handed to
         ``cropped_pos_embed``.
      2. Apply ``self.proj``. Optionally flatten to tokens (BCHW -> BNC)
         and optionally apply ``self.norm``.
      3. If ``self.pos_embed`` is None, return the tokens as-is (cast to
         their own dtype).
      4. Otherwise choose the positional embedding from one of three
         sources:
         - crop it via ``cropped_pos_embed`` when ``pos_embed_max_size``
           is truthy;
         - reuse the stored ``self.pos_embed.value`` when the grid matches
           ``(self.height, self.width)``;
         - regenerate a 2D sin-cos table for the new grid.
         The chosen embedding is cast to the token dtype and added.

    Args:
        latent: input tensor. Assumed to be (B, C, H, W) per the BCHW
            comment below — TODO confirm against callers.

    Returns:
        The token tensor plus positional embeddings, in the token dtype.
    """
    # TODO: support height and width that are only known at runtime.
    if self.pos_embed_max_size is None:
        patch = self.patch_size
        grid_h = latent.shape[-2] // patch
        grid_w = latent.shape[-1] // patch
    else:
        # Raw spatial dims; presumably cropped_pos_embed converts them to
        # patch units itself — NOTE(review): confirm.
        grid_h, grid_w = latent.shape[-2:]

    tokens = self.proj(latent)
    if self.flatten:
        tokens = tokens.flatten(2).transpose(1, 2)  # BCHW -> BNC
    if self.layer_norm:
        tokens = self.norm(tokens)

    if self.pos_embed is None:
        # No positional table configured: hand back the projected tokens.
        return tokens.cast(tokens.dtype)

    if self.pos_embed_max_size:
        # Crop the needed window out of the larger stored table.
        pos_embed = self.cropped_pos_embed(grid_h, grid_w)
    elif self.height == grid_h and self.width == grid_w:
        # Grid matches the configured size: the stored table fits directly.
        pos_embed = self.pos_embed.value
    else:
        # Grid differs from the configured size: rebuild sin-cos embeddings.
        fresh_table = get_2d_sincos_pos_embed(
            embed_dim=self.pos_embed.value.shape[-1],
            grid_size=(grid_h, grid_w),
            base_size=self.base_size,
            interpolation_scale=self.interpolation_scale,
        )
        # Add a leading batch axis so the table broadcasts over the batch.
        pos_embed = unsqueeze(fresh_table.cast('float32'), axis=0)

    pos_embed = pos_embed.cast(tokens.dtype)
    return (tokens + pos_embed).cast(tokens.dtype)
| 424 | |
| 425 | |
| 426 | def get_timestep_embedding( |
nothing calls this directly
no test coverage detected