r""" A 1D UNet model that takes a noisy sample and a timestep and returns a sample shaped output. This model inherits from [`ModelMixin`]. Check the superclass documentation for its generic methods implemented for all models (such as downloading or saving). Parameters: sam
| 38 | |
| 39 | |
| 40 | class UNet1DModel(ModelMixin, ConfigMixin): |
| 41 | r""" |
| 42 | A 1D UNet model that takes a noisy sample and a timestep and returns a sample shaped output. |
| 43 | |
| 44 | This model inherits from [`ModelMixin`]. Check the superclass documentation for its generic methods implemented |
| 45 | for all models (such as downloading or saving). |
| 46 | |
| 47 | Parameters: |
| 48 | sample_size (`int`, *optional*, defaults to 65536): Default length of sample. Should be adaptable at runtime. |
| 49 | in_channels (`int`, *optional*, defaults to 2): Number of channels in the input sample. |
| 50 | out_channels (`int`, *optional*, defaults to 2): Number of channels in the output. |
| 51 | extra_in_channels (`int`, *optional*, defaults to 0): |
| 52 | Number of additional channels to be added to the input of the first down block. Useful for cases where the |
| 53 | input data has more channels than what the model was initially designed for. |
| 54 | time_embedding_type (`str`, *optional*, defaults to `"fourier"`): Type of time embedding to use. |
| 55 | freq_shift (`float`, *optional*, defaults to 0.0): Frequency shift for Fourier time embedding. |
| 56 | flip_sin_to_cos (`bool`, *optional*, defaults to `True`): |
| 57 | Whether to flip sin to cos for Fourier time embedding. |
| 58 | down_block_types (`tuple[str]`, *optional*, defaults to `("DownBlock1DNoSkip", "DownBlock1D", "AttnDownBlock1D")`): |
| 59 | tuple of downsample block types. |
| 60 | up_block_types (`tuple[str]`, *optional*, defaults to `("AttnUpBlock1D", "UpBlock1D", "UpBlock1DNoSkip")`): |
| 61 | tuple of upsample block types. |
| 62 | block_out_channels (`tuple[int]`, *optional*, defaults to `(32, 32, 64)`): |
| 63 | tuple of block output channels. |
| 64 | mid_block_type (`str`, *optional*, defaults to `"UNetMidBlock1D"`): Block type for middle of UNet. |
| 65 | out_block_type (`str`, *optional*, defaults to `None`): Optional output processing block of UNet. |
| 66 | act_fn (`str`, *optional*, defaults to `None`): Optional activation function in UNet blocks. |
| 67 | norm_num_groups (`int`, *optional*, defaults to 8): The number of groups for normalization. |
| 68 | layers_per_block (`int`, *optional*, defaults to 1): The number of layers per block. |
| 69 | downsample_each_block (`bool`, *optional*, defaults to `False`): |
| 70 | Experimental feature for using a UNet without upsampling. |
| 71 | """ |
| 72 | |
| 73 | _skip_layerwise_casting_patterns = ["norm"] |
| 74 | |
| 75 | @register_to_config |
| 76 | def __init__( |
| 77 | self, |
| 78 | sample_size: int = 65536, |
| 79 | sample_rate: int | None = None, |
| 80 | in_channels: int = 2, |
| 81 | out_channels: int = 2, |
| 82 | extra_in_channels: int = 0, |
| 83 | time_embedding_type: str = "fourier", |
| 84 | time_embedding_dim: int | None = None, |
| 85 | flip_sin_to_cos: bool = True, |
| 86 | use_timestep_embedding: bool = False, |
| 87 | freq_shift: float = 0.0, |
| 88 | down_block_types: tuple[str, ...] = ("DownBlock1DNoSkip", "DownBlock1D", "AttnDownBlock1D"), |
| 89 | up_block_types: tuple[str, ...] = ("AttnUpBlock1D", "UpBlock1D", "UpBlock1DNoSkip"), |
| 90 | mid_block_type: str = "UNetMidBlock1D", |
| 91 | out_block_type: str = None, |
| 92 | block_out_channels: tuple[int, ...] = (32, 32, 64), |
| 93 | act_fn: str = None, |
| 94 | norm_num_groups: int = 8, |
| 95 | layers_per_block: int = 1, |
| 96 | downsample_each_block: bool = False, |
| 97 | ): |
no outgoing calls
no test coverage detected
searching dependent graphs…