| 1721 | return self.delegate(*args, **kwargs) |
| 1722 | |
def set_deep_shrink(self, ds_depth_1, ds_timesteps_1=650, ds_depth_2=None, ds_timesteps_2=None, ds_ratio=0.5):
    """Enable, reconfigure, or disable Deep Shrink on this instance.

    When ``ds_depth_1`` is ``None`` the feature is turned off: a "disabled"
    message is logged and all five ``ds_*`` attributes are reset to ``None``.
    Otherwise the raw arguments are logged and then stored on the instance.

    Args:
        ds_depth_1: Primary depth setting, or ``None`` to disable Deep Shrink.
        ds_timesteps_1: Timestep threshold stored alongside ``ds_depth_1``.
        ds_depth_2: Optional secondary depth. ``None`` is stored as ``-1``.
        ds_timesteps_2: Optional secondary timestep threshold. ``None`` is
            stored as ``1000``.
        ds_ratio: Ratio stored verbatim as ``self.ds_ratio``.

    Note:
        The forward pass that consumes these attributes is not shown here.
        Presumably ``-1`` / ``1000`` act as "no second stage" sentinels.
        Confirm against the caller.
    """
    if ds_depth_1 is None:
        logger.info("Deep Shrink is disabled.")
        # Every Deep Shrink field is cleared to None, in the same order as before.
        new_state = dict.fromkeys(("ds_depth_1", "ds_timesteps_1", "ds_depth_2", "ds_timesteps_2", "ds_ratio"))
    else:
        # Log the caller's raw values, before the secondary-stage fallbacks below are applied.
        logger.info(
            f"Deep Shrink is enabled: [depth={ds_depth_1}/{ds_depth_2}, timesteps={ds_timesteps_1}/{ds_timesteps_2}, ratio={ds_ratio}]"
        )
        new_state = {
            "ds_depth_1": ds_depth_1,
            "ds_timesteps_1": ds_timesteps_1,
            "ds_depth_2": -1 if ds_depth_2 is None else ds_depth_2,
            "ds_timesteps_2": 1000 if ds_timesteps_2 is None else ds_timesteps_2,
            "ds_ratio": ds_ratio,
        }

    # Insertion order is preserved, so attributes are written in the original sequence.
    for attr_name, attr_value in new_state.items():
        setattr(self, attr_name, attr_value)
| 1740 | |
| 1741 | def up_block_forward(self, _self, hidden_states, res_hidden_states_tuple, temb=None, upsample_size=None): |
| 1742 | for resnet in _self.resnets: |