Crops positional embeddings for SD3 compatibility.
(self, height, width)
| 356 | f"Unsupported pos_embed_type: {self.pos_embed_type}") |
| 357 | |
def cropped_pos_embed(self, height, width):
    """Crops positional embeddings for SD3 compatibility.

    The stored positional embedding is viewed as a square grid of
    ``pos_embed_max_size`` x ``pos_embed_max_size`` positions, a centered
    window matching the patched input size is sliced out of it, and the
    window is flattened back into a single sequence axis.

    Args:
        height: Input height before patching; it is floor-divided by
            ``self.patch_size`` to obtain the number of grid rows to keep.
            NOTE(review): presumably a Python int -- confirm with callers.
        width: Input width before patching; it is floor-divided by
            ``self.patch_size`` to obtain the number of grid columns to keep.

    Returns:
        The cropped embedding viewed as ``[1, -1, last_dim]``, where
        ``last_dim`` is the trailing dimension of the grid view.

    Raises:
        ValueError: If ``pos_embed_max_size`` is ``None``, or if the patched
            height or width exceeds ``pos_embed_max_size``.
    """
    limit = self.pos_embed_max_size
    if limit is None:
        raise ValueError("`pos_embed_max_size` must be set for cropping.")

    # From here on, height/width count patches rather than input units.
    height = height // self.patch_size
    width = width // self.patch_size

    # Height is validated before width, so it is reported first when both
    # dimensions are out of range.
    if height > limit:
        raise ValueError(
            f"Height ({height}) cannot be greater than `pos_embed_max_size`: {self.pos_embed_max_size}."
        )
    if width > limit:
        raise ValueError(
            f"Width ({width}) cannot be greater than `pos_embed_max_size`: {self.pos_embed_max_size}."
        )

    # Offsets that center the crop window inside the full grid.
    row_offset = (limit - height) // 2
    col_offset = (limit - width) // 2

    # View the embedding as [1, max_size, max_size, last_dim].
    grid = identity(self.pos_embed.value).view([1, limit, limit, -1])

    # Keep the leading and trailing dimensions whole; crop only rows/columns.
    window_sizes = concat([shape(grid, 0), height, width, shape(grid, 3)])
    window = slice(grid,
                   starts=[0, row_offset, col_offset, 0],
                   sizes=window_sizes)

    # Merge the two spatial axes into one sequence axis.
    last_axis = window.ndim() - 1
    flat_shape = concat([1, -1, shape(window, last_axis)])
    return window.view(flat_shape)
| 391 | |
| 392 | def forward(self, latent): |
| 393 | # [TODO] to support height and width for runtime |