Maps a string describing an interpolation function to the corresponding torchvision `InterpolationMode` enum. The full list of supported enums is documented at https://pytorch.org/vision/0.9/transforms.html#torchvision.transforms.functional.InterpolationMode. Args: interpolation_type (`str`): A string describing an interpolation method.
(interpolation_type: str)
| 197 | |
| 198 | |
def resolve_interpolation_mode(interpolation_type: str) -> "transforms.InterpolationMode":
    """
    Maps a string describing an interpolation function to the corresponding torchvision `InterpolationMode` enum. The
    full list of supported enums is documented at
    https://pytorch.org/vision/0.9/transforms.html#torchvision.transforms.functional.InterpolationMode.

    Args:
        interpolation_type (`str`):
            A string describing an interpolation method. Currently, `bilinear`, `bicubic`, `box`, `nearest`,
            `nearest_exact`, `hamming`, and `lanczos` are supported, corresponding to the supported interpolation modes
            in torchvision. Matching is exact and case-sensitive.

    Returns:
        `torchvision.transforms.InterpolationMode`: an `InterpolationMode` enum used by torchvision's `resize`
        transform.

    Raises:
        ImportError: If `torchvision` is not installed.
        ValueError: If `interpolation_type` is not one of the supported names.
    """
    if not is_torchvision_available():
        raise ImportError(
            "Please make sure to install `torchvision` to be able to use the `resolve_interpolation_mode()` function."
        )

    # Single source of truth for the accepted names. The error message below is built from this tuple, so the list
    # shown to users cannot drift away from what is actually accepted.
    supported_modes = ("bilinear", "bicubic", "box", "nearest", "nearest_exact", "hamming", "lanczos")

    # `in` on a tuple compares with `==`, so (as with the original if/elif chain) any unsupported value, including an
    # unhashable one, ends up in the ValueError branch rather than raising TypeError.
    if interpolation_type not in supported_modes:
        listed = ", ".join(f"`{mode}`" for mode in supported_modes[:-1]) + f", and `{supported_modes[-1]}`"
        raise ValueError(
            f"The given interpolation mode {interpolation_type!r} is not supported. Currently supported interpolation"
            f" modes are {listed}."
        )

    # Each supported name is the lower-case spelling of its `InterpolationMode` member (e.g. "nearest_exact" ->
    # NEAREST_EXACT). Only the requested member is looked up, so a torchvision release that lacks one of them
    # (NEAREST_EXACT is newer than the 0.9 docs linked in the docstring) can still resolve the others.
    return getattr(transforms.InterpolationMode, interpolation_type.upper())
| 241 | |
| 242 | |
| 243 | def compute_dream_and_update_latents( |