| 1754 | |
| 1755 | |
| 1756 | class ResnetDownsampleBlock2D(nn.Module): |
| 1757 | def __init__( |
| 1758 | self, |
| 1759 | in_channels: int, |
| 1760 | out_channels: int, |
| 1761 | temb_channels: int, |
| 1762 | dropout: float = 0.0, |
| 1763 | num_layers: int = 1, |
| 1764 | resnet_eps: float = 1e-6, |
| 1765 | resnet_time_scale_shift: str = "default", |
| 1766 | resnet_act_fn: str = "swish", |
| 1767 | resnet_groups: int = 32, |
| 1768 | resnet_pre_norm: bool = True, |
| 1769 | output_scale_factor: float = 1.0, |
| 1770 | add_downsample: bool = True, |
| 1771 | skip_time_act: bool = False, |
| 1772 | ): |
| 1773 | super().__init__() |
| 1774 | resnets = [] |
| 1775 | |
| 1776 | for i in range(num_layers): |
| 1777 | in_channels = in_channels if i == 0 else out_channels |
| 1778 | resnets.append( |
| 1779 | ResnetBlock2D( |
| 1780 | in_channels=in_channels, |
| 1781 | out_channels=out_channels, |
| 1782 | temb_channels=temb_channels, |
| 1783 | eps=resnet_eps, |
| 1784 | groups=resnet_groups, |
| 1785 | dropout=dropout, |
| 1786 | time_embedding_norm=resnet_time_scale_shift, |
| 1787 | non_linearity=resnet_act_fn, |
| 1788 | output_scale_factor=output_scale_factor, |
| 1789 | pre_norm=resnet_pre_norm, |
| 1790 | skip_time_act=skip_time_act, |
| 1791 | ) |
| 1792 | ) |
| 1793 | |
| 1794 | self.resnets = nn.ModuleList(resnets) |
| 1795 | |
| 1796 | if add_downsample: |
| 1797 | self.downsamplers = nn.ModuleList( |
| 1798 | [ |
| 1799 | ResnetBlock2D( |
| 1800 | in_channels=out_channels, |
| 1801 | out_channels=out_channels, |
| 1802 | temb_channels=temb_channels, |
| 1803 | eps=resnet_eps, |
| 1804 | groups=resnet_groups, |
| 1805 | dropout=dropout, |
| 1806 | time_embedding_norm=resnet_time_scale_shift, |
| 1807 | non_linearity=resnet_act_fn, |
| 1808 | output_scale_factor=output_scale_factor, |
| 1809 | pre_norm=resnet_pre_norm, |
| 1810 | skip_time_act=skip_time_act, |
| 1811 | down=True, |
| 1812 | ) |
| 1813 | ] |
no outgoing calls
no test coverage detected
searching dependent graphs…